diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index a9cac98d8..0d4445b3f 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -578,9 +578,13 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.SwitchStmt: - ast.Walk(c, n.Tag) - - eqOpcode, _ := convertToken(token.EQL, c.typeOf(n.Tag)) + eqOpcode := opcode.EQUAL + if n.Tag != nil { + ast.Walk(c, n.Tag) + eqOpcode, _ = convertToken(token.EQL, c.typeOf(n.Tag)) + } else { + emit.Bool(c.prog.BinWriter, true) + } switchEnd, label := c.generateLabel(labelEnd) lastSwitch := c.currentSwitch @@ -795,6 +799,20 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { emit.Call(c.prog.BinWriter, opcode.CALLL, f.label) } + if c.scope != nil && c.scope.voidCalls[n] { + var sz int + if f != nil { + sz = f.decl.Type.Results.NumFields() + } else if !isBuiltin { + // lambda invocation + f := c.typeOf(n.Fun).Underlying().(*types.Signature) + sz = f.Results().Len() + } + for i := 0; i < sz; i++ { + emit.Opcode(c.prog.BinWriter, opcode.DROP) + } + } + return nil case *ast.SelectorExpr: diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index 01ed1d175..aba84a5b9 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -90,8 +90,24 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool { } case *ast.BinaryExpr: return false + case *ast.IfStmt: + // we can't just return `false`, because we still need to process body + ce, ok := n.Cond.(*ast.CallExpr) + if ok { + c.voidCalls[ce] = false + } + case *ast.CaseClause: + for _, e := range n.List { + ce, ok := e.(*ast.CallExpr) + if ok { + c.voidCalls[ce] = false + } + } case *ast.CallExpr: - c.voidCalls[n] = true + _, ok := c.voidCalls[n] + if !ok { + c.voidCalls[n] = true + } return false } return true @@ -141,7 +157,7 @@ func (c *funcScope) countArgs() int { func (c *funcScope) stackSize() int64 { size := c.countLocals() numArgs := c.countArgs() - return int64(size + numArgs + len(c.voidCalls)) + return int64(size + numArgs) } // newVariable creates a new local variable or argument in the scope of the function. diff --git a/pkg/compiler/function_call_test.go b/pkg/compiler/function_call_test.go index d85bc29c8..4c7732f66 100644 --- a/pkg/compiler/function_call_test.go +++ b/pkg/compiler/function_call_test.go @@ -24,8 +24,8 @@ func TestSimpleFunctionCall(t *testing.T) { } func TestNotAssignedFunctionCall(t *testing.T) { - src := ` - package testcase + t.Run("Simple", func(t *testing.T) { + src := `package testcase func Main() int { getSomeInteger() getSomeInteger() @@ -34,12 +34,53 @@ func TestNotAssignedFunctionCall(t *testing.T) { func getSomeInteger() int { return 0 - } - ` - // disable stack checks because it is hard right now - // to distinguish between simple function call traversal - // and the same traversal inside an assignment. - evalWithoutStackChecks(t, src, big.NewInt(0)) + }` + eval(t, src, big.NewInt(0)) + }) + t.Run("If", func(t *testing.T) { + src := `package testcase + func f() bool { return true } + func Main() int { + if f() { + return 42 + } + return 0 + }` + eval(t, src, big.NewInt(42)) + }) + t.Run("Switch", func(t *testing.T) { + src := `package testcase + func f() bool { return true } + func Main() int { + switch true { + case f(): + return 42 + default: + return 0 + } + }` + eval(t, src, big.NewInt(42)) + }) + t.Run("Builtin", func(t *testing.T) { + src := `package foo + import "github.com/nspcc-dev/neo-go/pkg/interop/util" + func Main() int { + util.FromAddress("NPAsqZkx9WhNd4P72uhZxBhLinSuNkxfB8") + util.FromAddress("NPAsqZkx9WhNd4P72uhZxBhLinSuNkxfB8") + return 1 + }` + eval(t, src, big.NewInt(1)) + }) + t.Run("Lambda", func(t *testing.T) { + src := `package foo + func Main() int { + f := func() (int, int) { return 1, 2 } + f() + f() + return 42 + }` + eval(t, src, big.NewInt(42)) + }) } func TestMultipleFunctionCalls(t *testing.T) { diff --git a/pkg/compiler/switch_test.go b/pkg/compiler/switch_test.go index 7a34e3174..ce90ee0c2 100644 --- a/pkg/compiler/switch_test.go +++ b/pkg/compiler/switch_test.go @@ -18,6 +18,21 @@ var switchTestCases = []testCase{ }`, big.NewInt(2), }, + { + "switch with no tag", + `package main + func f() bool { return false } + func Main() int { + switch { + case f(): + return 1 + case true: + return 2 + } + return 3 + }`, + big.NewInt(2), + }, { "simple switch fail", `package main diff --git a/pkg/compiler/vm_test.go b/pkg/compiler/vm_test.go index 19be8744b..d2141a901 100644 --- a/pkg/compiler/vm_test.go +++ b/pkg/compiler/vm_test.go @@ -32,12 +32,6 @@ func runTestCases(t *testing.T, tcases []testCase) { } } -func evalWithoutStackChecks(t *testing.T, src string, result interface{}) { - v := vmAndCompile(t, src) - require.NoError(t, v.Run()) - assertResult(t, v, result) -} - func eval(t *testing.T, src string, result interface{}) { vm := vmAndCompile(t, src) err := vm.Run()