diff --git a/pkg/compiler/codegen.go b/pkg/compiler/codegen.go index 7c4cfa30b..801937d32 100644 --- a/pkg/compiler/codegen.go +++ b/pkg/compiler/codegen.go @@ -470,7 +470,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { case *ast.SwitchStmt: ast.Walk(c, n.Tag) - eqOpcode := c.getEqualityOpcode(n.Tag) switchEnd, label := c.generateLabel(labelEnd) lastSwitch := c.currentSwitch @@ -490,7 +489,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { for j := range cc.List { emit.Opcode(c.prog.BinWriter, opcode.DUP) ast.Walk(c, cc.List[j]) - emit.Opcode(c.prog.BinWriter, eqOpcode) + c.emitEquality(n.Tag, token.EQL) if j == l-1 { emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOT, lEnd) } else { @@ -630,19 +629,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor { c.prog.Err = errors.New("comparison with `nil` is not supported, use `len(..) == 0` instead") return nil } - if n.Op == token.EQL { - // VM has separate opcodes for number and string equality - op := c.getEqualityOpcode(n.X) - emit.Opcode(c.prog.BinWriter, op) - } else { - // VM has separate opcodes for number and string equality - if isStringType(c.typeInfo.Types[n.X].Type) { - emit.Opcode(c.prog.BinWriter, opcode.EQUAL) - emit.Opcode(c.prog.BinWriter, opcode.NOT) - } else { - emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL) - } - } + c.emitEquality(n.X, n.Op) default: c.convertToken(n.Op) } @@ -988,13 +975,26 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 { return c.labels[labelWithType{name: name, typ: typ}] } -func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode { +func (c *codegen) emitEquality(expr ast.Expr, op token.Token) { t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic) - if ok && t.Info()&types.IsNumeric != 0 { - return opcode.NUMEQUAL + isNum := ok && t.Info()&types.IsNumeric != 0 + switch op { + case token.EQL: + if isNum { + emit.Opcode(c.prog.BinWriter, opcode.NUMEQUAL) + } else { + emit.Opcode(c.prog.BinWriter, opcode.EQUAL) + } + case token.NEQ: + if isNum { + emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL) + } else { + emit.Opcode(c.prog.BinWriter, opcode.EQUAL) + emit.Opcode(c.prog.BinWriter, opcode.NOT) + } + default: + panic("invalid token in emitEqual()") } - - return opcode.EQUAL } // getByteArray returns byte array value from constant expr. diff --git a/pkg/compiler/struct_test.go b/pkg/compiler/struct_test.go index fad15b423..58f685087 100644 --- a/pkg/compiler/struct_test.go +++ b/pkg/compiler/struct_test.go @@ -1,6 +1,7 @@ package compiler_test import ( + "fmt" "math/big" "testing" @@ -366,3 +367,25 @@ var structTestCases = []testCase{ func TestStructs(t *testing.T) { runTestCases(t, structTestCases) } + +func TestStructCompare(t *testing.T) { + srcTmpl := `package testcase + type T struct { f int } + func Main() int { + a := T{f: %d} + b := T{f: %d} + if a != b { + return 2 + } + return 1 + }` + t.Run("Equal", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, 4, 4) + eval(t, src, big.NewInt(1)) + }) + t.Run("NotEqual", func(t *testing.T) { + src := fmt.Sprintf(srcTmpl, 4, 5) + eval(t, src, big.NewInt(2)) + }) + +}