Merge pull request #936 from nspcc-dev/fix/voidcall

compiler: allow empty returns
This commit is contained in:
Roman Khimov 2020-05-06 18:27:59 +03:00 committed by GitHub
commit d3f1ccd518
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 103 additions and 20 deletions

View file

@ -133,16 +133,13 @@ func (f funcUsage) funcUsed(name string) bool {
return ok
}
// hasReturnStmt looks if the given FuncDecl has a return statement.
func hasReturnStmt(decl ast.Node) (b bool) {
ast.Inspect(decl, func(node ast.Node) bool {
if _, ok := node.(*ast.ReturnStmt); ok {
b = true
return false
}
return true
})
return
// lastStmtIsReturn checks if last statement of the declaration was return statement..
func lastStmtIsReturn(decl *ast.FuncDecl) (b bool) {
if l := len(decl.Body.List); l != 0 {
_, ok := decl.Body.List[l-1].(*ast.ReturnStmt)
return ok
}
return false
}
func analyzeFuncUsage(pkgs map[*types.Package]*loader.PackageInfo) funcUsage {

View file

@ -263,8 +263,10 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl) {
ast.Walk(c, decl.Body)
// If this function returns the void (no return stmt) we will cleanup its junk on the stack.
if !hasReturnStmt(decl) {
// If we have reached the end of the function without encountering `return` statement,
// we should clean alt.stack manually.
// This can be the case with void and named-return functions.
if !lastStmtIsReturn(decl) {
c.saveSequencePoint(decl.Body)
emit.Opcode(c.prog.BinWriter, opcode.FROMALTSTACK)
emit.Opcode(c.prog.BinWriter, opcode.DROP)
@ -419,9 +421,22 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
}
c.dropItems(cnt)
// first result should be on top of the stack
for i := len(n.Results) - 1; i >= 0; i-- {
ast.Walk(c, n.Results[i])
if len(n.Results) == 0 {
results := c.scope.decl.Type.Results
if results.NumFields() != 0 {
// function with named returns
for i := len(results.List) - 1; i >= 0; i-- {
names := results.List[i].Names
for j := len(names) - 1; j >= 0; j-- {
c.emitLoadLocal(names[j].Name)
}
}
}
} else {
// first result should be on top of the stack
for i := len(n.Results) - 1; i >= 0; i-- {
ast.Walk(c, n.Results[i])
}
}
c.saveSequencePoint(n)

View file

@ -65,9 +65,11 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
}
}
case *ast.ReturnStmt:
switch n.Results[0].(type) {
case *ast.CallExpr:
return false
if len(n.Results) > 0 {
switch n.Results[0].(type) {
case *ast.CallExpr:
return false
}
}
case *ast.BinaryExpr:
return false
@ -82,6 +84,11 @@ func (c *funcScope) stackSize() int64 {
size := 0
ast.Inspect(c.decl, func(n ast.Node) bool {
switch n := n.(type) {
case *ast.FuncType:
num := n.Results.NumFields()
if num != 0 && len(n.Results.List[0].Names) != 0 {
size += num
}
case *ast.AssignStmt:
if n.Tok == token.DEFINE {
size += len(n.Rhs)

View file

@ -1,6 +1,7 @@
package compiler_test
import (
"fmt"
"math/big"
"testing"
)
@ -121,7 +122,39 @@ func TestFunctionWithVoidReturn(t *testing.T) {
return x + y
}
func getSomeInteger() { }
func getSomeInteger() { %s }
`
eval(t, src, big.NewInt(6))
t.Run("EmptyBody", func(t *testing.T) {
src := fmt.Sprintf(src, "")
eval(t, src, big.NewInt(6))
})
t.Run("SingleReturn", func(t *testing.T) {
src := fmt.Sprintf(src, "return")
eval(t, src, big.NewInt(6))
})
}
func TestFunctionWithVoidReturnBranch(t *testing.T) {
src := `
package testcase
func Main() int {
x := %t
f(x)
return 2
}
func f(x bool) {
if x {
return
}
}
`
t.Run("ReturnBranch", func(t *testing.T) {
src := fmt.Sprintf(src, true)
eval(t, src, big.NewInt(2))
})
t.Run("NoReturn", func(t *testing.T) {
src := fmt.Sprintf(src, false)
eval(t, src, big.NewInt(2))
})
}

View file

@ -1,8 +1,12 @@
package compiler_test
import (
"fmt"
"math/big"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReturnInt64(t *testing.T) {
@ -92,3 +96,30 @@ func TestSingleReturn(t *testing.T) {
`
eval(t, src, big.NewInt(9))
}
func TestNamedReturn(t *testing.T) {
src := `package foo
func Main() (a int, b int) {
a = 1
b = 2
c := 3
_ = c
return %s
}`
runCase := func(ret string, result ...interface{}) func(t *testing.T) {
return func(t *testing.T) {
src := fmt.Sprintf(src, ret)
v := vmAndCompile(t, src)
require.NoError(t, v.Run())
require.Equal(t, len(result), v.Estack().Len())
for i := range result {
assert.EqualValues(t, result[i], v.Estack().Pop().Value())
}
}
}
t.Run("NormalReturn", runCase("a, b", big.NewInt(1), big.NewInt(2)))
t.Run("EmptyReturn", runCase("", big.NewInt(1), big.NewInt(2)))
t.Run("AnotherVariable", runCase("b, c", big.NewInt(2), big.NewInt(3)))
}