compiler: support named returns

This commit is contained in:
Evgenii Stratonikov 2020-05-06 17:24:32 +03:00
parent 156a2eddc5
commit b0a89e8a1a
3 changed files with 52 additions and 3 deletions

View file

@ -421,10 +421,23 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
c.dropItems(cnt) c.dropItems(cnt)
if len(n.Results) == 0 {
results := c.scope.decl.Type.Results
if results.NumFields() != 0 {
// function with named returns
for i := len(results.List) - 1; i >= 0; i-- {
names := results.List[i].Names
for j := len(names) - 1; j >= 0; j-- {
c.emitLoadLocal(names[j].Name)
}
}
}
} else {
// first result should be on top of the stack // first result should be on top of the stack
for i := len(n.Results) - 1; i >= 0; i-- { for i := len(n.Results) - 1; i >= 0; i-- {
ast.Walk(c, n.Results[i]) ast.Walk(c, n.Results[i])
} }
}
c.saveSequencePoint(n) c.saveSequencePoint(n)
emit.Opcode(c.prog.BinWriter, opcode.FROMALTSTACK) emit.Opcode(c.prog.BinWriter, opcode.FROMALTSTACK)

View file

@ -84,6 +84,11 @@ func (c *funcScope) stackSize() int64 {
size := 0 size := 0
ast.Inspect(c.decl, func(n ast.Node) bool { ast.Inspect(c.decl, func(n ast.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *ast.FuncType:
num := n.Results.NumFields()
if num != 0 && len(n.Results.List[0].Names) != 0 {
size += num
}
case *ast.AssignStmt: case *ast.AssignStmt:
if n.Tok == token.DEFINE { if n.Tok == token.DEFINE {
size += len(n.Rhs) size += len(n.Rhs)

View file

@ -1,8 +1,12 @@
package compiler_test package compiler_test
import ( import (
"fmt"
"math/big" "math/big"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestReturnInt64(t *testing.T) { func TestReturnInt64(t *testing.T) {
@ -92,3 +96,30 @@ func TestSingleReturn(t *testing.T) {
` `
eval(t, src, big.NewInt(9)) eval(t, src, big.NewInt(9))
} }
func TestNamedReturn(t *testing.T) {
src := `package foo
func Main() (a int, b int) {
a = 1
b = 2
c := 3
_ = c
return %s
}`
runCase := func(ret string, result ...interface{}) func(t *testing.T) {
return func(t *testing.T) {
src := fmt.Sprintf(src, ret)
v := vmAndCompile(t, src)
require.NoError(t, v.Run())
require.Equal(t, len(result), v.Estack().Len())
for i := range result {
assert.EqualValues(t, result[i], v.Estack().Pop().Value())
}
}
}
t.Run("NormalReturn", runCase("a, b", big.NewInt(1), big.NewInt(2)))
t.Run("EmptyReturn", runCase("", big.NewInt(1), big.NewInt(2)))
t.Run("AnotherVariable", runCase("b, c", big.NewInt(2), big.NewInt(3)))
}