compiler: emit NOTEQUAL only for numbers
Dispatch based on types similarly to EQUAL.
This commit is contained in:
parent
9d6b0ee4a8
commit
cc5b5bff2e
2 changed files with 43 additions and 20 deletions
|
@ -470,7 +470,6 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
case *ast.SwitchStmt:
|
case *ast.SwitchStmt:
|
||||||
ast.Walk(c, n.Tag)
|
ast.Walk(c, n.Tag)
|
||||||
|
|
||||||
eqOpcode := c.getEqualityOpcode(n.Tag)
|
|
||||||
switchEnd, label := c.generateLabel(labelEnd)
|
switchEnd, label := c.generateLabel(labelEnd)
|
||||||
|
|
||||||
lastSwitch := c.currentSwitch
|
lastSwitch := c.currentSwitch
|
||||||
|
@ -490,7 +489,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
for j := range cc.List {
|
for j := range cc.List {
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.DUP)
|
emit.Opcode(c.prog.BinWriter, opcode.DUP)
|
||||||
ast.Walk(c, cc.List[j])
|
ast.Walk(c, cc.List[j])
|
||||||
emit.Opcode(c.prog.BinWriter, eqOpcode)
|
c.emitEquality(n.Tag, token.EQL)
|
||||||
if j == l-1 {
|
if j == l-1 {
|
||||||
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOT, lEnd)
|
emit.Jmp(c.prog.BinWriter, opcode.JMPIFNOT, lEnd)
|
||||||
} else {
|
} else {
|
||||||
|
@ -630,19 +629,7 @@ func (c *codegen) Visit(node ast.Node) ast.Visitor {
|
||||||
c.prog.Err = errors.New("comparison with `nil` is not supported, use `len(..) == 0` instead")
|
c.prog.Err = errors.New("comparison with `nil` is not supported, use `len(..) == 0` instead")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if n.Op == token.EQL {
|
c.emitEquality(n.X, n.Op)
|
||||||
// VM has separate opcodes for number and string equality
|
|
||||||
op := c.getEqualityOpcode(n.X)
|
|
||||||
emit.Opcode(c.prog.BinWriter, op)
|
|
||||||
} else {
|
|
||||||
// VM has separate opcodes for number and string equality
|
|
||||||
if isStringType(c.typeInfo.Types[n.X].Type) {
|
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
|
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.NOT)
|
|
||||||
} else {
|
|
||||||
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
c.convertToken(n.Op)
|
c.convertToken(n.Op)
|
||||||
}
|
}
|
||||||
|
@ -988,13 +975,26 @@ func (c *codegen) getLabelOffset(typ labelOffsetType, name string) uint16 {
|
||||||
return c.labels[labelWithType{name: name, typ: typ}]
|
return c.labels[labelWithType{name: name, typ: typ}]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *codegen) getEqualityOpcode(expr ast.Expr) opcode.Opcode {
|
func (c *codegen) emitEquality(expr ast.Expr, op token.Token) {
|
||||||
t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic)
|
t, ok := c.typeInfo.Types[expr].Type.Underlying().(*types.Basic)
|
||||||
if ok && t.Info()&types.IsNumeric != 0 {
|
isNum := ok && t.Info()&types.IsNumeric != 0
|
||||||
return opcode.NUMEQUAL
|
switch op {
|
||||||
|
case token.EQL:
|
||||||
|
if isNum {
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.NUMEQUAL)
|
||||||
|
} else {
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
|
||||||
|
}
|
||||||
|
case token.NEQ:
|
||||||
|
if isNum {
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.NUMNOTEQUAL)
|
||||||
|
} else {
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.EQUAL)
|
||||||
|
emit.Opcode(c.prog.BinWriter, opcode.NOT)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic("invalid token in emitEqual()")
|
||||||
}
|
}
|
||||||
|
|
||||||
return opcode.EQUAL
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getByteArray returns byte array value from constant expr.
|
// getByteArray returns byte array value from constant expr.
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package compiler_test
|
package compiler_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -366,3 +367,25 @@ var structTestCases = []testCase{
|
||||||
func TestStructs(t *testing.T) {
|
func TestStructs(t *testing.T) {
|
||||||
runTestCases(t, structTestCases)
|
runTestCases(t, structTestCases)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStructCompare(t *testing.T) {
|
||||||
|
srcTmpl := `package testcase
|
||||||
|
type T struct { f int }
|
||||||
|
func Main() int {
|
||||||
|
a := T{f: %d}
|
||||||
|
b := T{f: %d}
|
||||||
|
if a != b {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}`
|
||||||
|
t.Run("Equal", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, 4, 4)
|
||||||
|
eval(t, src, big.NewInt(1))
|
||||||
|
})
|
||||||
|
t.Run("NotEqual", func(t *testing.T) {
|
||||||
|
src := fmt.Sprintf(srcTmpl, 4, 5)
|
||||||
|
eval(t, src, big.NewInt(2))
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue