diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index d439b06db..cdb54a938 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -38,10 +38,33 @@ type codegen struct { // Current funcScope being converted. scope *funcScope + // A mapping from label's names to their ids. + labels map[labelWithType]int + + // A label for the for-loop being currently visited. + currentFor string + // A label for the switch statement being visited. + currentSwitch string + // A label to be used in the next statement. + nextLabel string + // Label table for recording jump destinations. l []int } +type labelOffsetType byte + +const ( + labelStart labelOffsetType = iota // labelStart is a default label type + labelEnd // labelEnd is a type for labels that are targets for break + labelPost // labelPost is a type for labels that are targets for continue +) + +type labelWithType struct { + name string + typ labelOffsetType +} + // newLabel creates a new label to jump to func (c *codegen) newLabel() (l int) { l = len(c.l) @@ -49,6 +72,14 @@ func (c *codegen) newLabel() (l int) { return } +// newNamedLabel creates a new label with a specified name. +func (c *codegen) newNamedLabel(typ labelOffsetType, name string) (l int) { + l = c.newLabel() + lt := labelWithType{name: name, typ: typ} + c.labels[lt] = l + return +} + func (c *codegen) setLabel(l int) { c.l[l] = c.pc() + 1 } @@ -374,7 +405,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { ast.Walk(c, n.Tag) eqOpcode := c.getEqualityOpcode(n.Tag) - switchEnd := c.newLabel() + switchEnd, label := c.generateLabel(labelEnd) + + lastSwitch := c.currentSwitch + c.currentSwitch = label for i := range n.Body.List { lEnd := c.newLabel() @@ -405,6 +439,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { c.setLabel(switchEnd) emit.Opcode(c.prog.BinWriter, opcode.DROP) + c.currentSwitch = lastSwitch + return nil case *ast.BasicLit: @@ -653,11 +689,43 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil + case *ast.BranchStmt: + var label string + if n.Label != nil { + label = n.Label.Name + } else if n.Tok == token.BREAK { + label = c.currentSwitch + } else if n.Tok == token.CONTINUE { + label = c.currentFor + } + + switch n.Tok { + case token.BREAK: + end := c.getLabelOffset(labelEnd, label) + emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(end)) + case token.CONTINUE: + post := c.getLabelOffset(labelPost, label) + emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(post)) + } + + return nil + + case *ast.LabeledStmt: + c.nextLabel = n.Label.Name + + ast.Walk(c, n.Stmt) + + return nil + case *ast.ForStmt: - var ( - fstart = c.newLabel() - fend = c.newLabel() - ) + fstart, label := c.generateLabel(labelStart) + fend := c.newNamedLabel(labelEnd, label) + fpost := c.newNamedLabel(labelPost, label) + + lastLabel := c.currentFor + lastSwitch := c.currentSwitch + c.currentFor = label + c.currentSwitch = label // Walk the initializer and condition. if n.Init != nil { @@ -673,6 +741,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { // Walk body followed by the iterator (post stmt). ast.Walk(c, n.Body) + c.setLabel(fpost) if n.Post != nil { ast.Walk(c, n.Post) } @@ -681,6 +750,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(fstart)) c.setLabel(fend) + c.currentFor = lastLabel + c.currentSwitch = lastSwitch + return nil case *ast.RangeStmt: @@ -691,8 +763,14 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil } - start := c.newLabel() - end := c.newLabel() + start, label := c.generateLabel(labelStart) + end := c.newNamedLabel(labelEnd, label) + post := c.newNamedLabel(labelPost, label) + + lastFor := c.currentFor + lastSwitch := c.currentSwitch + c.currentFor = label + c.currentSwitch = label ast.Walk(c, n.X) @@ -715,11 +793,16 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { ast.Walk(c, n.Body) + c.setLabel(post) + emit.Opcode(c.prog.BinWriter, opcode.INC) emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(start)) c.setLabel(end) + c.currentFor = lastFor + c.currentSwitch = lastSwitch + return nil // We dont really care about assertions for the core logic. @@ -749,6 +832,21 @@ func (c *codegen) emitReverse(num int) { } } +// generateLabel returns a new label. +func (c *codegen) generateLabel(typ labelOffsetType) (int, string) { + name := c.nextLabel + if name == "" { + name = fmt.Sprintf("@%d", len(c.l)) + } + + c.nextLabel = "" + return c.newNamedLabel(typ, name), name +} + +func (c *codegen) getLabelOffset(typ labelOffsetType, name string) int { + return c.labels[labelWithType{name: name, typ: typ}] +} + func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic) if ok && t.Info()&types.IsNumeric != 0 { @@ -1046,6 +1144,7 @@ func CodeGen(info *buildInfo) ([]byte, error) { prog: io.NewBufBinWriter(), l: []int{}, funcs: map[string]*funcScope{}, + labels: map[labelWithType]int{}, typeInfo: &pkg.Info, } diff --git a/pkg/compiler/for_test.go b/pkg/compiler/for_test.go index 6ee6a46eb..cb9f1f536 100644 --- a/pkg/compiler/for_test.go +++ b/pkg/compiler/for_test.go @@ -462,6 +462,218 @@ func TestForLoopRangeChangeVariable(t *testing.T) { eval(t, src, big.NewInt(12)) } +func TestForLoopBreak(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + for i < 10 { + i++ + if i == 5 { + break + } + } + return i + }` + + eval(t, src, big.NewInt(5)) +} + +func TestForLoopBreakLabel(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + loop: + for i < 10 { + i++ + if i == 5 { + break loop + } + } + return i + }` + + eval(t, src, big.NewInt(5)) +} + +func TestForLoopNestedBreak(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + for i < 10 { + i++ + for j := 0; j < 2; j++ { + i++ + if i == 5 { + break + } + } + } + return i + }` + + eval(t, src, big.NewInt(11)) +} + +func TestForLoopNestedBreakLabel(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + loop: + for i < 10 { + i++ + for j := 0; j < 2; j++ { + if i == 5 { + break loop + } + i++ + } + } + return i + }` + + eval(t, src, big.NewInt(5)) +} + +func TestForLoopContinue(t *testing.T) { + src := ` + package foo + func Main() int { + var i, j int + for i < 10 { + i++ + if i >= 5 { + continue + } + j++ + } + return j + }` + + eval(t, src, big.NewInt(4)) +} + +func TestForLoopContinueLabel(t *testing.T) { + src := ` + package foo + func Main() int { + var i, j int + loop: + for i < 10 { + i++ + if i >= 5 { + continue loop + } + j++ + } + return j + }` + + eval(t, src, big.NewInt(4)) +} + +func TestForLoopNestedContinue(t *testing.T) { + src := ` + package foo + func Main() int { + var i, k int + for i < 10 { + i++ + for j := 0; j < 3; j++ { + if j >= 2 { + continue + } + k++ + } + } + return k + }` + + eval(t, src, big.NewInt(20)) +} + +func TestForLoopNestedContinueLabel(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + loop: + for ; i < 10; i += 10 { + i++ + for j := 0; j < 4; j++ { + if i == 5 { + continue loop + } + i++ + } + } + return i + }` + + eval(t, src, big.NewInt(15)) +} + +func TestForLoopRangeBreak(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + arr := []int{1, 2, 3} + for i = range arr { + if arr[i] == 2 { + break + } + } + return i + }` + + eval(t, src, big.NewInt(1)) +} + +func TestForLoopRangeNestedBreak(t *testing.T) { + src := ` + package foo + func Main() int { + k := 5 + arr := []int{1, 2, 3} + urr := []int{4, 5, 6, 7} + loop: + for range arr { + k++ + for j := range urr { + k++ + if j == 3 { + break loop + } + } + } + return k + }` + + eval(t, src, big.NewInt(10)) +} + +func TestForLoopRangeContinue(t *testing.T) { + src := ` + package foo + func Main() int { + i := 6 + arr := []int{1, 2, 3} + for j := range arr { + if arr[j] < 2 { + continue + } + i++ + } + return i + }` + + eval(t, src, big.NewInt(8)) +} + func TestForLoopRangeNoVariable(t *testing.T) { src := ` package foo diff --git a/pkg/compiler/switch_test.go b/pkg/compiler/switch_test.go index 8b39c3c8b..8ffe00b8a 100644 --- a/pkg/compiler/switch_test.go +++ b/pkg/compiler/switch_test.go @@ -127,6 +127,66 @@ var switchTestCases = []testCase{ }`, big.NewInt(4), }, + { + "break from switch", + `package main + func Main() int { + i := 3 + switch i { + case 2: return 2 + case 3: + i = 1 + break + return 3 + case 4: return 4 + } + return i + }`, + big.NewInt(1), + }, + { + "break from outer for", + `package main + func Main() int { + i := 3 + loop: + for i < 10 { + i++ + switch i { + case 5: + i = 7 + break loop + return 3 + case 6: return 4 + } + } + return i + }`, + big.NewInt(7), + }, + { + "continue outer for", + `package main + func Main() int { + i := 2 + for i < 10 { + i++ + switch i { + case 3: + i = 7 + continue + case 4, 5, 6, 7: return 5 + case 8: return 2 + } + + if i == 7 { + return 6 + } + } + return i + }`, + big.NewInt(2), + }, } func TestSwitch(t *testing.T) {