From 6701e8cda03e9d59f19e2696c2652de8988b6556 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Tue, 6 Oct 2020 18:25:40 +0300 Subject: [PATCH] compiler: allow to use local variables in `init()` Because body of multiple `init()` functions constitute single method in contract, we initialize slot with maximum amount of locals encounterered in any of `init()` functions and clear them before emitting body of each instance of `init()`. --- pkg/compiler/analysis.go | 30 +++++++++++++++++----------- pkg/compiler/codegen.go | 40 ++++++++++++++++++++++++++------------ pkg/compiler/func_scope.go | 9 +++++++-- pkg/compiler/init_test.go | 6 ++++-- 4 files changed, 58 insertions(+), 27 deletions(-) diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index 0dcc3d262..66bbfcda0 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -37,20 +37,24 @@ func (c *codegen) getIdentName(pkg string, name string) string { } // traverseGlobals visits and initializes global variables. -// and returns number of variables initialized and -// true if any init functions were encountered. -func (c *codegen) traverseGlobals() (int, bool) { +// and returns number of variables initialized. +// Second return value is -1 if no `init()` functions were encountered +// and number of maximum amount of locals in any of init functions otherwise. +func (c *codegen) traverseGlobals() (int, int) { var hasDefer bool var n int - var hasInit bool + initLocals := -1 c.ForEachFile(func(f *ast.File, _ *types.Package) { n += countGlobals(f) - if !hasInit || !hasDefer { + if initLocals == -1 || !hasDefer { ast.Inspect(f, func(node ast.Node) bool { switch n := node.(type) { case *ast.FuncDecl: if isInitFunc(n) { - hasInit = true + c, _ := countLocals(n) + if c > initLocals { + initLocals = c + } } return !hasDefer case *ast.DeferStmt: @@ -64,14 +68,18 @@ func (c *codegen) traverseGlobals() (int, bool) { if hasDefer { n++ } - if n != 0 || hasInit { + if n != 0 || initLocals > -1 { if n > 255 { c.prog.BinWriter.Err = errors.New("too many global variables") - return 0, hasInit + return 0, initLocals } if n != 0 { emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)}) } + if initLocals > 0 { + emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(initLocals), 0}) + } + seenBefore := false c.ForEachPackage(func(pkg *loader.PackageInfo) { if n > 0 { for _, f := range pkg.Files { @@ -79,10 +87,10 @@ func (c *codegen) traverseGlobals() (int, bool) { c.convertGlobals(f, pkg.Pkg) } } - if hasInit { + if initLocals > -1 { for _, f := range pkg.Files { c.fillImportMap(f, pkg.Pkg) - c.convertInitFuncs(f, pkg.Pkg) + seenBefore = c.convertInitFuncs(f, pkg.Pkg, seenBefore) || seenBefore } } // because we reuse `convertFuncDecl` for init funcs, @@ -96,7 +104,7 @@ func (c *codegen) traverseGlobals() (int, bool) { c.globals[""] = c.exceptionIndex } } - return n, hasInit + return n, initLocals } // countGlobals counts the global variables in the program to add diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 7152aa171..da2f971b1 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -301,11 +301,23 @@ func isInitFunc(decl *ast.FuncDecl) bool { decl.Type.Results.NumFields() == 0 } -func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) { +func (c *codegen) clearSlots(n int) { + for i := 0; i < n; i++ { + emit.Opcodes(c.prog.BinWriter, opcode.PUSHNULL) + c.emitStoreByIndex(varLocal, i) + } +} + +func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package, seenBefore bool) bool { ast.Inspect(f, func(node ast.Node) bool { switch n := node.(type) { case *ast.FuncDecl: if isInitFunc(n) { + if seenBefore { + cnt, _ := countLocals(n) + c.clearSlots(cnt) + seenBefore = true + } c.convertFuncDecl(f, n, pkg) } case *ast.GenDecl: @@ -313,6 +325,7 @@ func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) { } return true }) + return seenBefore } func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) { @@ -345,16 +358,18 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types. // All globals copied into the scope of the function need to be added // to the stack size of the function. - sizeLoc := f.countLocals() - if sizeLoc > 255 { - c.prog.Err = errors.New("maximum of 255 local variables is allowed") - } - sizeArg := f.countArgs() - if sizeArg > 255 { - c.prog.Err = errors.New("maximum of 255 local variables is allowed") - } - if sizeLoc != 0 || sizeArg != 0 { - emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)}) + if !isInit { + sizeLoc := f.countLocals() + if sizeLoc > 255 { + c.prog.Err = errors.New("maximum of 255 local variables is allowed") + } + sizeArg := f.countArgs() + if sizeArg > 255 { + c.prog.Err = errors.New("maximum of 255 local variables is allowed") + } + if sizeLoc != 0 || sizeArg != 0 { + emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)}) + } } f.vars.newScope() @@ -1777,7 +1792,8 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error { // Bring all imported functions into scope. c.ForEachFile(c.resolveFuncDecls) - n, hasInit := c.traverseGlobals() + n, initLocals := c.traverseGlobals() + hasInit := initLocals > -1 if n > 0 || hasInit { emit.Opcodes(c.prog.BinWriter, opcode.RET) c.initEndOffset = c.prog.Len() diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index 41f3db76c..3e977c42e 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -152,10 +152,10 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool { return true } -func (c *funcScope) countLocals() int { +func countLocals(decl *ast.FuncDecl) (int, bool) { size := 0 hasDefer := false - ast.Inspect(c.decl, func(n ast.Node) bool { + ast.Inspect(decl, func(n ast.Node) bool { switch n := n.(type) { case *ast.FuncType: num := n.Results.NumFields() @@ -186,6 +186,11 @@ func (c *funcScope) countLocals() int { } return true }) + return size, hasDefer +} + +func (c *funcScope) countLocals() int { + size, hasDefer := countLocals(c.decl) if hasDefer { c.finallyProcessedIndex = size size++ diff --git a/pkg/compiler/init_test.go b/pkg/compiler/init_test.go index 696e119f2..45e1ad68c 100644 --- a/pkg/compiler/init_test.go +++ b/pkg/compiler/init_test.go @@ -24,11 +24,13 @@ func TestInit(t *testing.T) { var m = map[int]int{} var a = 2 func init() { - m[1] = 11 + b := 11 + m[1] = b } func init() { a = 1 - m[3] = 30 + var b int + m[3] = 30 + b } func Main() int { return m[1] + m[3] + a