*: support _initialize method in contracts

Invoke `_initialize` method on every call if present.
In NEO3 there is no entrypoint and methods are invoked by offset,
thus `Main` function is no longer required.
We still have special `Main` method in tests to simplify them.
This commit is contained in:
Evgenii Stratonikov 2020-07-24 13:40:54 +03:00
parent 466af55dea
commit 685d44dbc1
9 changed files with 156 additions and 40 deletions

View file

@ -27,17 +27,24 @@ func (c *codegen) newGlobal(name string) {
} }
// traverseGlobals visits and initializes global variables. // traverseGlobals visits and initializes global variables.
func (c *codegen) traverseGlobals(f ast.Node) { // and returns number of variables initialized.
n := countGlobals(f) func (c *codegen) traverseGlobals(fs ...*ast.File) int {
var n int
for _, f := range fs {
n += countGlobals(f)
}
if n != 0 { if n != 0 {
if n > 255 { if n > 255 {
c.prog.BinWriter.Err = errors.New("too many global variables") c.prog.BinWriter.Err = errors.New("too many global variables")
return return 0
} }
emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)}) emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)})
} for _, f := range fs {
c.convertGlobals(f) c.convertGlobals(f)
} }
}
return n
}
// countGlobals counts the global variables in the program to add // countGlobals counts the global variables in the program to add
// them with the stack size of the function. // them with the stack size of the function.

View file

@ -21,9 +21,6 @@ import (
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
) )
// The identifier of the entry function. Default set to Main.
const mainIdent = "Main"
type codegen struct { type codegen struct {
// Information about the program with all its dependencies. // Information about the program with all its dependencies.
buildInfo *buildInfo buildInfo *buildInfo
@ -62,6 +59,12 @@ type codegen struct {
// to a text span in the source file. // to a text span in the source file.
sequencePoints map[string][]DebugSeqPoint sequencePoints map[string][]DebugSeqPoint
// initEndOffset specifies the end of the initialization method.
initEndOffset int
// mainPkg is a main package metadata.
mainPkg *loader.PackageInfo
// Label table for recording jump destinations. // Label table for recording jump destinations.
l []int l []int
} }
@ -1412,13 +1415,6 @@ func (c *codegen) newLambda(u uint16, lit *ast.FuncLit) {
} }
func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
// Resolve the entrypoint of the program.
main, mainFile := resolveEntryPoint(mainIdent, pkg)
if main == nil {
c.prog.Err = fmt.Errorf("could not find func main. Did you forget to declare it? ")
return c.prog.Err
}
funUsage := analyzeFuncUsage(pkg, info.program.AllPackages) funUsage := analyzeFuncUsage(pkg, info.program.AllPackages)
// Bring all imported functions into scope. // Bring all imported functions into scope.
@ -1428,10 +1424,12 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
} }
} }
c.traverseGlobals(mainFile) c.mainPkg = pkg
n := c.traverseGlobals(pkg.Files...)
// convert the entry point first. if n > 0 {
c.convertFuncDecl(mainFile, main, pkg.Pkg) emit.Opcode(c.prog.BinWriter, opcode.RET)
c.initEndOffset = c.prog.Len()
}
// sort map keys to generate code deterministically. // sort map keys to generate code deterministically.
keys := make([]*types.Package, 0, len(info.program.AllPackages)) keys := make([]*types.Package, 0, len(info.program.AllPackages))
@ -1451,7 +1449,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
case *ast.FuncDecl: case *ast.FuncDecl:
// Don't convert the function if it's not used. This will save a lot // Don't convert the function if it's not used. This will save a lot
// of bytecode space. // of bytecode space.
if n.Name.Name != mainIdent && funUsage.funcUsed(n.Name.Name) { if funUsage.funcUsed(n.Name.Name) {
c.convertFuncDecl(f, n, k) c.convertFuncDecl(f, n, k)
} }
} }
@ -1497,13 +1495,11 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
for _, decl := range f.Decls { for _, decl := range f.Decls {
switch n := decl.(type) { switch n := decl.(type) {
case *ast.FuncDecl: case *ast.FuncDecl:
if n.Name.Name != mainIdent {
c.newFunc(n) c.newFunc(n)
c.funcs[n.Name.Name].pkg = pkg c.funcs[n.Name.Name].pkg = pkg
} }
} }
} }
}
func (c *codegen) writeJumps(b []byte) error { func (c *codegen) writeJumps(b []byte) error {
ctx := vm.NewContext(b) ctx := vm.NewContext(b)

View file

@ -17,6 +17,7 @@ import (
// DebugInfo represents smart-contract debug information. // DebugInfo represents smart-contract debug information.
type DebugInfo struct { type DebugInfo struct {
MainPkg string `json:"-"`
Hash util.Uint160 `json:"hash"` Hash util.Uint160 `json:"hash"`
Documents []string `json:"documents"` Documents []string `json:"documents"`
Methods []MethodDebugInfo `json:"methods"` Methods []MethodDebugInfo `json:"methods"`
@ -102,9 +103,25 @@ func (c *codegen) saveSequencePoint(n ast.Node) {
func (c *codegen) emitDebugInfo(contract []byte) *DebugInfo { func (c *codegen) emitDebugInfo(contract []byte) *DebugInfo {
d := &DebugInfo{ d := &DebugInfo{
MainPkg: c.mainPkg.Pkg.Name(),
Hash: hash.Hash160(contract), Hash: hash.Hash160(contract),
Events: []EventDebugInfo{}, Events: []EventDebugInfo{},
} }
if c.initEndOffset > 0 {
d.Methods = append(d.Methods, MethodDebugInfo{
ID: manifest.MethodInit,
Name: DebugMethodName{
Name: manifest.MethodInit,
Namespace: c.mainPkg.Pkg.Name(),
},
IsExported: true,
Range: DebugRange{
Start: 0,
End: uint16(c.initEndOffset),
},
ReturnType: "Void",
})
}
for name, scope := range c.funcs { for name, scope := range c.funcs {
m := c.methodInfoFromScope(name, scope) m := c.methodInfoFromScope(name, scope)
if m.Range.Start == m.Range.End { if m.Range.Start == m.Range.End {
@ -341,22 +358,13 @@ func parsePairJSON(data []byte, sep string) (string, string, error) {
// ConvertToManifest converts contract to the manifest.Manifest struct for debugger. // ConvertToManifest converts contract to the manifest.Manifest struct for debugger.
// Note: manifest is taken from the external source, however it can be generated ad-hoc. See #1038. // Note: manifest is taken from the external source, however it can be generated ad-hoc. See #1038.
func (di *DebugInfo) ConvertToManifest(fs smartcontract.PropertyState) (*manifest.Manifest, error) { func (di *DebugInfo) ConvertToManifest(fs smartcontract.PropertyState) (*manifest.Manifest, error) {
var ( var err error
mainNamespace string if di.MainPkg == "" {
err error
)
for _, method := range di.Methods {
if method.Name.Name == mainIdent {
mainNamespace = method.Name.Namespace
break
}
}
if mainNamespace == "" {
return nil, errors.New("no Main method was found") return nil, errors.New("no Main method was found")
} }
methods := make([]manifest.Method, 0) methods := make([]manifest.Method, 0)
for _, method := range di.Methods { for _, method := range di.Methods {
if method.IsExported && method.Name.Namespace == mainNamespace { if method.IsExported && method.Name.Namespace == di.MainPkg {
mMethod, err := method.ToManifestMethod() mMethod, err := method.ToManifestMethod()
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -3,7 +3,12 @@ package compiler_test
import ( import (
"fmt" "fmt"
"math/big" "math/big"
"strings"
"testing" "testing"
"github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/stretchr/testify/require"
) )
func TestChangeGlobal(t *testing.T) { func TestChangeGlobal(t *testing.T) {
@ -105,3 +110,20 @@ func TestArgumentLocal(t *testing.T) {
eval(t, src, big.NewInt(40)) eval(t, src, big.NewInt(40))
}) })
} }
func TestContractWithNoMain(t *testing.T) {
src := `package foo
var someGlobal int = 1
func Add3(a int) int {
someLocal := 2
return someGlobal + someLocal + a
}`
b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src))
require.NoError(t, err)
v := vm.New()
invokeMethod(t, "Add3", b, v, di)
v.Estack().PushVal(39)
require.NoError(t, v.Run())
require.Equal(t, 1, v.Estack().Len())
require.Equal(t, big.NewInt(42), v.PopResult())
}

View file

@ -63,9 +63,10 @@ func TestFromAddress(t *testing.T) {
} }
func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM { func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM {
b, err := compiler.Compile(strings.NewReader(src)) b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src))
require.NoError(t, err) require.NoError(t, err)
v := core.SpawnVM(ic) v := core.SpawnVM(ic)
invokeMethod(t, testMainIdent, b, v, di)
v.LoadScriptWithFlags(b, smartcontract.All) v.LoadScriptWithFlags(b, smartcontract.All)
return v return v
} }
@ -73,12 +74,16 @@ func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM {
func TestAppCall(t *testing.T) { func TestAppCall(t *testing.T) {
srcInner := ` srcInner := `
package foo package foo
var a int = 3
func Main(a []byte, b []byte) []byte { func Main(a []byte, b []byte) []byte {
panic("Main was called") panic("Main was called")
} }
func Append(a []byte, b []byte) []byte { func Append(a []byte, b []byte) []byte {
return append(a, b...) return append(a, b...)
} }
func Add3(n int) int {
return a + n
}
` `
inner, di, err := compiler.CompileWithDebugInfo(strings.NewReader(srcInner)) inner, di, err := compiler.CompileWithDebugInfo(strings.NewReader(srcInner))
@ -147,6 +152,21 @@ func TestAppCall(t *testing.T) {
assertResult(t, v, []byte{1, 2, 3, 4}) assertResult(t, v, []byte{1, 2, 3, 4})
}) })
t.Run("InitializedGlobals", func(t *testing.T) {
src := `package foo
import "github.com/nspcc-dev/neo-go/pkg/interop/engine"
func Main() int {
var addr = []byte(` + fmt.Sprintf("%#v", string(ih.BytesBE())) + `)
result := engine.AppCall(addr, "add3", 39)
return result.(int)
}`
v := spawnVM(t, ic, src)
require.NoError(t, v.Run())
assertResult(t, v, big.NewInt(42))
})
} }
func getAppCallScript(h string) string { func getAppCallScript(h string) string {

View file

@ -7,6 +7,8 @@ import (
"github.com/nspcc-dev/neo-go/pkg/compiler" "github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -20,6 +22,9 @@ type testCase struct {
result interface{} result interface{}
} }
// testMainIdent is a method invoked in tests by default.
const testMainIdent = "Main"
func runTestCases(t *testing.T, tcases []testCase) { func runTestCases(t *testing.T, tcases []testCase) {
for _, tcase := range tcases { for _, tcase := range tcases {
t.Run(tcase.name, func(t *testing.T) { eval(t, tcase.src, tcase.result) }) t.Run(tcase.name, func(t *testing.T) { eval(t, tcase.src, tcase.result) })
@ -65,12 +70,31 @@ func vmAndCompileInterop(t *testing.T, src string) (*vm.VM, *storagePlugin) {
storePlugin := newStoragePlugin() storePlugin := newStoragePlugin()
vm.RegisterInteropGetter(storePlugin.getInterop) vm.RegisterInteropGetter(storePlugin.getInterop)
b, err := compiler.Compile(strings.NewReader(src)) b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src))
require.NoError(t, err) require.NoError(t, err)
vm.Load(b) invokeMethod(t, testMainIdent, b, vm, di)
return vm, storePlugin return vm, storePlugin
} }
func invokeMethod(t *testing.T, method string, script []byte, v *vm.VM, di *compiler.DebugInfo) {
mainOffset := -1
initOffset := -1
for i := range di.Methods {
switch di.Methods[i].ID {
case method:
mainOffset = int(di.Methods[i].Range.Start)
case manifest.MethodInit:
initOffset = int(di.Methods[i].Range.Start)
}
}
require.True(t, mainOffset >= 0)
v.LoadScriptWithFlags(script, smartcontract.All)
v.Jump(v.Context(), mainOffset)
if initOffset >= 0 {
v.Call(v.Context(), initOffset)
}
}
type storagePlugin struct { type storagePlugin struct {
mem map[string][]byte mem map[string][]byte
interops map[uint32]vm.InteropFunc interops map[uint32]vm.InteropFunc

View file

@ -17,6 +17,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -536,6 +537,11 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
v.Jump(v.Context(), md.Offset) v.Jump(v.Context(), md.Offset)
} }
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
if md != nil {
v.Call(v.Context(), md.Offset)
}
return nil return nil
} }

View file

@ -330,6 +330,8 @@ func getTestContractState() *state.Contract {
byte(opcode.ADD), byte(opcode.RET), byte(opcode.ADD), byte(opcode.RET),
byte(opcode.PUSH7), byte(opcode.RET), byte(opcode.PUSH7), byte(opcode.RET),
byte(opcode.DROP), byte(opcode.RET), byte(opcode.DROP), byte(opcode.RET),
byte(opcode.INITSSLOT), 1, byte(opcode.PUSH3), byte(opcode.STSFLD0), byte(opcode.RET),
byte(opcode.LDSFLD0), byte(opcode.ADD), byte(opcode.RET),
} }
h := hash.Hash160(script) h := hash.Hash160(script)
m := manifest.NewManifest(h) m := manifest.NewManifest(h)
@ -354,6 +356,19 @@ func getTestContractState() *state.Contract {
Offset: 5, Offset: 5,
ReturnType: smartcontract.VoidType, ReturnType: smartcontract.VoidType,
}, },
{
Name: manifest.MethodInit,
Offset: 7,
ReturnType: smartcontract.VoidType,
},
{
Name: "add3",
Offset: 12,
Parameters: []manifest.Parameter{
manifest.NewParameter("addend", smartcontract.IntegerType),
},
ReturnType: smartcontract.IntegerType,
},
} }
return &state.Contract{ return &state.Contract{
Script: script, Script: script,
@ -382,6 +397,7 @@ func TestContractCall(t *testing.T) {
perm := manifest.NewPermission(manifest.PermissionHash, h) perm := manifest.NewPermission(manifest.PermissionHash, h)
perm.Methods.Add("add") perm.Methods.Add("add")
perm.Methods.Add("drop") perm.Methods.Add("drop")
perm.Methods.Add("add3")
m.Permissions = append(m.Permissions, *perm) m.Permissions = append(m.Permissions, *perm)
require.NoError(t, ic.DAO.PutContractState(&state.Contract{ require.NoError(t, ic.DAO.PutContractState(&state.Contract{
@ -441,6 +457,20 @@ func TestContractCall(t *testing.T) {
require.NoError(t, contractCall(ic, v)) require.NoError(t, contractCall(ic, v))
require.Error(t, v.Run()) require.Error(t, v.Run())
}) })
t.Run("CallInitialize", func(t *testing.T) {
t.Run("Directly", runInvalid(stackitem.NewArray([]stackitem.Item{}), "_initialize", h.BytesBE()))
initVM(v)
v.Estack().PushVal(stackitem.NewArray([]stackitem.Item{stackitem.Make(5)}))
v.Estack().PushVal("add3")
v.Estack().PushVal(h.BytesBE())
require.NoError(t, contractCall(ic, v))
require.NoError(t, v.Run())
require.Equal(t, 2, v.Estack().Len())
require.Equal(t, big.NewInt(8), v.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), v.Estack().Pop().Value())
})
} }
func TestContractCreate(t *testing.T) { func TestContractCreate(t *testing.T) {

View file

@ -11,6 +11,9 @@ import (
// MaxManifestSize is a max length for a valid contract manifest. // MaxManifestSize is a max length for a valid contract manifest.
const MaxManifestSize = 2048 const MaxManifestSize = 2048
// MethodInit is a name for default initialization method.
const MethodInit = "_initialize"
// ABI represents a contract application binary interface. // ABI represents a contract application binary interface.
type ABI struct { type ABI struct {
Hash util.Uint160 `json:"hash"` Hash util.Uint160 `json:"hash"`