diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 0a976b73f..5be47dc70 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" @@ -746,12 +747,27 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.CompositeLit: - switch typ := c.typeOf(n).Underlying().(type) { + t := c.typeOf(n) + switch typ := t.Underlying().(type) { case *types.Struct: c.convertStruct(n, false) case *types.Map: c.convertMap(n) default: + if tn, ok := t.(*types.Named); ok && isInteropPath(tn.String()) { + st, _ := scAndVMInteropTypeFromExpr(tn) + expectedLen := -1 + switch st { + case smartcontract.Hash160Type: + expectedLen = 20 + case smartcontract.Hash256Type: + expectedLen = 32 + } + if expectedLen != -1 && expectedLen != len(n.Elts) { + c.prog.Err = fmt.Errorf("%s type must have size %d", tn.Obj().Name(), expectedLen) + return nil + } + } ln := len(n.Elts) // ByteArrays needs a different approach than normal arrays. if isByteSlice(typ) { diff --git a/pkg/compiler/debug.go b/pkg/compiler/debug.go index 1676a0a3e..d82fc688e 100644 --- a/pkg/compiler/debug.go +++ b/pkg/compiler/debug.go @@ -228,6 +228,29 @@ func (c *codegen) scAndVMReturnTypeFromScope(scope *funcScope) (smartcontract.Pa } } +func scAndVMInteropTypeFromExpr(named *types.Named) (smartcontract.ParamType, stackitem.Type) { + name := named.Obj().Name() + pkg := named.Obj().Pkg().Name() + switch pkg { + case "blockchain", "contract": + return smartcontract.ArrayType, stackitem.ArrayT // Block, Transaction, Contract + case "interop": + if name != "Interface" { + switch name { + case "Hash160": + return smartcontract.Hash160Type, stackitem.ByteArrayT + case "Hash256": + return smartcontract.Hash256Type, stackitem.ByteArrayT + case "PublicKey": + return smartcontract.PublicKeyType, stackitem.ByteArrayT + case "Signature": + return smartcontract.SignatureType, stackitem.ByteArrayT + } + } + } + return smartcontract.InteropInterfaceType, stackitem.InteropT +} + func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr) (smartcontract.ParamType, stackitem.Type) { t := c.typeOf(typ) if c.typeOf(typ) == nil { @@ -235,26 +258,7 @@ func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr) (smartcontract.ParamType, st } if named, ok := t.(*types.Named); ok { if isInteropPath(named.String()) { - name := named.Obj().Name() - pkg := named.Obj().Pkg().Name() - switch pkg { - case "blockchain", "contract": - return smartcontract.ArrayType, stackitem.ArrayT // Block, Transaction, Contract - case "interop": - if name != "Interface" { - switch name { - case "Hash160": - return smartcontract.Hash160Type, stackitem.ByteArrayT - case "Hash256": - return smartcontract.Hash256Type, stackitem.ByteArrayT - case "PublicKey": - return smartcontract.PublicKeyType, stackitem.ByteArrayT - case "Signature": - return smartcontract.SignatureType, stackitem.ByteArrayT - } - } - } - return smartcontract.InteropInterfaceType, stackitem.InteropT + return scAndVMInteropTypeFromExpr(named) } } switch t := t.Underlying().(type) { diff --git a/pkg/compiler/interop_test.go b/pkg/compiler/interop_test.go index dfdf7e8f6..6aa603aaa 100644 --- a/pkg/compiler/interop_test.go +++ b/pkg/compiler/interop_test.go @@ -24,6 +24,42 @@ import ( "go.uber.org/zap/zaptest" ) +func TestTypeConstantSize(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop" + var a %T // type declaration is always ok + func Main() interface{} { + return %#v + }` + + t.Run("Hash160", func(t *testing.T) { + t.Run("good", func(t *testing.T) { + a := make(cinterop.Hash160, 20) + src := fmt.Sprintf(src, a, a) + eval(t, src, []byte(a)) + }) + t.Run("bad", func(t *testing.T) { + a := make(cinterop.Hash160, 19) + src := fmt.Sprintf(src, a, a) + _, err := compiler.Compile("foo.go", strings.NewReader(src)) + require.Error(t, err) + }) + }) + t.Run("Hash256", func(t *testing.T) { + t.Run("good", func(t *testing.T) { + a := make(cinterop.Hash256, 32) + src := fmt.Sprintf(src, a, a) + eval(t, src, []byte(a)) + }) + t.Run("bad", func(t *testing.T) { + a := make(cinterop.Hash256, 31) + src := fmt.Sprintf(src, a, a) + _, err := compiler.Compile("foo.go", strings.NewReader(src)) + require.Error(t, err) + }) + }) +} + func TestFromAddress(t *testing.T) { as1 := "NQRLhCpAru9BjGsMwk67vdMwmzKMRgsnnN" addr1, err := address.StringToUint160(as1)