vm: implement EQUAL opcode properly

When comparing elements of different types, conversions
should be performed. This commit implement custom equality
predicate for each stack item type.
This commit is contained in:
Evgenii Stratonikov 2020-03-11 16:44:10 +03:00
parent 5da82e8cf0
commit dfc59129c7
4 changed files with 100 additions and 13 deletions

View file

@ -175,6 +175,11 @@ func (c *Context) TryBytes() ([]byte, error) {
return nil, errors.New("can't convert Context to ByteArray") return nil, errors.New("can't convert Context to ByteArray")
} }
// Equals implements StackItem interface.
func (c *Context) Equals(s StackItem) bool {
return c == s
}
// ToContractParameter implements StackItem interface. // ToContractParameter implements StackItem interface.
func (c *Context) ToContractParameter(map[StackItem]bool) smartcontract.Parameter { func (c *Context) ToContractParameter(map[StackItem]bool) smartcontract.Parameter {
return smartcontract.Parameter{ return smartcontract.Parameter{

View file

@ -1,6 +1,7 @@
package vm package vm
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -21,6 +22,8 @@ type StackItem interface {
Dup() StackItem Dup() StackItem
// TryBytes converts StackItem to a byte slice. // TryBytes converts StackItem to a byte slice.
TryBytes() ([]byte, error) TryBytes() ([]byte, error)
// Equals checks if 2 StackItems are equal.
Equals(s StackItem) bool
// ToContractParameter converts StackItem to smartcontract.Parameter // ToContractParameter converts StackItem to smartcontract.Parameter
ToContractParameter(map[StackItem]bool) smartcontract.Parameter ToContractParameter(map[StackItem]bool) smartcontract.Parameter
} }
@ -126,6 +129,25 @@ func (i *StructItem) TryBytes() ([]byte, error) {
return nil, errors.New("can't convert Struct to ByteArray") return nil, errors.New("can't convert Struct to ByteArray")
} }
// Equals implements StackItem interface.
func (i *StructItem) Equals(s StackItem) bool {
if i == s {
return true
} else if s == nil {
return false
}
val, ok := s.(*StructItem)
if !ok || len(i.value) != len(val.value) {
return false
}
for j := range i.value {
if !i.value[j].Equals(val.value[j]) {
return false
}
}
return true
}
// ToContractParameter implements StackItem interface. // ToContractParameter implements StackItem interface.
func (i *StructItem) ToContractParameter(seen map[StackItem]bool) smartcontract.Parameter { func (i *StructItem) ToContractParameter(seen map[StackItem]bool) smartcontract.Parameter {
var value []smartcontract.Parameter var value []smartcontract.Parameter
@ -180,6 +202,21 @@ func (i *BigIntegerItem) TryBytes() ([]byte, error) {
return i.Bytes(), nil return i.Bytes(), nil
} }
// Equals implements StackItem interface.
func (i *BigIntegerItem) Equals(s StackItem) bool {
if i == s {
return true
} else if s == nil {
return false
}
val, ok := s.(*BigIntegerItem)
if ok {
return i.value.Cmp(val.value) == 0
}
bs, err := s.TryBytes()
return err == nil && bytes.Equal(i.Bytes(), bs)
}
// Value implements StackItem interface. // Value implements StackItem interface.
func (i *BigIntegerItem) Value() interface{} { func (i *BigIntegerItem) Value() interface{} {
return i.value return i.value
@ -254,6 +291,21 @@ func (i *BoolItem) TryBytes() ([]byte, error) {
return i.Bytes(), nil return i.Bytes(), nil
} }
// Equals implements StackItem interface.
func (i *BoolItem) Equals(s StackItem) bool {
if i == s {
return true
} else if s == nil {
return false
}
val, ok := s.(*BoolItem)
if ok {
return i.value == val.value
}
bs, err := s.TryBytes()
return err == nil && bytes.Equal(i.Bytes(), bs)
}
// ToContractParameter implements StackItem interface. // ToContractParameter implements StackItem interface.
func (i *BoolItem) ToContractParameter(map[StackItem]bool) smartcontract.Parameter { func (i *BoolItem) ToContractParameter(map[StackItem]bool) smartcontract.Parameter {
return smartcontract.Parameter{ return smartcontract.Parameter{
@ -293,6 +345,17 @@ func (i *ByteArrayItem) TryBytes() ([]byte, error) {
return i.value, nil return i.value, nil
} }
// Equals implements StackItem interface.
func (i *ByteArrayItem) Equals(s StackItem) bool {
if i == s {
return true
} else if s == nil {
return false
}
bs, err := s.TryBytes()
return err == nil && bytes.Equal(i.value, bs)
}
// Dup implements StackItem interface. // Dup implements StackItem interface.
func (i *ByteArrayItem) Dup() StackItem { func (i *ByteArrayItem) Dup() StackItem {
a := make([]byte, len(i.value)) a := make([]byte, len(i.value))
@ -339,6 +402,11 @@ func (i *ArrayItem) TryBytes() ([]byte, error) {
return nil, errors.New("can't convert Array to ByteArray") return nil, errors.New("can't convert Array to ByteArray")
} }
// Equals implements StackItem interface.
func (i *ArrayItem) Equals(s StackItem) bool {
return i == s
}
// Dup implements StackItem interface. // Dup implements StackItem interface.
func (i *ArrayItem) Dup() StackItem { func (i *ArrayItem) Dup() StackItem {
// reference type // reference type
@ -384,6 +452,11 @@ func (i *MapItem) TryBytes() ([]byte, error) {
return nil, errors.New("can't convert Map to ByteArray") return nil, errors.New("can't convert Map to ByteArray")
} }
// Equals implements StackItem interface.
func (i *MapItem) Equals(s StackItem) bool {
return i == s
}
func (i *MapItem) String() string { func (i *MapItem) String() string {
return "Map" return "Map"
} }
@ -486,6 +559,17 @@ func (i *InteropItem) TryBytes() ([]byte, error) {
return nil, errors.New("can't convert Interop to ByteArray") return nil, errors.New("can't convert Interop to ByteArray")
} }
// Equals implements StackItem interface.
func (i *InteropItem) Equals(s StackItem) bool {
if i == s {
return true
} else if s == nil {
return false
}
val, ok := s.(*InteropItem)
return ok && i.value == val.value
}
// ToContractParameter implements StackItem interface. // ToContractParameter implements StackItem interface.
func (i *InteropItem) ToContractParameter(map[StackItem]bool) smartcontract.Parameter { func (i *InteropItem) ToContractParameter(map[StackItem]bool) smartcontract.Parameter {
return smartcontract.Parameter{ return smartcontract.Parameter{

View file

@ -8,7 +8,6 @@ import (
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"os" "os"
"reflect"
"text/tabwriter" "text/tabwriter"
"unicode/utf8" "unicode/utf8"
@ -703,18 +702,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
if a == nil { if a == nil {
panic("no second-to-the-top element found") panic("no second-to-the-top element found")
} }
if ta, ok := a.value.(*ArrayItem); ok { v.estack.PushVal(a.value.Equals(b.value))
if tb, ok := b.value.(*ArrayItem); ok {
v.estack.PushVal(ta == tb)
break
}
} else if ma, ok := a.value.(*MapItem); ok {
if mb, ok := b.value.(*MapItem); ok {
v.estack.PushVal(ma == mb)
break
}
}
v.estack.PushVal(reflect.DeepEqual(a, b))
// Bit operations. // Bit operations.
case opcode.INVERT: case opcode.INVERT:

View file

@ -1006,6 +1006,16 @@ func TestEQUALGoodInteger(t *testing.T) {
assert.Equal(t, &BoolItem{true}, vm.estack.Pop().value) assert.Equal(t, &BoolItem{true}, vm.estack.Pop().value)
} }
func TestEQUALIntegerByteArray(t *testing.T) {
prog := makeProgram(opcode.EQUAL)
vm := load(prog)
vm.estack.PushVal([]byte{16})
vm.estack.PushVal(16)
runVM(t, vm)
assert.Equal(t, 1, vm.estack.Len())
assert.Equal(t, &BoolItem{true}, vm.estack.Pop().value)
}
func TestEQUALArrayTrue(t *testing.T) { func TestEQUALArrayTrue(t *testing.T) {
prog := makeProgram(opcode.DUP, opcode.EQUAL) prog := makeProgram(opcode.DUP, opcode.EQUAL)
vm := load(prog) vm := load(prog)