compiler: clean up stack on branch statements

When `return` or `break` statement is encountered inside
a for/range/switch statement, top stack items can be auxilliary.
They need to be cleaned up before returning from the function.
This commit is contained in:
Evgenii Stratonikov 2020-03-06 15:11:14 +03:00
parent 3f1e8f66b6
commit 2a1402f25d
3 changed files with 63 additions and 2 deletions

View file

@ -41,6 +41,8 @@ type codegen struct {
// A mapping from label's names to their ids. // A mapping from label's names to their ids.
labels map[labelWithType]uint16 labels map[labelWithType]uint16
// A list of nested label names together with evaluation stack depth.
labelList []labelWithStackSize
// A label for the for-loop being currently visited. // A label for the for-loop being currently visited.
currentFor string currentFor string
@ -66,6 +68,11 @@ type labelWithType struct {
typ labelOffsetType typ labelOffsetType
} }
type labelWithStackSize struct {
name string
sz int
}
// newLabel creates a new label to jump to // newLabel creates a new label to jump to
func (c *codegen) newLabel() (l uint16) { func (c *codegen) newLabel() (l uint16) {
li := len(c.l) li := len(c.l)
@ -373,6 +380,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
l := c.newLabel() l := c.newLabel()
c.setLabel(l) c.setLabel(l)
cnt := 0
for i := range c.labelList {
cnt += c.labelList[i].sz
}
c.dropItems(cnt)
// first result should be on top of the stack // first result should be on top of the stack
for i := len(n.Results) - 1; i >= 0; i-- { for i := len(n.Results) - 1; i >= 0; i-- {
ast.Walk(c, n.Results[i]) ast.Walk(c, n.Results[i])
@ -414,6 +427,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
lastSwitch := c.currentSwitch lastSwitch := c.currentSwitch
c.currentSwitch = label c.currentSwitch = label
c.pushStackLabel(label, 1)
startLabels := make([]uint16, len(n.Body.List)) startLabels := make([]uint16, len(n.Body.List))
for i := range startLabels { for i := range startLabels {
@ -451,7 +465,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
c.setLabel(switchEnd) c.setLabel(switchEnd)
emit.Opcode(c.prog.BinWriter, opcode.DROP) c.dropStackLabel()
c.currentSwitch = lastSwitch c.currentSwitch = lastSwitch
@ -725,6 +739,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
label = c.currentFor label = c.currentFor
} }
cnt := 0
for i := len(c.labelList) - 1; i >= 0 && c.labelList[i].name != label; i-- {
cnt += c.labelList[i].sz
}
c.dropItems(cnt)
switch n.Tok { switch n.Tok {
case token.BREAK: case token.BREAK:
end := c.getLabelOffset(labelEnd, label) end := c.getLabelOffset(labelEnd, label)
@ -759,6 +779,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
// Set label and walk the condition. // Set label and walk the condition.
c.pushStackLabel(label, 0)
c.setLabel(fstart) c.setLabel(fstart)
ast.Walk(c, n.Cond) ast.Walk(c, n.Cond)
@ -775,6 +796,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
// Jump back to condition. // Jump back to condition.
emit.Jmp(c.prog.BinWriter, opcode.JMP, fstart) emit.Jmp(c.prog.BinWriter, opcode.JMP, fstart)
c.setLabel(fend) c.setLabel(fend)
c.dropStackLabel()
c.currentFor = lastLabel c.currentFor = lastLabel
c.currentSwitch = lastSwitch c.currentSwitch = lastSwitch
@ -803,6 +825,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
emit.Opcode(c.prog.BinWriter, opcode.ARRAYSIZE) emit.Opcode(c.prog.BinWriter, opcode.ARRAYSIZE)
emit.Opcode(c.prog.BinWriter, opcode.PUSH0) emit.Opcode(c.prog.BinWriter, opcode.PUSH0)
c.pushStackLabel(label, 2)
c.setLabel(start) c.setLabel(start)
emit.Opcode(c.prog.BinWriter, opcode.OVER) emit.Opcode(c.prog.BinWriter, opcode.OVER)
@ -825,6 +848,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
emit.Jmp(c.prog.BinWriter, opcode.JMP, start) emit.Jmp(c.prog.BinWriter, opcode.JMP, start)
c.setLabel(end) c.setLabel(end)
c.dropStackLabel()
c.currentFor = lastFor c.currentFor = lastFor
c.currentSwitch = lastSwitch c.currentSwitch = lastSwitch
@ -847,6 +871,32 @@ func isFallthroughStmt(c ast.Node) bool {
return ok && s.Tok == token.FALLTHROUGH return ok && s.Tok == token.FALLTHROUGH
} }
func (c *codegen) pushStackLabel(name string, size int) {
c.labelList = append(c.labelList, labelWithStackSize{
name: name,
sz: size,
})
}
func (c *codegen) dropStackLabel() {
last := len(c.labelList) - 1
c.dropItems(c.labelList[last].sz)
c.labelList = c.labelList[:last]
}
func (c *codegen) dropItems(n int) {
if n < 4 {
for i := 0; i < n; i++ {
emit.Opcode(c.prog.BinWriter, opcode.DROP)
}
return
}
emit.Int(c.prog.BinWriter, int64(n))
emit.Opcode(c.prog.BinWriter, opcode.PACK)
emit.Opcode(c.prog.BinWriter, opcode.DROP)
}
// emitReverse reverses top num items of the stack. // emitReverse reverses top num items of the stack.
func (c *codegen) emitReverse(num int) { func (c *codegen) emitReverse(num int) {
switch num { switch num {

View file

@ -35,7 +35,10 @@ func TestNotAssignedFunctionCall(t *testing.T) {
return 0 return 0
} }
` `
eval(t, src, []byte{}) // disable stack checks because it is hard right now
// to distinguish between simple function call traversal
// and the same traversal inside an assignment.
evalWithoutStackChecks(t, src, []byte{})
} }
func TestMultipleFunctionCalls(t *testing.T) { func TestMultipleFunctionCalls(t *testing.T) {

View file

@ -23,10 +23,17 @@ func runTestCases(t *testing.T, tcases []testCase) {
} }
} }
func evalWithoutStackChecks(t *testing.T, src string, result interface{}) {
v := vmAndCompile(t, src)
require.NoError(t, v.Run())
assertResult(t, v, result)
}
func eval(t *testing.T, src string, result interface{}) { func eval(t *testing.T, src string, result interface{}) {
vm := vmAndCompile(t, src) vm := vmAndCompile(t, src)
err := vm.Run() err := vm.Run()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, vm.Estack().Len(), "stack contains unexpected items")
assertResult(t, vm, result) assertResult(t, vm, result)
} }
@ -35,6 +42,7 @@ func evalWithArgs(t *testing.T, src string, op []byte, args []vm.StackItem, resu
vm.LoadArgs(op, args) vm.LoadArgs(op, args)
err := vm.Run() err := vm.Run()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, vm.Estack().Len(), "stack contains unexpected items")
assertResult(t, vm, result) assertResult(t, vm, result)
} }