diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 5fd993f01..5abf00ec0 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -247,8 +247,24 @@ func getBaseOpcode(t varType) (opcode.Opcode, opcode.Opcode) { // emitLoadVar loads specified variable to the evaluation stack. func (c *codegen) emitLoadVar(pkg string, name string) { vi := c.getVarIndex(pkg, name) - if vi.tv.Value != nil { - c.emitLoadConst(vi.tv) + if vi.ctx != nil && c.typeAndValueOf(vi.ctx.expr).Value != nil { + c.emitLoadConst(c.typeAndValueOf(vi.ctx.expr)) + return + } else if vi.ctx != nil { + var oldScope []map[string]varInfo + oldMap := c.importMap + c.importMap = vi.ctx.importMap + if c.scope != nil { + oldScope = c.scope.vars.locals + c.scope.vars.locals = vi.ctx.scope + } + + ast.Walk(c, vi.ctx.expr) + + if c.scope != nil { + c.scope.vars.locals = oldScope + } + c.importMap = oldMap return } else if vi.index == unspecifiedVarIndex { emit.Opcodes(c.prog.BinWriter, opcode.PUSHNULL) diff --git a/pkg/compiler/inline.go b/pkg/compiler/inline.go index 6c5b676ce..e339dad2d 100644 --- a/pkg/compiler/inline.go +++ b/pkg/compiler/inline.go @@ -63,29 +63,18 @@ func (c *codegen) inlineCall(f *funcScope, n *ast.CallExpr) { break } name := sig.Params().At(i).Name() - if tv := c.typeAndValueOf(n.Args[i]); tv.Value != nil { + if !c.hasCalls(n.Args[i]) { + // If argument contains no calls, we save context and traverse the expression + // when argument is emitted. c.scope.vars.locals = newScope - c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, tv) + c.scope.vars.addAlias(name, -1, unspecifiedVarIndex, &varContext{ + importMap: c.importMap, + expr: n.Args[i], + scope: oldScope, + }) continue } - if arg, ok := n.Args[i].(*ast.Ident); ok { - // When function argument is variable or const, we may avoid - // introducing additional variables for parameters. - // This is done by providing additional alias to variable. - if vi := c.scope.vars.getVarInfo(arg.Name); vi != nil { - c.scope.vars.locals = newScope - c.scope.vars.addAlias(name, vi.refType, vi.index, vi.tv) - continue - } else if arg.Name == "nil" { - c.scope.vars.locals = newScope - c.scope.vars.addAlias(name, varLocal, unspecifiedVarIndex, types.TypeAndValue{}) - continue - } else if index, ok := c.globals[c.getIdentName("", arg.Name)]; ok { - c.scope.vars.locals = newScope - c.scope.vars.addAlias(name, varGlobal, index, types.TypeAndValue{}) - continue - } - } + ast.Walk(c, n.Args[i]) c.scope.vars.locals = newScope c.scope.newLocal(name) @@ -144,3 +133,18 @@ func (c *codegen) processNotify(f *funcScope, args []ast.Expr) { c.emittedEvents[name] = append(c.emittedEvents[name], params) } } + +// hasCalls returns true if expression contains any calls. +// We uses this as a rough heuristic to determine if expression calculation +// has any side-effects. +func (c *codegen) hasCalls(expr ast.Expr) bool { + var has bool + ast.Inspect(expr, func(n ast.Node) bool { + _, ok := n.(*ast.CallExpr) + if ok { + has = true + } + return !has + }) + return has +} diff --git a/pkg/compiler/inline_test.go b/pkg/compiler/inline_test.go index 644476731..87ab89197 100644 --- a/pkg/compiler/inline_test.go +++ b/pkg/compiler/inline_test.go @@ -137,24 +137,24 @@ func TestInline(t *testing.T) { }) t.Run("selector, global", func(t *testing.T) { src := fmt.Sprintf(srcTmpl, `return inline.Sum(inline.A, 2)`) - checkCallCount(t, src, 0, 1, 1) + checkCallCount(t, src, 0, 0, 0) eval(t, src, big.NewInt(3)) }) t.Run("selector, struct, simple", func(t *testing.T) { src := fmt.Sprintf(srcTmpl, `x := pair{a: 1, b: 2}; return inline.Sum(x.b, 1)`) - checkCallCount(t, src, 0, 1, 2) + checkCallCount(t, src, 0, 1, 1) eval(t, src, big.NewInt(3)) }) t.Run("selector, struct, complex", func(t *testing.T) { src := fmt.Sprintf(srcTmpl, `x := triple{a: 1, b: pair{a: 2, b: 3}} return inline.Sum(x.b.a, 1)`) - checkCallCount(t, src, 0, 1, 2) + checkCallCount(t, src, 0, 1, 1) eval(t, src, big.NewInt(3)) }) t.Run("expression", func(t *testing.T) { src := fmt.Sprintf(srcTmpl, `x, y := 1, 2 return inline.Sum(x+y, y*2)`) - checkCallCount(t, src, 0, 1, 4) + checkCallCount(t, src, 0, 1, 2) eval(t, src, big.NewInt(7)) }) } diff --git a/pkg/compiler/vars.go b/pkg/compiler/vars.go index 7b21151a4..a12aa1459 100644 --- a/pkg/compiler/vars.go +++ b/pkg/compiler/vars.go @@ -1,7 +1,7 @@ package compiler import ( - "go/types" + "go/ast" ) type varScope struct { @@ -10,10 +10,18 @@ type varScope struct { locals []map[string]varInfo } +type varContext struct { + importMap map[string]string + expr ast.Expr + scope []map[string]varInfo +} + type varInfo struct { refType varType index int - tv types.TypeAndValue + // ctx is set for inline arguments and contains + // context for expression traversal. + ctx *varContext } const unspecifiedVarIndex = -1 @@ -32,11 +40,11 @@ func (c *varScope) dropScope() { c.locals = c.locals[:len(c.locals)-1] } -func (c *varScope) addAlias(name string, vt varType, index int, tv types.TypeAndValue) { +func (c *varScope) addAlias(name string, vt varType, index int, ctx *varContext) { c.locals[len(c.locals)-1][name] = varInfo{ refType: vt, index: index, - tv: tv, + ctx: ctx, } }