compiler: enforce Hash160 and Hash256 size in literals

Can be useful to prevent small typos.
This commit is contained in:
Evgenii Stratonikov 2020-12-10 14:08:42 +03:00
parent d828096cbf
commit ff4880249d
3 changed files with 78 additions and 21 deletions

View file

@ -14,6 +14,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
@ -746,12 +747,27 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil
case *ast.CompositeLit:
switch typ := c.typeOf(n).Underlying().(type) {
t := c.typeOf(n)
switch typ := t.Underlying().(type) {
case *types.Struct:
c.convertStruct(n, false)
case *types.Map:
c.convertMap(n)
default:
if tn, ok := t.(*types.Named); ok && isInteropPath(tn.String()) {
st, _ := scAndVMInteropTypeFromExpr(tn)
expectedLen := -1
switch st {
case smartcontract.Hash160Type:
expectedLen = 20
case smartcontract.Hash256Type:
expectedLen = 32
}
if expectedLen != -1 && expectedLen != len(n.Elts) {
c.prog.Err = fmt.Errorf("%s type must have size %d", tn.Obj().Name(), expectedLen)
return nil
}
}
ln := len(n.Elts)
// ByteArrays needs a different approach than normal arrays.
if isByteSlice(typ) {

View file

@ -228,6 +228,29 @@ func (c *codegen) scAndVMReturnTypeFromScope(scope *funcScope) (smartcontract.Pa
}
}
func scAndVMInteropTypeFromExpr(named *types.Named) (smartcontract.ParamType, stackitem.Type) {
name := named.Obj().Name()
pkg := named.Obj().Pkg().Name()
switch pkg {
case "blockchain", "contract":
return smartcontract.ArrayType, stackitem.ArrayT // Block, Transaction, Contract
case "interop":
if name != "Interface" {
switch name {
case "Hash160":
return smartcontract.Hash160Type, stackitem.ByteArrayT
case "Hash256":
return smartcontract.Hash256Type, stackitem.ByteArrayT
case "PublicKey":
return smartcontract.PublicKeyType, stackitem.ByteArrayT
case "Signature":
return smartcontract.SignatureType, stackitem.ByteArrayT
}
}
}
return smartcontract.InteropInterfaceType, stackitem.InteropT
}
func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr) (smartcontract.ParamType, stackitem.Type) {
t := c.typeOf(typ)
if c.typeOf(typ) == nil {
@ -235,26 +258,7 @@ func (c *codegen) scAndVMTypeFromExpr(typ ast.Expr) (smartcontract.ParamType, st
}
if named, ok := t.(*types.Named); ok {
if isInteropPath(named.String()) {
name := named.Obj().Name()
pkg := named.Obj().Pkg().Name()
switch pkg {
case "blockchain", "contract":
return smartcontract.ArrayType, stackitem.ArrayT // Block, Transaction, Contract
case "interop":
if name != "Interface" {
switch name {
case "Hash160":
return smartcontract.Hash160Type, stackitem.ByteArrayT
case "Hash256":
return smartcontract.Hash256Type, stackitem.ByteArrayT
case "PublicKey":
return smartcontract.PublicKeyType, stackitem.ByteArrayT
case "Signature":
return smartcontract.SignatureType, stackitem.ByteArrayT
}
}
}
return smartcontract.InteropInterfaceType, stackitem.InteropT
return scAndVMInteropTypeFromExpr(named)
}
}
switch t := t.Underlying().(type) {

View file

@ -15,6 +15,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
cinterop "github.com/nspcc-dev/neo-go/pkg/interop"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/vm"
@ -23,6 +24,42 @@ import (
"go.uber.org/zap/zaptest"
)
func TestTypeConstantSize(t *testing.T) {
src := `package foo
import "github.com/nspcc-dev/neo-go/pkg/interop"
var a %T // type declaration is always ok
func Main() interface{} {
return %#v
}`
t.Run("Hash160", func(t *testing.T) {
t.Run("good", func(t *testing.T) {
a := make(cinterop.Hash160, 20)
src := fmt.Sprintf(src, a, a)
eval(t, src, []byte(a))
})
t.Run("bad", func(t *testing.T) {
a := make(cinterop.Hash160, 19)
src := fmt.Sprintf(src, a, a)
_, err := compiler.Compile("foo.go", strings.NewReader(src))
require.Error(t, err)
})
})
t.Run("Hash256", func(t *testing.T) {
t.Run("good", func(t *testing.T) {
a := make(cinterop.Hash256, 32)
src := fmt.Sprintf(src, a, a)
eval(t, src, []byte(a))
})
t.Run("bad", func(t *testing.T) {
a := make(cinterop.Hash256, 31)
src := fmt.Sprintf(src, a, a)
_, err := compiler.Compile("foo.go", strings.NewReader(src))
require.Error(t, err)
})
})
}
func TestFromAddress(t *testing.T) {
as1 := "NQRLhCpAru9BjGsMwk67vdMwmzKMRgsnnN"
addr1, err := address.StringToUint160(as1)