Merge pull request #1261 from nspcc-dev/feature/init
compiler: allow to use `init` function
This commit is contained in:
commit
fced917b71
7 changed files with 254 additions and 27 deletions
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
|
"golang.org/x/tools/go/loader"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -35,21 +36,54 @@ func (c *codegen) getIdentName(pkg string, name string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// traverseGlobals visits and initializes global variables.
|
// traverseGlobals visits and initializes global variables.
|
||||||
// and returns number of variables initialized.
|
// and returns number of variables initialized and
|
||||||
func (c *codegen) traverseGlobals() int {
|
// true if any init functions were encountered.
|
||||||
|
func (c *codegen) traverseGlobals() (int, bool) {
|
||||||
var n int
|
var n int
|
||||||
|
var hasInit bool
|
||||||
c.ForEachFile(func(f *ast.File, _ *types.Package) {
|
c.ForEachFile(func(f *ast.File, _ *types.Package) {
|
||||||
n += countGlobals(f)
|
n += countGlobals(f)
|
||||||
|
if !hasInit {
|
||||||
|
ast.Inspect(f, func(node ast.Node) bool {
|
||||||
|
n, ok := node.(*ast.FuncDecl)
|
||||||
|
if ok {
|
||||||
|
if isInitFunc(n) {
|
||||||
|
hasInit = true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
if n != 0 {
|
if n != 0 || hasInit {
|
||||||
if n > 255 {
|
if n > 255 {
|
||||||
c.prog.BinWriter.Err = errors.New("too many global variables")
|
c.prog.BinWriter.Err = errors.New("too many global variables")
|
||||||
return 0
|
return 0, hasInit
|
||||||
}
|
}
|
||||||
emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)})
|
if n != 0 {
|
||||||
c.ForEachFile(c.convertGlobals)
|
emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)})
|
||||||
|
}
|
||||||
|
c.ForEachPackage(func(pkg *loader.PackageInfo) {
|
||||||
|
if n > 0 {
|
||||||
|
for _, f := range pkg.Files {
|
||||||
|
c.fillImportMap(f, pkg.Pkg)
|
||||||
|
c.convertGlobals(f, pkg.Pkg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasInit {
|
||||||
|
for _, f := range pkg.Files {
|
||||||
|
c.fillImportMap(f, pkg.Pkg)
|
||||||
|
c.convertInitFuncs(f, pkg.Pkg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// because we reuse `convertFuncDecl` for init funcs,
|
||||||
|
// we need to cleare scope, so that global variables
|
||||||
|
// encountered after will be recognized as globals.
|
||||||
|
c.scope = nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return n
|
return n, hasInit
|
||||||
}
|
}
|
||||||
|
|
||||||
// countGlobals counts the global variables in the program to add
|
// countGlobals counts the global variables in the program to add
|
||||||
|
@ -103,6 +137,32 @@ func lastStmtIsReturn(decl *ast.FuncDecl) (b bool) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// analyzePkgOrder sets the order in which packages should be processed.
|
||||||
|
// From Go spec:
|
||||||
|
// A package with no imports is initialized by assigning initial values to all its package-level variables
|
||||||
|
// followed by calling all init functions in the order they appear in the source, possibly in multiple files,
|
||||||
|
// as presented to the compiler. If a package has imports, the imported packages are initialized before
|
||||||
|
// initializing the package itself. If multiple packages import a package, the imported package
|
||||||
|
// will be initialized only once. The importing of packages, by construction, guarantees
|
||||||
|
// that there can be no cyclic initialization dependencies.
|
||||||
|
func (c *codegen) analyzePkgOrder() {
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
info := c.buildInfo.program.Package(c.buildInfo.initialPackage)
|
||||||
|
c.visitPkg(info.Pkg, seen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codegen) visitPkg(pkg *types.Package, seen map[string]bool) {
|
||||||
|
pkgPath := pkg.Path()
|
||||||
|
if seen[pkgPath] {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, imp := range pkg.Imports() {
|
||||||
|
c.visitPkg(imp, seen)
|
||||||
|
}
|
||||||
|
seen[pkgPath] = true
|
||||||
|
c.packages = append(c.packages, pkgPath)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *codegen) analyzeFuncUsage() funcUsage {
|
func (c *codegen) analyzeFuncUsage() funcUsage {
|
||||||
usage := funcUsage{}
|
usage := funcUsage{}
|
||||||
|
|
||||||
|
|
|
@ -74,6 +74,9 @@ type codegen struct {
|
||||||
// mainPkg is a main package metadata.
|
// mainPkg is a main package metadata.
|
||||||
mainPkg *loader.PackageInfo
|
mainPkg *loader.PackageInfo
|
||||||
|
|
||||||
|
// packages contains packages in the order they were loaded.
|
||||||
|
packages []string
|
||||||
|
|
||||||
// Label table for recording jump destinations.
|
// Label table for recording jump destinations.
|
||||||
l []int
|
l []int
|
||||||
}
|
}
|
||||||
|
@ -275,24 +278,48 @@ func (c *codegen) convertGlobals(f *ast.File, _ *types.Package) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isInitFunc(decl *ast.FuncDecl) bool {
|
||||||
|
return decl.Name.Name == "init" && decl.Recv == nil &&
|
||||||
|
decl.Type.Params.NumFields() == 0 &&
|
||||||
|
decl.Type.Results.NumFields() == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) {
|
||||||
|
ast.Inspect(f, func(node ast.Node) bool {
|
||||||
|
switch n := node.(type) {
|
||||||
|
case *ast.FuncDecl:
|
||||||
|
if isInitFunc(n) {
|
||||||
|
c.convertFuncDecl(f, n, pkg)
|
||||||
|
}
|
||||||
|
case *ast.GenDecl:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) {
|
func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) {
|
||||||
var (
|
var (
|
||||||
f *funcScope
|
f *funcScope
|
||||||
ok, isLambda bool
|
ok, isLambda bool
|
||||||
)
|
)
|
||||||
|
isInit := isInitFunc(decl)
|
||||||
f, ok = c.funcs[c.getFuncNameFromDecl("", decl)]
|
if isInit {
|
||||||
if ok {
|
f = c.newFuncScope(decl, c.newLabel())
|
||||||
// If this function is a syscall or builtin we will not convert it to bytecode.
|
|
||||||
if isSyscall(f) || isCustomBuiltin(f) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.setLabel(f.label)
|
|
||||||
} else if f, ok = c.lambda[c.getIdentName("", decl.Name.Name)]; ok {
|
|
||||||
isLambda = ok
|
|
||||||
c.setLabel(f.label)
|
|
||||||
} else {
|
} else {
|
||||||
f = c.newFunc(decl)
|
f, ok = c.funcs[c.getFuncNameFromDecl("", decl)]
|
||||||
|
if ok {
|
||||||
|
// If this function is a syscall or builtin we will not convert it to bytecode.
|
||||||
|
if isSyscall(f) || isCustomBuiltin(f) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.setLabel(f.label)
|
||||||
|
} else if f, ok = c.lambda[c.getIdentName("", decl.Name.Name)]; ok {
|
||||||
|
isLambda = ok
|
||||||
|
c.setLabel(f.label)
|
||||||
|
} else {
|
||||||
|
f = c.newFunc(decl)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.rng.Start = uint16(c.prog.Len())
|
f.rng.Start = uint16(c.prog.Len())
|
||||||
|
@ -342,7 +369,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
|
||||||
// If we have reached the end of the function without encountering `return` statement,
|
// If we have reached the end of the function without encountering `return` statement,
|
||||||
// we should clean alt.stack manually.
|
// we should clean alt.stack manually.
|
||||||
// This can be the case with void and named-return functions.
|
// This can be the case with void and named-return functions.
|
||||||
if !lastStmtIsReturn(decl) {
|
if !isInit && !lastStmtIsReturn(decl) {
|
||||||
c.saveSequencePoint(decl.Body)
|
c.saveSequencePoint(decl.Body)
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.RET)
|
emit.Opcode(c.prog.BinWriter, opcode.RET)
|
||||||
}
|
}
|
||||||
|
@ -1462,13 +1489,14 @@ func (c *codegen) newLambda(u uint16, lit *ast.FuncLit) {
|
||||||
|
|
||||||
func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
||||||
c.mainPkg = pkg
|
c.mainPkg = pkg
|
||||||
|
c.analyzePkgOrder()
|
||||||
funUsage := c.analyzeFuncUsage()
|
funUsage := c.analyzeFuncUsage()
|
||||||
|
|
||||||
// Bring all imported functions into scope.
|
// Bring all imported functions into scope.
|
||||||
c.ForEachFile(c.resolveFuncDecls)
|
c.ForEachFile(c.resolveFuncDecls)
|
||||||
|
|
||||||
n := c.traverseGlobals()
|
n, hasInit := c.traverseGlobals()
|
||||||
if n > 0 {
|
if n > 0 || hasInit {
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.RET)
|
emit.Opcode(c.prog.BinWriter, opcode.RET)
|
||||||
c.initEndOffset = c.prog.Len()
|
c.initEndOffset = c.prog.Len()
|
||||||
}
|
}
|
||||||
|
@ -1488,7 +1516,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
||||||
// Don't convert the function if it's not used. This will save a lot
|
// Don't convert the function if it's not used. This will save a lot
|
||||||
// of bytecode space.
|
// of bytecode space.
|
||||||
name := c.getFuncNameFromDecl(pkg.Path(), n)
|
name := c.getFuncNameFromDecl(pkg.Path(), n)
|
||||||
if funUsage.funcUsed(name) && !isInteropPath(pkg.Path()) {
|
if !isInitFunc(n) && funUsage.funcUsed(name) && !isInteropPath(pkg.Path()) {
|
||||||
c.convertFuncDecl(f, n, pkg)
|
c.convertFuncDecl(f, n, pkg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,16 +47,25 @@ type buildInfo struct {
|
||||||
program *loader.Program
|
program *loader.Program
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForEachFile executes fn on each file used in current program.
|
// ForEachPackage executes fn on each package used in the current program
|
||||||
func (c *codegen) ForEachFile(fn func(*ast.File, *types.Package)) {
|
// in the order they should be initialized.
|
||||||
for _, pkg := range c.buildInfo.program.AllPackages {
|
func (c *codegen) ForEachPackage(fn func(*loader.PackageInfo)) {
|
||||||
|
for i := range c.packages {
|
||||||
|
pkg := c.buildInfo.program.Package(c.packages[i])
|
||||||
c.typeInfo = &pkg.Info
|
c.typeInfo = &pkg.Info
|
||||||
c.currPkg = pkg.Pkg
|
c.currPkg = pkg.Pkg
|
||||||
|
fn(pkg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForEachFile executes fn on each file used in current program.
|
||||||
|
func (c *codegen) ForEachFile(fn func(*ast.File, *types.Package)) {
|
||||||
|
c.ForEachPackage(func(pkg *loader.PackageInfo) {
|
||||||
for _, f := range pkg.Files {
|
for _, f := range pkg.Files {
|
||||||
c.fillImportMap(f, pkg.Pkg)
|
c.fillImportMap(f, pkg.Pkg)
|
||||||
fn(f, pkg.Pkg)
|
fn(f, pkg.Pkg)
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// fillImportMap fills import map for f.
|
// fillImportMap fills import map for f.
|
||||||
|
|
107
pkg/compiler/init_test.go
Normal file
107
pkg/compiler/init_test.go
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
package compiler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/big"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInit(t *testing.T) {
|
||||||
|
t.Run("Simple", func(t *testing.T) {
|
||||||
|
src := `package foo
|
||||||
|
var a int
|
||||||
|
func init() {
|
||||||
|
a = 42
|
||||||
|
}
|
||||||
|
func Main() int {
|
||||||
|
return a
|
||||||
|
}`
|
||||||
|
eval(t, src, big.NewInt(42))
|
||||||
|
})
|
||||||
|
t.Run("Multi", func(t *testing.T) {
|
||||||
|
src := `package foo
|
||||||
|
var m = map[int]int{}
|
||||||
|
var a = 2
|
||||||
|
func init() {
|
||||||
|
m[1] = 11
|
||||||
|
}
|
||||||
|
func init() {
|
||||||
|
a = 1
|
||||||
|
m[3] = 30
|
||||||
|
}
|
||||||
|
func Main() int {
|
||||||
|
return m[1] + m[3] + a
|
||||||
|
}`
|
||||||
|
eval(t, src, big.NewInt(42))
|
||||||
|
})
|
||||||
|
t.Run("WithCall", func(t *testing.T) {
|
||||||
|
src := `package foo
|
||||||
|
var m = map[int]int{}
|
||||||
|
func init() {
|
||||||
|
initMap(m)
|
||||||
|
}
|
||||||
|
func initMap(m map[int]int) {
|
||||||
|
m[11] = 42
|
||||||
|
}
|
||||||
|
func Main() int {
|
||||||
|
return m[11]
|
||||||
|
}`
|
||||||
|
eval(t, src, big.NewInt(42))
|
||||||
|
})
|
||||||
|
t.Run("InvalidSignature", func(t *testing.T) {
|
||||||
|
src := `package foo
|
||||||
|
type Foo int
|
||||||
|
var a int
|
||||||
|
func (f Foo) init() {
|
||||||
|
a = 2
|
||||||
|
}
|
||||||
|
func Main() int {
|
||||||
|
return a
|
||||||
|
}`
|
||||||
|
eval(t, src, big.NewInt(0))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestImportOrder(t *testing.T) {
|
||||||
|
t.Run("1,2", func(t *testing.T) {
|
||||||
|
src := `package foo
|
||||||
|
import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg1"
|
||||||
|
import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg2"
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
|
||||||
|
func Main() int { return pkg3.A }`
|
||||||
|
v := vmAndCompile(t, src)
|
||||||
|
v.PrintOps()
|
||||||
|
eval(t, src, big.NewInt(2))
|
||||||
|
})
|
||||||
|
t.Run("2,1", func(t *testing.T) {
|
||||||
|
src := `package foo
|
||||||
|
import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg2"
|
||||||
|
import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg1"
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
|
||||||
|
func Main() int { return pkg3.A }`
|
||||||
|
eval(t, src, big.NewInt(1))
|
||||||
|
})
|
||||||
|
t.Run("InitializeOnce", func(t *testing.T) {
|
||||||
|
src := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
|
||||||
|
var A = pkg3.A
|
||||||
|
func Main() int { return A }`
|
||||||
|
eval(t, src, big.NewInt(3))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitWithNoGlobals(t *testing.T) {
|
||||||
|
src := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/interop/runtime"
|
||||||
|
func init() {
|
||||||
|
runtime.Notify("called in '_initialize'")
|
||||||
|
}
|
||||||
|
func Main() int {
|
||||||
|
return 42
|
||||||
|
}`
|
||||||
|
v, s := vmAndCompileInterop(t, src)
|
||||||
|
require.NoError(t, v.Run())
|
||||||
|
assertResult(t, v, big.NewInt(42))
|
||||||
|
require.True(t, len(s.events) == 1)
|
||||||
|
}
|
7
pkg/compiler/testdata/pkg1/pkg1.go
vendored
Normal file
7
pkg/compiler/testdata/pkg1/pkg1.go
vendored
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package pkg1
|
||||||
|
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
pkg3.A = 1
|
||||||
|
}
|
9
pkg/compiler/testdata/pkg2/pkg2.go
vendored
Normal file
9
pkg/compiler/testdata/pkg2/pkg2.go
vendored
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package pkg2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
pkg3.A = 2
|
||||||
|
}
|
7
pkg/compiler/testdata/pkg3/pkg3.go
vendored
Normal file
7
pkg/compiler/testdata/pkg3/pkg3.go
vendored
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package pkg3
|
||||||
|
|
||||||
|
var A int
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
A = 3
|
||||||
|
}
|
Loading…
Reference in a new issue