Merge pull request #936 from nspcc-dev/fix/voidcall
compiler: allow empty returns
This commit is contained in:
commit
d3f1ccd518
5 changed files with 103 additions and 20 deletions
|
@ -133,16 +133,13 @@ func (f funcUsage) funcUsed(name string) bool {
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasReturnStmt looks if the given FuncDecl has a return statement.
|
// lastStmtIsReturn checks if last statement of the declaration was return statement..
|
||||||
func hasReturnStmt(decl ast.Node) (b bool) {
|
func lastStmtIsReturn(decl *ast.FuncDecl) (b bool) {
|
||||||
ast.Inspect(decl, func(node ast.Node) bool {
|
if l := len(decl.Body.List); l != 0 {
|
||||||
if _, ok := node.(*ast.ReturnStmt); ok {
|
_, ok := decl.Body.List[l-1].(*ast.ReturnStmt)
|
||||||
b = true
|
return ok
|
||||||
return false
|
}
|
||||||
}
|
return false
|
||||||
return true
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func analyzeFuncUsage(pkgs map[*types.Package]*loader.PackageInfo) funcUsage {
|
func analyzeFuncUsage(pkgs map[*types.Package]*loader.PackageInfo) funcUsage {
|
||||||
|
|
|
@ -263,8 +263,10 @@ func (c *codegen) convertFuncDecl(file ast.Node, decl *ast.FuncDecl) {
|
||||||
|
|
||||||
ast.Walk(c, decl.Body)
|
ast.Walk(c, decl.Body)
|
||||||
|
|
||||||
// If this function returns the void (no return stmt) we will cleanup its junk on the stack.
|
// If we have reached the end of the function without encountering `return` statement,
|
||||||
if !hasReturnStmt(decl) {
|
// we should clean alt.stack manually.
|
||||||
|
// This can be the case with void and named-return functions.
|
||||||
|
if !lastStmtIsReturn(decl) {
|
||||||
c.saveSequencePoint(decl.Body)
|
c.saveSequencePoint(decl.Body)
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.FROMALTSTACK)
|
emit.Opcode(c.prog.BinWriter, opcode.FROMALTSTACK)
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.DROP)
|
emit.Opcode(c.prog.BinWriter, opcode.DROP)
|
||||||
|
@ -419,9 +421,22 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
}
|
}
|
||||||
c.dropItems(cnt)
|
c.dropItems(cnt)
|
||||||
|
|
||||||
// first result should be on top of the stack
|
if len(n.Results) == 0 {
|
||||||
for i := len(n.Results) - 1; i >= 0; i-- {
|
results := c.scope.decl.Type.Results
|
||||||
ast.Walk(c, n.Results[i])
|
if results.NumFields() != 0 {
|
||||||
|
// function with named returns
|
||||||
|
for i := len(results.List) - 1; i >= 0; i-- {
|
||||||
|
names := results.List[i].Names
|
||||||
|
for j := len(names) - 1; j >= 0; j-- {
|
||||||
|
c.emitLoadLocal(names[j].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// first result should be on top of the stack
|
||||||
|
for i := len(n.Results) - 1; i >= 0; i-- {
|
||||||
|
ast.Walk(c, n.Results[i])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.saveSequencePoint(n)
|
c.saveSequencePoint(n)
|
||||||
|
|
|
@ -65,9 +65,11 @@ func (c *funcScope) analyzeVoidCalls(node ast.Node) bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case *ast.ReturnStmt:
|
case *ast.ReturnStmt:
|
||||||
switch n.Results[0].(type) {
|
if len(n.Results) > 0 {
|
||||||
case *ast.CallExpr:
|
switch n.Results[0].(type) {
|
||||||
return false
|
case *ast.CallExpr:
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case *ast.BinaryExpr:
|
case *ast.BinaryExpr:
|
||||||
return false
|
return false
|
||||||
|
@ -82,6 +84,11 @@ func (c *funcScope) stackSize() int64 {
|
||||||
size := 0
|
size := 0
|
||||||
ast.Inspect(c.decl, func(n ast.Node) bool {
|
ast.Inspect(c.decl, func(n ast.Node) bool {
|
||||||
switch n := n.(type) {
|
switch n := n.(type) {
|
||||||
|
case *ast.FuncType:
|
||||||
|
num := n.Results.NumFields()
|
||||||
|
if num != 0 && len(n.Results.List[0].Names) != 0 {
|
||||||
|
size += num
|
||||||
|
}
|
||||||
case *ast.AssignStmt:
|
case *ast.AssignStmt:
|
||||||
if n.Tok == token.DEFINE {
|
if n.Tok == token.DEFINE {
|
||||||
size += len(n.Rhs)
|
size += len(n.Rhs)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package compiler_test
|
package compiler_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -121,7 +122,39 @@ func TestFunctionWithVoidReturn(t *testing.T) {
|
||||||
return x + y
|
return x + y
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSomeInteger() { }
|
func getSomeInteger() { %s }
|
||||||
`
|
`
|
||||||
eval(t, src, big.NewInt(6))
|
t.Run("EmptyBody", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(src, "")
|
||||||
|
eval(t, src, big.NewInt(6))
|
||||||
|
})
|
||||||
|
t.Run("SingleReturn", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(src, "return")
|
||||||
|
eval(t, src, big.NewInt(6))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFunctionWithVoidReturnBranch(t *testing.T) {
|
||||||
|
src := `
|
||||||
|
package testcase
|
||||||
|
func Main() int {
|
||||||
|
x := %t
|
||||||
|
f(x)
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
|
||||||
|
func f(x bool) {
|
||||||
|
if x {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`
|
||||||
|
t.Run("ReturnBranch", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(src, true)
|
||||||
|
eval(t, src, big.NewInt(2))
|
||||||
|
})
|
||||||
|
t.Run("NoReturn", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(src, false)
|
||||||
|
eval(t, src, big.NewInt(2))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,12 @@
|
||||||
package compiler_test
|
package compiler_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReturnInt64(t *testing.T) {
|
func TestReturnInt64(t *testing.T) {
|
||||||
|
@ -92,3 +96,30 @@ func TestSingleReturn(t *testing.T) {
|
||||||
`
|
`
|
||||||
eval(t, src, big.NewInt(9))
|
eval(t, src, big.NewInt(9))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNamedReturn(t *testing.T) {
|
||||||
|
src := `package foo
|
||||||
|
func Main() (a int, b int) {
|
||||||
|
a = 1
|
||||||
|
b = 2
|
||||||
|
c := 3
|
||||||
|
_ = c
|
||||||
|
return %s
|
||||||
|
}`
|
||||||
|
|
||||||
|
runCase := func(ret string, result ...interface{}) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(src, ret)
|
||||||
|
v := vmAndCompile(t, src)
|
||||||
|
require.NoError(t, v.Run())
|
||||||
|
require.Equal(t, len(result), v.Estack().Len())
|
||||||
|
for i := range result {
|
||||||
|
assert.EqualValues(t, result[i], v.Estack().Pop().Value())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("NormalReturn", runCase("a, b", big.NewInt(1), big.NewInt(2)))
|
||||||
|
t.Run("EmptyReturn", runCase("", big.NewInt(1), big.NewInt(2)))
|
||||||
|
t.Run("AnotherVariable", runCase("b, c", big.NewInt(2), big.NewInt(3)))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue