Merge pull request #931 from nspcc-dev/fix/voidcall

compiler: allow to use `return` with no arguments
This commit is contained in:
Roman Khimov 2020-05-06 18:19:16 +03:00 committed by GitHub
commit 91a3b655b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 103 additions and 20 deletions

View file

@ -133,16 +133,13 @@ func (f funcUsage) funcUsed(name string) bool {
return ok return ok
} }
// hasReturnStmt looks if the given FuncDecl has a return statement. // lastStmtIsReturn checks if last statement of the declaration was return statement..
func hasReturnStmt(decl ast.Node) (b bool) { func lastStmtIsReturn(decl *ast.FuncDecl) (b bool) {
ast.Inspect(decl, func(node ast.Node) bool { if l := len(decl.Body.List); l != 0 {
if _, ok := node.(*ast.ReturnStmt); ok { _, ok := decl.Body.List[l-1].(*ast.ReturnStmt)
b = true return ok
return false }
} return false
return true
})
return
} }
func analyzeFuncUsage(pkgs map[*types.Package]*loader.PackageInfo) funcUsage { func analyzeFuncUsage(pkgs map[*types.Package]*loader.PackageInfo) funcUsage {

View file

@ -261,8 +261,10 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl) {
ast.Walk(c, decl.Body) ast.Walk(c, decl.Body)
// If this function returns the void (no return stmt) we will cleanup its junk on the stack. // If we have reached the end of the function without encountering `return` statement,
if !hasReturnStmt(decl) { // we should clean alt.stack manually.
// This can be the case with void and named-return functions.
if !lastStmtIsReturn(decl) {
c.saveSequencePoint(decl.Body) c.saveSequencePoint(decl.Body)
emit.Opcode(c.prog.BinWriter, opcode.FROMALTSTACK) emit.Opcode(c.prog.BinWriter, opcode.FROMALTSTACK)
emit.Opcode(c.prog.BinWriter, opcode.DROP) emit.Opcode(c.prog.BinWriter, opcode.DROP)
@ -417,9 +419,22 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
c.dropItems(cnt) c.dropItems(cnt)
// first result should be on top of the stack if len(n.Results) == 0 {
for i := len(n.Results) - 1; i >= 0; i-- { results := c.scope.decl.Type.Results
ast.Walk(c, n.Results[i]) if results.NumFields() != 0 {
// function with named returns
for i := len(results.List) - 1; i >= 0; i-- {
names := results.List[i].Names
for j := len(names) - 1; j >= 0; j-- {
c.emitLoadLocal(names[j].Name)
}
}
}
} else {
// first result should be on top of the stack
for i := len(n.Results) - 1; i >= 0; i-- {
ast.Walk(c, n.Results[i])
}
} }
c.saveSequencePoint(n) c.saveSequencePoint(n)

View file

@ -65,9 +65,11 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
} }
} }
case *ast.ReturnStmt: case *ast.ReturnStmt:
switch n.Results[0].(type) { if len(n.Results) > 0 {
case *ast.CallExpr: switch n.Results[0].(type) {
return false case *ast.CallExpr:
return false
}
} }
case *ast.BinaryExpr: case *ast.BinaryExpr:
return false return false
@ -82,6 +84,11 @@ func (c *funcScope) stackSize() int64 {
size := 0 size := 0
ast.Inspect(c.decl, func(n ast.Node) bool { ast.Inspect(c.decl, func(n ast.Node) bool {
switch n := n.(type) { switch n := n.(type) {
case *ast.FuncType:
num := n.Results.NumFields()
if num != 0 && len(n.Results.List[0].Names) != 0 {
size += num
}
case *ast.AssignStmt: case *ast.AssignStmt:
if n.Tok == token.DEFINE { if n.Tok == token.DEFINE {
size += len(n.Rhs) size += len(n.Rhs)

View file

@ -1,6 +1,7 @@
package compiler_test package compiler_test
import ( import (
"fmt"
"math/big" "math/big"
"testing" "testing"
) )
@ -121,7 +122,39 @@ func TestFunctionWithVoidReturn(t *testing.T) {
return x + y return x + y
} }
func getSomeInteger() { } func getSomeInteger() { %s }
` `
eval(t, src, big.NewInt(6)) t.Run("EmptyBody", func(t *testing.T) {
src := fmt.Sprintf(src, "")
eval(t, src, big.NewInt(6))
})
t.Run("SingleReturn", func(t *testing.T) {
src := fmt.Sprintf(src, "return")
eval(t, src, big.NewInt(6))
})
}
func TestFunctionWithVoidReturnBranch(t *testing.T) {
src := `
package testcase
func Main() int {
x := %t
f(x)
return 2
}
func f(x bool) {
if x {
return
}
}
`
t.Run("ReturnBranch", func(t *testing.T) {
src := fmt.Sprintf(src, true)
eval(t, src, big.NewInt(2))
})
t.Run("NoReturn", func(t *testing.T) {
src := fmt.Sprintf(src, false)
eval(t, src, big.NewInt(2))
})
} }

View file

@ -1,8 +1,12 @@
package compiler_test package compiler_test
import ( import (
"fmt"
"math/big" "math/big"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestMultipleReturn1(t *testing.T) { func TestMultipleReturn1(t *testing.T) {
@ -84,3 +88,30 @@ func TestSingleReturn(t *testing.T) {
` `
eval(t, src, big.NewInt(9)) eval(t, src, big.NewInt(9))
} }
func TestNamedReturn(t *testing.T) {
src := `package foo
func Main() (a int, b int) {
a = 1
b = 2
c := 3
_ = c
return %s
}`
runCase := func(ret string, result ...interface{}) func(t *testing.T) {
return func(t *testing.T) {
src := fmt.Sprintf(src, ret)
v := vmAndCompile(t, src)
require.NoError(t, v.Run())
require.Equal(t, len(result), v.Estack().Len())
for i := range result {
assert.EqualValues(t, result[i], v.Estack().Pop().Value())
}
}
}
t.Run("NormalReturn", runCase("a, b", big.NewInt(1), big.NewInt(2)))
t.Run("EmptyReturn", runCase("", big.NewInt(1), big.NewInt(2)))
t.Run("AnotherVariable", runCase("b, c", big.NewInt(2), big.NewInt(3)))
}