diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 37ff255e8..955ded585 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -4,6 +4,7 @@ import ( "errors" "go/ast" "go/types" + "strings" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" @@ -11,12 +12,12 @@ import ( ) var ( - // Go language builtin functions and custom builtin utility functions. - builtinFuncs = []string{ - "len", "append", "SHA256", - "AppCall", + // Go language builtin functions. + goBuiltins = []string{"len", "append", "panic"} + // Custom builtin utility functions. + customBuiltins = []string{ + "SHA256", "AppCall", "FromAddress", "Equals", - "panic", "ToBool", "ToByteArray", "ToInteger", } ) @@ -134,20 +135,21 @@ func analyzeFuncUsage(pkgs map[*types.Package]*loader.PackageInfo) funcUsage { return usage } -func isBuiltin(expr ast.Expr) bool { - var name string +func isGoBuiltin(name string) bool { + for i := range goBuiltins { + if name == goBuiltins[i] { + return true + } + } + return false +} - switch t := expr.(type) { - case *ast.Ident: - name = t.Name - case *ast.SelectorExpr: - name = t.Sel.Name - default: +func isCustomBuiltin(f *funcScope) bool { + if !isInteropPath(f.pkg.Path()) { return false } - - for _, n := range builtinFuncs { - if name == n { + for _, n := range customBuiltins { + if f.name == n { return true } } @@ -155,9 +157,13 @@ func isBuiltin(expr ast.Expr) bool { } func isSyscall(fun *funcScope) bool { - if fun.selector == nil { + if fun.selector == nil || fun.pkg == nil || !isInteropPath(fun.pkg.Path()) { return false } _, ok := syscalls[fun.selector.Name][fun.name] return ok } + +func isInteropPath(s string) bool { + return strings.HasPrefix(s, "github.com/nspcc-dev/neo-go/pkg/interop") +} diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 29127d02e..492e3cdef 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -281,8 +281,8 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl) { f, ok = c.funcs[decl.Name.Name] if ok { - // If this function is a syscall we will not convert it to bytecode. - if isSyscall(f) { + // If this function is a syscall or builtin we will not convert it to bytecode. + if isSyscall(f) || isCustomBuiltin(f) { return } c.setLabel(f.label) @@ -691,12 +691,13 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { ok bool name string numArgs = len(n.Args) - isBuiltin = isBuiltin(n.Fun) + isBuiltin bool ) switch fun := n.Fun.(type) { case *ast.Ident: f, ok = c.funcs[fun.Name] + isBuiltin = isGoBuiltin(fun.Name) if !ok && !isBuiltin { name = fun.Name } @@ -717,6 +718,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { c.prog.Err = fmt.Errorf("could not resolve function %s", fun.Sel.Name) return nil } + isBuiltin = isCustomBuiltin(f) case *ast.ArrayType: // For now we will assume that there are only byte slice conversions. // E.g. []byte("foobar") or []byte(scriptHash). @@ -1309,7 +1311,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { // Bring all imported functions into scope. for _, pkg := range info.program.AllPackages { for _, f := range pkg.Files { - c.resolveFuncDecls(f) + c.resolveFuncDecls(f, pkg.Pkg) } } @@ -1378,12 +1380,13 @@ func CodeGen(info *buildInfo) ([]byte, *DebugInfo, error) { return buf, c.emitDebugInfo(), nil } -func (c *codegen) resolveFuncDecls(f *ast.File) { +func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) { for _, decl := range f.Decls { switch n := decl.(type) { case *ast.FuncDecl: if n.Name.Name != mainIdent { c.newFunc(n) + c.funcs[n.Name.Name].pkg = pkg } } } diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index fde6ceadf..5c202c8aa 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -3,6 +3,7 @@ package compiler import ( "go/ast" "go/token" + "go/types" ) // A funcScope represents the scope within the function context. @@ -18,6 +19,9 @@ type funcScope struct { // The declaration of the function in the AST. Nil if this scope is not a function. decl *ast.FuncDecl + // Package where the function is defined. + pkg *types.Package + // Program label of the scope label uint16 diff --git a/pkg/compiler/interop_test.go b/pkg/compiler/interop_test.go index efdb19654..a56ea338b 100644 --- a/pkg/compiler/interop_test.go +++ b/pkg/compiler/interop_test.go @@ -2,6 +2,7 @@ package compiler_test import ( "fmt" + "math/big" "strings" "testing" @@ -15,6 +16,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) @@ -137,3 +139,49 @@ func getAppCallScript(h string) string { } ` } + +func TestBuiltinDoesNotCompile(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop/util" + func Main() bool { + a := 1 + b := 2 + return util.Equals(a, b) + }` + + v := vmAndCompile(t, src) + ctx := v.Context() + retCount := 0 + for op, _, err := ctx.Next(); err == nil; op, _, err = ctx.Next() { + if ctx.IP() > len(ctx.Program()) { + break + } + if op == opcode.RET { + retCount++ + } + } + require.Equal(t, 1, retCount) +} + +func TestInteropPackage(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/block" + func Main() int { + b := block.Block{} + a := block.GetTransactionCount(b) + return a + }` + eval(t, src, big.NewInt(42)) +} + +func TestBuiltinPackage(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/util" + func Main() int { + if util.Equals(1, 2) { // always returns true + return 1 + } + return 2 + }` + eval(t, src, big.NewInt(1)) +} diff --git a/pkg/compiler/testdata/block/block.go b/pkg/compiler/testdata/block/block.go new file mode 100644 index 000000000..5e5975fdc --- /dev/null +++ b/pkg/compiler/testdata/block/block.go @@ -0,0 +1,9 @@ +package block + +// Block is opaque type. +type Block struct{} + +// GetTransactionCount is a mirror of `GetTransactionCount` interop. +func GetTransactionCount(b Block) int { + return 42 +} diff --git a/pkg/compiler/testdata/util/equals.go b/pkg/compiler/testdata/util/equals.go new file mode 100644 index 000000000..391fb74db --- /dev/null +++ b/pkg/compiler/testdata/util/equals.go @@ -0,0 +1,6 @@ +package util + +// Equals is a mirror of `Equals` builtin. +func Equals(a, b interface{}) bool { + return true +}