diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 4de352221..089660638 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -446,7 +446,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { // RHS can contain exactly one expression, thus there is no need to iterate. ast.Walk(c, n.Lhs[0]) ast.Walk(c, n.Rhs[0]) - c.convertToken(n.Tok) + c.emitToken(n.Tok, c.typeOf(n.Rhs[0])) } for i := 0; i < len(n.Lhs); i++ { switch t := n.Lhs[i].(type) { @@ -581,7 +581,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { case *ast.SwitchStmt: ast.Walk(c, n.Tag) - eqOpcode := c.getEqualityOpcode(n.Tag) + eqOpcode, _ := convertToken(token.EQL, c.typeOf(n.Tag)) switchEnd, label := c.generateLabel(labelEnd) lastSwitch := c.currentSwitch @@ -735,28 +735,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { default: ast.Walk(c, n.X) ast.Walk(c, n.Y) - switch { - case n.Op == token.ADD: - // VM has separate opcodes for number and string concatenation - if isString(tinfo.Type) { - emit.Opcode(c.prog.BinWriter, opcode.CAT) - } else { - emit.Opcode(c.prog.BinWriter, opcode.ADD) - } - case n.Op == token.EQL: - // VM has separate opcodes for number and string equality - op := c.getEqualityOpcode(n.X) - emit.Opcode(c.prog.BinWriter, op) - case n.Op == token.NEQ: - // VM has separate opcodes for number and string equality - if isString(c.typeOf(n.X)) { - emit.Opcode(c.prog.BinWriter, opcode.NOTEQUAL) - } else { - emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL) - } - default: - c.convertToken(n.Op) - } + c.emitToken(n.Op, c.typeOf(n.X)) return nil } @@ -925,7 +904,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { case *ast.IncDecStmt: ast.Walk(c, n.X) - c.convertToken(n.Tok) + c.emitToken(n.Tok, c.typeOf(n.X)) // For now only identifiers are supported for (post) for stmts. // for i := 0; i < 10; i++ {} @@ -1205,15 +1184,6 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 { return c.labels[labelWithType{name: name, typ: typ}] } -func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { - t, ok := c.typeOf(expr).Underlying().(*types.Basic) - if ok && t.Info()&types.IsNumeric != 0 { - return opcode.NUMEQUAL - } - - return opcode.EQUAL -} - // getByteArray returns byte array value from constant expr. // Only literals are supported. func (c *codegen) getByteArray(expr ast.Expr) []byte { @@ -1459,59 +1429,79 @@ func (c *codegen) convertStruct(lit *ast.CompositeLit, ptr bool) { } } -func (c *codegen) convertToken(tok token.Token) { +func (c *codegen) emitToken(tok token.Token, typ types.Type) { + op, err := convertToken(tok, typ) + if err != nil { + c.prog.Err = err + return + } + emit.Opcode(c.prog.BinWriter, op) +} + +func convertToken(tok token.Token, typ types.Type) (opcode.Opcode, error) { switch tok { case token.ADD_ASSIGN: - emit.Opcode(c.prog.BinWriter, opcode.ADD) + return opcode.ADD, nil case token.SUB_ASSIGN: - emit.Opcode(c.prog.BinWriter, opcode.SUB) + return opcode.SUB, nil case token.MUL_ASSIGN: - emit.Opcode(c.prog.BinWriter, opcode.MUL) + return opcode.MUL, nil case token.QUO_ASSIGN: - emit.Opcode(c.prog.BinWriter, opcode.DIV) + return opcode.DIV, nil case token.REM_ASSIGN: - emit.Opcode(c.prog.BinWriter, opcode.MOD) + return opcode.MOD, nil case token.ADD: - emit.Opcode(c.prog.BinWriter, opcode.ADD) + // VM has separate opcodes for number and string concatenation + if isString(typ) { + return opcode.CAT, nil + } + return opcode.ADD, nil case token.SUB: - emit.Opcode(c.prog.BinWriter, opcode.SUB) + return opcode.SUB, nil case token.MUL: - emit.Opcode(c.prog.BinWriter, opcode.MUL) + return opcode.MUL, nil case token.QUO: - emit.Opcode(c.prog.BinWriter, opcode.DIV) + return opcode.DIV, nil case token.REM: - emit.Opcode(c.prog.BinWriter, opcode.MOD) + return opcode.MOD, nil case token.LSS: - emit.Opcode(c.prog.BinWriter, opcode.LT) + return opcode.LT, nil case token.LEQ: - emit.Opcode(c.prog.BinWriter, opcode.LTE) + return opcode.LTE, nil case token.GTR: - emit.Opcode(c.prog.BinWriter, opcode.GT) + return opcode.GT, nil case token.GEQ: - emit.Opcode(c.prog.BinWriter, opcode.GTE) + return opcode.GTE, nil case token.EQL: - emit.Opcode(c.prog.BinWriter, opcode.NUMEQUAL) + // VM has separate opcodes for number and string equality + if isNumber(typ) { + return opcode.NUMEQUAL, nil + } + return opcode.EQUAL, nil case token.NEQ: - emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL) + // VM has separate opcodes for number and string equality + if isNumber(typ) { + return opcode.NUMNOTEQUAL, nil + } + return opcode.NOTEQUAL, nil case token.DEC: - emit.Opcode(c.prog.BinWriter, opcode.DEC) + return opcode.DEC, nil case token.INC: - emit.Opcode(c.prog.BinWriter, opcode.INC) + return opcode.INC, nil case token.NOT: - emit.Opcode(c.prog.BinWriter, opcode.NOT) + return opcode.NOT, nil case token.AND: - emit.Opcode(c.prog.BinWriter, opcode.AND) + return opcode.AND, nil case token.OR: - emit.Opcode(c.prog.BinWriter, opcode.OR) + return opcode.OR, nil case token.SHL: - emit.Opcode(c.prog.BinWriter, opcode.SHL) + return opcode.SHL, nil case token.SHR: - emit.Opcode(c.prog.BinWriter, opcode.SHR) + return opcode.SHR, nil case token.XOR: - emit.Opcode(c.prog.BinWriter, opcode.XOR) + return opcode.XOR, nil default: - c.prog.Err = fmt.Errorf("compiler could not convert token: %s", tok) - return + return 0, fmt.Errorf("compiler could not convert token: %s", tok) } } diff --git a/pkg/compiler/codegen_test.go b/pkg/compiler/codegen_test.go index b499771d8..9245edef9 100644 --- a/pkg/compiler/codegen_test.go +++ b/pkg/compiler/codegen_test.go @@ -2,9 +2,9 @@ package compiler import ( "go/token" + "go/types" "testing" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/stretchr/testify/assert" ) @@ -14,59 +14,74 @@ func TestConvertToken(t *testing.T) { name string token token.Token opcode opcode.Opcode + typ types.Type } testCases := []testCase{ - {"ADD", + {"ADD (string)", token.ADD, opcode.ADD, + types.Typ[types.Int], + }, + {"ADD (number)", + token.ADD, + opcode.CAT, + types.Typ[types.String], }, {"SUB", token.SUB, opcode.SUB, + nil, }, {"MUL", token.MUL, opcode.MUL, + nil, }, {"QUO", token.QUO, opcode.DIV, + nil, }, {"REM", token.REM, opcode.MOD, + nil, }, {"ADD_ASSIGN", token.ADD_ASSIGN, opcode.ADD, + nil, }, {"SUB_ASSIGN", token.SUB_ASSIGN, opcode.SUB, + nil, }, {"MUL_ASSIGN", token.MUL_ASSIGN, opcode.MUL, + nil, }, {"QUO_ASSIGN", token.QUO_ASSIGN, opcode.DIV, + nil, }, {"REM_ASSIGN", token.REM_ASSIGN, opcode.MOD, + nil, }, } for _, tcase := range testCases { - t.Run(tcase.name, func(t *testing.T) { eval(t, tcase.token, tcase.opcode) }) + t.Run(tcase.name, func(t *testing.T) { eval(t, tcase.token, tcase.opcode, tcase.typ) }) } } -func eval(t *testing.T, token token.Token, opcode opcode.Opcode) { - codegen := &codegen{prog: io.NewBufBinWriter()} - codegen.convertToken(token) - readOpcode := codegen.prog.Bytes() - assert.Equal(t, []byte{byte(opcode)}, readOpcode) +func eval(t *testing.T, token token.Token, opcode opcode.Opcode, typ types.Type) { + op, err := convertToken(token, typ) + assert.NoError(t, err) + assert.Equal(t, opcode, op) } diff --git a/pkg/compiler/types.go b/pkg/compiler/types.go index 77bbcfc1c..e03dcf286 100644 --- a/pkg/compiler/types.go +++ b/pkg/compiler/types.go @@ -31,6 +31,11 @@ func isByte(typ types.Type) bool { return isBasicTypeOfKind(typ, types.Uint8, types.Int8) } +func isNumber(typ types.Type) bool { + t, ok := typ.Underlying().(*types.Basic) + return ok && t.Info()&types.IsNumeric != 0 +} + func isString(typ types.Type) bool { return isBasicTypeOfKind(typ, types.String) }