Merge pull request #626 from nspcc-dev/feature/switch

compiler: implement switch statement support

Implements 1 & 2 from #628.
This commit is contained in:
Roman Khimov 2020-01-29 10:03:47 +03:00 committed by GitHub
commit a839efb35e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 188 additions and 7 deletions

View file

@ -114,8 +114,7 @@ func (c *codegen) emitStoreLocal(pos int) {
} }
emitInt(c.prog.BinWriter, int64(pos)) emitInt(c.prog.BinWriter, int64(pos))
emitInt(c.prog.BinWriter, 2) emitOpcode(c.prog.BinWriter, opcode.ROT)
emitOpcode(c.prog.BinWriter, opcode.ROLL)
emitOpcode(c.prog.BinWriter, opcode.SETITEM) emitOpcode(c.prog.BinWriter, opcode.SETITEM)
} }
@ -337,6 +336,44 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
c.setLabel(lElseEnd) c.setLabel(lElseEnd)
return nil return nil
case *ast.SwitchStmt:
// fallthrough is not supported
ast.Walk(c, n.Tag)
eqOpcode := c.getEqualityOpcode(n.Tag)
switchEnd := c.newLabel()
for i := range n.Body.List {
lEnd := c.newLabel()
lStart := c.newLabel()
cc := n.Body.List[i].(*ast.CaseClause)
if l := len(cc.List); l != 0 { // if not `default`
for j := range cc.List {
emitOpcode(c.prog.BinWriter, opcode.DUP)
ast.Walk(c, cc.List[j])
emitOpcode(c.prog.BinWriter, eqOpcode)
if j == l-1 {
emitJmp(c.prog.BinWriter, opcode.JMPIFNOT, int16(lEnd))
} else {
emitJmp(c.prog.BinWriter, opcode.JMPIF, int16(lStart))
}
}
}
c.setLabel(lStart)
for _, stmt := range cc.Body {
ast.Walk(c, stmt)
}
emitJmp(c.prog.BinWriter, opcode.JMP, int16(switchEnd))
c.setLabel(lEnd)
}
c.setLabel(switchEnd)
emitOpcode(c.prog.BinWriter, opcode.DROP)
return nil
case *ast.BasicLit: case *ast.BasicLit:
c.emitLoadConst(c.typeInfo.Types[n]) c.emitLoadConst(c.typeInfo.Types[n])
return nil return nil
@ -431,11 +468,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
} }
case n.Op == token.EQL: case n.Op == token.EQL:
// VM has separate opcodes for number and string equality // VM has separate opcodes for number and string equality
if isStringType(c.typeInfo.Types[n.X].Type) { op := c.getEqualityOpcode(n.X)
emitOpcode(c.prog.BinWriter, opcode.EQUAL) emitOpcode(c.prog.BinWriter, op)
} else {
emitOpcode(c.prog.BinWriter, opcode.NUMEQUAL)
}
case n.Op == token.NEQ: case n.Op == token.NEQ:
// VM has separate opcodes for number and string equality // VM has separate opcodes for number and string equality
if isStringType(c.typeInfo.Types[n.X].Type) { if isStringType(c.typeInfo.Types[n.X].Type) {
@ -657,6 +691,15 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
return c return c
} }
func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode {
t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic)
if ok && t.Info()&types.IsNumeric != 0 {
return opcode.NUMEQUAL
}
return opcode.EQUAL
}
// getByteArray returns byte array value from constant expr. // getByteArray returns byte array value from constant expr.
// Only literals are supported. // Only literals are supported.
func (c *codegen) getByteArray(expr ast.Expr) []byte { func (c *codegen) getByteArray(expr ast.Expr) []byte {

138
pkg/compiler/switch_test.go Normal file
View file

@ -0,0 +1,138 @@
package compiler_test
import (
"math/big"
"testing"
)
var switchTestCases = []testCase{
{
"simple switch success",
`package main
func Main() int {
a := 5
switch a {
case 5: return 2
}
return 1
}`,
big.NewInt(2),
},
{
"simple switch fail",
`package main
func Main() int {
a := 6
switch a {
case 5:
return 2
}
return 1
}`,
big.NewInt(1),
},
{
"multiple cases success",
`package main
func Main() int {
a := 6
switch a {
case 5: return 2
case 6: return 3
}
return 1
}`,
big.NewInt(3),
},
{
"multiple cases fail",
`package main
func Main() int {
a := 7
switch a {
case 5: return 2
case 6: return 3
}
return 1
}`,
big.NewInt(1),
},
{
"default case",
`package main
func Main() int {
a := 7
switch a {
case 5: return 2
case 6: return 3
default: return 4
}
return 1
}`,
big.NewInt(4),
},
{
"empty case before default",
`package main
func Main() int {
a := 6
switch a {
case 5: return 2
case 6:
default: return 4
}
return 1
}`,
big.NewInt(1),
},
{
"expression in case clause",
`package main
func Main() int {
a := 6
b := 3
switch a {
case 5: return 2
case b*3-3: return 3
}
return 1
}`,
big.NewInt(3),
},
{
"multiple expressions in case",
`package main
func Main() int {
a := 8
b := 3
switch a {
case 5: return 2
case b*3-3, 7, 8: return 3
}
return 1
}`,
big.NewInt(3),
},
{
"string switch",
`package main
func Main() int {
name := "Valera"
switch name {
case "Misha": return 2
case "Katya", "Dima": return 3
case "Lera", "Valer" + "a": return 4
}
return 1
}`,
big.NewInt(4),
},
}
func TestSwitch(t *testing.T) {
for _, tc := range switchTestCases {
t.Run(tc.name, func(t *testing.T) {
eval(t, tc.src, tc.result)
})
}
}