diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 1c1dab8bd..899d9e35f 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -8,6 +8,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "golang.org/x/tools/go/loader" ) var ( @@ -35,21 +36,54 @@ func (c *codegen) getIdentName(pkg string, name string) string { } // traverseGlobals visits and initializes global variables. -// and returns number of variables initialized. -func (c *codegen) traverseGlobals() int { +// and returns number of variables initialized and +// true if any init functions were encountered. +func (c *codegen) traverseGlobals() (int, bool) { var n int + var hasInit bool c.ForEachFile(func(f *ast.File, _ *types.Package) { n += countGlobals(f) + if !hasInit { + ast.Inspect(f, func(node ast.Node) bool { + n, ok := node.(*ast.FuncDecl) + if ok { + if isInitFunc(n) { + hasInit = true + } + return false + } + return true + }) + } }) - if n != 0 { + if n != 0 || hasInit { if n > 255 { c.prog.BinWriter.Err = errors.New("too many global variables") - return 0 + return 0, hasInit } - emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)}) - c.ForEachFile(c.convertGlobals) + if n != 0 { + emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)}) + } + c.ForEachPackage(func(pkg *loader.PackageInfo) { + if n > 0 { + for _, f := range pkg.Files { + c.fillImportMap(f, pkg.Pkg) + c.convertGlobals(f, pkg.Pkg) + } + } + if hasInit { + for _, f := range pkg.Files { + c.fillImportMap(f, pkg.Pkg) + c.convertInitFuncs(f, pkg.Pkg) + } + } + // because we reuse `convertFuncDecl` for init funcs, + // we need to cleare scope, so that global variables + // encountered after will be recognized as globals. + c.scope = nil + }) } - return n + return n, hasInit } // countGlobals counts the global variables in the program to add @@ -103,6 +137,32 @@ func lastStmtIsReturn(decl *ast.FuncDecl) (b bool) { return false } +// analyzePkgOrder sets the order in which packages should be processed. +// From Go spec: +// A package with no imports is initialized by assigning initial values to all its package-level variables +// followed by calling all init functions in the order they appear in the source, possibly in multiple files, +// as presented to the compiler. If a package has imports, the imported packages are initialized before +// initializing the package itself. If multiple packages import a package, the imported package +// will be initialized only once. The importing of packages, by construction, guarantees +// that there can be no cyclic initialization dependencies. +func (c *codegen) analyzePkgOrder() { + seen := make(map[string]bool) + info := c.buildInfo.program.Package(c.buildInfo.initialPackage) + c.visitPkg(info.Pkg, seen) +} + +func (c *codegen) visitPkg(pkg *types.Package, seen map[string]bool) { + pkgPath := pkg.Path() + if seen[pkgPath] { + return + } + for _, imp := range pkg.Imports() { + c.visitPkg(imp, seen) + } + seen[pkgPath] = true + c.packages = append(c.packages, pkgPath) +} + func (c *codegen) analyzeFuncUsage() funcUsage { usage := funcUsage{} diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index fc34b7b04..384c19197 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -74,6 +74,9 @@ type codegen struct { // mainPkg is a main package metadata. mainPkg *loader.PackageInfo + // packages contains packages in the order they were loaded. + packages []string + // Label table for recording jump destinations. l []int } @@ -275,24 +278,48 @@ func (c *codegen) convertGlobals(f *ast.File, _ *types.Package) { }) } +func isInitFunc(decl *ast.FuncDecl) bool { + return decl.Name.Name == "init" && decl.Recv == nil && + decl.Type.Params.NumFields() == 0 && + decl.Type.Results.NumFields() == 0 +} + +func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) { + ast.Inspect(f, func(node ast.Node) bool { + switch n := node.(type) { + case *ast.FuncDecl: + if isInitFunc(n) { + c.convertFuncDecl(f, n, pkg) + } + case *ast.GenDecl: + return false + } + return true + }) +} + func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) { var ( f *funcScope ok, isLambda bool ) - - f, ok = c.funcs[c.getFuncNameFromDecl("", decl)] - if ok { - // If this function is a syscall or builtin we will not convert it to bytecode. - if isSyscall(f) || isCustomBuiltin(f) { - return - } - c.setLabel(f.label) - } else if f, ok = c.lambda[c.getIdentName("", decl.Name.Name)]; ok { - isLambda = ok - c.setLabel(f.label) + isInit := isInitFunc(decl) + if isInit { + f = c.newFuncScope(decl, c.newLabel()) } else { - f = c.newFunc(decl) + f, ok = c.funcs[c.getFuncNameFromDecl("", decl)] + if ok { + // If this function is a syscall or builtin we will not convert it to bytecode. + if isSyscall(f) || isCustomBuiltin(f) { + return + } + c.setLabel(f.label) + } else if f, ok = c.lambda[c.getIdentName("", decl.Name.Name)]; ok { + isLambda = ok + c.setLabel(f.label) + } else { + f = c.newFunc(decl) + } } f.rng.Start = uint16(c.prog.Len()) @@ -342,7 +369,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types. // If we have reached the end of the function without encountering `return` statement, // we should clean alt.stack manually. // This can be the case with void and named-return functions. - if !lastStmtIsReturn(decl) { + if !isInit && !lastStmtIsReturn(decl) { c.saveSequencePoint(decl.Body) emit.Opcode(c.prog.BinWriter, opcode.RET) } @@ -1462,13 +1489,14 @@ func (c *codegen) newLambda(u uint16, lit *ast.FuncLit) { func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { c.mainPkg = pkg + c.analyzePkgOrder() funUsage := c.analyzeFuncUsage() // Bring all imported functions into scope. c.ForEachFile(c.resolveFuncDecls) - n := c.traverseGlobals() - if n > 0 { + n, hasInit := c.traverseGlobals() + if n > 0 || hasInit { emit.Opcode(c.prog.BinWriter, opcode.RET) c.initEndOffset = c.prog.Len() } @@ -1488,7 +1516,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { // Don't convert the function if it's not used. This will save a lot // of bytecode space. name := c.getFuncNameFromDecl(pkg.Path(), n) - if funUsage.funcUsed(name) && !isInteropPath(pkg.Path()) { + if !isInitFunc(n) && funUsage.funcUsed(name) && !isInteropPath(pkg.Path()) { c.convertFuncDecl(f, n, pkg) } } diff --git a/pkg/compiler/compiler.go b/pkg/compiler/compiler.go index 76d48358f..aaefbc039 100644 --- a/pkg/compiler/compiler.go +++ b/pkg/compiler/compiler.go @@ -47,16 +47,25 @@ type buildInfo struct { program *loader.Program } -// ForEachFile executes fn on each file used in current program. -func (c *codegen) ForEachFile(fn func(*ast.File, *types.Package)) { - for _, pkg := range c.buildInfo.program.AllPackages { +// ForEachPackage executes fn on each package used in the current program +// in the order they should be initialized. +func (c *codegen) ForEachPackage(fn func(*loader.PackageInfo)) { + for i := range c.packages { + pkg := c.buildInfo.program.Package(c.packages[i]) c.typeInfo = &pkg.Info c.currPkg = pkg.Pkg + fn(pkg) + } +} + +// ForEachFile executes fn on each file used in current program. +func (c *codegen) ForEachFile(fn func(*ast.File, *types.Package)) { + c.ForEachPackage(func(pkg *loader.PackageInfo) { for _, f := range pkg.Files { c.fillImportMap(f, pkg.Pkg) fn(f, pkg.Pkg) } - } + }) } // fillImportMap fills import map for f. diff --git a/pkg/compiler/init_test.go b/pkg/compiler/init_test.go new file mode 100644 index 000000000..dd74639ad --- /dev/null +++ b/pkg/compiler/init_test.go @@ -0,0 +1,107 @@ +package compiler_test + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestInit(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + src := `package foo + var a int + func init() { + a = 42 + } + func Main() int { + return a + }` + eval(t, src, big.NewInt(42)) + }) + t.Run("Multi", func(t *testing.T) { + src := `package foo + var m = map[int]int{} + var a = 2 + func init() { + m[1] = 11 + } + func init() { + a = 1 + m[3] = 30 + } + func Main() int { + return m[1] + m[3] + a + }` + eval(t, src, big.NewInt(42)) + }) + t.Run("WithCall", func(t *testing.T) { + src := `package foo + var m = map[int]int{} + func init() { + initMap(m) + } + func initMap(m map[int]int) { + m[11] = 42 + } + func Main() int { + return m[11] + }` + eval(t, src, big.NewInt(42)) + }) + t.Run("InvalidSignature", func(t *testing.T) { + src := `package foo + type Foo int + var a int + func (f Foo) init() { + a = 2 + } + func Main() int { + return a + }` + eval(t, src, big.NewInt(0)) + }) +} + +func TestImportOrder(t *testing.T) { + t.Run("1,2", func(t *testing.T) { + src := `package foo + import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg1" + import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg2" + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3" + func Main() int { return pkg3.A }` + v := vmAndCompile(t, src) + v.PrintOps() + eval(t, src, big.NewInt(2)) + }) + t.Run("2,1", func(t *testing.T) { + src := `package foo + import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg2" + import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg1" + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3" + func Main() int { return pkg3.A }` + eval(t, src, big.NewInt(1)) + }) + t.Run("InitializeOnce", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3" + var A = pkg3.A + func Main() int { return A }` + eval(t, src, big.NewInt(3)) + }) +} + +func TestInitWithNoGlobals(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop/runtime" + func init() { + runtime.Notify("called in '_initialize'") + } + func Main() int { + return 42 + }` + v, s := vmAndCompileInterop(t, src) + require.NoError(t, v.Run()) + assertResult(t, v, big.NewInt(42)) + require.True(t, len(s.events) == 1) +} diff --git a/pkg/compiler/testdata/pkg1/pkg1.go b/pkg/compiler/testdata/pkg1/pkg1.go new file mode 100644 index 000000000..910eb0e45 --- /dev/null +++ b/pkg/compiler/testdata/pkg1/pkg1.go @@ -0,0 +1,7 @@ +package pkg1 + +import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3" + +func init() { + pkg3.A = 1 +} diff --git a/pkg/compiler/testdata/pkg2/pkg2.go b/pkg/compiler/testdata/pkg2/pkg2.go new file mode 100644 index 000000000..e95cda011 --- /dev/null +++ b/pkg/compiler/testdata/pkg2/pkg2.go @@ -0,0 +1,9 @@ +package pkg2 + +import ( + "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3" +) + +func init() { + pkg3.A = 2 +} diff --git a/pkg/compiler/testdata/pkg3/pkg3.go b/pkg/compiler/testdata/pkg3/pkg3.go new file mode 100644 index 000000000..bf7c49a9e --- /dev/null +++ b/pkg/compiler/testdata/pkg3/pkg3.go @@ -0,0 +1,7 @@ +package pkg3 + +var A int + +func init() { + A = 3 +}