From dfc59129c7909b31e3bfda28ec9d5e5b37f0cc2a Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Wed, 11 Mar 2020 16:44:10 +0300 Subject: [PATCH] vm: implement EQUAL opcode properly When comparing elements of different types, conversions should be performed. This commit implement custom equality predicate for each stack item type. --- pkg/vm/context.go | 5 +++ pkg/vm/stack_item.go | 84 ++++++++++++++++++++++++++++++++++++++++++++ pkg/vm/vm.go | 14 +------- pkg/vm/vm_test.go | 10 ++++++ 4 files changed, 100 insertions(+), 13 deletions(-) diff --git a/pkg/vm/context.go b/pkg/vm/context.go index b934b2ba3..64817e4b4 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -175,6 +175,11 @@ func (c *Context) TryBytes() ([]byte, error) { return nil, errors.New("can't convert Context to ByteArray") } +// Equals implements StackItem interface. +func (c *Context) Equals(s StackItem) bool { + return c == s +} + // ToContractParameter implements StackItem interface. func (c *Context) ToContractParameter(map[StackItem]bool) smartcontract.Parameter { return smartcontract.Parameter{ diff --git a/pkg/vm/stack_item.go b/pkg/vm/stack_item.go index 188b11de6..0a03fe20f 100644 --- a/pkg/vm/stack_item.go +++ b/pkg/vm/stack_item.go @@ -1,6 +1,7 @@ package vm import ( + "bytes" "encoding/binary" "encoding/hex" "encoding/json" @@ -21,6 +22,8 @@ type StackItem interface { Dup() StackItem // TryBytes converts StackItem to a byte slice. TryBytes() ([]byte, error) + // Equals checks if 2 StackItems are equal. + Equals(s StackItem) bool // ToContractParameter converts StackItem to smartcontract.Parameter ToContractParameter(map[StackItem]bool) smartcontract.Parameter } @@ -126,6 +129,25 @@ func (i *StructItem) TryBytes() ([]byte, error) { return nil, errors.New("can't convert Struct to ByteArray") } +// Equals implements StackItem interface. +func (i *StructItem) Equals(s StackItem) bool { + if i == s { + return true + } else if s == nil { + return false + } + val, ok := s.(*StructItem) + if !ok || len(i.value) != len(val.value) { + return false + } + for j := range i.value { + if !i.value[j].Equals(val.value[j]) { + return false + } + } + return true +} + // ToContractParameter implements StackItem interface. func (i *StructItem) ToContractParameter(seen map[StackItem]bool) smartcontract.Parameter { var value []smartcontract.Parameter @@ -180,6 +202,21 @@ func (i *BigIntegerItem) TryBytes() ([]byte, error) { return i.Bytes(), nil } +// Equals implements StackItem interface. +func (i *BigIntegerItem) Equals(s StackItem) bool { + if i == s { + return true + } else if s == nil { + return false + } + val, ok := s.(*BigIntegerItem) + if ok { + return i.value.Cmp(val.value) == 0 + } + bs, err := s.TryBytes() + return err == nil && bytes.Equal(i.Bytes(), bs) +} + // Value implements StackItem interface. func (i *BigIntegerItem) Value() interface{} { return i.value @@ -254,6 +291,21 @@ func (i *BoolItem) TryBytes() ([]byte, error) { return i.Bytes(), nil } +// Equals implements StackItem interface. +func (i *BoolItem) Equals(s StackItem) bool { + if i == s { + return true + } else if s == nil { + return false + } + val, ok := s.(*BoolItem) + if ok { + return i.value == val.value + } + bs, err := s.TryBytes() + return err == nil && bytes.Equal(i.Bytes(), bs) +} + // ToContractParameter implements StackItem interface. func (i *BoolItem) ToContractParameter(map[StackItem]bool) smartcontract.Parameter { return smartcontract.Parameter{ @@ -293,6 +345,17 @@ func (i *ByteArrayItem) TryBytes() ([]byte, error) { return i.value, nil } +// Equals implements StackItem interface. +func (i *ByteArrayItem) Equals(s StackItem) bool { + if i == s { + return true + } else if s == nil { + return false + } + bs, err := s.TryBytes() + return err == nil && bytes.Equal(i.value, bs) +} + // Dup implements StackItem interface. func (i *ByteArrayItem) Dup() StackItem { a := make([]byte, len(i.value)) @@ -339,6 +402,11 @@ func (i *ArrayItem) TryBytes() ([]byte, error) { return nil, errors.New("can't convert Array to ByteArray") } +// Equals implements StackItem interface. +func (i *ArrayItem) Equals(s StackItem) bool { + return i == s +} + // Dup implements StackItem interface. func (i *ArrayItem) Dup() StackItem { // reference type @@ -384,6 +452,11 @@ func (i *MapItem) TryBytes() ([]byte, error) { return nil, errors.New("can't convert Map to ByteArray") } +// Equals implements StackItem interface. +func (i *MapItem) Equals(s StackItem) bool { + return i == s +} + func (i *MapItem) String() string { return "Map" } @@ -486,6 +559,17 @@ func (i *InteropItem) TryBytes() ([]byte, error) { return nil, errors.New("can't convert Interop to ByteArray") } +// Equals implements StackItem interface. +func (i *InteropItem) Equals(s StackItem) bool { + if i == s { + return true + } else if s == nil { + return false + } + val, ok := s.(*InteropItem) + return ok && i.value == val.value +} + // ToContractParameter implements StackItem interface. func (i *InteropItem) ToContractParameter(map[StackItem]bool) smartcontract.Parameter { return smartcontract.Parameter{ diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 4ec5b603f..9bb0ba675 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -8,7 +8,6 @@ import ( "io/ioutil" "math/big" "os" - "reflect" "text/tabwriter" "unicode/utf8" @@ -703,18 +702,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if a == nil { panic("no second-to-the-top element found") } - if ta, ok := a.value.(*ArrayItem); ok { - if tb, ok := b.value.(*ArrayItem); ok { - v.estack.PushVal(ta == tb) - break - } - } else if ma, ok := a.value.(*MapItem); ok { - if mb, ok := b.value.(*MapItem); ok { - v.estack.PushVal(ma == mb) - break - } - } - v.estack.PushVal(reflect.DeepEqual(a, b)) + v.estack.PushVal(a.value.Equals(b.value)) // Bit operations. case opcode.INVERT: diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index c8488d189..bbc654a9a 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -1006,6 +1006,16 @@ func TestEQUALGoodInteger(t *testing.T) { assert.Equal(t, &BoolItem{true}, vm.estack.Pop().value) } +func TestEQUALIntegerByteArray(t *testing.T) { + prog := makeProgram(opcode.EQUAL) + vm := load(prog) + vm.estack.PushVal([]byte{16}) + vm.estack.PushVal(16) + runVM(t, vm) + assert.Equal(t, 1, vm.estack.Len()) + assert.Equal(t, &BoolItem{true}, vm.estack.Pop().value) +} + func TestEQUALArrayTrue(t *testing.T) { prog := makeProgram(opcode.DUP, opcode.EQUAL) vm := load(prog)