forked from TrueCloudLab/neoneo-go
*: support _initialize
method in contracts
Invoke `_initialize` method on every call if present. In NEO3 there is no entrypoint and methods are invoked by offset, thus `Main` function is no longer required. We still have special `Main` method in tests to simplify them.
This commit is contained in:
parent
466af55dea
commit
685d44dbc1
9 changed files with 156 additions and 40 deletions
|
@ -27,17 +27,24 @@ func (c *codegen) newGlobal(name string) {
|
|||
}
|
||||
|
||||
// traverseGlobals visits and initializes global variables.
|
||||
func (c *codegen) traverseGlobals(f ast.Node) {
|
||||
n := countGlobals(f)
|
||||
// and returns number of variables initialized.
|
||||
func (c *codegen) traverseGlobals(fs ...*ast.File) int {
|
||||
var n int
|
||||
for _, f := range fs {
|
||||
n += countGlobals(f)
|
||||
}
|
||||
if n != 0 {
|
||||
if n > 255 {
|
||||
c.prog.BinWriter.Err = errors.New("too many global variables")
|
||||
return
|
||||
return 0
|
||||
}
|
||||
emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)})
|
||||
}
|
||||
for _, f := range fs {
|
||||
c.convertGlobals(f)
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// countGlobals counts the global variables in the program to add
|
||||
// them with the stack size of the function.
|
||||
|
|
|
@ -21,9 +21,6 @@ import (
|
|||
"golang.org/x/tools/go/loader"
|
||||
)
|
||||
|
||||
// The identifier of the entry function. Default set to Main.
|
||||
const mainIdent = "Main"
|
||||
|
||||
type codegen struct {
|
||||
// Information about the program with all its dependencies.
|
||||
buildInfo *buildInfo
|
||||
|
@ -62,6 +59,12 @@ type codegen struct {
|
|||
// to a text span in the source file.
|
||||
sequencePoints map[string][]DebugSeqPoint
|
||||
|
||||
// initEndOffset specifies the end of the initialization method.
|
||||
initEndOffset int
|
||||
|
||||
// mainPkg is a main package metadata.
|
||||
mainPkg *loader.PackageInfo
|
||||
|
||||
// Label table for recording jump destinations.
|
||||
l []int
|
||||
}
|
||||
|
@ -1412,13 +1415,6 @@ func (c *codegen) newLambda(u uint16, lit *ast.FuncLit) {
|
|||
}
|
||||
|
||||
func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
||||
// Resolve the entrypoint of the program.
|
||||
main, mainFile := resolveEntryPoint(mainIdent, pkg)
|
||||
if main == nil {
|
||||
c.prog.Err = fmt.Errorf("could not find func main. Did you forget to declare it? ")
|
||||
return c.prog.Err
|
||||
}
|
||||
|
||||
funUsage := analyzeFuncUsage(pkg, info.program.AllPackages)
|
||||
|
||||
// Bring all imported functions into scope.
|
||||
|
@ -1428,10 +1424,12 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
|||
}
|
||||
}
|
||||
|
||||
c.traverseGlobals(mainFile)
|
||||
|
||||
// convert the entry point first.
|
||||
c.convertFuncDecl(mainFile, main, pkg.Pkg)
|
||||
c.mainPkg = pkg
|
||||
n := c.traverseGlobals(pkg.Files...)
|
||||
if n > 0 {
|
||||
emit.Opcode(c.prog.BinWriter, opcode.RET)
|
||||
c.initEndOffset = c.prog.Len()
|
||||
}
|
||||
|
||||
// sort map keys to generate code deterministically.
|
||||
keys := make([]*types.Package, 0, len(info.program.AllPackages))
|
||||
|
@ -1451,7 +1449,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
|||
case *ast.FuncDecl:
|
||||
// Don't convert the function if it's not used. This will save a lot
|
||||
// of bytecode space.
|
||||
if n.Name.Name != mainIdent && funUsage.funcUsed(n.Name.Name) {
|
||||
if funUsage.funcUsed(n.Name.Name) {
|
||||
c.convertFuncDecl(f, n, k)
|
||||
}
|
||||
}
|
||||
|
@ -1497,13 +1495,11 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
|
|||
for _, decl := range f.Decls {
|
||||
switch n := decl.(type) {
|
||||
case *ast.FuncDecl:
|
||||
if n.Name.Name != mainIdent {
|
||||
c.newFunc(n)
|
||||
c.funcs[n.Name.Name].pkg = pkg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *codegen) writeJumps(b []byte) error {
|
||||
ctx := vm.NewContext(b)
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
|
||||
// DebugInfo represents smart-contract debug information.
|
||||
type DebugInfo struct {
|
||||
MainPkg string `json:"-"`
|
||||
Hash util.Uint160 `json:"hash"`
|
||||
Documents []string `json:"documents"`
|
||||
Methods []MethodDebugInfo `json:"methods"`
|
||||
|
@ -102,9 +103,25 @@ func (c *codegen) saveSequencePoint(n ast.Node) {
|
|||
|
||||
func (c *codegen) emitDebugInfo(contract []byte) *DebugInfo {
|
||||
d := &DebugInfo{
|
||||
MainPkg: c.mainPkg.Pkg.Name(),
|
||||
Hash: hash.Hash160(contract),
|
||||
Events: []EventDebugInfo{},
|
||||
}
|
||||
if c.initEndOffset > 0 {
|
||||
d.Methods = append(d.Methods, MethodDebugInfo{
|
||||
ID: manifest.MethodInit,
|
||||
Name: DebugMethodName{
|
||||
Name: manifest.MethodInit,
|
||||
Namespace: c.mainPkg.Pkg.Name(),
|
||||
},
|
||||
IsExported: true,
|
||||
Range: DebugRange{
|
||||
Start: 0,
|
||||
End: uint16(c.initEndOffset),
|
||||
},
|
||||
ReturnType: "Void",
|
||||
})
|
||||
}
|
||||
for name, scope := range c.funcs {
|
||||
m := c.methodInfoFromScope(name, scope)
|
||||
if m.Range.Start == m.Range.End {
|
||||
|
@ -341,22 +358,13 @@ func parsePairJSON(data []byte, sep string) (string, string, error) {
|
|||
// ConvertToManifest converts contract to the manifest.Manifest struct for debugger.
|
||||
// Note: manifest is taken from the external source, however it can be generated ad-hoc. See #1038.
|
||||
func (di *DebugInfo) ConvertToManifest(fs smartcontract.PropertyState) (*manifest.Manifest, error) {
|
||||
var (
|
||||
mainNamespace string
|
||||
err error
|
||||
)
|
||||
for _, method := range di.Methods {
|
||||
if method.Name.Name == mainIdent {
|
||||
mainNamespace = method.Name.Namespace
|
||||
break
|
||||
}
|
||||
}
|
||||
if mainNamespace == "" {
|
||||
var err error
|
||||
if di.MainPkg == "" {
|
||||
return nil, errors.New("no Main method was found")
|
||||
}
|
||||
methods := make([]manifest.Method, 0)
|
||||
for _, method := range di.Methods {
|
||||
if method.IsExported && method.Name.Namespace == mainNamespace {
|
||||
if method.IsExported && method.Name.Namespace == di.MainPkg {
|
||||
mMethod, err := method.ToManifestMethod()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -3,7 +3,12 @@ package compiler_test
|
|||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/compiler"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChangeGlobal(t *testing.T) {
|
||||
|
@ -105,3 +110,20 @@ func TestArgumentLocal(t *testing.T) {
|
|||
eval(t, src, big.NewInt(40))
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractWithNoMain(t *testing.T) {
|
||||
src := `package foo
|
||||
var someGlobal int = 1
|
||||
func Add3(a int) int {
|
||||
someLocal := 2
|
||||
return someGlobal + someLocal + a
|
||||
}`
|
||||
b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src))
|
||||
require.NoError(t, err)
|
||||
v := vm.New()
|
||||
invokeMethod(t, "Add3", b, v, di)
|
||||
v.Estack().PushVal(39)
|
||||
require.NoError(t, v.Run())
|
||||
require.Equal(t, 1, v.Estack().Len())
|
||||
require.Equal(t, big.NewInt(42), v.PopResult())
|
||||
}
|
||||
|
|
|
@ -63,9 +63,10 @@ func TestFromAddress(t *testing.T) {
|
|||
}
|
||||
|
||||
func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM {
|
||||
b, err := compiler.Compile(strings.NewReader(src))
|
||||
b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src))
|
||||
require.NoError(t, err)
|
||||
v := core.SpawnVM(ic)
|
||||
invokeMethod(t, testMainIdent, b, v, di)
|
||||
v.LoadScriptWithFlags(b, smartcontract.All)
|
||||
return v
|
||||
}
|
||||
|
@ -73,12 +74,16 @@ func spawnVM(t *testing.T, ic *interop.Context, src string) *vm.VM {
|
|||
func TestAppCall(t *testing.T) {
|
||||
srcInner := `
|
||||
package foo
|
||||
var a int = 3
|
||||
func Main(a []byte, b []byte) []byte {
|
||||
panic("Main was called")
|
||||
}
|
||||
func Append(a []byte, b []byte) []byte {
|
||||
return append(a, b...)
|
||||
}
|
||||
func Add3(n int) int {
|
||||
return a + n
|
||||
}
|
||||
`
|
||||
|
||||
inner, di, err := compiler.CompileWithDebugInfo(strings.NewReader(srcInner))
|
||||
|
@ -147,6 +152,21 @@ func TestAppCall(t *testing.T) {
|
|||
|
||||
assertResult(t, v, []byte{1, 2, 3, 4})
|
||||
})
|
||||
|
||||
t.Run("InitializedGlobals", func(t *testing.T) {
|
||||
src := `package foo
|
||||
import "github.com/nspcc-dev/neo-go/pkg/interop/engine"
|
||||
func Main() int {
|
||||
var addr = []byte(` + fmt.Sprintf("%#v", string(ih.BytesBE())) + `)
|
||||
result := engine.AppCall(addr, "add3", 39)
|
||||
return result.(int)
|
||||
}`
|
||||
|
||||
v := spawnVM(t, ic, src)
|
||||
require.NoError(t, v.Run())
|
||||
|
||||
assertResult(t, v, big.NewInt(42))
|
||||
})
|
||||
}
|
||||
|
||||
func getAppCallScript(h string) string {
|
||||
|
|
|
@ -7,6 +7,8 @@ import (
|
|||
|
||||
"github.com/nspcc-dev/neo-go/pkg/compiler"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
|
@ -20,6 +22,9 @@ type testCase struct {
|
|||
result interface{}
|
||||
}
|
||||
|
||||
// testMainIdent is a method invoked in tests by default.
|
||||
const testMainIdent = "Main"
|
||||
|
||||
func runTestCases(t *testing.T, tcases []testCase) {
|
||||
for _, tcase := range tcases {
|
||||
t.Run(tcase.name, func(t *testing.T) { eval(t, tcase.src, tcase.result) })
|
||||
|
@ -65,12 +70,31 @@ func vmAndCompileInterop(t *testing.T, src string) (*vm.VM, *storagePlugin) {
|
|||
storePlugin := newStoragePlugin()
|
||||
vm.RegisterInteropGetter(storePlugin.getInterop)
|
||||
|
||||
b, err := compiler.Compile(strings.NewReader(src))
|
||||
b, di, err := compiler.CompileWithDebugInfo(strings.NewReader(src))
|
||||
require.NoError(t, err)
|
||||
vm.Load(b)
|
||||
invokeMethod(t, testMainIdent, b, vm, di)
|
||||
return vm, storePlugin
|
||||
}
|
||||
|
||||
func invokeMethod(t *testing.T, method string, script []byte, v *vm.VM, di *compiler.DebugInfo) {
|
||||
mainOffset := -1
|
||||
initOffset := -1
|
||||
for i := range di.Methods {
|
||||
switch di.Methods[i].ID {
|
||||
case method:
|
||||
mainOffset = int(di.Methods[i].Range.Start)
|
||||
case manifest.MethodInit:
|
||||
initOffset = int(di.Methods[i].Range.Start)
|
||||
}
|
||||
}
|
||||
require.True(t, mainOffset >= 0)
|
||||
v.LoadScriptWithFlags(script, smartcontract.All)
|
||||
v.Jump(v.Context(), mainOffset)
|
||||
if initOffset >= 0 {
|
||||
v.Call(v.Context(), initOffset)
|
||||
}
|
||||
}
|
||||
|
||||
type storagePlugin struct {
|
||||
mem map[string][]byte
|
||||
interops map[uint32]vm.InteropFunc
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
|
@ -536,6 +537,11 @@ func contractCallExInternal(ic *interop.Context, v *vm.VM, h []byte, method stac
|
|||
v.Jump(v.Context(), md.Offset)
|
||||
}
|
||||
|
||||
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
|
||||
if md != nil {
|
||||
v.Call(v.Context(), md.Offset)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -330,6 +330,8 @@ func getTestContractState() *state.Contract {
|
|||
byte(opcode.ADD), byte(opcode.RET),
|
||||
byte(opcode.PUSH7), byte(opcode.RET),
|
||||
byte(opcode.DROP), byte(opcode.RET),
|
||||
byte(opcode.INITSSLOT), 1, byte(opcode.PUSH3), byte(opcode.STSFLD0), byte(opcode.RET),
|
||||
byte(opcode.LDSFLD0), byte(opcode.ADD), byte(opcode.RET),
|
||||
}
|
||||
h := hash.Hash160(script)
|
||||
m := manifest.NewManifest(h)
|
||||
|
@ -354,6 +356,19 @@ func getTestContractState() *state.Contract {
|
|||
Offset: 5,
|
||||
ReturnType: smartcontract.VoidType,
|
||||
},
|
||||
{
|
||||
Name: manifest.MethodInit,
|
||||
Offset: 7,
|
||||
ReturnType: smartcontract.VoidType,
|
||||
},
|
||||
{
|
||||
Name: "add3",
|
||||
Offset: 12,
|
||||
Parameters: []manifest.Parameter{
|
||||
manifest.NewParameter("addend", smartcontract.IntegerType),
|
||||
},
|
||||
ReturnType: smartcontract.IntegerType,
|
||||
},
|
||||
}
|
||||
return &state.Contract{
|
||||
Script: script,
|
||||
|
@ -382,6 +397,7 @@ func TestContractCall(t *testing.T) {
|
|||
perm := manifest.NewPermission(manifest.PermissionHash, h)
|
||||
perm.Methods.Add("add")
|
||||
perm.Methods.Add("drop")
|
||||
perm.Methods.Add("add3")
|
||||
m.Permissions = append(m.Permissions, *perm)
|
||||
|
||||
require.NoError(t, ic.DAO.PutContractState(&state.Contract{
|
||||
|
@ -441,6 +457,20 @@ func TestContractCall(t *testing.T) {
|
|||
require.NoError(t, contractCall(ic, v))
|
||||
require.Error(t, v.Run())
|
||||
})
|
||||
|
||||
t.Run("CallInitialize", func(t *testing.T) {
|
||||
t.Run("Directly", runInvalid(stackitem.NewArray([]stackitem.Item{}), "_initialize", h.BytesBE()))
|
||||
|
||||
initVM(v)
|
||||
v.Estack().PushVal(stackitem.NewArray([]stackitem.Item{stackitem.Make(5)}))
|
||||
v.Estack().PushVal("add3")
|
||||
v.Estack().PushVal(h.BytesBE())
|
||||
require.NoError(t, contractCall(ic, v))
|
||||
require.NoError(t, v.Run())
|
||||
require.Equal(t, 2, v.Estack().Len())
|
||||
require.Equal(t, big.NewInt(8), v.Estack().Pop().Value())
|
||||
require.Equal(t, big.NewInt(42), v.Estack().Pop().Value())
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractCreate(t *testing.T) {
|
||||
|
|
|
@ -11,6 +11,9 @@ import (
|
|||
// MaxManifestSize is a max length for a valid contract manifest.
|
||||
const MaxManifestSize = 2048
|
||||
|
||||
// MethodInit is a name for default initialization method.
|
||||
const MethodInit = "_initialize"
|
||||
|
||||
// ABI represents a contract application binary interface.
|
||||
type ABI struct {
|
||||
Hash util.Uint160 `json:"hash"`
|
||||
|
|
Loading…
Reference in a new issue