vm: make EQUAL type strict

Do not perform type conversions when comparing elements.
This commit is contained in:
Evgenii Stratonikov 2020-07-23 12:53:12 +03:00
parent 18dcc16553
commit 9252ef65bb
3 changed files with 10 additions and 17 deletions

View file

@ -356,11 +356,7 @@ func (i *BigInteger) Equals(s Item) bool {
return false return false
} }
val, ok := s.(*BigInteger) val, ok := s.(*BigInteger)
if ok { return ok && i.value.Cmp(val.value) == 0
return i.value.Cmp(val.value) == 0
}
bs, err := s.TryBytes()
return err == nil && bytes.Equal(i.Bytes(), bs)
} }
// Value implements Item interface. // Value implements Item interface.
@ -454,11 +450,7 @@ func (i *Bool) Equals(s Item) bool {
return false return false
} }
val, ok := s.(*Bool) val, ok := s.(*Bool)
if ok { return ok && i.value == val.value
return i.value == val.value
}
bs, err := s.TryBytes()
return err == nil && bytes.Equal(i.Bytes(), bs)
} }
// Type implements Item interface. // Type implements Item interface.
@ -530,8 +522,8 @@ func (i *ByteArray) Equals(s Item) bool {
} else if s == nil { } else if s == nil {
return false return false
} }
bs, err := s.TryBytes() val, ok := s.(*ByteArray)
return err == nil && bytes.Equal(i.value, bs) return ok && bytes.Equal(i.value, val.value)
} }
// Dup implements Item interface. // Dup implements Item interface.

View file

@ -211,7 +211,7 @@ var equalsTestCases = map[string][]struct {
{ {
item1: NewBool(true), item1: NewBool(true),
item2: NewBigInteger(big.NewInt(1)), item2: NewBigInteger(big.NewInt(1)),
result: true, result: false,
}, },
{ {
item1: NewBool(true), item1: NewBool(true),
@ -238,7 +238,7 @@ var equalsTestCases = map[string][]struct {
{ {
item1: NewByteArray([]byte{1}), item1: NewByteArray([]byte{1}),
item2: NewBigInteger(big.NewInt(1)), item2: NewBigInteger(big.NewInt(1)),
result: true, result: false,
}, },
{ {
item1: NewByteArray([]byte{1, 2, 3}), item1: NewByteArray([]byte{1, 2, 3}),

View file

@ -1129,7 +1129,8 @@ func TestEQUAL(t *testing.T) {
t.Run("NoArgs", getTestFuncForVM(prog, nil)) t.Run("NoArgs", getTestFuncForVM(prog, nil))
t.Run("OneArgument", getTestFuncForVM(prog, nil, 1)) t.Run("OneArgument", getTestFuncForVM(prog, nil, 1))
t.Run("Integer", getTestFuncForVM(prog, true, 5, 5)) t.Run("Integer", getTestFuncForVM(prog, true, 5, 5))
t.Run("IntegerByteArray", getTestFuncForVM(prog, true, []byte{16}, 16)) t.Run("IntegerByteArray", getTestFuncForVM(prog, false, []byte{16}, 16))
t.Run("BooleanInteger", getTestFuncForVM(prog, false, true, 1))
t.Run("Map", getTestFuncForVM(prog, false, stackitem.NewMap(), stackitem.NewMap())) t.Run("Map", getTestFuncForVM(prog, false, stackitem.NewMap(), stackitem.NewMap()))
t.Run("Array", getTestFuncForVM(prog, false, []stackitem.Item{}, []stackitem.Item{})) t.Run("Array", getTestFuncForVM(prog, false, []stackitem.Item{}, []stackitem.Item{}))
t.Run("Buffer", getTestFuncForVM(prog, false, stackitem.NewBuffer([]byte{42}), stackitem.NewBuffer([]byte{42}))) t.Run("Buffer", getTestFuncForVM(prog, false, stackitem.NewBuffer([]byte{42}), stackitem.NewBuffer([]byte{42})))
@ -1363,14 +1364,14 @@ func TestPICKITEMDupMap(t *testing.T) {
prog := makeProgram(opcode.DUP, opcode.PUSHINT8, 42, opcode.PICKITEM, opcode.ABS) prog := makeProgram(opcode.DUP, opcode.PUSHINT8, 42, opcode.PICKITEM, opcode.ABS)
vm := load(prog) vm := load(prog)
m := stackitem.NewMap() m := stackitem.NewMap()
m.Add(stackitem.Make([]byte{42}), stackitem.Make(-1)) m.Add(stackitem.Make(42), stackitem.Make(-1))
vm.estack.Push(&Element{value: m}) vm.estack.Push(&Element{value: m})
runVM(t, vm) runVM(t, vm)
assert.Equal(t, 2, vm.estack.Len()) assert.Equal(t, 2, vm.estack.Len())
assert.Equal(t, int64(1), vm.estack.Pop().BigInt().Int64()) assert.Equal(t, int64(1), vm.estack.Pop().BigInt().Int64())
items := vm.estack.Pop().Value().([]stackitem.MapElement) items := vm.estack.Pop().Value().([]stackitem.MapElement)
assert.Equal(t, 1, len(items)) assert.Equal(t, 1, len(items))
assert.Equal(t, []byte{42}, items[0].Key.Value()) assert.Equal(t, big.NewInt(42), items[0].Key.Value())
assert.Equal(t, big.NewInt(-1), items[0].Value.Value()) assert.Equal(t, big.NewInt(-1), items[0].Value.Value())
} }