Merge pull request #1972 from nspcc-dev/compiler-inline-selector

Allow to inline selector statements
This commit is contained in:
Roman Khimov 2021-06-04 11:25:32 +03:00 committed by GitHub
commit 53433c2752
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 157 additions and 49 deletions

View file

@ -247,8 +247,24 @@ func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) {
// emitLoadVar loads specified variable to the evaluation stack. // emitLoadVar loads specified variable to the evaluation stack.
func (c *codegen) emitLoadVar(pkg string, name string) { func (c *codegen) emitLoadVar(pkg string, name string) {
vi := c.getVarIndex(pkg, name) vi := c.getVarIndex(pkg, name)
if vi.tv.Value != nil { if vi.ctx != nil && c.typeAndValueOf(vi.ctx.expr).Value != nil {
c.emitLoadConst(vi.tv) c.emitLoadConst(c.typeAndValueOf(vi.ctx.expr))
return
} else if vi.ctx != nil {
var oldScope []map[string]varInfo
oldMap := c.importMap
c.importMap = vi.ctx.importMap
if c.scope != nil {
oldScope = c.scope.vars.locals
c.scope.vars.locals = vi.ctx.scope
}
ast.Walk(c, vi.ctx.expr)
if c.scope != nil {
c.scope.vars.locals = oldScope
}
c.importMap = oldMap
return return
} else if vi.index == unspecifiedVarIndex { } else if vi.index == unspecifiedVarIndex {
emit.Opcodes(c.prog.BinWriter, opcode.PUSHNULL) emit.Opcodes(c.prog.BinWriter, opcode.PUSHNULL)
@ -853,12 +869,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
switch fun := n.Fun.(type) { switch fun := n.Fun.(type) {
case *ast.Ident: case *ast.Ident:
var pkgName string f, ok = c.getFuncFromIdent(fun)
if len(c.pkgInfoInline) != 0 {
pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path()
}
f, ok = c.funcs[c.getIdentName(pkgName, fun.Name)]
isBuiltin = isGoBuiltin(fun.Name) isBuiltin = isGoBuiltin(fun.Name)
if !ok && !isBuiltin { if !ok && !isBuiltin {
name = fun.Name name = fun.Name
@ -1940,6 +1951,16 @@ func (c *codegen) newFunc(decl *ast.FuncDecl) *funcScope {
return f return f
} }
func (c *codegen) getFuncFromIdent(fun *ast.Ident) (*funcScope, bool) {
var pkgName string
if len(c.pkgInfoInline) != 0 {
pkgName = c.pkgInfoInline[len(c.pkgInfoInline)-1].Pkg.Path()
}
f, ok := c.funcs[c.getIdentName(pkgName, fun.Name)]
return f, ok
}
// getFuncNameFromSelector returns fully-qualified function name from the selector expression. // getFuncNameFromSelector returns fully-qualified function name from the selector expression.
// Second return value is true iff this was a method call, not foreign package call. // Second return value is true iff this was a method call, not foreign package call.
func (c *codegen) getFuncNameFromSelector(e *ast.SelectorExpr) (string, bool) { func (c *codegen) getFuncNameFromSelector(e *ast.SelectorExpr) (string, bool) {

View file

@ -63,29 +63,18 @@ func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) {
break break
} }
name := sig.Params().At(i).Name() name := sig.Params().At(i).Name()
if tv := c.typeAndValueOf(n.Args[i]); tv.Value != nil { if !c.hasCalls(n.Args[i]) {
// If argument contains no calls, we save context and traverse the expression
// when argument is emitted.
c.scope.vars.locals = newScope c.scope.vars.locals = newScope
c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, tv) c.scope.vars.addAlias(name, -1, unspecifiedVarIndex, &varContext{
importMap: c.importMap,
expr: n.Args[i],
scope: oldScope,
})
continue continue
} }
if arg, ok := n.Args[i].(*ast.Ident); ok {
// When function argument is variable or const, we may avoid
// introducing additional variables for parameters.
// This is done by providing additional alias to variable.
if vi := c.scope.vars.getVarInfo(arg.Name); vi != nil {
c.scope.vars.locals = newScope
c.scope.vars.addAlias(name, vi.refType, vi.index, vi.tv)
continue
} else if arg.Name == "nil" {
c.scope.vars.locals = newScope
c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, types.TypeAndValue{})
continue
} else if index, ok := c.globals[c.getIdentName("", arg.Name)]; ok {
c.scope.vars.locals = newScope
c.scope.vars.addAlias(name, varGlobal, index, types.TypeAndValue{})
continue
}
}
ast.Walk(c, n.Args[i]) ast.Walk(c, n.Args[i])
c.scope.vars.locals = newScope c.scope.vars.locals = newScope
c.scope.newLocal(name) c.scope.newLocal(name)
@ -144,3 +133,31 @@ func (c *codegen) processNotify(f *funcScope, args []ast.Expr) {
c.emittedEvents[name] = append(c.emittedEvents[name], params) c.emittedEvents[name] = append(c.emittedEvents[name], params)
} }
} }
// hasCalls returns true if expression contains any calls.
// We uses this as a rough heuristic to determine if expression calculation
// has any side-effects.
func (c *codegen) hasCalls(expr ast.Expr) bool {
var has bool
ast.Inspect(expr, func(n ast.Node) bool {
ce, ok := n.(*ast.CallExpr)
if !has && ok {
isFunc := true
fun, ok := ce.Fun.(*ast.Ident)
if ok {
_, isFunc = c.getFuncFromIdent(fun)
} else {
var sel *ast.SelectorExpr
sel, ok = ce.Fun.(*ast.SelectorExpr)
if ok {
name, _ := c.getFuncNameFromSelector(sel)
_, isFunc = c.funcs[name]
fun = sel.Sel
}
}
has = isFunc || fun.Obj != nil && (fun.Obj.Kind == ast.Var || fun.Obj.Kind == ast.Fun)
}
return !has
})
return has
}

View file

@ -11,19 +11,31 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int) { func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot, expectedLocalsMain int) {
v := vmAndCompile(t, src) v, sp := vmAndCompileInterop(t, src)
mainStart := -1
for _, m := range sp.info.Methods {
if m.Name.Name == "main" {
mainStart = int(m.Range.Start)
}
}
require.True(t, mainStart >= 0)
ctx := v.Context() ctx := v.Context()
actualCall := 0 actualCall := 0
actualInitSlot := 0 actualInitSlot := 0
for op, _, err := ctx.Next(); ; op, _, err = ctx.Next() { for op, param, err := ctx.Next(); ; op, param, err = ctx.Next() {
require.NoError(t, err) require.NoError(t, err)
switch op { switch op {
case opcode.CALL, opcode.CALLL: case opcode.CALL, opcode.CALLL:
actualCall++ actualCall++
case opcode.INITSLOT: case opcode.INITSLOT:
actualInitSlot++ actualInitSlot++
if ctx.IP() == mainStart {
require.Equal(t, expectedLocalsMain, int(param[0]))
}
} }
if ctx.IP() == ctx.LenInstr() { if ctx.IP() == ctx.LenInstr() {
break break
@ -36,6 +48,13 @@ func checkCallCount(t *testing.T, src string, expectedCall, expectedInitSlot int
func TestInline(t *testing.T) { func TestInline(t *testing.T) {
srcTmpl := `package foo srcTmpl := `package foo
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline" import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/inline"
import "github.com/nspcc-dev/neo-go/pkg/compiler/testdata/foo"
var _ = foo.Dummy
type pair struct { a, b int }
type triple struct {
a int
b pair
}
// local alias // local alias
func sum(a, b int) int { func sum(a, b int) int {
return 42 return 42
@ -47,77 +66,115 @@ func TestInline(t *testing.T) {
t.Run("no return", func(t *testing.T) { t.Run("no return", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `inline.NoArgsNoReturn() src := fmt.Sprintf(srcTmpl, `inline.NoArgsNoReturn()
return 1`) return 1`)
checkCallCount(t, src, 0, 0) checkCallCount(t, src, 0, 0, 0)
eval(t, src, big.NewInt(1)) eval(t, src, big.NewInt(1))
}) })
t.Run("has return, dropped", func(t *testing.T) { t.Run("has return, dropped", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `inline.NoArgsReturn1() src := fmt.Sprintf(srcTmpl, `inline.NoArgsReturn1()
return 2`) return 2`)
checkCallCount(t, src, 0, 0) checkCallCount(t, src, 0, 0, 0)
eval(t, src, big.NewInt(2)) eval(t, src, big.NewInt(2))
}) })
t.Run("drop twice", func(t *testing.T) { t.Run("drop twice", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `inline.DropInsideInline() src := fmt.Sprintf(srcTmpl, `inline.DropInsideInline()
return 42`) return 42`)
checkCallCount(t, src, 0, 0) checkCallCount(t, src, 0, 0, 0)
eval(t, src, big.NewInt(42)) eval(t, src, big.NewInt(42))
}) })
t.Run("no args return 1", func(t *testing.T) { t.Run("no args return 1", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.NoArgsReturn1()`) src := fmt.Sprintf(srcTmpl, `return inline.NoArgsReturn1()`)
checkCallCount(t, src, 0, 0) checkCallCount(t, src, 0, 0, 0)
eval(t, src, big.NewInt(1)) eval(t, src, big.NewInt(1))
}) })
t.Run("sum", func(t *testing.T) { t.Run("sum", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(1, 2)`) src := fmt.Sprintf(srcTmpl, `return inline.Sum(1, 2)`)
checkCallCount(t, src, 0, 0) checkCallCount(t, src, 0, 0, 0)
eval(t, src, big.NewInt(3)) eval(t, src, big.NewInt(3))
}) })
t.Run("sum squared (nested inline)", func(t *testing.T) { t.Run("sum squared (nested inline)", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.SumSquared(1, 2)`) src := fmt.Sprintf(srcTmpl, `return inline.SumSquared(1, 2)`)
checkCallCount(t, src, 0, 0) checkCallCount(t, src, 0, 0, 0)
eval(t, src, big.NewInt(9)) eval(t, src, big.NewInt(9))
}) })
t.Run("inline function in inline function parameter", func(t *testing.T) { t.Run("inline function in inline function parameter", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), inline.Sum(3, 4))`) src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), inline.Sum(3, 4))`)
checkCallCount(t, src, 0, 1) checkCallCount(t, src, 0, 1, 2)
eval(t, src, big.NewInt(9+3+4)) eval(t, src, big.NewInt(9+3+4))
}) })
t.Run("global name clash", func(t *testing.T) { t.Run("global name clash", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.GetSumSameName()`) src := fmt.Sprintf(srcTmpl, `return inline.GetSumSameName()`)
checkCallCount(t, src, 0, 0) checkCallCount(t, src, 0, 0, 0)
eval(t, src, big.NewInt(42)) eval(t, src, big.NewInt(42))
}) })
t.Run("local name clash", func(t *testing.T) { t.Run("local name clash", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), sum(3, 4))`) src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.SumSquared(1, 2), sum(3, 4))`)
checkCallCount(t, src, 1, 2) checkCallCount(t, src, 1, 2, 2)
eval(t, src, big.NewInt(51)) eval(t, src, big.NewInt(51))
}) })
t.Run("var args, empty", func(t *testing.T) { t.Run("var args, empty", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11)`) src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11)`)
checkCallCount(t, src, 0, 1) checkCallCount(t, src, 0, 1, 3)
eval(t, src, big.NewInt(11)) eval(t, src, big.NewInt(11))
}) })
t.Run("var args, direct", func(t *testing.T) { t.Run("var args, direct", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11, 14, 17)`) src := fmt.Sprintf(srcTmpl, `return inline.VarSum(11, 14, 17)`)
checkCallCount(t, src, 0, 1) checkCallCount(t, src, 0, 1, 3)
eval(t, src, big.NewInt(42)) eval(t, src, big.NewInt(42))
}) })
t.Run("var args, array", func(t *testing.T) { t.Run("var args, array", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `arr := []int{14, 17} src := fmt.Sprintf(srcTmpl, `arr := []int{14, 17}
return inline.VarSum(11, arr...)`) return inline.VarSum(11, arr...)`)
checkCallCount(t, src, 0, 1) checkCallCount(t, src, 0, 1, 3)
eval(t, src, big.NewInt(42)) eval(t, src, big.NewInt(42))
}) })
t.Run("globals", func(t *testing.T) { t.Run("globals", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Concat(Num)`) src := fmt.Sprintf(srcTmpl, `return inline.Concat(Num)`)
checkCallCount(t, src, 0, 0) checkCallCount(t, src, 0, 0, 0)
eval(t, src, big.NewInt(221)) eval(t, src, big.NewInt(221))
}) })
t.Run("locals, alias", func(t *testing.T) { t.Run("locals, alias", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `num := 1; return inline.Concat(num)`) src := fmt.Sprintf(srcTmpl, `num := 1; return inline.Concat(num)`)
checkCallCount(t, src, 0, 1) checkCallCount(t, src, 0, 1, 1)
eval(t, src, big.NewInt(221)) eval(t, src, big.NewInt(221))
}) })
t.Run("selector, global", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.A, 2)`)
checkCallCount(t, src, 0, 0, 0)
eval(t, src, big.NewInt(3))
})
t.Run("selector, struct, simple", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `x := pair{a: 1, b: 2}; return inline.Sum(x.b, 1)`)
checkCallCount(t, src, 0, 1, 1)
eval(t, src, big.NewInt(3))
})
t.Run("selector, struct, complex", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `x := triple{a: 1, b: pair{a: 2, b: 3}}
return inline.Sum(x.b.a, 1)`)
checkCallCount(t, src, 0, 1, 1)
eval(t, src, big.NewInt(3))
})
t.Run("expression", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `x, y := 1, 2
return inline.Sum(x+y, y*2)`)
checkCallCount(t, src, 0, 1, 2)
eval(t, src, big.NewInt(7))
})
t.Run("foreign package call", func(t *testing.T) {
src := fmt.Sprintf(srcTmpl, `return inline.Sum(foo.Bar(), foo.Dummy+1)`)
checkCallCount(t, src, 1, 1, 1)
eval(t, src, big.NewInt(3))
})
}
func TestIssue1879(t *testing.T) {
src := `package foo
import "github.com/nspcc-dev/neo-go/pkg/interop/runtime"
func Main() int {
data := "main is called"
runtime.Log("log " + string(data))
return 42
}`
checkCallCount(t, src, 0, 1, 1)
} }
func TestInlineInLoop(t *testing.T) { func TestInlineInLoop(t *testing.T) {

View file

@ -5,6 +5,9 @@ func NewBar() int {
return 10 return 10
} }
// Dummy is dummy constant.
var Dummy = 1
// Foo is a type. // Foo is a type.
type Foo struct{} type Foo struct{}

View file

@ -1,7 +1,7 @@
package compiler package compiler
import ( import (
"go/types" "go/ast"
) )
type varScope struct { type varScope struct {
@ -10,10 +10,18 @@ type varScope struct {
locals []map[string]varInfo locals []map[string]varInfo
} }
type varContext struct {
importMap map[string]string
expr ast.Expr
scope []map[string]varInfo
}
type varInfo struct { type varInfo struct {
refType varType refType varType
index int index int
tv types.TypeAndValue // ctx is set for inline arguments and contains
// context for expression traversal.
ctx *varContext
} }
const unspecifiedVarIndex = -1 const unspecifiedVarIndex = -1
@ -32,11 +40,11 @@ func (c *varScope) dropScope() {
c.locals = c.locals[:len(c.locals)-1] c.locals = c.locals[:len(c.locals)-1]
} }
func (c *varScope) addAlias(name string, vt varType, index int, tv types.TypeAndValue) { func (c *varScope) addAlias(name string, vt varType, index int, ctx *varContext) {
c.locals[len(c.locals)-1][name] = varInfo{ c.locals[len(c.locals)-1][name] = varInfo{
refType: vt, refType: vt,
index: index, index: index,
tv: tv, ctx: ctx,
} }
} }

View file

@ -69,6 +69,7 @@ func vmAndCompileInterop(t *testing.T, src string) (*vm.VM, *storagePlugin) {
b, di, err := compiler.CompileWithDebugInfo("foo.go", strings.NewReader(src)) b, di, err := compiler.CompileWithDebugInfo("foo.go", strings.NewReader(src))
require.NoError(t, err) require.NoError(t, err)
storePlugin.info = di
invokeMethod(t, testMainIdent, b, vm, di) invokeMethod(t, testMainIdent, b, vm, di)
return vm, storePlugin return vm, storePlugin
} }
@ -93,6 +94,7 @@ func invokeMethod(t *testing.T, method string, script []byte, v *vm.VM, di *comp
} }
type storagePlugin struct { type storagePlugin struct {
info *compiler.DebugInfo
mem map[string][]byte mem map[string][]byte
interops map[uint32]func(v *vm.VM) error interops map[uint32]func(v *vm.VM) error
events []state.NotificationEvent events []state.NotificationEvent