Merge pull request #1711 from nspcc-dev/compiler/inline

Allow to inline internal wrappers in compiler
This commit is contained in:
Roman Khimov 2021-02-15 19:05:30 +03:00 committed by GitHub
commit 4d0681d898
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 483 additions and 51 deletions

View file

@ -55,14 +55,14 @@ func (c *codegen) traverseGlobals() (int, int, int) {
switch n := node.(type) { switch n := node.(type) {
case *ast.FuncDecl: case *ast.FuncDecl:
if isInitFunc(n) { if isInitFunc(n) {
c, _ := countLocals(n) num, _ := c.countLocals(n)
if c > initLocals { if num > initLocals {
initLocals = c initLocals = num
} }
} else if isDeployFunc(n) { } else if isDeployFunc(n) {
c, _ := countLocals(n) num, _ := c.countLocals(n)
if c > deployLocals { if num > deployLocals {
deployLocals = c deployLocals = num
} }
} }
return !hasDefer return !hasDefer
@ -175,10 +175,16 @@ func (f funcUsage) funcUsed(name string) bool {
} }
// lastStmtIsReturn checks if last statement of the declaration was return statement.. // lastStmtIsReturn checks if last statement of the declaration was return statement..
func lastStmtIsReturn(decl *ast.FuncDecl) (b bool) { func lastStmtIsReturn(body *ast.BlockStmt) (b bool) {
if l := len(decl.Body.List); l != 0 { if l := len(body.List); l != 0 {
_, ok := decl.Body.List[l-1].(*ast.ReturnStmt) switch inner := body.List[l-1].(type) {
return ok case *ast.BlockStmt:
return lastStmtIsReturn(inner)
case *ast.ReturnStmt:
return true
default:
return false
}
} }
return false return false
} }
@ -298,3 +304,11 @@ func canConvert(s string) bool {
} }
return true return true
} }
// canInline returns true if function is to be inlined.
// Currently there is a static list of function which are inlined,
// this may change in future.
func canInline(s string) bool {
return isNativeHelpersPath(s) ||
strings.HasPrefix(s, "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline")
}

View file

@ -31,6 +31,8 @@ type codegen struct {
// Type information. // Type information.
typeInfo *types.Info typeInfo *types.Info
// pkgInfoInline is stack of type information for packages containing inline functions.
pkgInfoInline []*loader.PackageInfo
// A mapping of func identifiers with their scope. // A mapping of func identifiers with their scope.
funcs map[string]*funcScope funcs map[string]*funcScope
@ -192,20 +194,21 @@ func (c *codegen) emitStoreStructField(i int) {
// getVarIndex returns variable type and position in corresponding slot, // getVarIndex returns variable type and position in corresponding slot,
// according to current scope. // according to current scope.
func (c *codegen) getVarIndex(pkg string, name string) (varType, int) { func (c *codegen) getVarIndex(pkg string, name string) *varInfo {
if pkg == "" { if pkg == "" {
if c.scope != nil { if c.scope != nil {
vt, val := c.scope.vars.getVarIndex(name) vi := c.scope.vars.getVarInfo(name)
if val >= 0 { if vi != nil {
return vt, val return vi
} }
} }
} }
if i, ok := c.globals[c.getIdentName(pkg, name)]; ok { if i, ok := c.globals[c.getIdentName(pkg, name)]; ok {
return varGlobal, i return &varInfo{refType: varGlobal, index: i}
} }
return varLocal, c.scope.newVariable(varLocal, name) c.scope.newVariable(varLocal, name)
return c.scope.vars.getVarInfo(name)
} }
func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) { func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) {
@ -223,8 +226,15 @@ func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) {
// emitLoadVar loads specified variable to the evaluation stack. // emitLoadVar loads specified variable to the evaluation stack.
func (c *codegen) emitLoadVar(pkg string, name string) { func (c *codegen) emitLoadVar(pkg string, name string) {
t, i := c.getVarIndex(pkg, name) vi := c.getVarIndex(pkg, name)
c.emitLoadByIndex(t, i) if vi.tv.Value != nil {
c.emitLoadConst(vi.tv)
return
} else if vi.index == unspecifiedVarIndex {
emit.Opcodes(c.prog.BinWriter, opcode.PUSHNULL)
return
}
c.emitLoadByIndex(vi.refType, vi.index)
} }
// emitLoadByIndex loads specified variable type with index i. // emitLoadByIndex loads specified variable type with index i.
@ -243,8 +253,8 @@ func (c *codegen) emitStoreVar(pkg string, name string) {
emit.Opcodes(c.prog.BinWriter, opcode.DROP) emit.Opcodes(c.prog.BinWriter, opcode.DROP)
return return
} }
t, i := c.getVarIndex(pkg, name) vi := c.getVarIndex(pkg, name)
c.emitStoreByIndex(t, i) c.emitStoreByIndex(vi.refType, vi.index)
} }
// emitLoadByIndex stores top value in the specified variable type with index i. // emitLoadByIndex stores top value in the specified variable type with index i.
@ -320,7 +330,7 @@ func (c *codegen) convertInitFuncs(f *ast.File, pkg *types.Package, seenBefore b
case *ast.FuncDecl: case *ast.FuncDecl:
if isInitFunc(n) { if isInitFunc(n) {
if seenBefore { if seenBefore {
cnt, _ := countLocals(n) cnt, _ := c.countLocals(n)
c.clearSlots(cnt) c.clearSlots(cnt)
seenBefore = true seenBefore = true
} }
@ -352,7 +362,7 @@ func (c *codegen) convertDeployFuncs() {
case *ast.FuncDecl: case *ast.FuncDecl:
if isDeployFunc(n) { if isDeployFunc(n) {
if seenBefore { if seenBefore {
cnt, _ := countLocals(n) cnt, _ := c.countLocals(n)
c.clearSlots(cnt) c.clearSlots(cnt)
} }
c.convertFuncDecl(f, n, pkg) c.convertFuncDecl(f, n, pkg)
@ -398,7 +408,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
// All globals copied into the scope of the function need to be added // All globals copied into the scope of the function need to be added
// to the stack size of the function. // to the stack size of the function.
if !isInit && !isDeploy { if !isInit && !isDeploy {
sizeLoc := f.countLocals() sizeLoc := c.countLocalsWithDefer(f)
if sizeLoc > 255 { if sizeLoc > 255 {
c.prog.Err = errors.New("maximum of 255 local variables is allowed") c.prog.Err = errors.New("maximum of 255 local variables is allowed")
} }
@ -440,7 +450,7 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl, pkg *types.
// If we have reached the end of the function without encountering `return` statement, // If we have reached the end of the function without encountering `return` statement,
// we should clean alt.stack manually. // we should clean alt.stack manually.
// This can be the case with void and named-return functions. // This can be the case with void and named-return functions.
if !isInit && !isDeploy && !lastStmtIsReturn(decl) { if !isInit && !isDeploy && !lastStmtIsReturn(decl.Body) {
c.saveSequencePoint(decl.Body) c.saveSequencePoint(decl.Body)
emit.Opcodes(c.prog.BinWriter, opcode.RET) emit.Opcodes(c.prog.BinWriter, opcode.RET)
} }
@ -623,7 +633,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.processDefers() c.processDefers()
c.saveSequencePoint(n) c.saveSequencePoint(n)
if len(c.pkgInfoInline) == 0 {
emit.Opcodes(c.prog.BinWriter, opcode.RET) emit.Opcodes(c.prog.BinWriter, opcode.RET)
}
return nil return nil
case *ast.IfStmt: case *ast.IfStmt:
@ -800,7 +812,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
switch fun := n.Fun.(type) { switch fun := n.Fun.(type) {
case *ast.Ident: case *ast.Ident:
f, ok = c.funcs[c.getIdentName("", fun.Name)] var pkgName string
if len(c.pkgInfoInline) != 0 {
pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path()
}
f, ok = c.funcs[c.getIdentName(pkgName, fun.Name)]
isBuiltin = isGoBuiltin(fun.Name) isBuiltin = isGoBuiltin(fun.Name)
if !ok && !isBuiltin { if !ok && !isBuiltin {
name = fun.Name name = fun.Name
@ -809,6 +826,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
if fun.Obj != nil && fun.Obj.Kind == ast.Var { if fun.Obj != nil && fun.Obj.Kind == ast.Var {
isFunc = true isFunc = true
} }
if ok && canInline(f.pkg.Path()) {
c.inlineCall(f, n)
return nil
}
case *ast.SelectorExpr: case *ast.SelectorExpr:
// If this is a method call we need to walk the AST to load the struct locally. // If this is a method call we need to walk the AST to load the struct locally.
// Otherwise this is a function call from a imported package and we can call it // Otherwise this is a function call from a imported package and we can call it
@ -824,6 +845,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
if ok { if ok {
f.selector = fun.X.(*ast.Ident) f.selector = fun.X.(*ast.Ident)
isBuiltin = isCustomBuiltin(f) isBuiltin = isCustomBuiltin(f)
if canInline(f.pkg.Path()) {
c.inlineCall(f, n)
return nil
}
} else { } else {
typ := c.typeOf(fun) typ := c.typeOf(fun)
if _, ok := typ.(*types.Signature); ok { if _, ok := typ.(*types.Signature); ok {
@ -867,10 +892,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
typ, ok := c.typeOf(n.Fun).(*types.Signature) typ, ok := c.typeOf(n.Fun).(*types.Signature)
if ok && typ.Variadic() && !n.Ellipsis.IsValid() { if ok && typ.Variadic() && !n.Ellipsis.IsValid() {
// pack variadic args into an array only if last argument is not of form `...` // pack variadic args into an array only if last argument is not of form `...`
varSize := len(n.Args) - typ.Params().Len() + 1 varSize := c.packVarArgs(n, typ)
c.emitReverse(varSize)
emit.Int(c.prog.BinWriter, int64(varSize))
emit.Opcodes(c.prog.BinWriter, opcode.PACK)
numArgs -= varSize - 1 numArgs -= varSize - 1
} }
c.emitReverse(numArgs) c.emitReverse(numArgs)
@ -1207,6 +1229,16 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return c return c
} }
// packVarArgs packs variadic arguments into an array
// and returns amount of arguments packed.
func (c *codegen) packVarArgs(n *ast.CallExpr, typ *types.Signature) int {
varSize := len(n.Args) - typ.Params().Len() + 1
c.emitReverse(varSize)
emit.Int(c.prog.BinWriter, int64(varSize))
emit.Opcodes(c.prog.BinWriter, opcode.PACK)
return varSize
}
// processDefers emits code for `defer` statements. // processDefers emits code for `defer` statements.
// TRY-related opcodes handle exception as follows: // TRY-related opcodes handle exception as follows:
// 1. CATCH block is executed only if exception has occurred. // 1. CATCH block is executed only if exception has occurred.
@ -1919,7 +1951,7 @@ func (c *codegen) compile(info *buildInfo, pkg *loader.PackageInfo) error {
// of bytecode space. // of bytecode space.
name := c.getFuncNameFromDecl(pkg.Path(), n) name := c.getFuncNameFromDecl(pkg.Path(), n)
if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) && if !isInitFunc(n) && !isDeployFunc(n) && funUsage.funcUsed(name) &&
(!isInteropPath(pkg.Path()) || isNativeHelpersPath(pkg.Path())) { (!isInteropPath(pkg.Path()) && !canInline(pkg.Path())) {
c.convertFuncDecl(f, n, pkg) c.convertFuncDecl(f, n, pkg)
} }
} }
@ -1970,7 +2002,8 @@ func (c *codegen) resolveFuncDecls(f *ast.File, pkg *types.Package) {
for _, decl := range f.Decls { for _, decl := range f.Decls {
switch n := decl.(type) { switch n := decl.(type) {
case *ast.FuncDecl: case *ast.FuncDecl:
c.newFunc(n) fs := c.newFunc(n)
fs.file = f
} }
} }
} }

View file

@ -22,6 +22,8 @@ type funcScope struct {
// Package where the function is defined. // Package where the function is defined.
pkg *types.Package pkg *types.Package
file *ast.File
// Program label of the scope // Program label of the scope
label uint16 label uint16
@ -100,11 +102,47 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
return true return true
} }
func countLocals(decl *ast.FuncDecl) (int, bool) { func (c *codegen) countLocals(decl *ast.FuncDecl) (int, bool) {
return c.countLocalsInline(decl, nil, nil)
}
func (c *codegen) countLocalsInline(decl *ast.FuncDecl, pkg *types.Package, f *funcScope) (int, bool) {
oldMap := c.importMap
if pkg != nil {
c.fillImportMap(f.file, pkg)
}
size := 0 size := 0
hasDefer := false hasDefer := false
ast.Inspect(decl, func(n ast.Node) bool { ast.Inspect(decl, func(n ast.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *ast.CallExpr:
var name string
switch fun := n.Fun.(type) {
case *ast.Ident:
var pkgName string
if pkg != nil {
pkgName = pkg.Path()
}
name = c.getIdentName(pkgName, fun.Name)
case *ast.SelectorExpr:
name, _ = c.getFuncNameFromSelector(fun)
default:
return false
}
if inner, ok := c.funcs[name]; ok && canInline(name) {
for i := range n.Args {
switch n.Args[i].(type) {
case *ast.Ident:
case *ast.BasicLit:
default:
size++
}
}
innerSz, _ := c.countLocalsInline(inner.decl, inner.pkg, inner)
size += innerSz
}
return false
case *ast.FuncType: case *ast.FuncType:
num := n.Results.NumFields() num := n.Results.NumFields()
if num != 0 && len(n.Results.List[0].Names) != 0 { if num != 0 && len(n.Results.List[0].Names) != 0 {
@ -117,7 +155,11 @@ func countLocals(decl *ast.FuncDecl) (int, bool) {
case *ast.DeferStmt: case *ast.DeferStmt:
hasDefer = true hasDefer = true
return false return false
case *ast.ReturnStmt, *ast.IfStmt: case *ast.ReturnStmt:
if pkg == nil {
size++
}
case *ast.IfStmt:
size++ size++
// This handles the inline GenDecl like "var x = 2" // This handles the inline GenDecl like "var x = 2"
case *ast.ValueSpec: case *ast.ValueSpec:
@ -134,13 +176,16 @@ func countLocals(decl *ast.FuncDecl) (int, bool) {
} }
return true return true
}) })
if pkg != nil {
c.importMap = oldMap
}
return size, hasDefer return size, hasDefer
} }
func (c *funcScope) countLocals() int { func (c *codegen) countLocalsWithDefer(f *funcScope) int {
size, hasDefer := countLocals(c.decl) size, hasDefer := c.countLocals(f.decl)
if hasDefer { if hasDefer {
c.finallyProcessedIndex = size f.finallyProcessedIndex = size
size++ size++
} }
return size return size
@ -154,12 +199,6 @@ func (c *funcScope) countArgs() int {
return n return n
} }
func (c *funcScope) stackSize() int64 {
size := c.countLocals()
numArgs := c.countArgs()
return int64(size + numArgs)
}
// newVariable creates a new local variable or argument in the scope of the function. // newVariable creates a new local variable or argument in the scope of the function.
func (c *funcScope) newVariable(t varType, name string) int { func (c *funcScope) newVariable(t varType, name string) int {
return c.vars.newVariable(t, name) return c.vars.newVariable(t, name)

94
pkg/compiler/inline.go Normal file
View file

@ -0,0 +1,94 @@
package compiler
import (
"go/ast"
"go/types"
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
)
// inlineCall inlines call of n for function represented by f.
// Call `f(a,b)` for definition `func f(x,y int)` is translated to block:
// {
// x := a
// y := b
// <inline body of f directly>
// }
func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
pkg := c.buildInfo.program.Package(f.pkg.Path())
sig := c.typeOf(n.Fun).(*types.Signature)
// Arguments need to be walked with the current scope,
// while stored in the new.
oldScope := c.scope.vars.locals
c.scope.vars.newScope()
newScope := c.scope.vars.locals
defer c.scope.vars.dropScope()
hasVarArgs := !n.Ellipsis.IsValid()
needPack := sig.Variadic() && hasVarArgs
for i := range n.Args {
c.scope.vars.locals = oldScope
// true if normal arg or var arg is `slice...`
needStore := i < sig.Params().Len()-1 || !sig.Variadic() || !hasVarArgs
if !needStore {
break
}
name := sig.Params().At(i).Name()
if tv := c.typeAndValueOf(n.Args[i]); tv.Value != nil {
c.scope.vars.locals = newScope
c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, tv)
continue
}
if arg, ok := n.Args[i].(*ast.Ident); ok {
// When function argument is variable or const, we may avoid
// introducing additional variables for parameters.
// This is done by providing additional alias to variable.
if vi := c.scope.vars.getVarInfo(arg.Name); vi != nil {
c.scope.vars.locals = newScope
c.scope.vars.addAlias(name, vi.refType, vi.index, vi.tv)
continue
} else if arg.Name == "nil" {
c.scope.vars.locals = newScope
c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, types.TypeAndValue{})
continue
} else if index, ok := c.globals[c.getIdentName("", arg.Name)]; ok {
c.scope.vars.locals = newScope
c.scope.vars.addAlias(name, varGlobal, index, types.TypeAndValue{})
continue
}
}
ast.Walk(c, n.Args[i])
c.scope.vars.locals = newScope
c.scope.newLocal(name)
c.emitStoreVar("", name)
}
if needPack {
// traverse variadic args and pack them
// if they are provided directly i.e. without `...`
c.scope.vars.locals = oldScope
for i := sig.Params().Len() - 1; i < len(n.Args); i++ {
ast.Walk(c, n.Args[i])
}
c.scope.vars.locals = newScope
c.packVarArgs(n, sig)
name := sig.Params().At(sig.Params().Len() - 1).Name()
c.scope.newLocal(name)
c.emitStoreVar("", name)
}
c.pkgInfoInline = append(c.pkgInfoInline, pkg)
oldMap := c.importMap
c.fillImportMap(f.file, pkg.Pkg)
ast.Inspect(f.decl, c.scope.analyzeVoidCalls)
ast.Walk(c, f.decl.Body)
if c.scope.voidCalls[n] {
for i := 0; i < f.decl.Type.Results.NumFields(); i++ {
emit.Opcodes(c.prog.BinWriter, opcode.DROP)
}
}
c.importMap = oldMap
c.pkgInfoInline = c.pkgInfoInline[:len(c.pkgInfoInline)-1]
}

163
pkg/compiler/inline_test.go Normal file
View file

@ -0,0 +1,163 @@
package compiler_test
import (
"fmt"
"math/big"
"strings"
"testing"
"github.com/nspcc-dev/neo-go/pkg/compiler"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/stretchr/testify/require"
)
func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int) {
v := vmAndCompile(t, src)
ctx := v.Context()
actualCall := 0
actualInitSlot := 0
for op, _, err := ctx.Next(); ; op, _, err = ctx.Next() {
require.NoError(t, err)
switch op {
case opcode.CALL, opcode.CALLL:
actualCall++
case opcode.INITSLOT:
actualInitSlot++
}
if ctx.IP() == ctx.LenInstr() {
break
}
}
require.Equal(t, expectedCall, actualCall)
require.Equal(t, expectedInitSlot, actualInitSlot)
}
func TestInline(t *testing.T) {
srcTmpl := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
// local alias
func sum(a, b int) int {
return 42
}
var Num = 1
func Main() int {
%s
}`
t.Run("no return", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `inline.NoArgsNoReturn()
return 1`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(1))
})
t.Run("has return, dropped", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `inline.NoArgsReturn1()
return 2`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(2))
})
t.Run("drop twice", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `inline.DropInsideInline()
return 42`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(42))
})
t.Run("no args return 1", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.NoArgsReturn1()`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(1))
})
t.Run("sum", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(1, 2)`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(3))
})
t.Run("sum squared (nested inline)", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.SumSquared(1, 2)`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(9))
})
t.Run("inline function in inline function parameter", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), inline.Sum(3, 4))`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(9+3+4))
})
t.Run("global name clash", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.GetSumSameName()`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(42))
})
t.Run("local name clash", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), sum(3, 4))`)
checkCallCount(t, src, 1, 2)
eval(t, src, big.NewInt(51))
})
t.Run("var args, empty", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11)`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(11))
})
t.Run("var args, direct", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11, 14, 17)`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(42))
})
t.Run("var args, array", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `arr := []int{14, 17}
return inline.VarSum(11, arr...)`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(42))
})
t.Run("globals", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Concat(Num)`)
checkCallCount(t, src, 0, 1)
eval(t, src, big.NewInt(221))
})
}
func TestInlineConversion(t *testing.T) {
src1 := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
var _ = inline.A
func Main() int {
a := 2
return inline.SumSquared(1, a)
}`
b1, err := compiler.Compile("foo.go", strings.NewReader(src1))
require.NoError(t, err)
src2 := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
var _ = inline.A
func Main() int {
a := 2
{
return (1 + a) * (1 + a)
}
}`
b2, err := compiler.Compile("foo.go", strings.NewReader(src2))
require.NoError(t, err)
require.Equal(t, b2, b1)
}
func TestInlineConversionQualified(t *testing.T) {
src1 := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
var A = 1
func Main() int {
return inline.Concat(A)
}`
b1, err := compiler.Compile("foo.go", strings.NewReader(src1))
require.NoError(t, err)
src2 := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/b"
var A = 1
func Main() int {
return A * 100 + b.A * 10 + inline.A
}`
b2, err := compiler.Compile("foo.go", strings.NewReader(src2))
require.NoError(t, err)
require.Equal(t, b2, b1)
}

7
pkg/compiler/testdata/inline/a/a.go vendored Normal file
View file

@ -0,0 +1,7 @@
package a
var A = 29
func GetA() int {
return A
}

7
pkg/compiler/testdata/inline/b/b.go vendored Normal file
View file

@ -0,0 +1,7 @@
package b
var A = 12
func GetA() int {
return A
}

44
pkg/compiler/testdata/inline/inline.go vendored Normal file
View file

@ -0,0 +1,44 @@
package inline
import (
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/a"
"github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline/b"
)
func NoArgsNoReturn() {}
func NoArgsReturn1() int {
return 1
}
func Sum(a, b int) int {
return a + b
}
func sum(x, y int) int {
return x + y
}
func SumSquared(a, b int) int {
return sum(a, b) * (a + b)
}
var A = 1
func GetSumSameName() int {
return a.GetA() + b.GetA() + A
}
func DropInsideInline() int {
sum(1, 2)
sum(3, 4)
return 7
}
func VarSum(a int, b ...int) int {
sum := a
for i := range b {
sum += b[i]
}
return sum
}
func Concat(n int) int {
return n*100 + b.A*10 + A
}

View file

@ -8,6 +8,11 @@ import (
) )
func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue { func (c *codegen) typeAndValueOf(e ast.Expr) types.TypeAndValue {
for i := len(c.pkgInfoInline) - 1; i >= 0; i-- {
if tv, ok := c.pkgInfoInline[i].Types[e]; ok {
return tv
}
}
return c.typeInfo.Types[e] return c.typeInfo.Types[e]
} }

View file

@ -1,12 +1,24 @@
package compiler package compiler
import (
"go/types"
)
type varScope struct { type varScope struct {
localsCnt int localsCnt int
argCnt int argCnt int
arguments map[string]int arguments map[string]int
locals []map[string]int locals []map[string]varInfo
} }
type varInfo struct {
refType varType
index int
tv types.TypeAndValue
}
const unspecifiedVarIndex = -1
func newVarScope() varScope { func newVarScope() varScope {
return varScope{ return varScope{
arguments: make(map[string]int), arguments: make(map[string]int),
@ -14,23 +26,34 @@ func newVarScope() varScope {
} }
func (c *varScope) newScope() { func (c *varScope) newScope() {
c.locals = append(c.locals, map[string]int{}) c.locals = append(c.locals, map[string]varInfo{})
} }
func (c *varScope) dropScope() { func (c *varScope) dropScope() {
c.locals = c.locals[:len(c.locals)-1] c.locals = c.locals[:len(c.locals)-1]
} }
func (c *varScope) getVarIndex(name string) (varType, int) { func (c *varScope) addAlias(name string, vt varType, index int, tv types.TypeAndValue) {
c.locals[len(c.locals)-1][name] = varInfo{
refType: vt,
index: index,
tv: tv,
}
}
func (c *varScope) getVarInfo(name string) *varInfo {
for i := len(c.locals) - 1; i >= 0; i-- { for i := len(c.locals) - 1; i >= 0; i-- {
if i, ok := c.locals[i][name]; ok { if vi, ok := c.locals[i][name]; ok {
return varLocal, i return &vi
} }
} }
if i, ok := c.arguments[name]; ok { if i, ok := c.arguments[name]; ok {
return varArgument, i return &varInfo{
refType: varArgument,
index: i,
} }
return 0, -1 }
return nil
} }
// newVariable creates a new local variable or argument in the scope of the function. // newVariable creates a new local variable or argument in the scope of the function.
@ -56,7 +79,10 @@ func (c *varScope) newVariable(t varType, name string) int {
func (c *varScope) newLocal(name string) int { func (c *varScope) newLocal(name string) int {
idx := len(c.locals) - 1 idx := len(c.locals) - 1
m := c.locals[idx] m := c.locals[idx]
m[name] = c.localsCnt m[name] = varInfo{
refType: varLocal,
index: c.localsCnt,
}
c.localsCnt++ c.localsCnt++
c.locals[idx] = m c.locals[idx] = m
return c.localsCnt - 1 return c.localsCnt - 1