diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index d439b06db..d73cd0bcf 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -38,10 +38,30 @@ type codegen struct { // Current funcScope being converted. scope *funcScope + // A mapping from label's names to their ids. + labels map[labelWithType]int + + // A label for the for-loop being currently visited. + currentFor string + // A label to be used in the next statement. + nextLabel string + // Label table for recording jump destinations. l []int } +type labelOffsetType byte + +const ( + labelStart labelOffsetType = iota // labelStart is a default label type + labelEnd // labelEnd is a type for labels that are targets for break +) + +type labelWithType struct { + name string + typ labelOffsetType +} + // newLabel creates a new label to jump to func (c *codegen) newLabel() (l int) { l = len(c.l) @@ -49,6 +69,14 @@ func (c *codegen) newLabel() (l int) { return } +// newNamedLabel creates a new label with a specified name. +func (c *codegen) newNamedLabel(typ labelOffsetType, name string) (l int) { + l = c.newLabel() + lt := labelWithType{name: name, typ: typ} + c.labels[lt] = l + return +} + func (c *codegen) setLabel(l int) { c.l[l] = c.pc() + 1 } @@ -653,11 +681,35 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return nil + case *ast.BranchStmt: + label := c.currentFor + if n.Label != nil { + label = n.Label.Name + } + + switch n.Tok { + case token.BREAK: + end := c.getLabelOffset(labelEnd, label) + emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(end)) + case token.CONTINUE: + c.prog.Err = fmt.Errorf("continue statement is not supported yet") + } + + return nil + + case *ast.LabeledStmt: + c.nextLabel = n.Label.Name + + ast.Walk(c, n.Stmt) + + return nil + case *ast.ForStmt: - var ( - fstart = c.newLabel() - fend = c.newLabel() - ) + fstart, label := c.generateLabel(labelStart) + fend := c.newNamedLabel(labelEnd, label) + + lastLabel := c.currentFor + c.currentFor = label // Walk the initializer and condition. if n.Init != nil { @@ -681,6 +733,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(fstart)) c.setLabel(fend) + c.currentFor = lastLabel + return nil case *ast.RangeStmt: @@ -749,6 +803,21 @@ func (c *codegen) emitReverse(num int) { } } +// generateLabel returns a new label. +func (c *codegen) generateLabel(typ labelOffsetType) (int, string) { + name := c.nextLabel + if name == "" { + name = fmt.Sprintf("@%d", len(c.l)) + } + + c.nextLabel = "" + return c.newNamedLabel(typ, name), name +} + +func (c *codegen) getLabelOffset(typ labelOffsetType, name string) int { + return c.labels[labelWithType{name: name, typ: typ}] +} + func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic) if ok && t.Info()&types.IsNumeric != 0 { @@ -1046,6 +1115,7 @@ func CodeGen(info *buildInfo) ([]byte, error) { prog: io.NewBufBinWriter(), l: []int{}, funcs: map[string]*funcScope{}, + labels: map[labelWithType]int{}, typeInfo: &pkg.Info, } diff --git a/pkg/compiler/for_test.go b/pkg/compiler/for_test.go index 6ee6a46eb..b0da885d9 100644 --- a/pkg/compiler/for_test.go +++ b/pkg/compiler/for_test.go @@ -462,6 +462,82 @@ func TestForLoopRangeChangeVariable(t *testing.T) { eval(t, src, big.NewInt(12)) } +func TestForLoopBreak(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + for i < 10 { + i++ + if i == 5 { + break + } + } + return i + }` + + eval(t, src, big.NewInt(5)) +} + +func TestForLoopBreakLabel(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + loop: + for i < 10 { + i++ + if i == 5 { + break loop + } + } + return i + }` + + eval(t, src, big.NewInt(5)) +} + +func TestForLoopNestedBreak(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + for i < 10 { + i++ + for j := 0; j < 2; j++ { + i++ + if i == 5 { + break + } + } + } + return i + }` + + eval(t, src, big.NewInt(11)) +} + +func TestForLoopNestedBreakLabel(t *testing.T) { + src := ` + package foo + func Main() int { + var i int + loop: + for i < 10 { + i++ + for j := 0; j < 2; j++ { + if i == 5 { + break loop + } + i++ + } + } + return i + }` + + eval(t, src, big.NewInt(5)) +} + func TestForLoopRangeNoVariable(t *testing.T) { src := ` package foo