diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 9956f9a9f..8113eaf8d 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -407,7 +407,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil case *ast.SwitchStmt: - // fallthrough is not supported ast.Walk(c, n.Tag) eqOpcode := c.getEqualityOpcode(n.Tag) @@ -416,9 +415,13 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { lastSwitch := c.currentSwitch c.currentSwitch = label + startLabels := make([]uint16, len(n.Body.List)) + for i := range startLabels { + startLabels[i] = c.newLabel() + } for i := range n.Body.List { lEnd := c.newLabel() - lStart := c.newLabel() + lStart := startLabels[i] cc := n.Body.List[i].(*ast.CaseClause) if l := len(cc.List); l != 0 { // if not `default` @@ -435,7 +438,12 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } c.setLabel(lStart) - for _, stmt := range cc.Body { + last := len(cc.Body) - 1 + for j, stmt := range cc.Body { + if j == last && isFallthroughStmt(stmt) { + emit.Jmp(c.prog.BinWriter, opcode.JMP, startLabels[i+1]) + break + } ast.Walk(c, stmt) } emit.Jmp(c.prog.BinWriter, opcode.JMP, switchEnd) @@ -834,6 +842,11 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return c } +func isFallthroughStmt(c ast.Node) bool { + s, ok := c.(*ast.BranchStmt) + return ok && s.Tok == token.FALLTHROUGH +} + // emitReverse reverses top num items of the stack. func (c *codegen) emitReverse(num int) { switch num { diff --git a/pkg/compiler/switch_test.go b/pkg/compiler/switch_test.go index 8ffe00b8a..7a34e3174 100644 --- a/pkg/compiler/switch_test.go +++ b/pkg/compiler/switch_test.go @@ -187,6 +187,41 @@ var switchTestCases = []testCase{ }`, big.NewInt(2), }, + { + "simple fallthrough", + `package main + func Main() int { + n := 2 + switch n { + case 1: return 5 + case 2: fallthrough + case 3: return 6 + } + return 7 + }`, + big.NewInt(6), + }, + { + "double fallthrough", + `package main + func Main() int { + n := 2 + k := 5 + switch n { + case 0: return k + case 1: fallthrough + case 2: + k++ + fallthrough + case 3: + case 4: + k++ + return k + } + return k + }`, + big.NewInt(6), + }, } func TestSwitch(t *testing.T) {