forked from TrueCloudLab/neoneo-go
compiler: allow to use local variables in init()
Because body of multiple `init()` functions constitute single method in contract, we initialize slot with maximum amount of locals encounterered in any of `init()` functions and clear them before emitting body of each instance of `init()`.
This commit is contained in:
parent
2d9ef9219a
commit
6701e8cda0
4 changed files with 58 additions and 27 deletions
|
@ -37,20 +37,24 @@ func (c *codegen) getIdentName(pkg string, name string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// traverseGlobals visits and initializes global variables.
|
// traverseGlobals visits and initializes global variables.
|
||||||
// and returns number of variables initialized and
|
// and returns number of variables initialized.
|
||||||
// true if any init functions were encountered.
|
// Second return value is -1 if no `init()` functions were encountered
|
||||||
func (c *codegen) traverseGlobals() (int, bool) {
|
// and number of maximum amount of locals in any of init functions otherwise.
|
||||||
|
func (c *codegen) traverseGlobals() (int, int) {
|
||||||
var hasDefer bool
|
var hasDefer bool
|
||||||
var n int
|
var n int
|
||||||
var hasInit bool
|
initLocals := -1
|
||||||
c.ForEachFile(func(f *ast.File, _ *types.Package) {
|
c.ForEachFile(func(f *ast.File, _ *types.Package) {
|
||||||
n += countGlobals(f)
|
n += countGlobals(f)
|
||||||
if !hasInit || !hasDefer {
|
if initLocals == -1 || !hasDefer {
|
||||||
ast.Inspect(f, func(node ast.Node) bool {
|
ast.Inspect(f, func(node ast.Node) bool {
|
||||||
switch n := node.(type) {
|
switch n := node.(type) {
|
||||||
case *ast.FuncDecl:
|
case *ast.FuncDecl:
|
||||||
if isInitFunc(n) {
|
if isInitFunc(n) {
|
||||||
hasInit = true
|
c, _ := countLocals(n)
|
||||||
|
if c > initLocals {
|
||||||
|
initLocals = c
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return !hasDefer
|
return !hasDefer
|
||||||
case *ast.DeferStmt:
|
case *ast.DeferStmt:
|
||||||
|
@ -64,14 +68,18 @@ func (c *codegen) traverseGlobals() (int, bool) {
|
||||||
if hasDefer {
|
if hasDefer {
|
||||||
n++
|
n++
|
||||||
}
|
}
|
||||||
if n != 0 || hasInit {
|
if n != 0 || initLocals > -1 {
|
||||||
if n > 255 {
|
if n > 255 {
|
||||||
c.prog.BinWriter.Err = errors.New("too many global variables")
|
c.prog.BinWriter.Err = errors.New("too many global variables")
|
||||||
return 0, hasInit
|
return 0, initLocals
|
||||||
}
|
}
|
||||||
if n != 0 {
|
if n != 0 {
|
||||||
emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)})
|
emit.Instruction(c.prog.BinWriter, opcode.INITSSLOT, []byte{byte(n)})
|
||||||
}
|
}
|
||||||
|
if initLocals > 0 {
|
||||||
|
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(initLocals), 0})
|
||||||
|
}
|
||||||
|
seenBefore := false
|
||||||
c.ForEachPackage(func(pkg *loader.PackageInfo) {
|
c.ForEachPackage(func(pkg *loader.PackageInfo) {
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
for _, f := range pkg.Files {
|
for _, f := range pkg.Files {
|
||||||
|
@ -79,10 +87,10 @@ func (c *codegen) traverseGlobals() (int, bool) {
|
||||||
c.convertGlobals(f, pkg.Pkg)
|
c.convertGlobals(f, pkg.Pkg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasInit {
|
if initLocals > -1 {
|
||||||
for _, f := range pkg.Files {
|
for _, f := range pkg.Files {
|
||||||
c.fillImportMap(f, pkg.Pkg)
|
c.fillImportMap(f, pkg.Pkg)
|
||||||
c.convertInitFuncs(f, pkg.Pkg)
|
seenBefore = c.convertInitFuncs(f, pkg.Pkg, seenBefore) || seenBefore
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// because we reuse `convertFuncDecl` for init funcs,
|
// because we reuse `convertFuncDecl` for init funcs,
|
||||||
|
@ -96,7 +104,7 @@ func (c *codegen) traverseGlobals() (int, bool) {
|
||||||
c.globals["<exception>"] = c.exceptionIndex
|
c.globals["<exception>"] = c.exceptionIndex
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return n, hasInit
|
return n, initLocals
|
||||||
}
|
}
|
||||||
|
|
||||||
// countGlobals counts the global variables in the program to add
|
// countGlobals counts the global variables in the program to add
|
||||||
|
|
|
@ -301,11 +301,23 @@ func isInitFunc(decl *ast.FuncDecl) bool {
|
||||||
decl.Type.Results.NumFields() == 0
|
decl.Type.Results.NumFields() == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) {
|
func (c *codegen) clearSlots(n int) {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
emit.Opcodes(c.prog.BinWriter, opcode.PUSHNULL)
|
||||||
|
c.emitStoreByIndex(varLocal, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package, seenBefore bool) bool {
|
||||||
ast.Inspect(f, func(node ast.Node) bool {
|
ast.Inspect(f, func(node ast.Node) bool {
|
||||||
switch n := node.(type) {
|
switch n := node.(type) {
|
||||||
case *ast.FuncDecl:
|
case *ast.FuncDecl:
|
||||||
if isInitFunc(n) {
|
if isInitFunc(n) {
|
||||||
|
if seenBefore {
|
||||||
|
cnt, _ := countLocals(n)
|
||||||
|
c.clearSlots(cnt)
|
||||||
|
seenBefore = true
|
||||||
|
}
|
||||||
c.convertFuncDecl(f, n, pkg)
|
c.convertFuncDecl(f, n, pkg)
|
||||||
}
|
}
|
||||||
case *ast.GenDecl:
|
case *ast.GenDecl:
|
||||||
|
@ -313,6 +325,7 @@ func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package) {
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
return seenBefore
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) {
|
func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.Package) {
|
||||||
|
@ -345,16 +358,18 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
|
||||||
|
|
||||||
// All globals copied into the scope of the function need to be added
|
// All globals copied into the scope of the function need to be added
|
||||||
// to the stack size of the function.
|
// to the stack size of the function.
|
||||||
sizeLoc := f.countLocals()
|
if !isInit {
|
||||||
if sizeLoc > 255 {
|
sizeLoc := f.countLocals()
|
||||||
c.prog.Err = errors.New("maximum of 255 local variables is allowed")
|
if sizeLoc > 255 {
|
||||||
}
|
c.prog.Err = errors.New("maximum of 255 local variables is allowed")
|
||||||
sizeArg := f.countArgs()
|
}
|
||||||
if sizeArg > 255 {
|
sizeArg := f.countArgs()
|
||||||
c.prog.Err = errors.New("maximum of 255 local variables is allowed")
|
if sizeArg > 255 {
|
||||||
}
|
c.prog.Err = errors.New("maximum of 255 local variables is allowed")
|
||||||
if sizeLoc != 0 || sizeArg != 0 {
|
}
|
||||||
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)})
|
if sizeLoc != 0 || sizeArg != 0 {
|
||||||
|
emit.Instruction(c.prog.BinWriter, opcode.INITSLOT, []byte{byte(sizeLoc), byte(sizeArg)})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.vars.newScope()
|
f.vars.newScope()
|
||||||
|
@ -1777,7 +1792,8 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
||||||
// Bring all imported functions into scope.
|
// Bring all imported functions into scope.
|
||||||
c.ForEachFile(c.resolveFuncDecls)
|
c.ForEachFile(c.resolveFuncDecls)
|
||||||
|
|
||||||
n, hasInit := c.traverseGlobals()
|
n, initLocals := c.traverseGlobals()
|
||||||
|
hasInit := initLocals > -1
|
||||||
if n > 0 || hasInit {
|
if n > 0 || hasInit {
|
||||||
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
||||||
c.initEndOffset = c.prog.Len()
|
c.initEndOffset = c.prog.Len()
|
||||||
|
|
|
@ -152,10 +152,10 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *funcScope) countLocals() int {
|
func countLocals(decl *ast.FuncDecl) (int, bool) {
|
||||||
size := 0
|
size := 0
|
||||||
hasDefer := false
|
hasDefer := false
|
||||||
ast.Inspect(c.decl, func(n ast.Node) bool {
|
ast.Inspect(decl, func(n ast.Node) bool {
|
||||||
switch n := n.(type) {
|
switch n := n.(type) {
|
||||||
case *ast.FuncType:
|
case *ast.FuncType:
|
||||||
num := n.Results.NumFields()
|
num := n.Results.NumFields()
|
||||||
|
@ -186,6 +186,11 @@ func (c *funcScope) countLocals() int {
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
return size, hasDefer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *funcScope) countLocals() int {
|
||||||
|
size, hasDefer := countLocals(c.decl)
|
||||||
if hasDefer {
|
if hasDefer {
|
||||||
c.finallyProcessedIndex = size
|
c.finallyProcessedIndex = size
|
||||||
size++
|
size++
|
||||||
|
|
|
@ -24,11 +24,13 @@ func TestInit(t *testing.T) {
|
||||||
var m = map[int]int{}
|
var m = map[int]int{}
|
||||||
var a = 2
|
var a = 2
|
||||||
func init() {
|
func init() {
|
||||||
m[1] = 11
|
b := 11
|
||||||
|
m[1] = b
|
||||||
}
|
}
|
||||||
func init() {
|
func init() {
|
||||||
a = 1
|
a = 1
|
||||||
m[3] = 30
|
var b int
|
||||||
|
m[3] = 30 + b
|
||||||
}
|
}
|
||||||
func Main() int {
|
func Main() int {
|
||||||
return m[1] + m[3] + a
|
return m[1] + m[3] + a
|
||||||
|
|
Loading…
Reference in a new issue