Merge pull request #1261 from nspcc-dev/feature/init

compiler: allow to use `init` function
This commit is contained in:
Roman Khimov 2020-08-05 19:04:51 +03:00 committed by GitHub
commit fced917b71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 254 additions and 27 deletions

View file

@ -8,6 +8,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"golang.org/x/tools/go/loader"
) )
var ( var (
@ -35,21 +36,54 @@ func (c *codegen) getIdentName(pkg string, name string) string {
} }
// traverseGlobals visits and initializes global variables. // traverseGlobals visits and initializes global variables.
// and returns number of variables initialized. // and returns number of variables initialized and
func (c *codegen) traverseGlobals() int { // true if any init functions were encountered.
func (c *codegen) traverseGlobals() (int, bool) {
var n int var n int
var hasInit bool
c.ForEachFile(func(f *ast.File, _ *types.Package) { c.ForEachFile(func(f *ast.File, _ *types.Package) {
n += countGlobals(f) n += countGlobals(f)
if !hasInit {
ast.Inspect(f, func(node ast.Node) bool {
n, ok := node.(*ast.FuncDecl)
if ok {
if isInitFunc(n) {
hasInit = true
}
return false
}
return true
}) })
if n != 0 { }
})
if n != 0 || hasInit {
if n > 255 { if n > 255 {
c.prog.BinWriter.Err = errors.New("too many global variables") c.prog.BinWriter.Err = errors.New("too many global variables")
return 0 return 0, hasInit
} }
if n != 0 {
emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)}) emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)})
c.ForEachFile(c.convertGlobals)
} }
return n c.ForEachPackage(func(pkg *loader.PackageInfo) {
if n > 0 {
for _, f := range pkg.Files {
c.fillImportMap(f, pkg.Pkg)
c.convertGlobals(f, pkg.Pkg)
}
}
if hasInit {
for _, f := range pkg.Files {
c.fillImportMap(f, pkg.Pkg)
c.convertInitFuncs(f, pkg.Pkg)
}
}
// because we reuse `convertFuncDecl` for init funcs,
// we need to cleare scope, so that global variables
// encountered after will be recognized as globals.
c.scope = nil
})
}
return n, hasInit
} }
// countGlobals counts the global variables in the program to add // countGlobals counts the global variables in the program to add
@ -103,6 +137,32 @@ func lastStmtIsReturn(decl *ast.FuncDecl) (b bool) {
return false return false
} }
// analyzePkgOrder sets the order in which packages should be processed.
// From Go spec:
// A package with no imports is initialized by assigning initial values to all its package-level variables
// followed by calling all init functions in the order they appear in the source, possibly in multiple files,
// as presented to the compiler. If a package has imports, the imported packages are initialized before
// initializing the package itself. If multiple packages import a package, the imported package
// will be initialized only once. The importing of packages, by construction, guarantees
// that there can be no cyclic initialization dependencies.
func (c *codegen) analyzePkgOrder() {
seen := make(map[string]bool)
info := c.buildInfo.program.Package(c.buildInfo.initialPackage)
c.visitPkg(info.Pkg, seen)
}
func (c *codegen) visitPkg(pkg *types.Package, seen map[string]bool) {
pkgPath := pkg.Path()
if seen[pkgPath] {
return
}
for _, imp := range pkg.Imports() {
c.visitPkg(imp, seen)
}
seen[pkgPath] = true
c.packages = append(c.packages, pkgPath)
}
func (c *codegen) analyzeFuncUsage() funcUsage { func (c *codegen) analyzeFuncUsage() funcUsage {
usage := funcUsage{} usage := funcUsage{}

View file

@ -74,6 +74,9 @@ type codegen struct {
// mainPkg is a main package metadata. // mainPkg is a main package metadata.
mainPkg *loader.PackageInfo mainPkg *loader.PackageInfo
// packages contains packages in the order they were loaded.
packages []string
// Label table for recording jump destinations. // Label table for recording jump destinations.
l []int l []int
} }
@ -275,12 +278,35 @@ func (c *codegen) convertGlobals(f *ast.File, _ *types.Package) {
}) })
} }
func isInitFunc(decl *ast.FuncDecl) bool {
return decl.Name.Name == "init" && decl.Recv == nil &&
decl.Type.Params.NumFields() == 0 &&
decl.Type.Results.NumFields() == 0
}
func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) {
ast.Inspect(f, func(node ast.Node) bool {
switch n := node.(type) {
case *ast.FuncDecl:
if isInitFunc(n) {
c.convertFuncDecl(f, n, pkg)
}
case *ast.GenDecl:
return false
}
return true
})
}
func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) { func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) {
var ( var (
f *funcScope f *funcScope
ok, isLambda bool ok, isLambda bool
) )
isInit := isInitFunc(decl)
if isInit {
f = c.newFuncScope(decl, c.newLabel())
} else {
f, ok = c.funcs[c.getFuncNameFromDecl("", decl)] f, ok = c.funcs[c.getFuncNameFromDecl("", decl)]
if ok { if ok {
// If this function is a syscall or builtin we will not convert it to bytecode. // If this function is a syscall or builtin we will not convert it to bytecode.
@ -294,6 +320,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
} else { } else {
f = c.newFunc(decl) f = c.newFunc(decl)
} }
}
f.rng.Start = uint16(c.prog.Len()) f.rng.Start = uint16(c.prog.Len())
c.scope = f c.scope = f
@ -342,7 +369,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
// If we have reached the end of the function without encountering `return` statement, // If we have reached the end of the function without encountering `return` statement,
// we should clean alt.stack manually. // we should clean alt.stack manually.
// This can be the case with void and named-return functions. // This can be the case with void and named-return functions.
if !lastStmtIsReturn(decl) { if !isInit && !lastStmtIsReturn(decl) {
c.saveSequencePoint(decl.Body) c.saveSequencePoint(decl.Body)
emit.Opcode(c.prog.BinWriter, opcode.RET) emit.Opcode(c.prog.BinWriter, opcode.RET)
} }
@ -1462,13 +1489,14 @@ func (c *codegen) newLambda(u uint16, lit *ast.FuncLit) {
func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
c.mainPkg = pkg c.mainPkg = pkg
c.analyzePkgOrder()
funUsage := c.analyzeFuncUsage() funUsage := c.analyzeFuncUsage()
// Bring all imported functions into scope. // Bring all imported functions into scope.
c.ForEachFile(c.resolveFuncDecls) c.ForEachFile(c.resolveFuncDecls)
n := c.traverseGlobals() n, hasInit := c.traverseGlobals()
if n > 0 { if n > 0 || hasInit {
emit.Opcode(c.prog.BinWriter, opcode.RET) emit.Opcode(c.prog.BinWriter, opcode.RET)
c.initEndOffset = c.prog.Len() c.initEndOffset = c.prog.Len()
} }
@ -1488,7 +1516,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
// Don't convert the function if it's not used. This will save a lot // Don't convert the function if it's not used. This will save a lot
// of bytecode space. // of bytecode space.
name := c.getFuncNameFromDecl(pkg.Path(), n) name := c.getFuncNameFromDecl(pkg.Path(), n)
if funUsage.funcUsed(name) && !isInteropPath(pkg.Path()) { if !isInitFunc(n) && funUsage.funcUsed(name) && !isInteropPath(pkg.Path()) {
c.convertFuncDecl(f, n, pkg) c.convertFuncDecl(f, n, pkg)
} }
} }

View file

@ -47,16 +47,25 @@ type buildInfo struct {
program *loader.Program program *loader.Program
} }
// ForEachFile executes fn on each file used in current program. // ForEachPackage executes fn on each package used in the current program
func (c *codegen) ForEachFile(fn func(*ast.File, *types.Package)) { // in the order they should be initialized.
for _, pkg := range c.buildInfo.program.AllPackages { func (c *codegen) ForEachPackage(fn func(*loader.PackageInfo)) {
for i := range c.packages {
pkg := c.buildInfo.program.Package(c.packages[i])
c.typeInfo = &pkg.Info c.typeInfo = &pkg.Info
c.currPkg = pkg.Pkg c.currPkg = pkg.Pkg
fn(pkg)
}
}
// ForEachFile executes fn on each file used in current program.
func (c *codegen) ForEachFile(fn func(*ast.File, *types.Package)) {
c.ForEachPackage(func(pkg *loader.PackageInfo) {
for _, f := range pkg.Files { for _, f := range pkg.Files {
c.fillImportMap(f, pkg.Pkg) c.fillImportMap(f, pkg.Pkg)
fn(f, pkg.Pkg) fn(f, pkg.Pkg)
} }
} })
} }
// fillImportMap fills import map for f. // fillImportMap fills import map for f.

107
pkg/compiler/init_test.go Normal file
View file

@ -0,0 +1,107 @@
package compiler_test
import (
"math/big"
"testing"
"github.com/stretchr/testify/require"
)
func TestInit(t *testing.T) {
t.Run("Simple", func(t *testing.T) {
src := `package foo
var a int
func init() {
a = 42
}
func Main() int {
return a
}`
eval(t, src, big.NewInt(42))
})
t.Run("Multi", func(t *testing.T) {
src := `package foo
var m = map[int]int{}
var a = 2
func init() {
m[1] = 11
}
func init() {
a = 1
m[3] = 30
}
func Main() int {
return m[1] + m[3] + a
}`
eval(t, src, big.NewInt(42))
})
t.Run("WithCall", func(t *testing.T) {
src := `package foo
var m = map[int]int{}
func init() {
initMap(m)
}
func initMap(m map[int]int) {
m[11] = 42
}
func Main() int {
return m[11]
}`
eval(t, src, big.NewInt(42))
})
t.Run("InvalidSignature", func(t *testing.T) {
src := `package foo
type Foo int
var a int
func (f Foo) init() {
a = 2
}
func Main() int {
return a
}`
eval(t, src, big.NewInt(0))
})
}
func TestImportOrder(t *testing.T) {
t.Run("1,2", func(t *testing.T) {
src := `package foo
import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg1"
import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg2"
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
func Main() int { return pkg3.A }`
v := vmAndCompile(t, src)
v.PrintOps()
eval(t, src, big.NewInt(2))
})
t.Run("2,1", func(t *testing.T) {
src := `package foo
import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg2"
import _ "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg1"
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
func Main() int { return pkg3.A }`
eval(t, src, big.NewInt(1))
})
t.Run("InitializeOnce", func(t *testing.T) {
src := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
var A = pkg3.A
func Main() int { return A }`
eval(t, src, big.NewInt(3))
})
}
func TestInitWithNoGlobals(t *testing.T) {
src := `package foo
import "github.com/nspcc-dev/neo-go/pkg/interop/runtime"
func init() {
runtime.Notify("called in '_initialize'")
}
func Main() int {
return 42
}`
v, s := vmAndCompileInterop(t, src)
require.NoError(t, v.Run())
assertResult(t, v, big.NewInt(42))
require.True(t, len(s.events) == 1)
}

7
pkg/compiler/testdata/pkg1/pkg1.go vendored Normal file
View file

@ -0,0 +1,7 @@
package pkg1
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
func init() {
pkg3.A = 1
}

9
pkg/compiler/testdata/pkg2/pkg2.go vendored Normal file
View file

@ -0,0 +1,9 @@
package pkg2
import (
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/pkg3"
)
func init() {
pkg3.A = 2
}

7
pkg/compiler/testdata/pkg3/pkg3.go vendored Normal file
View file

@ -0,0 +1,7 @@
package pkg3
var A int
func init() {
A = 3
}