compiler: implement fallthrough in switch

Closes #628.
This commit is contained in:
Evgenii Stratonikov 2020-03-10 12:34:05 +03:00
parent 4b83e9a5cd
commit 91301df161
2 changed files with 51 additions and 3 deletions

View file

@ -407,7 +407,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
case *ast.SwitchStmt: case *ast.SwitchStmt:
// fallthrough is not supported
ast.Walk(c, n.Tag) ast.Walk(c, n.Tag)
eqOpcode := c.getEqualityOpcode(n.Tag) eqOpcode := c.getEqualityOpcode(n.Tag)
@ -416,9 +415,13 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
lastSwitch := c.currentSwitch lastSwitch := c.currentSwitch
c.currentSwitch = label c.currentSwitch = label
startLabels := make([]uint16, len(n.Body.List))
for i := range startLabels {
startLabels[i] = c.newLabel()
}
for i := range n.Body.List { for i := range n.Body.List {
lEnd := c.newLabel() lEnd := c.newLabel()
lStart := c.newLabel() lStart := startLabels[i]
cc := n.Body.List[i].(*ast.CaseClause) cc := n.Body.List[i].(*ast.CaseClause)
if l := len(cc.List); l != 0 { // if not `default` if l := len(cc.List); l != 0 { // if not `default`
@ -435,7 +438,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
c.setLabel(lStart) c.setLabel(lStart)
for _, stmt := range cc.Body { last := len(cc.Body) - 1
for j, stmt := range cc.Body {
if j == last && isFallthroughStmt(stmt) {
emit.Jmp(c.prog.BinWriter, opcode.JMP, startLabels[i+1])
break
}
ast.Walk(c, stmt) ast.Walk(c, stmt)
} }
emit.Jmp(c.prog.BinWriter, opcode.JMP, switchEnd) emit.Jmp(c.prog.BinWriter, opcode.JMP, switchEnd)
@ -834,6 +842,11 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return c return c
} }
func isFallthroughStmt(c ast.Node) bool {
s, ok := c.(*ast.BranchStmt)
return ok && s.Tok == token.FALLTHROUGH
}
// emitReverse reverses top num items of the stack. // emitReverse reverses top num items of the stack.
func (c *codegen) emitReverse(num int) { func (c *codegen) emitReverse(num int) {
switch num { switch num {

View file

@ -187,6 +187,41 @@ var switchTestCases = []testCase{
}`, }`,
big.NewInt(2), big.NewInt(2),
}, },
{
"simple fallthrough",
`package main
func Main() int {
n := 2
switch n {
case 1: return 5
case 2: fallthrough
case 3: return 6
}
return 7
}`,
big.NewInt(6),
},
{
"double fallthrough",
`package main
func Main() int {
n := 2
k := 5
switch n {
case 0: return k
case 1: fallthrough
case 2:
k++
fallthrough
case 3:
case 4:
k++
return k
}
return k
}`,
big.NewInt(6),
},
} }
func TestSwitch(t *testing.T) { func TestSwitch(t *testing.T) {