diff --git a/pkg/compiler/analysis.go b/pkg/compiler/analysis.go index b075d7032..fc2c36067 100644 --- a/pkg/compiler/analysis.go +++ b/pkg/compiler/analysis.go @@ -133,16 +133,13 @@ func (f funcUsage) funcUsed(name string) bool { return ok } -// hasReturnStmt looks if the given FuncDecl has a return statement. -func hasReturnStmt(decl ast.Node) (b bool) { - ast.Inspect(decl, func(node ast.Node) bool { - if _, ok := node.(*ast.ReturnStmt); ok { - b = true - return false - } - return true - }) - return +// lastStmtIsReturn checks if last statement of the declaration was return statement.. +func lastStmtIsReturn(decl *ast.FuncDecl) (b bool) { + if l := len(decl.Body.List); l != 0 { + _, ok := decl.Body.List[l-1].(*ast.ReturnStmt) + return ok + } + return false } func analyzeFuncUsage(pkgs map[*types.Package]*loader.PackageInfo) funcUsage { diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 696e96914..067db290e 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -263,8 +263,10 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl) { ast.Walk(c, decl.Body) - // If this function returns the void (no return stmt) we will cleanup its junk on the stack. - if !hasReturnStmt(decl) { + // If we have reached the end of the function without encountering `return` statement, + // we should clean alt.stack manually. + // This can be the case with void and named-return functions. + if !lastStmtIsReturn(decl) { c.saveSequencePoint(decl.Body) emit.Opcode(c.prog.BinWriter, opcode.FROMALTSTACK) emit.Opcode(c.prog.BinWriter, opcode.DROP) @@ -419,9 +421,22 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } c.dropItems(cnt) - // first result should be on top of the stack - for i := len(n.Results) - 1; i >= 0; i-- { - ast.Walk(c, n.Results[i]) + if len(n.Results) == 0 { + results := c.scope.decl.Type.Results + if results.NumFields() != 0 { + // function with named returns + for i := len(results.List) - 1; i >= 0; i-- { + names := results.List[i].Names + for j := len(names) - 1; j >= 0; j-- { + c.emitLoadLocal(names[j].Name) + } + } + } + } else { + // first result should be on top of the stack + for i := len(n.Results) - 1; i >= 0; i-- { + ast.Walk(c, n.Results[i]) + } } c.saveSequencePoint(n) diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index d4a0ca862..47c9aa238 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -65,9 +65,11 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool { } } case *ast.ReturnStmt: - switch n.Results[0].(type) { - case *ast.CallExpr: - return false + if len(n.Results) > 0 { + switch n.Results[0].(type) { + case *ast.CallExpr: + return false + } } case *ast.BinaryExpr: return false @@ -82,6 +84,11 @@ func (c *funcScope) stackSize() int64 { size := 0 ast.Inspect(c.decl, func(n ast.Node) bool { switch n := n.(type) { + case *ast.FuncType: + num := n.Results.NumFields() + if num != 0 && len(n.Results.List[0].Names) != 0 { + size += num + } case *ast.AssignStmt: if n.Tok == token.DEFINE { size += len(n.Rhs) diff --git a/pkg/compiler/function_call_test.go b/pkg/compiler/function_call_test.go index 7c1f22367..005707d76 100644 --- a/pkg/compiler/function_call_test.go +++ b/pkg/compiler/function_call_test.go @@ -1,6 +1,7 @@ package compiler_test import ( + "fmt" "math/big" "testing" ) @@ -121,7 +122,39 @@ func TestFunctionWithVoidReturn(t *testing.T) { return x + y } - func getSomeInteger() { } + func getSomeInteger() { %s } ` - eval(t, src, big.NewInt(6)) + t.Run("EmptyBody", func(t *testing.T) { + src := fmt.Sprintf(src, "") + eval(t, src, big.NewInt(6)) + }) + t.Run("SingleReturn", func(t *testing.T) { + src := fmt.Sprintf(src, "return") + eval(t, src, big.NewInt(6)) + }) +} + +func TestFunctionWithVoidReturnBranch(t *testing.T) { + src := ` + package testcase + func Main() int { + x := %t + f(x) + return 2 + } + + func f(x bool) { + if x { + return + } + } + ` + t.Run("ReturnBranch", func(t *testing.T) { + src := fmt.Sprintf(src, true) + eval(t, src, big.NewInt(2)) + }) + t.Run("NoReturn", func(t *testing.T) { + src := fmt.Sprintf(src, false) + eval(t, src, big.NewInt(2)) + }) } diff --git a/pkg/compiler/return_test.go b/pkg/compiler/return_test.go index ae656ae7a..53f33dca8 100644 --- a/pkg/compiler/return_test.go +++ b/pkg/compiler/return_test.go @@ -1,8 +1,12 @@ package compiler_test import ( + "fmt" "math/big" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestReturnInt64(t *testing.T) { @@ -92,3 +96,30 @@ func TestSingleReturn(t *testing.T) { ` eval(t, src, big.NewInt(9)) } + +func TestNamedReturn(t *testing.T) { + src := `package foo + func Main() (a int, b int) { + a = 1 + b = 2 + c := 3 + _ = c + return %s + }` + + runCase := func(ret string, result ...interface{}) func(t *testing.T) { + return func(t *testing.T) { + src := fmt.Sprintf(src, ret) + v := vmAndCompile(t, src) + require.NoError(t, v.Run()) + require.Equal(t, len(result), v.Estack().Len()) + for i := range result { + assert.EqualValues(t, result[i], v.Estack().Pop().Value()) + } + } + } + + t.Run("NormalReturn", runCase("a, b", big.NewInt(1), big.NewInt(2))) + t.Run("EmptyReturn", runCase("", big.NewInt(1), big.NewInt(2))) + t.Run("AnotherVariable", runCase("b, c", big.NewInt(2), big.NewInt(3))) +}