From 685d44dbc15938b165e074c4c63927c28b45e3e1 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 24 Jul 2020 13:40:54 +0300 Subject: [PATCH] *: support `_initialize` method in contracts Invoke `_initialize` method on every call if present. In NEO3 there is no entrypoint and methods are invoked by offset, thus `Main` function is no longer required. We still have special `Main` method in tests to simplify them. --- pkg/compiler/analysis.go | 15 ++++++++--- pkg/compiler/codegen.go | 34 +++++++++++------------- pkg/compiler/debug.go | 36 ++++++++++++++++---------- pkg/compiler/global_test.go | 22 ++++++++++++++++ pkg/compiler/interop_test.go | 22 +++++++++++++++- pkg/compiler/vm_test.go | 28 ++++++++++++++++++-- pkg/core/interop_system.go | 6 +++++ pkg/core/interop_system_test.go | 30 +++++++++++++++++++++ pkg/smartcontract/manifest/manifest.go | 3 +++ 9 files changed, 156 insertions(+), 40 deletions(-) diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 621825a61..0033e427b 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -27,16 +27,23 @@ func (c *codegen) newGlobal(name string) { } // traverseGlobals visits and initializes global variables. -func (c *codegen) traverseGlobals(f ast.Node) { - n := countGlobals(f) +// and returns number of variables initialized. +func (c *codegen) traverseGlobals(fs ...*ast.File) int { + var n int + for _, f := range fs { + n += countGlobals(f) + } if n != 0 { if n > 255 { c.prog.BinWriter.Err = errors.New("too many global variables") - return + return 0 } emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)}) + for _, f := range fs { + c.convertGlobals(f) + } } - c.convertGlobals(f) + return n } // countGlobals counts the global variables in the program to add diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 2dbcbc900..5b96f5f94 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -21,9 +21,6 @@ import ( "golang.org/x/tools/go/loader" ) -// The identifier of the entry function. Default set to Main. -const mainIdent = "Main" - type codegen struct { // Information about the program with all its dependencies. buildInfo *buildInfo @@ -62,6 +59,12 @@ type codegen struct { // to a text span in the source file. sequencePoints map[string][]DebugSeqPoint + // initEndOffset specifies the end of the initialization method. + initEndOffset int + + // mainPkg is a main package metadata. + mainPkg *loader.PackageInfo + // Label table for recording jump destinations. l []int } @@ -1412,13 +1415,6 @@ func (c *codegen) newLambda(u uint16, lit *ast.FuncLit) { } func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { - // Resolve the entrypoint of the program. - main, mainFile := resolveEntryPoint(mainIdent, pkg) - if main == nil { - c.prog.Err = fmt.Errorf("could not find func main. Did you forget to declare it? ") - return c.prog.Err - } - funUsage := analyzeFuncUsage(pkg, info.program.AllPackages) // Bring all imported functions into scope. @@ -1428,10 +1424,12 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { } } - c.traverseGlobals(mainFile) - - // convert the entry point first. - c.convertFuncDecl(mainFile, main, pkg.Pkg) + c.mainPkg = pkg + n := c.traverseGlobals(pkg.Files...) + if n > 0 { + emit.Opcode(c.prog.BinWriter, opcode.RET) + c.initEndOffset = c.prog.Len() + } // sort map keys to generate code deterministically. keys := make([]*types.Package, 0, len(info.program.AllPackages)) @@ -1451,7 +1449,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { case *ast.FuncDecl: // Don't convert the function if it's not used. This will save a lot // of bytecode space. - if n.Name.Name != mainIdent && funUsage.funcUsed(n.Name.Name) { + if funUsage.funcUsed(n.Name.Name) { c.convertFuncDecl(f, n, k) } } @@ -1497,10 +1495,8 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) { for _, decl := range f.Decls { switch n := decl.(type) { case *ast.FuncDecl: - if n.Name.Name != mainIdent { - c.newFunc(n) - c.funcs[n.Name.Name].pkg = pkg - } + c.newFunc(n) + c.funcs[n.Name.Name].pkg = pkg } } } diff --git a/pkg/compiler/debug.go b/pkg/compiler/debug.go index ebf5a3e61..01b1aed64 100644 --- a/pkg/compiler/debug.go +++ b/pkg/compiler/debug.go @@ -17,6 +17,7 @@ import ( // DebugInfo represents smart-contract debug information. type DebugInfo struct { + MainPkg string `json:"-"` Hash util.Uint160 `json:"hash"` Documents []string `json:"documents"` Methods []MethodDebugInfo `json:"methods"` @@ -102,8 +103,24 @@ func (c *codegen) saveSequencePoint(n ast.Node) { func (c *codegen) emitDebugInfo(contract []byte) *DebugInfo { d := &DebugInfo{ - Hash: hash.Hash160(contract), - Events: []EventDebugInfo{}, + MainPkg: c.mainPkg.Pkg.Name(), + Hash: hash.Hash160(contract), + Events: []EventDebugInfo{}, + } + if c.initEndOffset > 0 { + d.Methods = append(d.Methods, MethodDebugInfo{ + ID: manifest.MethodInit, + Name: DebugMethodName{ + Name: manifest.MethodInit, + Namespace: c.mainPkg.Pkg.Name(), + }, + IsExported: true, + Range: DebugRange{ + Start: 0, + End: uint16(c.initEndOffset), + }, + ReturnType: "Void", + }) } for name, scope := range c.funcs { m := c.methodInfoFromScope(name, scope) @@ -341,22 +358,13 @@ func parsePairJSON(data []byte, sep string) (string, string, error) { // ConvertToManifest converts contract to the manifest.Manifest struct for debugger. // Note: manifest is taken from the external source, however it can be generated ad-hoc. See #1038. func (di *DebugInfo) ConvertToManifest(fs smartcontract.PropertyState) (*manifest.Manifest, error) { - var ( - mainNamespace string - err error - ) - for _, method := range di.Methods { - if method.Name.Name == mainIdent { - mainNamespace = method.Name.Namespace - break - } - } - if mainNamespace == "" { + var err error + if di.MainPkg == "" { return nil, errors.New("no Main method was found") } methods := make([]manifest.Method, 0) for _, method := range di.Methods { - if method.IsExported && method.Name.Namespace == mainNamespace { + if method.IsExported && method.Name.Namespace == di.MainPkg { mMethod, err := method.ToManifestMethod() if err != nil { return nil, err diff --git a/pkg/compiler/global_test.go b/pkg/compiler/global_test.go index 8b98761d3..359e8ef69 100644 --- a/pkg/compiler/global_test.go +++ b/pkg/compiler/global_test.go @@ -3,7 +3,12 @@ package compiler_test import ( "fmt" "math/big" + "strings" "testing" + + "github.com/nspcc-dev/neo-go/pkg/compiler" + "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/stretchr/testify/require" ) func TestChangeGlobal(t *testing.T) { @@ -105,3 +110,20 @@ func TestArgumentLocal(t *testing.T) { eval(t, src, big.NewInt(40)) }) } + +func TestContractWithNoMain(t *testing.T) { + src := `package foo + var someGlobal int = 1 + func Add3(a int) int { + someLocal := 2 + return someGlobal + someLocal + a + }` + b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src)) + require.NoError(t, err) + v := vm.New() + invokeMethod(t, "Add3", b, v, di) + v.Estack().PushVal(39) + require.NoError(t, v.Run()) + require.Equal(t, 1, v.Estack().Len()) + require.Equal(t, big.NewInt(42), v.PopResult()) +} diff --git a/pkg/compiler/interop_test.go b/pkg/compiler/interop_test.go index 8f7326637..e0cceeb59 100644 --- a/pkg/compiler/interop_test.go +++ b/pkg/compiler/interop_test.go @@ -63,9 +63,10 @@ func TestFromAddress(t *testing.T) { } func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM { - b, err := compiler.Compile(strings.NewReader(src)) + b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src)) require.NoError(t, err) v := core.SpawnVM(ic) + invokeMethod(t, testMainIdent, b, v, di) v.LoadScriptWithFlags(b, smartcontract.All) return v } @@ -73,12 +74,16 @@ func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM { func TestAppCall(t *testing.T) { srcInner := ` package foo + var a int = 3 func Main(a []byte, b []byte) []byte { panic("Main was called") } func Append(a []byte, b []byte) []byte { return append(a, b...) } + func Add3(n int) int { + return a + n + } ` inner, di, err := compiler.CompileWithDebugInfo(strings.NewReader(srcInner)) @@ -147,6 +152,21 @@ func TestAppCall(t *testing.T) { assertResult(t, v, []byte{1, 2, 3, 4}) }) + + t.Run("InitializedGlobals", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop/engine" + func Main() int { + var addr = []byte(` + fmt.Sprintf("%#v", string(ih.BytesBE())) + `) + result := engine.AppCall(addr, "add3", 39) + return result.(int) + }` + + v := spawnVM(t, ic, src) + require.NoError(t, v.Run()) + + assertResult(t, v, big.NewInt(42)) + }) } func getAppCallScript(h string) string { diff --git a/pkg/compiler/vm_test.go b/pkg/compiler/vm_test.go index 89d0d40c0..f2af0d5b8 100644 --- a/pkg/compiler/vm_test.go +++ b/pkg/compiler/vm_test.go @@ -7,6 +7,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/compiler" "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -20,6 +22,9 @@ type testCase struct { result interface{} } +// testMainIdent is a method invoked in tests by default. +const testMainIdent = "Main" + func runTestCases(t *testing.T, tcases []testCase) { for _, tcase := range tcases { t.Run(tcase.name, func(t *testing.T) { eval(t, tcase.src, tcase.result) }) @@ -65,12 +70,31 @@ func vmAndCompileInterop(t *testing.T, src string) (*vm.VM, *storagePlugin) { storePlugin := newStoragePlugin() vm.RegisterInteropGetter(storePlugin.getInterop) - b, err := compiler.Compile(strings.NewReader(src)) + b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src)) require.NoError(t, err) - vm.Load(b) + invokeMethod(t, testMainIdent, b, vm, di) return vm, storePlugin } +func invokeMethod(t *testing.T, method string, script []byte, v *vm.VM, di *compiler.DebugInfo) { + mainOffset := -1 + initOffset := -1 + for i := range di.Methods { + switch di.Methods[i].ID { + case method: + mainOffset = int(di.Methods[i].Range.Start) + case manifest.MethodInit: + initOffset = int(di.Methods[i].Range.Start) + } + } + require.True(t, mainOffset >= 0) + v.LoadScriptWithFlags(script, smartcontract.All) + v.Jump(v.Context(), mainOffset) + if initOffset >= 0 { + v.Call(v.Context(), initOffset) + } +} + type storagePlugin struct { mem map[string][]byte interops map[uint32]vm.InteropFunc diff --git a/pkg/core/interop_system.go b/pkg/core/interop_system.go index 698d391a0..b7aaa116a 100644 --- a/pkg/core/interop_system.go +++ b/pkg/core/interop_system.go @@ -17,6 +17,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -536,6 +537,11 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac v.Jump(v.Context(), md.Offset) } + md = cs.Manifest.ABI.GetMethod(manifest.MethodInit) + if md != nil { + v.Call(v.Context(), md.Offset) + } + return nil } diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index d3b33add2..33f9414b1 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -330,6 +330,8 @@ func getTestContractState() *state.Contract { byte(opcode.ADD), byte(opcode.RET), byte(opcode.PUSH7), byte(opcode.RET), byte(opcode.DROP), byte(opcode.RET), + byte(opcode.INITSSLOT), 1, byte(opcode.PUSH3), byte(opcode.STSFLD0), byte(opcode.RET), + byte(opcode.LDSFLD0), byte(opcode.ADD), byte(opcode.RET), } h := hash.Hash160(script) m := manifest.NewManifest(h) @@ -354,6 +356,19 @@ func getTestContractState() *state.Contract { Offset: 5, ReturnType: smartcontract.VoidType, }, + { + Name: manifest.MethodInit, + Offset: 7, + ReturnType: smartcontract.VoidType, + }, + { + Name: "add3", + Offset: 12, + Parameters: []manifest.Parameter{ + manifest.NewParameter("addend", smartcontract.IntegerType), + }, + ReturnType: smartcontract.IntegerType, + }, } return &state.Contract{ Script: script, @@ -382,6 +397,7 @@ func TestContractCall(t *testing.T) { perm := manifest.NewPermission(manifest.PermissionHash, h) perm.Methods.Add("add") perm.Methods.Add("drop") + perm.Methods.Add("add3") m.Permissions = append(m.Permissions, *perm) require.NoError(t, ic.DAO.PutContractState(&state.Contract{ @@ -441,6 +457,20 @@ func TestContractCall(t *testing.T) { require.NoError(t, contractCall(ic, v)) require.Error(t, v.Run()) }) + + t.Run("CallInitialize", func(t *testing.T) { + t.Run("Directly", runInvalid(stackitem.NewArray([]stackitem.Item{}), "_initialize", h.BytesBE())) + + initVM(v) + v.Estack().PushVal(stackitem.NewArray([]stackitem.Item{stackitem.Make(5)})) + v.Estack().PushVal("add3") + v.Estack().PushVal(h.BytesBE()) + require.NoError(t, contractCall(ic, v)) + require.NoError(t, v.Run()) + require.Equal(t, 2, v.Estack().Len()) + require.Equal(t, big.NewInt(8), v.Estack().Pop().Value()) + require.Equal(t, big.NewInt(42), v.Estack().Pop().Value()) + }) } func TestContractCreate(t *testing.T) { diff --git a/pkg/smartcontract/manifest/manifest.go b/pkg/smartcontract/manifest/manifest.go index d733c90c9..d3db02bba 100644 --- a/pkg/smartcontract/manifest/manifest.go +++ b/pkg/smartcontract/manifest/manifest.go @@ -11,6 +11,9 @@ import ( // MaxManifestSize is a max length for a valid contract manifest. const MaxManifestSize = 2048 +// MethodInit is a name for default initialization method. +const MethodInit = "_initialize" + // ABI represents a contract application binary interface. type ABI struct { Hash util.Uint160 `json:"hash"`