From 28571bd3dcb3bc184724e9774dfa76f597de1975 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Tue, 28 Jan 2020 15:47:56 +0300 Subject: [PATCH] compiler: implement switch statement support --- pkg/compiler/codegen.go | 54 ++++++++++++-- pkg/compiler/switch_test.go | 138 ++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 5 deletions(-) create mode 100644 pkg/compiler/switch_test.go diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index d9c0ebb41..a8973da3b 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -337,6 +337,44 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { c.setLabel(lElseEnd) return nil + case *ast.SwitchStmt: + // fallthrough is not supported + ast.Walk(c, n.Tag) + + eqOpcode := c.getEqualityOpcode(n.Tag) + switchEnd := c.newLabel() + + for i := range n.Body.List { + lEnd := c.newLabel() + lStart := c.newLabel() + cc := n.Body.List[i].(*ast.CaseClause) + + if l := len(cc.List); l != 0 { // if not `default` + for j := range cc.List { + emitOpcode(c.prog.BinWriter, opcode.DUP) + ast.Walk(c, cc.List[j]) + emitOpcode(c.prog.BinWriter, eqOpcode) + if j == l-1 { + emitJmp(c.prog.BinWriter, opcode.JMPIFNOT, int16(lEnd)) + } else { + emitJmp(c.prog.BinWriter, opcode.JMPIF, int16(lStart)) + } + } + } + + c.setLabel(lStart) + for _, stmt := range cc.Body { + ast.Walk(c, stmt) + } + emitJmp(c.prog.BinWriter, opcode.JMP, int16(switchEnd)) + c.setLabel(lEnd) + } + + c.setLabel(switchEnd) + emitOpcode(c.prog.BinWriter, opcode.DROP) + + return nil + case *ast.BasicLit: c.emitLoadConst(c.typeInfo.Types[n]) return nil @@ -431,11 +469,8 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { } case n.Op == token.EQL: // VM has separate opcodes for number and string equality - if isStringType(c.typeInfo.Types[n.X].Type) { - emitOpcode(c.prog.BinWriter, opcode.EQUAL) - } else { - emitOpcode(c.prog.BinWriter, opcode.NUMEQUAL) - } + op := c.getEqualityOpcode(n.X) + emitOpcode(c.prog.BinWriter, op) case n.Op == token.NEQ: // VM has separate opcodes for number and string equality if isStringType(c.typeInfo.Types[n.X].Type) { @@ -657,6 +692,15 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { return c } +func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { + t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic) + if ok && t.Info()&types.IsNumeric != 0 { + return opcode.NUMEQUAL + } + + return opcode.EQUAL +} + // getByteArray returns byte array value from constant expr. // Only literals are supported. func (c *codegen) getByteArray(expr ast.Expr) []byte { diff --git a/pkg/compiler/switch_test.go b/pkg/compiler/switch_test.go new file mode 100644 index 000000000..8b39c3c8b --- /dev/null +++ b/pkg/compiler/switch_test.go @@ -0,0 +1,138 @@ +package compiler_test + +import ( + "math/big" + "testing" +) + +var switchTestCases = []testCase{ + { + "simple switch success", + `package main + func Main() int { + a := 5 + switch a { + case 5: return 2 + } + return 1 + }`, + big.NewInt(2), + }, + { + "simple switch fail", + `package main + func Main() int { + a := 6 + switch a { + case 5: + return 2 + } + return 1 + }`, + big.NewInt(1), + }, + { + "multiple cases success", + `package main + func Main() int { + a := 6 + switch a { + case 5: return 2 + case 6: return 3 + } + return 1 + }`, + big.NewInt(3), + }, + { + "multiple cases fail", + `package main + func Main() int { + a := 7 + switch a { + case 5: return 2 + case 6: return 3 + } + return 1 + }`, + big.NewInt(1), + }, + { + "default case", + `package main + func Main() int { + a := 7 + switch a { + case 5: return 2 + case 6: return 3 + default: return 4 + } + return 1 + }`, + big.NewInt(4), + }, + { + "empty case before default", + `package main + func Main() int { + a := 6 + switch a { + case 5: return 2 + case 6: + default: return 4 + } + return 1 + }`, + big.NewInt(1), + }, + { + "expression in case clause", + `package main + func Main() int { + a := 6 + b := 3 + switch a { + case 5: return 2 + case b*3-3: return 3 + } + return 1 + }`, + big.NewInt(3), + }, + { + "multiple expressions in case", + `package main + func Main() int { + a := 8 + b := 3 + switch a { + case 5: return 2 + case b*3-3, 7, 8: return 3 + } + return 1 + }`, + big.NewInt(3), + }, + { + "string switch", + `package main + func Main() int { + name := "Valera" + switch name { + case "Misha": return 2 + case "Katya", "Dima": return 3 + case "Lera", "Valer" + "a": return 4 + } + return 1 + }`, + big.NewInt(4), + }, +} + +func TestSwitch(t *testing.T) { + for _, tc := range switchTestCases { + t.Run(tc.name, func(t *testing.T) { + eval(t, tc.src, tc.result) + }) + } +}