Merge pull request #1711 from nspcc-dev/compiler/inline
Allow to inline internal wrappers in compiler
This commit is contained in:
commit
4d0681d898
10 changed files with 483 additions and 51 deletions
|
@ -55,14 +55,14 @@ func (c *codegen) traverseGlobals() (int, int, int) {
|
||||||
switch n := node.(type) {
|
switch n := node.(type) {
|
||||||
case *ast.FuncDecl:
|
case *ast.FuncDecl:
|
||||||
if isInitFunc(n) {
|
if isInitFunc(n) {
|
||||||
c, _ := countLocals(n)
|
num, _ := c.countLocals(n)
|
||||||
if c > initLocals {
|
if num > initLocals {
|
||||||
initLocals = c
|
initLocals = num
|
||||||
}
|
}
|
||||||
} else if isDeployFunc(n) {
|
} else if isDeployFunc(n) {
|
||||||
c, _ := countLocals(n)
|
num, _ := c.countLocals(n)
|
||||||
if c > deployLocals {
|
if num > deployLocals {
|
||||||
deployLocals = c
|
deployLocals = num
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return !hasDefer
|
return !hasDefer
|
||||||
|
@ -175,10 +175,16 @@ func (f funcUsage) funcUsed(name string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// lastStmtIsReturn checks if last statement of the declaration was return statement..
|
// lastStmtIsReturn checks if last statement of the declaration was return statement..
|
||||||
func lastStmtIsReturn(decl *ast.FuncDecl) (b bool) {
|
func lastStmtIsReturn(body *ast.BlockStmt) (b bool) {
|
||||||
if l := len(decl.Body.List); l != 0 {
|
if l := len(body.List); l != 0 {
|
||||||
_, ok := decl.Body.List[l-1].(*ast.ReturnStmt)
|
switch inner := body.List[l-1].(type) {
|
||||||
return ok
|
case *ast.BlockStmt:
|
||||||
|
return lastStmtIsReturn(inner)
|
||||||
|
case *ast.ReturnStmt:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -298,3 +304,11 @@ func canConvert(s string) bool {
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// canInline returns true if function is to be inlined.
|
||||||
|
// Currently there is a static list of function which are inlined,
|
||||||
|
// this may change in future.
|
||||||
|
func canInline(s string) bool {
|
||||||
|
return isNativeHelpersPath(s) ||
|
||||||
|
strings.HasPrefix(s, "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline")
|
||||||
|
}
|
||||||
|
|
|
@ -31,6 +31,8 @@ type codegen struct {
|
||||||
|
|
||||||
// Type information.
|
// Type information.
|
||||||
typeInfo *types.Info
|
typeInfo *types.Info
|
||||||
|
// pkgInfoInline is stack of type information for packages containing inline functions.
|
||||||
|
pkgInfoInline []*loader.PackageInfo
|
||||||
|
|
||||||
// A mapping of func identifiers with their scope.
|
// A mapping of func identifiers with their scope.
|
||||||
funcs map[string]*funcScope
|
funcs map[string]*funcScope
|
||||||
|
@ -192,20 +194,21 @@ func (c *codegen) emitStoreStructField(i int) {
|
||||||
|
|
||||||
// getVarIndex returns variable type and position in corresponding slot,
|
// getVarIndex returns variable type and position in corresponding slot,
|
||||||
// according to current scope.
|
// according to current scope.
|
||||||
func (c *codegen) getVarIndex(pkg string, name string) (varType, int) {
|
func (c *codegen) getVarIndex(pkg string, name string) *varInfo {
|
||||||
if pkg == "" {
|
if pkg == "" {
|
||||||
if c.scope != nil {
|
if c.scope != nil {
|
||||||
vt, val := c.scope.vars.getVarIndex(name)
|
vi := c.scope.vars.getVarInfo(name)
|
||||||
if val >= 0 {
|
if vi != nil {
|
||||||
return vt, val
|
return vi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if i, ok := c.globals[c.getIdentName(pkg, name)]; ok {
|
if i, ok := c.globals[c.getIdentName(pkg, name)]; ok {
|
||||||
return varGlobal, i
|
return &varInfo{refType: varGlobal, index: i}
|
||||||
}
|
}
|
||||||
|
|
||||||
return varLocal, c.scope.newVariable(varLocal, name)
|
c.scope.newVariable(varLocal, name)
|
||||||
|
return c.scope.vars.getVarInfo(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) {
|
func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) {
|
||||||
|
@ -223,8 +226,15 @@ func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) {
|
||||||
|
|
||||||
// emitLoadVar loads specified variable to the evaluation stack.
|
// emitLoadVar loads specified variable to the evaluation stack.
|
||||||
func (c *codegen) emitLoadVar(pkg string, name string) {
|
func (c *codegen) emitLoadVar(pkg string, name string) {
|
||||||
t, i := c.getVarIndex(pkg, name)
|
vi := c.getVarIndex(pkg, name)
|
||||||
c.emitLoadByIndex(t, i)
|
if vi.tv.Value != nil {
|
||||||
|
c.emitLoadConst(vi.tv)
|
||||||
|
return
|
||||||
|
} else if vi.index == unspecifiedVarIndex {
|
||||||
|
emit.Opcodes(c.prog.BinWriter, opcode.PUSHNULL)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.emitLoadByIndex(vi.refType, vi.index)
|
||||||
}
|
}
|
||||||
|
|
||||||
// emitLoadByIndex loads specified variable type with index i.
|
// emitLoadByIndex loads specified variable type with index i.
|
||||||
|
@ -243,8 +253,8 @@ func (c *codegen) emitStoreVar(pkg string, name string) {
|
||||||
emit.Opcodes(c.prog.BinWriter, opcode.DROP)
|
emit.Opcodes(c.prog.BinWriter, opcode.DROP)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t, i := c.getVarIndex(pkg, name)
|
vi := c.getVarIndex(pkg, name)
|
||||||
c.emitStoreByIndex(t, i)
|
c.emitStoreByIndex(vi.refType, vi.index)
|
||||||
}
|
}
|
||||||
|
|
||||||
// emitLoadByIndex stores top value in the specified variable type with index i.
|
// emitLoadByIndex stores top value in the specified variable type with index i.
|
||||||
|
@ -320,7 +330,7 @@ func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package, seenBefore b
|
||||||
case *ast.FuncDecl:
|
case *ast.FuncDecl:
|
||||||
if isInitFunc(n) {
|
if isInitFunc(n) {
|
||||||
if seenBefore {
|
if seenBefore {
|
||||||
cnt, _ := countLocals(n)
|
cnt, _ := c.countLocals(n)
|
||||||
c.clearSlots(cnt)
|
c.clearSlots(cnt)
|
||||||
seenBefore = true
|
seenBefore = true
|
||||||
}
|
}
|
||||||
|
@ -352,7 +362,7 @@ func (c *codegen) convertDeployFuncs() {
|
||||||
case *ast.FuncDecl:
|
case *ast.FuncDecl:
|
||||||
if isDeployFunc(n) {
|
if isDeployFunc(n) {
|
||||||
if seenBefore {
|
if seenBefore {
|
||||||
cnt, _ := countLocals(n)
|
cnt, _ := c.countLocals(n)
|
||||||
c.clearSlots(cnt)
|
c.clearSlots(cnt)
|
||||||
}
|
}
|
||||||
c.convertFuncDecl(f, n, pkg)
|
c.convertFuncDecl(f, n, pkg)
|
||||||
|
@ -398,7 +408,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
|
||||||
// All globals copied into the scope of the function need to be added
|
// All globals copied into the scope of the function need to be added
|
||||||
// to the stack size of the function.
|
// to the stack size of the function.
|
||||||
if !isInit && !isDeploy {
|
if !isInit && !isDeploy {
|
||||||
sizeLoc := f.countLocals()
|
sizeLoc := c.countLocalsWithDefer(f)
|
||||||
if sizeLoc > 255 {
|
if sizeLoc > 255 {
|
||||||
c.prog.Err = errors.New("maximum of 255 local variables is allowed")
|
c.prog.Err = errors.New("maximum of 255 local variables is allowed")
|
||||||
}
|
}
|
||||||
|
@ -440,7 +450,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
|
||||||
// If we have reached the end of the function without encountering `return` statement,
|
// If we have reached the end of the function without encountering `return` statement,
|
||||||
// we should clean alt.stack manually.
|
// we should clean alt.stack manually.
|
||||||
// This can be the case with void and named-return functions.
|
// This can be the case with void and named-return functions.
|
||||||
if !isInit && !isDeploy && !lastStmtIsReturn(decl) {
|
if !isInit && !isDeploy && !lastStmtIsReturn(decl.Body) {
|
||||||
c.saveSequencePoint(decl.Body)
|
c.saveSequencePoint(decl.Body)
|
||||||
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
||||||
}
|
}
|
||||||
|
@ -623,7 +633,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
c.processDefers()
|
c.processDefers()
|
||||||
|
|
||||||
c.saveSequencePoint(n)
|
c.saveSequencePoint(n)
|
||||||
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
if len(c.pkgInfoInline) == 0 {
|
||||||
|
emit.Opcodes(c.prog.BinWriter, opcode.RET)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case *ast.IfStmt:
|
case *ast.IfStmt:
|
||||||
|
@ -800,7 +812,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
|
|
||||||
switch fun := n.Fun.(type) {
|
switch fun := n.Fun.(type) {
|
||||||
case *ast.Ident:
|
case *ast.Ident:
|
||||||
f, ok = c.funcs[c.getIdentName("", fun.Name)]
|
var pkgName string
|
||||||
|
if len(c.pkgInfoInline) != 0 {
|
||||||
|
pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path()
|
||||||
|
}
|
||||||
|
f, ok = c.funcs[c.getIdentName(pkgName, fun.Name)]
|
||||||
|
|
||||||
isBuiltin = isGoBuiltin(fun.Name)
|
isBuiltin = isGoBuiltin(fun.Name)
|
||||||
if !ok && !isBuiltin {
|
if !ok && !isBuiltin {
|
||||||
name = fun.Name
|
name = fun.Name
|
||||||
|
@ -809,6 +826,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
if fun.Obj != nil && fun.Obj.Kind == ast.Var {
|
if fun.Obj != nil && fun.Obj.Kind == ast.Var {
|
||||||
isFunc = true
|
isFunc = true
|
||||||
}
|
}
|
||||||
|
if ok && canInline(f.pkg.Path()) {
|
||||||
|
c.inlineCall(f, n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
case *ast.SelectorExpr:
|
case *ast.SelectorExpr:
|
||||||
// If this is a method call we need to walk the AST to load the struct locally.
|
// If this is a method call we need to walk the AST to load the struct locally.
|
||||||
// Otherwise this is a function call from a imported package and we can call it
|
// Otherwise this is a function call from a imported package and we can call it
|
||||||
|
@ -824,6 +845,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
if ok {
|
if ok {
|
||||||
f.selector = fun.X.(*ast.Ident)
|
f.selector = fun.X.(*ast.Ident)
|
||||||
isBuiltin = isCustomBuiltin(f)
|
isBuiltin = isCustomBuiltin(f)
|
||||||
|
if canInline(f.pkg.Path()) {
|
||||||
|
c.inlineCall(f, n)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
typ := c.typeOf(fun)
|
typ := c.typeOf(fun)
|
||||||
if _, ok := typ.(*types.Signature); ok {
|
if _, ok := typ.(*types.Signature); ok {
|
||||||
|
@ -867,10 +892,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
typ, ok := c.typeOf(n.Fun).(*types.Signature)
|
typ, ok := c.typeOf(n.Fun).(*types.Signature)
|
||||||
if ok && typ.Variadic() && !n.Ellipsis.IsValid() {
|
if ok && typ.Variadic() && !n.Ellipsis.IsValid() {
|
||||||
// pack variadic args into an array only if last argument is not of form `...`
|
// pack variadic args into an array only if last argument is not of form `...`
|
||||||
varSize := len(n.Args) - typ.Params().Len() + 1
|
varSize := c.packVarArgs(n, typ)
|
||||||
c.emitReverse(varSize)
|
|
||||||
emit.Int(c.prog.BinWriter, int64(varSize))
|
|
||||||
emit.Opcodes(c.prog.BinWriter, opcode.PACK)
|
|
||||||
numArgs -= varSize - 1
|
numArgs -= varSize - 1
|
||||||
}
|
}
|
||||||
c.emitReverse(numArgs)
|
c.emitReverse(numArgs)
|
||||||
|
@ -1207,6 +1229,16 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// packVarArgs packs variadic arguments into an array
|
||||||
|
// and returns amount of arguments packed.
|
||||||
|
func (c *codegen) packVarArgs(n *ast.CallExpr, typ *types.Signature) int {
|
||||||
|
varSize := len(n.Args) - typ.Params().Len() + 1
|
||||||
|
c.emitReverse(varSize)
|
||||||
|
emit.Int(c.prog.BinWriter, int64(varSize))
|
||||||
|
emit.Opcodes(c.prog.BinWriter, opcode.PACK)
|
||||||
|
return varSize
|
||||||
|
}
|
||||||
|
|
||||||
// processDefers emits code for `defer` statements.
|
// processDefers emits code for `defer` statements.
|
||||||
// TRY-related opcodes handle exception as follows:
|
// TRY-related opcodes handle exception as follows:
|
||||||
// 1. CATCH block is executed only if exception has occurred.
|
// 1. CATCH block is executed only if exception has occurred.
|
||||||
|
@ -1919,7 +1951,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
|
||||||
// of bytecode space.
|
// of bytecode space.
|
||||||
name := c.getFuncNameFromDecl(pkg.Path(), n)
|
name := c.getFuncNameFromDecl(pkg.Path(), n)
|
||||||
if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) &&
|
if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) &&
|
||||||
(!isInteropPath(pkg.Path()) || isNativeHelpersPath(pkg.Path())) {
|
(!isInteropPath(pkg.Path()) && !canInline(pkg.Path())) {
|
||||||
c.convertFuncDecl(f, n, pkg)
|
c.convertFuncDecl(f, n, pkg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1970,7 +2002,8 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
|
||||||
for _, decl := range f.Decls {
|
for _, decl := range f.Decls {
|
||||||
switch n := decl.(type) {
|
switch n := decl.(type) {
|
||||||
case *ast.FuncDecl:
|
case *ast.FuncDecl:
|
||||||
c.newFunc(n)
|
fs := c.newFunc(n)
|
||||||
|
fs.file = f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,8 @@ type funcScope struct {
|
||||||
// Package where the function is defined.
|
// Package where the function is defined.
|
||||||
pkg *types.Package
|
pkg *types.Package
|
||||||
|
|
||||||
|
file *ast.File
|
||||||
|
|
||||||
// Program label of the scope
|
// Program label of the scope
|
||||||
label uint16
|
label uint16
|
||||||
|
|
||||||
|
@ -100,11 +102,47 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func countLocals(decl *ast.FuncDecl) (int, bool) {
|
func (c *codegen) countLocals(decl *ast.FuncDecl) (int, bool) {
|
||||||
|
return c.countLocalsInline(decl, nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *codegen) countLocalsInline(decl *ast.FuncDecl, pkg *types.Package, f *funcScope) (int, bool) {
|
||||||
|
oldMap := c.importMap
|
||||||
|
if pkg != nil {
|
||||||
|
c.fillImportMap(f.file, pkg)
|
||||||
|
}
|
||||||
|
|
||||||
size := 0
|
size := 0
|
||||||
hasDefer := false
|
hasDefer := false
|
||||||
ast.Inspect(decl, func(n ast.Node) bool {
|
ast.Inspect(decl, func(n ast.Node) bool {
|
||||||
switch n := n.(type) {
|
switch n := n.(type) {
|
||||||
|
case *ast.CallExpr:
|
||||||
|
var name string
|
||||||
|
switch fun := n.Fun.(type) {
|
||||||
|
case *ast.Ident:
|
||||||
|
var pkgName string
|
||||||
|
if pkg != nil {
|
||||||
|
pkgName = pkg.Path()
|
||||||
|
}
|
||||||
|
name = c.getIdentName(pkgName, fun.Name)
|
||||||
|
case *ast.SelectorExpr:
|
||||||
|
name, _ = c.getFuncNameFromSelector(fun)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if inner, ok := c.funcs[name]; ok && canInline(name) {
|
||||||
|
for i := range n.Args {
|
||||||
|
switch n.Args[i].(type) {
|
||||||
|
case *ast.Ident:
|
||||||
|
case *ast.BasicLit:
|
||||||
|
default:
|
||||||
|
size++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
innerSz, _ := c.countLocalsInline(inner.decl, inner.pkg, inner)
|
||||||
|
size += innerSz
|
||||||
|
}
|
||||||
|
return false
|
||||||
case *ast.FuncType:
|
case *ast.FuncType:
|
||||||
num := n.Results.NumFields()
|
num := n.Results.NumFields()
|
||||||
if num != 0 && len(n.Results.List[0].Names) != 0 {
|
if num != 0 && len(n.Results.List[0].Names) != 0 {
|
||||||
|
@ -117,7 +155,11 @@ func countLocals(decl *ast.FuncDecl) (int, bool) {
|
||||||
case *ast.DeferStmt:
|
case *ast.DeferStmt:
|
||||||
hasDefer = true
|
hasDefer = true
|
||||||
return false
|
return false
|
||||||
case *ast.ReturnStmt, *ast.IfStmt:
|
case *ast.ReturnStmt:
|
||||||
|
if pkg == nil {
|
||||||
|
size++
|
||||||
|
}
|
||||||
|
case *ast.IfStmt:
|
||||||
size++
|
size++
|
||||||
// This handles the inline GenDecl like "var x = 2"
|
// This handles the inline GenDecl like "var x = 2"
|
||||||
case *ast.ValueSpec:
|
case *ast.ValueSpec:
|
||||||
|
@ -134,13 +176,16 @@ func countLocals(decl *ast.FuncDecl) (int, bool) {
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
if pkg != nil {
|
||||||
|
c.importMap = oldMap
|
||||||
|
}
|
||||||
return size, hasDefer
|
return size, hasDefer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *funcScope) countLocals() int {
|
func (c *codegen) countLocalsWithDefer(f *funcScope) int {
|
||||||
size, hasDefer := countLocals(c.decl)
|
size, hasDefer := c.countLocals(f.decl)
|
||||||
if hasDefer {
|
if hasDefer {
|
||||||
c.finallyProcessedIndex = size
|
f.finallyProcessedIndex = size
|
||||||
size++
|
size++
|
||||||
}
|
}
|
||||||
return size
|
return size
|
||||||
|
@ -154,12 +199,6 @@ func (c *funcScope) countArgs() int {
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *funcScope) stackSize() int64 {
|
|
||||||
size := c.countLocals()
|
|
||||||
numArgs := c.countArgs()
|
|
||||||
return int64(size + numArgs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newVariable creates a new local variable or argument in the scope of the function.
|
// newVariable creates a new local variable or argument in the scope of the function.
|
||||||
func (c *funcScope) newVariable(t varType, name string) int {
|
func (c *funcScope) newVariable(t varType, name string) int {
|
||||||
return c.vars.newVariable(t, name)
|
return c.vars.newVariable(t, name)
|
||||||
|
|
94
pkg/compiler/inline.go
Normal file
94
pkg/compiler/inline.go
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
package compiler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go/ast"
|
||||||
|
"go/types"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// inlineCall inlines call of n for function represented by f.
|
||||||
|
// Call `f(a,b)` for definition `func f(x,y int)` is translated to block:
|
||||||
|
// {
|
||||||
|
// x := a
|
||||||
|
// y := b
|
||||||
|
// <inline body of f directly>
|
||||||
|
// }
|
||||||
|
func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
|
||||||
|
pkg := c.buildInfo.program.Package(f.pkg.Path())
|
||||||
|
sig := c.typeOf(n.Fun).(*types.Signature)
|
||||||
|
|
||||||
|
// Arguments need to be walked with the current scope,
|
||||||
|
// while stored in the new.
|
||||||
|
oldScope := c.scope.vars.locals
|
||||||
|
c.scope.vars.newScope()
|
||||||
|
newScope := c.scope.vars.locals
|
||||||
|
defer c.scope.vars.dropScope()
|
||||||
|
|
||||||
|
hasVarArgs := !n.Ellipsis.IsValid()
|
||||||
|
needPack := sig.Variadic() && hasVarArgs
|
||||||
|
for i := range n.Args {
|
||||||
|
c.scope.vars.locals = oldScope
|
||||||
|
// true if normal arg or var arg is `slice...`
|
||||||
|
needStore := i < sig.Params().Len()-1 || !sig.Variadic() || !hasVarArgs
|
||||||
|
if !needStore {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
name := sig.Params().At(i).Name()
|
||||||
|
if tv := c.typeAndValueOf(n.Args[i]); tv.Value != nil {
|
||||||
|
c.scope.vars.locals = newScope
|
||||||
|
c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, tv)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if arg, ok := n.Args[i].(*ast.Ident); ok {
|
||||||
|
// When function argument is variable or const, we may avoid
|
||||||
|
// introducing additional variables for parameters.
|
||||||
|
// This is done by providing additional alias to variable.
|
||||||
|
if vi := c.scope.vars.getVarInfo(arg.Name); vi != nil {
|
||||||
|
c.scope.vars.locals = newScope
|
||||||
|
c.scope.vars.addAlias(name, vi.refType, vi.index, vi.tv)
|
||||||
|
continue
|
||||||
|
} else if arg.Name == "nil" {
|
||||||
|
c.scope.vars.locals = newScope
|
||||||
|
c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, types.TypeAndValue{})
|
||||||
|
continue
|
||||||
|
} else if index, ok := c.globals[c.getIdentName("", arg.Name)]; ok {
|
||||||
|
c.scope.vars.locals = newScope
|
||||||
|
c.scope.vars.addAlias(name, varGlobal, index, types.TypeAndValue{})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast.Walk(c, n.Args[i])
|
||||||
|
c.scope.vars.locals = newScope
|
||||||
|
c.scope.newLocal(name)
|
||||||
|
c.emitStoreVar("", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if needPack {
|
||||||
|
// traverse variadic args and pack them
|
||||||
|
// if they are provided directly i.e. without `...`
|
||||||
|
c.scope.vars.locals = oldScope
|
||||||
|
for i := sig.Params().Len() - 1; i < len(n.Args); i++ {
|
||||||
|
ast.Walk(c, n.Args[i])
|
||||||
|
}
|
||||||
|
c.scope.vars.locals = newScope
|
||||||
|
c.packVarArgs(n, sig)
|
||||||
|
name := sig.Params().At(sig.Params().Len() - 1).Name()
|
||||||
|
c.scope.newLocal(name)
|
||||||
|
c.emitStoreVar("", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.pkgInfoInline = append(c.pkgInfoInline, pkg)
|
||||||
|
oldMap := c.importMap
|
||||||
|
c.fillImportMap(f.file, pkg.Pkg)
|
||||||
|
ast.Inspect(f.decl, c.scope.analyzeVoidCalls)
|
||||||
|
ast.Walk(c, f.decl.Body)
|
||||||
|
if c.scope.voidCalls[n] {
|
||||||
|
for i := 0; i < f.decl.Type.Results.NumFields(); i++ {
|
||||||
|
emit.Opcodes(c.prog.BinWriter, opcode.DROP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.importMap = oldMap
|
||||||
|
c.pkgInfoInline = c.pkgInfoInline[:len(c.pkgInfoInline)-1]
|
||||||
|
}
|
163
pkg/compiler/inline_test.go
Normal file
163
pkg/compiler/inline_test.go
Normal file
|
@ -0,0 +1,163 @@
|
||||||
|
package compiler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/compiler"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int) {
|
||||||
|
v := vmAndCompile(t, src)
|
||||||
|
ctx := v.Context()
|
||||||
|
actualCall := 0
|
||||||
|
actualInitSlot := 0
|
||||||
|
|
||||||
|
for op, _, err := ctx.Next(); ; op, _, err = ctx.Next() {
|
||||||
|
require.NoError(t, err)
|
||||||
|
switch op {
|
||||||
|
case opcode.CALL, opcode.CALLL:
|
||||||
|
actualCall++
|
||||||
|
case opcode.INITSLOT:
|
||||||
|
actualInitSlot++
|
||||||
|
}
|
||||||
|
if ctx.IP() == ctx.LenInstr() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, expectedCall, actualCall)
|
||||||
|
require.Equal(t, expectedInitSlot, actualInitSlot)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInline(t *testing.T) {
|
||||||
|
srcTmpl := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||||
|
// local alias
|
||||||
|
func sum(a, b int) int {
|
||||||
|
return 42
|
||||||
|
}
|
||||||
|
var Num = 1
|
||||||
|
func Main() int {
|
||||||
|
%s
|
||||||
|
}`
|
||||||
|
t.Run("no return", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `inline.NoArgsNoReturn()
|
||||||
|
return 1`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(1))
|
||||||
|
})
|
||||||
|
t.Run("has return, dropped", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `inline.NoArgsReturn1()
|
||||||
|
return 2`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(2))
|
||||||
|
})
|
||||||
|
t.Run("drop twice", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `inline.DropInsideInline()
|
||||||
|
return 42`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(42))
|
||||||
|
})
|
||||||
|
t.Run("no args return 1", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.NoArgsReturn1()`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(1))
|
||||||
|
})
|
||||||
|
t.Run("sum", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.Sum(1, 2)`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(3))
|
||||||
|
})
|
||||||
|
t.Run("sum squared (nested inline)", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.SumSquared(1, 2)`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(9))
|
||||||
|
})
|
||||||
|
t.Run("inline function in inline function parameter", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), inline.Sum(3, 4))`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(9+3+4))
|
||||||
|
})
|
||||||
|
t.Run("global name clash", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.GetSumSameName()`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(42))
|
||||||
|
})
|
||||||
|
t.Run("local name clash", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), sum(3, 4))`)
|
||||||
|
checkCallCount(t, src, 1, 2)
|
||||||
|
eval(t, src, big.NewInt(51))
|
||||||
|
})
|
||||||
|
t.Run("var args, empty", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11)`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(11))
|
||||||
|
})
|
||||||
|
t.Run("var args, direct", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11, 14, 17)`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(42))
|
||||||
|
})
|
||||||
|
t.Run("var args, array", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `arr := []int{14, 17}
|
||||||
|
return inline.VarSum(11, arr...)`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(42))
|
||||||
|
})
|
||||||
|
t.Run("globals", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, `return inline.Concat(Num)`)
|
||||||
|
checkCallCount(t, src, 0, 1)
|
||||||
|
eval(t, src, big.NewInt(221))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineConversion(t *testing.T) {
|
||||||
|
src1 := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||||
|
var _ = inline.A
|
||||||
|
func Main() int {
|
||||||
|
a := 2
|
||||||
|
return inline.SumSquared(1, a)
|
||||||
|
}`
|
||||||
|
b1, err := compiler.Compile("foo.go", strings.NewReader(src1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
src2 := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||||
|
var _ = inline.A
|
||||||
|
func Main() int {
|
||||||
|
a := 2
|
||||||
|
{
|
||||||
|
return (1 + a) * (1 + a)
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
b2, err := compiler.Compile("foo.go", strings.NewReader(src2))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, b2, b1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineConversionQualified(t *testing.T) {
|
||||||
|
src1 := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||||
|
var A = 1
|
||||||
|
func Main() int {
|
||||||
|
return inline.Concat(A)
|
||||||
|
}`
|
||||||
|
b1, err := compiler.Compile("foo.go", strings.NewReader(src1))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
src2 := `package foo
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||||
|
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/b"
|
||||||
|
var A = 1
|
||||||
|
func Main() int {
|
||||||
|
return A * 100 + b.A * 10 + inline.A
|
||||||
|
}`
|
||||||
|
b2, err := compiler.Compile("foo.go", strings.NewReader(src2))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, b2, b1)
|
||||||
|
}
|
7
pkg/compiler/testdata/inline/a/a.go
vendored
Normal file
7
pkg/compiler/testdata/inline/a/a.go
vendored
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package a
|
||||||
|
|
||||||
|
var A = 29
|
||||||
|
|
||||||
|
func GetA() int {
|
||||||
|
return A
|
||||||
|
}
|
7
pkg/compiler/testdata/inline/b/b.go
vendored
Normal file
7
pkg/compiler/testdata/inline/b/b.go
vendored
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package b
|
||||||
|
|
||||||
|
var A = 12
|
||||||
|
|
||||||
|
func GetA() int {
|
||||||
|
return A
|
||||||
|
}
|
44
pkg/compiler/testdata/inline/inline.go
vendored
Normal file
44
pkg/compiler/testdata/inline/inline.go
vendored
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package inline
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/a"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/b"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NoArgsNoReturn() {}
|
||||||
|
func NoArgsReturn1() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
func Sum(a, b int) int {
|
||||||
|
return a + b
|
||||||
|
}
|
||||||
|
func sum(x, y int) int {
|
||||||
|
return x + y
|
||||||
|
}
|
||||||
|
func SumSquared(a, b int) int {
|
||||||
|
return sum(a, b) * (a + b)
|
||||||
|
}
|
||||||
|
|
||||||
|
var A = 1
|
||||||
|
|
||||||
|
func GetSumSameName() int {
|
||||||
|
return a.GetA() + b.GetA() + A
|
||||||
|
}
|
||||||
|
|
||||||
|
func DropInsideInline() int {
|
||||||
|
sum(1, 2)
|
||||||
|
sum(3, 4)
|
||||||
|
return 7
|
||||||
|
}
|
||||||
|
|
||||||
|
func VarSum(a int, b ...int) int {
|
||||||
|
sum := a
|
||||||
|
for i := range b {
|
||||||
|
sum += b[i]
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
func Concat(n int) int {
|
||||||
|
return n*100 + b.A*10 + A
|
||||||
|
}
|
|
@ -8,6 +8,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue {
|
func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue {
|
||||||
|
for i := len(c.pkgInfoInline) - 1; i >= 0; i-- {
|
||||||
|
if tv, ok := c.pkgInfoInline[i].Types[e]; ok {
|
||||||
|
return tv
|
||||||
|
}
|
||||||
|
}
|
||||||
return c.typeInfo.Types[e]
|
return c.typeInfo.Types[e]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,24 @@
|
||||||
package compiler
|
package compiler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go/types"
|
||||||
|
)
|
||||||
|
|
||||||
type varScope struct {
|
type varScope struct {
|
||||||
localsCnt int
|
localsCnt int
|
||||||
argCnt int
|
argCnt int
|
||||||
arguments map[string]int
|
arguments map[string]int
|
||||||
locals []map[string]int
|
locals []map[string]varInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type varInfo struct {
|
||||||
|
refType varType
|
||||||
|
index int
|
||||||
|
tv types.TypeAndValue
|
||||||
|
}
|
||||||
|
|
||||||
|
const unspecifiedVarIndex = -1
|
||||||
|
|
||||||
func newVarScope() varScope {
|
func newVarScope() varScope {
|
||||||
return varScope{
|
return varScope{
|
||||||
arguments: make(map[string]int),
|
arguments: make(map[string]int),
|
||||||
|
@ -14,23 +26,34 @@ func newVarScope() varScope {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *varScope) newScope() {
|
func (c *varScope) newScope() {
|
||||||
c.locals = append(c.locals, map[string]int{})
|
c.locals = append(c.locals, map[string]varInfo{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *varScope) dropScope() {
|
func (c *varScope) dropScope() {
|
||||||
c.locals = c.locals[:len(c.locals)-1]
|
c.locals = c.locals[:len(c.locals)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *varScope) getVarIndex(name string) (varType, int) {
|
func (c *varScope) addAlias(name string, vt varType, index int, tv types.TypeAndValue) {
|
||||||
|
c.locals[len(c.locals)-1][name] = varInfo{
|
||||||
|
refType: vt,
|
||||||
|
index: index,
|
||||||
|
tv: tv,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *varScope) getVarInfo(name string) *varInfo {
|
||||||
for i := len(c.locals) - 1; i >= 0; i-- {
|
for i := len(c.locals) - 1; i >= 0; i-- {
|
||||||
if i, ok := c.locals[i][name]; ok {
|
if vi, ok := c.locals[i][name]; ok {
|
||||||
return varLocal, i
|
return &vi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if i, ok := c.arguments[name]; ok {
|
if i, ok := c.arguments[name]; ok {
|
||||||
return varArgument, i
|
return &varInfo{
|
||||||
|
refType: varArgument,
|
||||||
|
index: i,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return 0, -1
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// newVariable creates a new local variable or argument in the scope of the function.
|
// newVariable creates a new local variable or argument in the scope of the function.
|
||||||
|
@ -56,7 +79,10 @@ func (c *varScope) newVariable(t varType, name string) int {
|
||||||
func (c *varScope) newLocal(name string) int {
|
func (c *varScope) newLocal(name string) int {
|
||||||
idx := len(c.locals) - 1
|
idx := len(c.locals) - 1
|
||||||
m := c.locals[idx]
|
m := c.locals[idx]
|
||||||
m[name] = c.localsCnt
|
m[name] = varInfo{
|
||||||
|
refType: varLocal,
|
||||||
|
index: c.localsCnt,
|
||||||
|
}
|
||||||
c.localsCnt++
|
c.localsCnt++
|
||||||
c.locals[idx] = m
|
c.locals[idx] = m
|
||||||
return c.localsCnt - 1
|
return c.localsCnt - 1
|
||||||
|
|
Loading…
Reference in a new issue