Merge pull request #678 from nspcc-dev/feature/breakfor

compiler: support break statement in for loops and switch statements

Closes #677.
Implements 4-th point from #628.
This commit is contained in:
Roman Khimov 2020-02-21 12:22:16 +03:00 committed by GitHub
commit f345db58ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 378 additions and 7 deletions

View file

@ -38,10 +38,33 @@ type codegen struct {
// Current funcScope being converted. // Current funcScope being converted.
scope *funcScope scope *funcScope
// A mapping from label's names to their ids.
labels map[labelWithType]int
// A label for the for-loop being currently visited.
currentFor string
// A label for the switch statement being visited.
currentSwitch string
// A label to be used in the next statement.
nextLabel string
// Label table for recording jump destinations. // Label table for recording jump destinations.
l []int l []int
} }
type labelOffsetType byte
const (
labelStart labelOffsetType = iota // labelStart is a default label type
labelEnd // labelEnd is a type for labels that are targets for break
labelPost // labelPost is a type for labels that are targets for continue
)
type labelWithType struct {
name string
typ labelOffsetType
}
// newLabel creates a new label to jump to // newLabel creates a new label to jump to
func (c *codegen) newLabel() (l int) { func (c *codegen) newLabel() (l int) {
l = len(c.l) l = len(c.l)
@ -49,6 +72,14 @@ func (c *codegen) newLabel() (l int) {
return return
} }
// newNamedLabel creates a new label with a specified name.
func (c *codegen) newNamedLabel(typ labelOffsetType, name string) (l int) {
l = c.newLabel()
lt := labelWithType{name: name, typ: typ}
c.labels[lt] = l
return
}
func (c *codegen) setLabel(l int) { func (c *codegen) setLabel(l int) {
c.l[l] = c.pc() + 1 c.l[l] = c.pc() + 1
} }
@ -374,7 +405,10 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
ast.Walk(c, n.Tag) ast.Walk(c, n.Tag)
eqOpcode := c.getEqualityOpcode(n.Tag) eqOpcode := c.getEqualityOpcode(n.Tag)
switchEnd := c.newLabel() switchEnd, label := c.generateLabel(labelEnd)
lastSwitch := c.currentSwitch
c.currentSwitch = label
for i := range n.Body.List { for i := range n.Body.List {
lEnd := c.newLabel() lEnd := c.newLabel()
@ -405,6 +439,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.setLabel(switchEnd) c.setLabel(switchEnd)
emit.Opcode(c.prog.BinWriter, opcode.DROP) emit.Opcode(c.prog.BinWriter, opcode.DROP)
c.currentSwitch = lastSwitch
return nil return nil
case *ast.BasicLit: case *ast.BasicLit:
@ -653,11 +689,43 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
case *ast.BranchStmt:
var label string
if n.Label != nil {
label = n.Label.Name
} else if n.Tok == token.BREAK {
label = c.currentSwitch
} else if n.Tok == token.CONTINUE {
label = c.currentFor
}
switch n.Tok {
case token.BREAK:
end := c.getLabelOffset(labelEnd, label)
emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(end))
case token.CONTINUE:
post := c.getLabelOffset(labelPost, label)
emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(post))
}
return nil
case *ast.LabeledStmt:
c.nextLabel = n.Label.Name
ast.Walk(c, n.Stmt)
return nil
case *ast.ForStmt: case *ast.ForStmt:
var ( fstart, label := c.generateLabel(labelStart)
fstart = c.newLabel() fend := c.newNamedLabel(labelEnd, label)
fend = c.newLabel() fpost := c.newNamedLabel(labelPost, label)
)
lastLabel := c.currentFor
lastSwitch := c.currentSwitch
c.currentFor = label
c.currentSwitch = label
// Walk the initializer and condition. // Walk the initializer and condition.
if n.Init != nil { if n.Init != nil {
@ -673,6 +741,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
// Walk body followed by the iterator (post stmt). // Walk body followed by the iterator (post stmt).
ast.Walk(c, n.Body) ast.Walk(c, n.Body)
c.setLabel(fpost)
if n.Post != nil { if n.Post != nil {
ast.Walk(c, n.Post) ast.Walk(c, n.Post)
} }
@ -681,6 +750,9 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(fstart)) emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(fstart))
c.setLabel(fend) c.setLabel(fend)
c.currentFor = lastLabel
c.currentSwitch = lastSwitch
return nil return nil
case *ast.RangeStmt: case *ast.RangeStmt:
@ -691,8 +763,14 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return nil return nil
} }
start := c.newLabel() start, label := c.generateLabel(labelStart)
end := c.newLabel() end := c.newNamedLabel(labelEnd, label)
post := c.newNamedLabel(labelPost, label)
lastFor := c.currentFor
lastSwitch := c.currentSwitch
c.currentFor = label
c.currentSwitch = label
ast.Walk(c, n.X) ast.Walk(c, n.X)
@ -715,11 +793,16 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
ast.Walk(c, n.Body) ast.Walk(c, n.Body)
c.setLabel(post)
emit.Opcode(c.prog.BinWriter, opcode.INC) emit.Opcode(c.prog.BinWriter, opcode.INC)
emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(start)) emit.Jmp(c.prog.BinWriter, opcode.JMP, int16(start))
c.setLabel(end) c.setLabel(end)
c.currentFor = lastFor
c.currentSwitch = lastSwitch
return nil return nil
// We dont really care about assertions for the core logic. // We dont really care about assertions for the core logic.
@ -749,6 +832,21 @@ func (c *codegen) emitReverse(num int) {
} }
} }
// generateLabel returns a new label.
func (c *codegen) generateLabel(typ labelOffsetType) (int, string) {
name := c.nextLabel
if name == "" {
name = fmt.Sprintf("@%d", len(c.l))
}
c.nextLabel = ""
return c.newNamedLabel(typ, name), name
}
func (c *codegen) getLabelOffset(typ labelOffsetType, name string) int {
return c.labels[labelWithType{name: name, typ: typ}]
}
func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode {
t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic) t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic)
if ok && t.Info()&types.IsNumeric != 0 { if ok && t.Info()&types.IsNumeric != 0 {
@ -1046,6 +1144,7 @@ func CodeGen(info *buildInfo) ([]byte, error) {
prog: io.NewBufBinWriter(), prog: io.NewBufBinWriter(),
l: []int{}, l: []int{},
funcs: map[string]*funcScope{}, funcs: map[string]*funcScope{},
labels: map[labelWithType]int{},
typeInfo: &pkg.Info, typeInfo: &pkg.Info,
} }

View file

@ -462,6 +462,218 @@ func TestForLoopRangeChangeVariable(t *testing.T) {
eval(t, src, big.NewInt(12)) eval(t, src, big.NewInt(12))
} }
func TestForLoopBreak(t *testing.T) {
src := `
package foo
func Main() int {
var i int
for i < 10 {
i++
if i == 5 {
break
}
}
return i
}`
eval(t, src, big.NewInt(5))
}
func TestForLoopBreakLabel(t *testing.T) {
src := `
package foo
func Main() int {
var i int
loop:
for i < 10 {
i++
if i == 5 {
break loop
}
}
return i
}`
eval(t, src, big.NewInt(5))
}
func TestForLoopNestedBreak(t *testing.T) {
src := `
package foo
func Main() int {
var i int
for i < 10 {
i++
for j := 0; j < 2; j++ {
i++
if i == 5 {
break
}
}
}
return i
}`
eval(t, src, big.NewInt(11))
}
func TestForLoopNestedBreakLabel(t *testing.T) {
src := `
package foo
func Main() int {
var i int
loop:
for i < 10 {
i++
for j := 0; j < 2; j++ {
if i == 5 {
break loop
}
i++
}
}
return i
}`
eval(t, src, big.NewInt(5))
}
func TestForLoopContinue(t *testing.T) {
src := `
package foo
func Main() int {
var i, j int
for i < 10 {
i++
if i >= 5 {
continue
}
j++
}
return j
}`
eval(t, src, big.NewInt(4))
}
func TestForLoopContinueLabel(t *testing.T) {
src := `
package foo
func Main() int {
var i, j int
loop:
for i < 10 {
i++
if i >= 5 {
continue loop
}
j++
}
return j
}`
eval(t, src, big.NewInt(4))
}
func TestForLoopNestedContinue(t *testing.T) {
src := `
package foo
func Main() int {
var i, k int
for i < 10 {
i++
for j := 0; j < 3; j++ {
if j >= 2 {
continue
}
k++
}
}
return k
}`
eval(t, src, big.NewInt(20))
}
func TestForLoopNestedContinueLabel(t *testing.T) {
src := `
package foo
func Main() int {
var i int
loop:
for ; i < 10; i += 10 {
i++
for j := 0; j < 4; j++ {
if i == 5 {
continue loop
}
i++
}
}
return i
}`
eval(t, src, big.NewInt(15))
}
func TestForLoopRangeBreak(t *testing.T) {
src := `
package foo
func Main() int {
var i int
arr := []int{1, 2, 3}
for i = range arr {
if arr[i] == 2 {
break
}
}
return i
}`
eval(t, src, big.NewInt(1))
}
func TestForLoopRangeNestedBreak(t *testing.T) {
src := `
package foo
func Main() int {
k := 5
arr := []int{1, 2, 3}
urr := []int{4, 5, 6, 7}
loop:
for range arr {
k++
for j := range urr {
k++
if j == 3 {
break loop
}
}
}
return k
}`
eval(t, src, big.NewInt(10))
}
func TestForLoopRangeContinue(t *testing.T) {
src := `
package foo
func Main() int {
i := 6
arr := []int{1, 2, 3}
for j := range arr {
if arr[j] < 2 {
continue
}
i++
}
return i
}`
eval(t, src, big.NewInt(8))
}
func TestForLoopRangeNoVariable(t *testing.T) { func TestForLoopRangeNoVariable(t *testing.T) {
src := ` src := `
package foo package foo

View file

@ -127,6 +127,66 @@ var switchTestCases = []testCase{
}`, }`,
big.NewInt(4), big.NewInt(4),
}, },
{
"break from switch",
`package main
func Main() int {
i := 3
switch i {
case 2: return 2
case 3:
i = 1
break
return 3
case 4: return 4
}
return i
}`,
big.NewInt(1),
},
{
"break from outer for",
`package main
func Main() int {
i := 3
loop:
for i < 10 {
i++
switch i {
case 5:
i = 7
break loop
return 3
case 6: return 4
}
}
return i
}`,
big.NewInt(7),
},
{
"continue outer for",
`package main
func Main() int {
i := 2
for i < 10 {
i++
switch i {
case 3:
i = 7
continue
case 4, 5, 6, 7: return 5
case 8: return 2
}
if i == 7 {
return 6
}
}
return i
}`,
big.NewInt(2),
},
} }
func TestSwitch(t *testing.T) { func TestSwitch(t *testing.T) {