compiler: enforce Hash160 and Hash256 size in literals

Can be useful to prevent small typos.
This commit is contained in:
Evgenii Stratonikov 2020-12-10 14:08:42 +03:00
parent d828096cbf
commit ff4880249d
3 changed files with 78 additions and 21 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
@ -746,12 +747,27 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
case *ast.CompositeLit: case *ast.CompositeLit:
switch typ := c.typeOf(n).Underlying().(type) { t := c.typeOf(n)
switch typ := t.Underlying().(type) {
case *types.Struct: case *types.Struct:
c.convertStruct(n, false) c.convertStruct(n, false)
case *types.Map: case *types.Map:
c.convertMap(n) c.convertMap(n)
default: default:
if tn, ok := t.(*types.Named); ok && isInteropPath(tn.String()) {
st, _ := scAndVMInteropTypeFromExpr(tn)
expectedLen := -1
switch st {
case smartcontract.Hash160Type:
expectedLen = 20
case smartcontract.Hash256Type:
expectedLen = 32
}
if expectedLen != -1 && expectedLen != len(n.Elts) {
c.prog.Err = fmt.Errorf("%s type must have size %d", tn.Obj().Name(), expectedLen)
return nil
}
}
ln := len(n.Elts) ln := len(n.Elts)
// ByteArrays needs a different approach than normal arrays. // ByteArrays needs a different approach than normal arrays.
if isByteSlice(typ) { if isByteSlice(typ) {

View file

@ -228,13 +228,7 @@ func (c *codegen) scAndVMReturnTypeFromScope(scope *funcScope) (smartcontract.Pa
} }
} }
func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr) (smartcontract.ParamType, stackitem.Type) { func scAndVMInteropTypeFromExpr(named *types.Named) (smartcontract.ParamType, stackitem.Type) {
t := c.typeOf(typ)
if c.typeOf(typ) == nil {
return smartcontract.AnyType, stackitem.AnyT
}
if named, ok := t.(*types.Named); ok {
if isInteropPath(named.String()) {
name := named.Obj().Name() name := named.Obj().Name()
pkg := named.Obj().Pkg().Name() pkg := named.Obj().Pkg().Name()
switch pkg { switch pkg {
@ -255,6 +249,16 @@ func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr) (smartcontract.ParamType, st
} }
} }
return smartcontract.InteropInterfaceType, stackitem.InteropT return smartcontract.InteropInterfaceType, stackitem.InteropT
}
func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr) (smartcontract.ParamType, stackitem.Type) {
t := c.typeOf(typ)
if c.typeOf(typ) == nil {
return smartcontract.AnyType, stackitem.AnyT
}
if named, ok := t.(*types.Named); ok {
if isInteropPath(named.String()) {
return scAndVMInteropTypeFromExpr(named)
} }
} }
switch t := t.Underlying().(type) { switch t := t.Underlying().(type) {

View file

@ -15,6 +15,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/encoding/address"
cinterop "github.com/nspcc-dev/neo-go/pkg/interop"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
@ -23,6 +24,42 @@ import (
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
) )
func TestTypeConstantSize(t *testing.T) {
src := `package foo
import "github.com/nspcc-dev/neo-go/pkg/interop"
var a %T // type declaration is always ok
func Main() interface{} {
return %#v
}`
t.Run("Hash160", func(t *testing.T) {
t.Run("good", func(t *testing.T) {
a := make(cinterop.Hash160, 20)
src := fmt.Sprintf(src, a, a)
eval(t, src, []byte(a))
})
t.Run("bad", func(t *testing.T) {
a := make(cinterop.Hash160, 19)
src := fmt.Sprintf(src, a, a)
_, err := compiler.Compile("foo.go", strings.NewReader(src))
require.Error(t, err)
})
})
t.Run("Hash256", func(t *testing.T) {
t.Run("good", func(t *testing.T) {
a := make(cinterop.Hash256, 32)
src := fmt.Sprintf(src, a, a)
eval(t, src, []byte(a))
})
t.Run("bad", func(t *testing.T) {
a := make(cinterop.Hash256, 31)
src := fmt.Sprintf(src, a, a)
_, err := compiler.Compile("foo.go", strings.NewReader(src))
require.Error(t, err)
})
})
}
func TestFromAddress(t *testing.T) { func TestFromAddress(t *testing.T) {
as1 := "NQRLhCpAru9BjGsMwk67vdMwmzKMRgsnnN" as1 := "NQRLhCpAru9BjGsMwk67vdMwmzKMRgsnnN"
addr1, err := address.StringToUint160(as1) addr1, err := address.StringToUint160(as1)