compiler: allow to use local variables in init()

Because body of multiple `init()` functions constitute single
method in contract, we initialize slot with maximum amount of
locals encounterered in any of `init()` functions and clear them
before emitting body of each instance of `init()`.
This commit is contained in:
Evgenii Stratonikov 2020-10-06 18:25:40 +03:00
parent 2d9ef9219a
commit 6701e8cda0
4 changed files with 58 additions and 27 deletions

View file

@ -37,20 +37,24 @@ func (c *codegen) getIdentName(pkg string, name string) string {
} }
// traverseGlobals visits and initializes global variables. // traverseGlobals visits and initializes global variables.
// and returns number of variables initialized and // and returns number of variables initialized.
// true if any init functions were encountered. // Second return value is -1 if no `init()` functions were encountered
func (c *codegen) traverseGlobals() (int, bool) { // and number of maximum amount of locals in any of init functions otherwise.
func (c *codegen) traverseGlobals() (int, int) {
var hasDefer bool var hasDefer bool
var n int var n int
var hasInit bool initLocals := -1
c.ForEachFile(func(f *ast.File, _ *types.Package) { c.ForEachFile(func(f *ast.File, _ *types.Package) {
n += countGlobals(f) n += countGlobals(f)
if !hasInit || !hasDefer { if initLocals == -1 || !hasDefer {
ast.Inspect(f, func(node ast.Node) bool { ast.Inspect(f, func(node ast.Node) bool {
switch n := node.(type) { switch n := node.(type) {
case *ast.FuncDecl: case *ast.FuncDecl:
if isInitFunc(n) { if isInitFunc(n) {
hasInit = true c, _ := countLocals(n)
if c > initLocals {
initLocals = c
}
} }
return !hasDefer return !hasDefer
case *ast.DeferStmt: case *ast.DeferStmt:
@ -64,14 +68,18 @@ func (c *codegen) traverseGlobals() (int, bool) {
if hasDefer { if hasDefer {
n++ n++
} }
if n != 0 || hasInit { if n != 0 || initLocals > -1 {
if n > 255 { if n > 255 {
c.prog.BinWriter.Err = errors.New("too many global variables") c.prog.BinWriter.Err = errors.New("too many global variables")
return 0, hasInit return 0, initLocals
} }
if n != 0 { if n != 0 {
emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)}) emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)})
} }
if initLocals > 0 {
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(initLocals), 0})
}
seenBefore := false
c.ForEachPackage(func(pkg *loader.PackageInfo) { c.ForEachPackage(func(pkg *loader.PackageInfo) {
if n > 0 { if n > 0 {
for _, f := range pkg.Files { for _, f := range pkg.Files {
@ -79,10 +87,10 @@ func (c *codegen) traverseGlobals() (int, bool) {
c.convertGlobals(f, pkg.Pkg) c.convertGlobals(f, pkg.Pkg)
} }
} }
if hasInit { if initLocals > -1 {
for _, f := range pkg.Files { for _, f := range pkg.Files {
c.fillImportMap(f, pkg.Pkg) c.fillImportMap(f, pkg.Pkg)
c.convertInitFuncs(f, pkg.Pkg) seenBefore = c.convertInitFuncs(f, pkg.Pkg, seenBefore) || seenBefore
} }
} }
// because we reuse `convertFuncDecl` for init funcs, // because we reuse `convertFuncDecl` for init funcs,
@ -96,7 +104,7 @@ func (c *codegen) traverseGlobals() (int, bool) {
c.globals["<exception>"] = c.exceptionIndex c.globals["<exception>"] = c.exceptionIndex
} }
} }
return n, hasInit return n, initLocals
} }
// countGlobals counts the global variables in the program to add // countGlobals counts the global variables in the program to add

View file

@ -301,11 +301,23 @@ func isInitFunc(decl *ast.FuncDecl) bool {
decl.Type.Results.NumFields() == 0 decl.Type.Results.NumFields() == 0
} }
func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) { func (c *codegen) clearSlots(n int) {
for i := 0; i < n; i++ {
emit.Opcodes(c.prog.BinWriter, opcode.PUSHNULL)
c.emitStoreByIndex(varLocal, i)
}
}
func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package, seenBefore bool) bool {
ast.Inspect(f, func(node ast.Node) bool { ast.Inspect(f, func(node ast.Node) bool {
switch n := node.(type) { switch n := node.(type) {
case *ast.FuncDecl: case *ast.FuncDecl:
if isInitFunc(n) { if isInitFunc(n) {
if seenBefore {
cnt, _ := countLocals(n)
c.clearSlots(cnt)
seenBefore = true
}
c.convertFuncDecl(f, n, pkg) c.convertFuncDecl(f, n, pkg)
} }
case *ast.GenDecl: case *ast.GenDecl:
@ -313,6 +325,7 @@ func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) {
} }
return true return true
}) })
return seenBefore
} }
func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) { func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) {
@ -345,16 +358,18 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
// All globals copied into the scope of the function need to be added // All globals copied into the scope of the function need to be added
// to the stack size of the function. // to the stack size of the function.
sizeLoc := f.countLocals() if !isInit {
if sizeLoc > 255 { sizeLoc := f.countLocals()
c.prog.Err = errors.New("maximum of 255 local variables is allowed") if sizeLoc > 255 {
} c.prog.Err = errors.New("maximum of 255 local variables is allowed")
sizeArg := f.countArgs() }
if sizeArg > 255 { sizeArg := f.countArgs()
c.prog.Err = errors.New("maximum of 255 local variables is allowed") if sizeArg > 255 {
} c.prog.Err = errors.New("maximum of 255 local variables is allowed")
if sizeLoc != 0 || sizeArg != 0 { }
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)}) if sizeLoc != 0 || sizeArg != 0 {
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)})
}
} }
f.vars.newScope() f.vars.newScope()
@ -1777,7 +1792,8 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
// Bring all imported functions into scope. // Bring all imported functions into scope.
c.ForEachFile(c.resolveFuncDecls) c.ForEachFile(c.resolveFuncDecls)
n, hasInit := c.traverseGlobals() n, initLocals := c.traverseGlobals()
hasInit := initLocals > -1
if n > 0 || hasInit { if n > 0 || hasInit {
emit.Opcodes(c.prog.BinWriter, opcode.RET) emit.Opcodes(c.prog.BinWriter, opcode.RET)
c.initEndOffset = c.prog.Len() c.initEndOffset = c.prog.Len()

View file

@ -152,10 +152,10 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
return true return true
} }
func (c *funcScope) countLocals() int { func countLocals(decl *ast.FuncDecl) (int, bool) {
size := 0 size := 0
hasDefer := false hasDefer := false
ast.Inspect(c.decl, func(n ast.Node) bool { ast.Inspect(decl, func(n ast.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *ast.FuncType: case *ast.FuncType:
num := n.Results.NumFields() num := n.Results.NumFields()
@ -186,6 +186,11 @@ func (c *funcScope) countLocals() int {
} }
return true return true
}) })
return size, hasDefer
}
func (c *funcScope) countLocals() int {
size, hasDefer := countLocals(c.decl)
if hasDefer { if hasDefer {
c.finallyProcessedIndex = size c.finallyProcessedIndex = size
size++ size++

View file

@ -24,11 +24,13 @@ func TestInit(t *testing.T) {
var m = map[int]int{} var m = map[int]int{}
var a = 2 var a = 2
func init() { func init() {
m[1] = 11 b := 11
m[1] = b
} }
func init() { func init() {
a = 1 a = 1
m[3] = 30 var b int
m[3] = 30 + b
} }
func Main() int { func Main() int {
return m[1] + m[3] + a return m[1] + m[3] + a