Merge pull request #1972 from nspcc-dev/compiler-inline-selector
Allow to inline selector statements
This commit is contained in:
commit
53433c2752
6 changed files with 157 additions and 49 deletions
|
@ -247,8 +247,24 @@ func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) {
|
|||
// emitLoadVar loads specified variable to the evaluation stack.
|
||||
func (c *codegen) emitLoadVar(pkg string, name string) {
|
||||
vi := c.getVarIndex(pkg, name)
|
||||
if vi.tv.Value != nil {
|
||||
c.emitLoadConst(vi.tv)
|
||||
if vi.ctx != nil && c.typeAndValueOf(vi.ctx.expr).Value != nil {
|
||||
c.emitLoadConst(c.typeAndValueOf(vi.ctx.expr))
|
||||
return
|
||||
} else if vi.ctx != nil {
|
||||
var oldScope []map[string]varInfo
|
||||
oldMap := c.importMap
|
||||
c.importMap = vi.ctx.importMap
|
||||
if c.scope != nil {
|
||||
oldScope = c.scope.vars.locals
|
||||
c.scope.vars.locals = vi.ctx.scope
|
||||
}
|
||||
|
||||
ast.Walk(c, vi.ctx.expr)
|
||||
|
||||
if c.scope != nil {
|
||||
c.scope.vars.locals = oldScope
|
||||
}
|
||||
c.importMap = oldMap
|
||||
return
|
||||
} else if vi.index == unspecifiedVarIndex {
|
||||
emit.Opcodes(c.prog.BinWriter, opcode.PUSHNULL)
|
||||
|
@ -853,12 +869,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
|||
|
||||
switch fun := n.Fun.(type) {
|
||||
case *ast.Ident:
|
||||
var pkgName string
|
||||
if len(c.pkgInfoInline) != 0 {
|
||||
pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path()
|
||||
}
|
||||
f, ok = c.funcs[c.getIdentName(pkgName, fun.Name)]
|
||||
|
||||
f, ok = c.getFuncFromIdent(fun)
|
||||
isBuiltin = isGoBuiltin(fun.Name)
|
||||
if !ok && !isBuiltin {
|
||||
name = fun.Name
|
||||
|
@ -1940,6 +1951,16 @@ func (c *codegen) newFunc(decl *ast.FuncDecl) *funcScope {
|
|||
return f
|
||||
}
|
||||
|
||||
func (c *codegen) getFuncFromIdent(fun *ast.Ident) (*funcScope, bool) {
|
||||
var pkgName string
|
||||
if len(c.pkgInfoInline) != 0 {
|
||||
pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path()
|
||||
}
|
||||
|
||||
f, ok := c.funcs[c.getIdentName(pkgName, fun.Name)]
|
||||
return f, ok
|
||||
}
|
||||
|
||||
// getFuncNameFromSelector returns fully-qualified function name from the selector expression.
|
||||
// Second return value is true iff this was a method call, not foreign package call.
|
||||
func (c *codegen) getFuncNameFromSelector(e *ast.SelectorExpr) (string, bool) {
|
||||
|
|
|
@ -63,29 +63,18 @@ func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
|
|||
break
|
||||
}
|
||||
name := sig.Params().At(i).Name()
|
||||
if tv := c.typeAndValueOf(n.Args[i]); tv.Value != nil {
|
||||
if !c.hasCalls(n.Args[i]) {
|
||||
// If argument contains no calls, we save context and traverse the expression
|
||||
// when argument is emitted.
|
||||
c.scope.vars.locals = newScope
|
||||
c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, tv)
|
||||
c.scope.vars.addAlias(name, -1, unspecifiedVarIndex, &varContext{
|
||||
importMap: c.importMap,
|
||||
expr: n.Args[i],
|
||||
scope: oldScope,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if arg, ok := n.Args[i].(*ast.Ident); ok {
|
||||
// When function argument is variable or const, we may avoid
|
||||
// introducing additional variables for parameters.
|
||||
// This is done by providing additional alias to variable.
|
||||
if vi := c.scope.vars.getVarInfo(arg.Name); vi != nil {
|
||||
c.scope.vars.locals = newScope
|
||||
c.scope.vars.addAlias(name, vi.refType, vi.index, vi.tv)
|
||||
continue
|
||||
} else if arg.Name == "nil" {
|
||||
c.scope.vars.locals = newScope
|
||||
c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, types.TypeAndValue{})
|
||||
continue
|
||||
} else if index, ok := c.globals[c.getIdentName("", arg.Name)]; ok {
|
||||
c.scope.vars.locals = newScope
|
||||
c.scope.vars.addAlias(name, varGlobal, index, types.TypeAndValue{})
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
ast.Walk(c, n.Args[i])
|
||||
c.scope.vars.locals = newScope
|
||||
c.scope.newLocal(name)
|
||||
|
@ -144,3 +133,31 @@ func (c *codegen) processNotify(f *funcScope, args []ast.Expr) {
|
|||
c.emittedEvents[name] = append(c.emittedEvents[name], params)
|
||||
}
|
||||
}
|
||||
|
||||
// hasCalls returns true if expression contains any calls.
|
||||
// We uses this as a rough heuristic to determine if expression calculation
|
||||
// has any side-effects.
|
||||
func (c *codegen) hasCalls(expr ast.Expr) bool {
|
||||
var has bool
|
||||
ast.Inspect(expr, func(n ast.Node) bool {
|
||||
ce, ok := n.(*ast.CallExpr)
|
||||
if !has && ok {
|
||||
isFunc := true
|
||||
fun, ok := ce.Fun.(*ast.Ident)
|
||||
if ok {
|
||||
_, isFunc = c.getFuncFromIdent(fun)
|
||||
} else {
|
||||
var sel *ast.SelectorExpr
|
||||
sel, ok = ce.Fun.(*ast.SelectorExpr)
|
||||
if ok {
|
||||
name, _ := c.getFuncNameFromSelector(sel)
|
||||
_, isFunc = c.funcs[name]
|
||||
fun = sel.Sel
|
||||
}
|
||||
}
|
||||
has = isFunc || fun.Obj != nil && (fun.Obj.Kind == ast.Var || fun.Obj.Kind == ast.Fun)
|
||||
}
|
||||
return !has
|
||||
})
|
||||
return has
|
||||
}
|
||||
|
|
|
@ -11,19 +11,31 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int) {
|
||||
v := vmAndCompile(t, src)
|
||||
func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot, expectedLocalsMain int) {
|
||||
v, sp := vmAndCompileInterop(t, src)
|
||||
|
||||
mainStart := -1
|
||||
for _, m := range sp.info.Methods {
|
||||
if m.Name.Name == "main" {
|
||||
mainStart = int(m.Range.Start)
|
||||
}
|
||||
}
|
||||
require.True(t, mainStart >= 0)
|
||||
|
||||
ctx := v.Context()
|
||||
actualCall := 0
|
||||
actualInitSlot := 0
|
||||
|
||||
for op, _, err := ctx.Next(); ; op, _, err = ctx.Next() {
|
||||
for op, param, err := ctx.Next(); ; op, param, err = ctx.Next() {
|
||||
require.NoError(t, err)
|
||||
switch op {
|
||||
case opcode.CALL, opcode.CALLL:
|
||||
actualCall++
|
||||
case opcode.INITSLOT:
|
||||
actualInitSlot++
|
||||
if ctx.IP() == mainStart {
|
||||
require.Equal(t, expectedLocalsMain, int(param[0]))
|
||||
}
|
||||
}
|
||||
if ctx.IP() == ctx.LenInstr() {
|
||||
break
|
||||
|
@ -36,6 +48,13 @@ func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int
|
|||
func TestInline(t *testing.T) {
|
||||
srcTmpl := `package foo
|
||||
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
|
||||
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/foo"
|
||||
var _ = foo.Dummy
|
||||
type pair struct { a, b int }
|
||||
type triple struct {
|
||||
a int
|
||||
b pair
|
||||
}
|
||||
// local alias
|
||||
func sum(a, b int) int {
|
||||
return 42
|
||||
|
@ -47,77 +66,115 @@ func TestInline(t *testing.T) {
|
|||
t.Run("no return", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `inline.NoArgsNoReturn()
|
||||
return 1`)
|
||||
checkCallCount(t, src, 0, 0)
|
||||
checkCallCount(t, src, 0, 0, 0)
|
||||
eval(t, src, big.NewInt(1))
|
||||
})
|
||||
t.Run("has return, dropped", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `inline.NoArgsReturn1()
|
||||
return 2`)
|
||||
checkCallCount(t, src, 0, 0)
|
||||
checkCallCount(t, src, 0, 0, 0)
|
||||
eval(t, src, big.NewInt(2))
|
||||
})
|
||||
t.Run("drop twice", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `inline.DropInsideInline()
|
||||
return 42`)
|
||||
checkCallCount(t, src, 0, 0)
|
||||
checkCallCount(t, src, 0, 0, 0)
|
||||
eval(t, src, big.NewInt(42))
|
||||
})
|
||||
t.Run("no args return 1", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.NoArgsReturn1()`)
|
||||
checkCallCount(t, src, 0, 0)
|
||||
checkCallCount(t, src, 0, 0, 0)
|
||||
eval(t, src, big.NewInt(1))
|
||||
})
|
||||
t.Run("sum", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.Sum(1, 2)`)
|
||||
checkCallCount(t, src, 0, 0)
|
||||
checkCallCount(t, src, 0, 0, 0)
|
||||
eval(t, src, big.NewInt(3))
|
||||
})
|
||||
t.Run("sum squared (nested inline)", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.SumSquared(1, 2)`)
|
||||
checkCallCount(t, src, 0, 0)
|
||||
checkCallCount(t, src, 0, 0, 0)
|
||||
eval(t, src, big.NewInt(9))
|
||||
})
|
||||
t.Run("inline function in inline function parameter", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), inline.Sum(3, 4))`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
checkCallCount(t, src, 0, 1, 2)
|
||||
eval(t, src, big.NewInt(9+3+4))
|
||||
})
|
||||
t.Run("global name clash", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.GetSumSameName()`)
|
||||
checkCallCount(t, src, 0, 0)
|
||||
checkCallCount(t, src, 0, 0, 0)
|
||||
eval(t, src, big.NewInt(42))
|
||||
})
|
||||
t.Run("local name clash", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), sum(3, 4))`)
|
||||
checkCallCount(t, src, 1, 2)
|
||||
checkCallCount(t, src, 1, 2, 2)
|
||||
eval(t, src, big.NewInt(51))
|
||||
})
|
||||
t.Run("var args, empty", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11)`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
checkCallCount(t, src, 0, 1, 3)
|
||||
eval(t, src, big.NewInt(11))
|
||||
})
|
||||
t.Run("var args, direct", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11, 14, 17)`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
checkCallCount(t, src, 0, 1, 3)
|
||||
eval(t, src, big.NewInt(42))
|
||||
})
|
||||
t.Run("var args, array", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `arr := []int{14, 17}
|
||||
return inline.VarSum(11, arr...)`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
checkCallCount(t, src, 0, 1, 3)
|
||||
eval(t, src, big.NewInt(42))
|
||||
})
|
||||
t.Run("globals", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.Concat(Num)`)
|
||||
checkCallCount(t, src, 0, 0)
|
||||
checkCallCount(t, src, 0, 0, 0)
|
||||
eval(t, src, big.NewInt(221))
|
||||
})
|
||||
t.Run("locals, alias", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `num := 1; return inline.Concat(num)`)
|
||||
checkCallCount(t, src, 0, 1)
|
||||
checkCallCount(t, src, 0, 1, 1)
|
||||
eval(t, src, big.NewInt(221))
|
||||
})
|
||||
t.Run("selector, global", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.A, 2)`)
|
||||
checkCallCount(t, src, 0, 0, 0)
|
||||
eval(t, src, big.NewInt(3))
|
||||
})
|
||||
t.Run("selector, struct, simple", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `x := pair{a: 1, b: 2}; return inline.Sum(x.b, 1)`)
|
||||
checkCallCount(t, src, 0, 1, 1)
|
||||
eval(t, src, big.NewInt(3))
|
||||
})
|
||||
t.Run("selector, struct, complex", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `x := triple{a: 1, b: pair{a: 2, b: 3}}
|
||||
return inline.Sum(x.b.a, 1)`)
|
||||
checkCallCount(t, src, 0, 1, 1)
|
||||
eval(t, src, big.NewInt(3))
|
||||
})
|
||||
t.Run("expression", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `x, y := 1, 2
|
||||
return inline.Sum(x+y, y*2)`)
|
||||
checkCallCount(t, src, 0, 1, 2)
|
||||
eval(t, src, big.NewInt(7))
|
||||
})
|
||||
t.Run("foreign package call", func(t *testing.T) {
|
||||
src := fmt.Sprintf(srcTmpl, `return inline.Sum(foo.Bar(), foo.Dummy+1)`)
|
||||
checkCallCount(t, src, 1, 1, 1)
|
||||
eval(t, src, big.NewInt(3))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIssue1879(t *testing.T) {
|
||||
src := `package foo
|
||||
import "github.com/nspcc-dev/neo-go/pkg/interop/runtime"
|
||||
func Main() int {
|
||||
data := "main is called"
|
||||
runtime.Log("log " + string(data))
|
||||
return 42
|
||||
}`
|
||||
checkCallCount(t, src, 0, 1, 1)
|
||||
}
|
||||
|
||||
func TestInlineInLoop(t *testing.T) {
|
||||
|
|
3
pkg/compiler/testdata/foo/foo.go
vendored
3
pkg/compiler/testdata/foo/foo.go
vendored
|
@ -5,6 +5,9 @@ func NewBar() int {
|
|||
return 10
|
||||
}
|
||||
|
||||
// Dummy is dummy constant.
|
||||
var Dummy = 1
|
||||
|
||||
// Foo is a type.
|
||||
type Foo struct{}
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package compiler
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
"go/ast"
|
||||
)
|
||||
|
||||
type varScope struct {
|
||||
|
@ -10,10 +10,18 @@ type varScope struct {
|
|||
locals []map[string]varInfo
|
||||
}
|
||||
|
||||
type varContext struct {
|
||||
importMap map[string]string
|
||||
expr ast.Expr
|
||||
scope []map[string]varInfo
|
||||
}
|
||||
|
||||
type varInfo struct {
|
||||
refType varType
|
||||
index int
|
||||
tv types.TypeAndValue
|
||||
// ctx is set for inline arguments and contains
|
||||
// context for expression traversal.
|
||||
ctx *varContext
|
||||
}
|
||||
|
||||
const unspecifiedVarIndex = -1
|
||||
|
@ -32,11 +40,11 @@ func (c *varScope) dropScope() {
|
|||
c.locals = c.locals[:len(c.locals)-1]
|
||||
}
|
||||
|
||||
func (c *varScope) addAlias(name string, vt varType, index int, tv types.TypeAndValue) {
|
||||
func (c *varScope) addAlias(name string, vt varType, index int, ctx *varContext) {
|
||||
c.locals[len(c.locals)-1][name] = varInfo{
|
||||
refType: vt,
|
||||
index: index,
|
||||
tv: tv,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -69,6 +69,7 @@ func vmAndCompileInterop(t *testing.T, src string) (*vm.VM, *storagePlugin) {
|
|||
b, di, err := compiler.CompileWithDebugInfo("foo.go", strings.NewReader(src))
|
||||
require.NoError(t, err)
|
||||
|
||||
storePlugin.info = di
|
||||
invokeMethod(t, testMainIdent, b, vm, di)
|
||||
return vm, storePlugin
|
||||
}
|
||||
|
@ -93,6 +94,7 @@ func invokeMethod(t *testing.T, method string, script []byte, v *vm.VM, di *comp
|
|||
}
|
||||
|
||||
type storagePlugin struct {
|
||||
info *compiler.DebugInfo
|
||||
mem map[string][]byte
|
||||
interops map[uint32]func(v *vm.VM) error
|
||||
events []state.NotificationEvent
|
||||
|
|
Loading…
Reference in a new issue