diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index da5b085c4..067db290e 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -421,9 +421,22 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } c.dropItems(cnt) - // first result should be on top of the stack - for i := len(n.Results) - 1; i >= 0; i-- { - ast.Walk(c, n.Results[i]) + if len(n.Results) == 0 { + results := c.scope.decl.Type.Results + if results.NumFields() != 0 { + // function with named returns + for i := len(results.List) - 1; i >= 0; i-- { + names := results.List[i].Names + for j := len(names) - 1; j >= 0; j-- { + c.emitLoadLocal(names[j].Name) + } + } + } + } else { + // first result should be on top of the stack + for i := len(n.Results) - 1; i >= 0; i-- { + ast.Walk(c, n.Results[i]) + } } c.saveSequencePoint(n) diff --git a/pkg/compiler/func_scope.go b/pkg/compiler/func_scope.go index de9d039e7..47c9aa238 100644 --- a/pkg/compiler/func_scope.go +++ b/pkg/compiler/func_scope.go @@ -84,6 +84,11 @@ func (c *funcScope) stackSize() int64 { size := 0 ast.Inspect(c.decl, func(n ast.Node) bool { switch n := n.(type) { + case *ast.FuncType: + num := n.Results.NumFields() + if num != 0 && len(n.Results.List[0].Names) != 0 { + size += num + } case *ast.AssignStmt: if n.Tok == token.DEFINE { size += len(n.Rhs) diff --git a/pkg/compiler/return_test.go b/pkg/compiler/return_test.go index ae656ae7a..53f33dca8 100644 --- a/pkg/compiler/return_test.go +++ b/pkg/compiler/return_test.go @@ -1,8 +1,12 @@ package compiler_test import ( + "fmt" "math/big" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestReturnInt64(t *testing.T) { @@ -92,3 +96,30 @@ func TestSingleReturn(t *testing.T) { ` eval(t, src, big.NewInt(9)) } + +func TestNamedReturn(t *testing.T) { + src := `package foo + func Main() (a int, b int) { + a = 1 + b = 2 + c := 3 + _ = c + return %s + }` + + runCase := func(ret string, result ...interface{}) func(t *testing.T) { + return func(t *testing.T) { + src := fmt.Sprintf(src, ret) + v := vmAndCompile(t, src) + require.NoError(t, v.Run()) + require.Equal(t, len(result), v.Estack().Len()) + for i := range result { + assert.EqualValues(t, result[i], v.Estack().Pop().Value()) + } + } + } + + t.Run("NormalReturn", runCase("a, b", big.NewInt(1), big.NewInt(2))) + t.Run("EmptyReturn", runCase("", big.NewInt(1), big.NewInt(2))) + t.Run("AnotherVariable", runCase("b, c", big.NewInt(2), big.NewInt(3))) +}