diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 1c1dab8bd..4905490c3 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -48,6 +48,7 @@ func (c *codegen) traverseGlobals() int { } emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)}) c.ForEachFile(c.convertGlobals) + c.ForEachFile(c.convertInitFuncs) } return n } diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index fc34b7b04..59bcd9020 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -275,24 +275,48 @@ func (c *codegen) convertGlobals(f *ast.File, _ *types.Package) { }) } +func isInitFunc(decl *ast.FuncDecl) bool { + return decl.Name.Name == "init" && decl.Recv == nil && + decl.Type.Params.NumFields() == 0 && + decl.Type.Results.NumFields() == 0 +} + +func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) { + ast.Inspect(f, func(node ast.Node) bool { + switch n := node.(type) { + case *ast.FuncDecl: + if isInitFunc(n) { + c.convertFuncDecl(f, n, pkg) + } + case *ast.GenDecl: + return false + } + return true + }) +} + func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) { var ( f *funcScope ok, isLambda bool ) - - f, ok = c.funcs[c.getFuncNameFromDecl("", decl)] - if ok { - // If this function is a syscall or builtin we will not convert it to bytecode. - if isSyscall(f) || isCustomBuiltin(f) { - return - } - c.setLabel(f.label) - } else if f, ok = c.lambda[c.getIdentName("", decl.Name.Name)]; ok { - isLambda = ok - c.setLabel(f.label) + isInit := isInitFunc(decl) + if isInit { + f = c.newFuncScope(decl, c.newLabel()) } else { - f = c.newFunc(decl) + f, ok = c.funcs[c.getFuncNameFromDecl("", decl)] + if ok { + // If this function is a syscall or builtin we will not convert it to bytecode. + if isSyscall(f) || isCustomBuiltin(f) { + return + } + c.setLabel(f.label) + } else if f, ok = c.lambda[c.getIdentName("", decl.Name.Name)]; ok { + isLambda = ok + c.setLabel(f.label) + } else { + f = c.newFunc(decl) + } } f.rng.Start = uint16(c.prog.Len()) @@ -342,7 +366,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types. // If we have reached the end of the function without encountering `return` statement, // we should clean alt.stack manually. // This can be the case with void and named-return functions. - if !lastStmtIsReturn(decl) { + if !isInit && !lastStmtIsReturn(decl) { c.saveSequencePoint(decl.Body) emit.Opcode(c.prog.BinWriter, opcode.RET) } @@ -1488,7 +1512,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { // Don't convert the function if it's not used. This will save a lot // of bytecode space. name := c.getFuncNameFromDecl(pkg.Path(), n) - if funUsage.funcUsed(name) && !isInteropPath(pkg.Path()) { + if !isInitFunc(n) && funUsage.funcUsed(name) && !isInteropPath(pkg.Path()) { c.convertFuncDecl(f, n, pkg) } } diff --git a/pkg/compiler/init_test.go b/pkg/compiler/init_test.go new file mode 100644 index 000000000..878e87ab4 --- /dev/null +++ b/pkg/compiler/init_test.go @@ -0,0 +1,62 @@ +package compiler_test + +import ( + "math/big" + "testing" +) + +func TestInit(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + src := `package foo + var a int + func init() { + a = 42 + } + func Main() int { + return a + }` + eval(t, src, big.NewInt(42)) + }) + t.Run("Multi", func(t *testing.T) { + src := `package foo + var m = map[int]int{} + var a = 2 + func init() { + m[1] = 11 + } + func init() { + a = 1 + m[3] = 30 + } + func Main() int { + return m[1] + m[3] + a + }` + eval(t, src, big.NewInt(42)) + }) + t.Run("WithCall", func(t *testing.T) { + src := `package foo + var m = map[int]int{} + func init() { + initMap(m) + } + func initMap(m map[int]int) { + m[11] = 42 + } + func Main() int { + return m[11] + }` + eval(t, src, big.NewInt(42)) + }) + t.Run("InvalidSignature", func(t *testing.T) { + src := `package foo + type Foo int + var a int + func (f Foo) init() { + a = 2 + } + func Main() int { + return a + }` + eval(t, src, big.NewInt(0)) + }) +}