diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index 54fe8ff00..5d7f6acf5 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -208,6 +208,28 @@ func compareItems(t *testing.T, a, b stackitem.Item) { p, ok := b.(*stackitem.Pointer) require.True(t, ok) require.Equal(t, si.Position(), p.Position()) // there no script in test files + case *stackitem.Array, *stackitem.Struct: + require.Equal(t, a.Type(), b.Type()) + + as := a.Value().([]stackitem.Item) + bs := a.Value().([]stackitem.Item) + require.Equal(t, len(as), len(bs)) + + for i := range as { + compareItems(t, as[i], bs[i]) + } + + case *stackitem.Map: + require.Equal(t, a.Type(), b.Type()) + + as := a.Value().([]stackitem.MapElement) + bs := a.Value().([]stackitem.MapElement) + require.Equal(t, len(as), len(bs)) + + for i := range as { + compareItems(t, as[i].Key, bs[i].Key) + compareItems(t, as[i].Value, bs[i].Value) + } default: require.Equal(t, a, b) } diff --git a/pkg/vm/ref_counter.go b/pkg/vm/ref_counter.go index 182177f6c..e3d787a0a 100644 --- a/pkg/vm/ref_counter.go +++ b/pkg/vm/ref_counter.go @@ -5,15 +5,19 @@ import ( ) // refCounter represents reference counter for the VM. -type refCounter struct { - items map[stackitem.Item]int - size int -} +type refCounter int + +type ( + rcInc interface { + IncRC() int + } + rcDec interface { + DecRC() int + } +) func newRefCounter() *refCounter { - return &refCounter{ - items: make(map[stackitem.Item]int), - } + return new(refCounter) } // Add adds an item to the reference counter. @@ -21,23 +25,20 @@ func (r *refCounter) Add(item stackitem.Item) { if r == nil { return } - r.size++ + *r++ - switch item.(type) { - case *stackitem.Array, *stackitem.Struct, *stackitem.Map: - if r.items[item]++; r.items[item] > 1 { - return + irc, ok := item.(rcInc) + if !ok || irc.IncRC() > 1 { + return + } + switch t := item.(type) { + case *stackitem.Array, *stackitem.Struct: + for _, it := range item.Value().([]stackitem.Item) { + r.Add(it) } - - switch t := item.(type) { - case *stackitem.Array, *stackitem.Struct: - for _, it := range item.Value().([]stackitem.Item) { - r.Add(it) - } - case *stackitem.Map: - for i := range t.Value().([]stackitem.MapElement) { - r.Add(t.Value().([]stackitem.MapElement)[i].Value) - } + case *stackitem.Map: + for i := range t.Value().([]stackitem.MapElement) { + r.Add(t.Value().([]stackitem.MapElement)[i].Value) } } } @@ -47,26 +48,20 @@ func (r *refCounter) Remove(item stackitem.Item) { if r == nil { return } - r.size-- + *r-- - switch item.(type) { - case *stackitem.Array, *stackitem.Struct, *stackitem.Map: - if r.items[item] > 1 { - r.items[item]-- - return + irc, ok := item.(rcDec) + if !ok || irc.DecRC() > 0 { + return + } + switch t := item.(type) { + case *stackitem.Array, *stackitem.Struct: + for _, it := range item.Value().([]stackitem.Item) { + r.Remove(it) } - - delete(r.items, item) - - switch t := item.(type) { - case *stackitem.Array, *stackitem.Struct: - for _, it := range item.Value().([]stackitem.Item) { - r.Remove(it) - } - case *stackitem.Map: - for i := range t.Value().([]stackitem.MapElement) { - r.Remove(t.Value().([]stackitem.MapElement)[i].Value) - } + case *stackitem.Map: + for i := range t.Value().([]stackitem.MapElement) { + r.Remove(t.Value().([]stackitem.MapElement)[i].Value) } } } diff --git a/pkg/vm/ref_counter_test.go b/pkg/vm/ref_counter_test.go index b50390609..9e0a82d99 100644 --- a/pkg/vm/ref_counter_test.go +++ b/pkg/vm/ref_counter_test.go @@ -10,24 +10,34 @@ import ( func TestRefCounter_Add(t *testing.T) { r := newRefCounter() - require.Equal(t, 0, r.size) + require.Equal(t, 0, int(*r)) r.Add(stackitem.Null{}) - require.Equal(t, 1, r.size) + require.Equal(t, 1, int(*r)) r.Add(stackitem.Null{}) - require.Equal(t, 2, r.size) // count scalar items twice + require.Equal(t, 2, int(*r)) // count scalar items twice arr := stackitem.NewArray([]stackitem.Item{stackitem.NewByteArray([]byte{1}), stackitem.NewBool(false)}) r.Add(arr) - require.Equal(t, 5, r.size) // array + 2 elements + require.Equal(t, 5, int(*r)) // array + 2 elements r.Add(arr) - require.Equal(t, 6, r.size) // count only array + require.Equal(t, 6, int(*r)) // count only array r.Remove(arr) - require.Equal(t, 5, r.size) + require.Equal(t, 5, int(*r)) r.Remove(arr) - require.Equal(t, 2, r.size) + require.Equal(t, 2, int(*r)) +} + +func BenchmarkRefCounter_Add(b *testing.B) { + a := stackitem.NewArray(nil) + rc := newRefCounter() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rc.Add(a) + } } diff --git a/pkg/vm/stackitem/item.go b/pkg/vm/stackitem/item.go index a57dfb30d..8d9b7d63c 100644 --- a/pkg/vm/stackitem/item.go +++ b/pkg/vm/stackitem/item.go @@ -188,6 +188,7 @@ func convertPrimitive(item Item, typ Type) (Item, error) { // Struct represents a struct on the stack. type Struct struct { value []Item + rc } // NewStruct returns an new Struct object. @@ -311,7 +312,7 @@ func (i *Struct) Clone() (*Struct, error) { } func (i *Struct) clone(limit *int) (*Struct, error) { - ret := &Struct{make([]Item, len(i.value))} + ret := &Struct{value: make([]Item, len(i.value))} for j := range i.value { *limit-- if *limit < 0 { @@ -624,6 +625,7 @@ func (i *ByteArray) Convert(typ Type) (Item, error) { // Array represents a new Array object. type Array struct { value []Item + rc } // NewArray returns a new Array object. @@ -724,6 +726,7 @@ type MapElement struct { // if need be. type Map struct { value []MapElement + rc } // NewMap returns new Map object. diff --git a/pkg/vm/stackitem/reference.go b/pkg/vm/stackitem/reference.go new file mode 100644 index 000000000..0102d6b0e --- /dev/null +++ b/pkg/vm/stackitem/reference.go @@ -0,0 +1,15 @@ +package stackitem + +type rc struct { + count int +} + +func (r *rc) IncRC() int { + r.count++ + return r.count +} + +func (r *rc) DecRC() int { + r.count-- + return r.count +} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 4856270e4..2f973ed25 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -105,7 +105,6 @@ func NewWithTrigger(t trigger.Type) *VM { Invocations: make(map[util.Uint160]int), } - vm.refs.items = make(map[stackitem.Item]int) initStack(&vm.istack, "invocation", nil) vm.estack = newStack("evaluation", &vm.refs) return vm @@ -520,7 +519,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if errRecover := recover(); errRecover != nil { v.state = FaultState err = newError(ctx.ip, op, errRecover) - } else if v.refs.size > MaxStackSize { + } else if v.refs > MaxStackSize { v.state = FaultState err = newError(ctx.ip, op, "stack is too big") } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 8fb07646f..3391ddeca 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -415,7 +415,7 @@ func TestStackLimit(t *testing.T) { require.NoError(t, vm.Step(), "failed to initialize static slot") for i := range expected { require.NoError(t, vm.Step()) - require.Equal(t, expected[i].size, vm.refs.size, "i: %d", i) + require.Equal(t, expected[i].size, int(vm.refs), "i: %d", i) } } @@ -829,7 +829,7 @@ func getTestFuncForVM(prog []byte, result interface{}, args ...interface{}) func if result != nil { f = func(t *testing.T, v *VM) { require.Equal(t, 1, v.estack.Len()) - require.Equal(t, stackitem.Make(result), v.estack.Pop().value) + require.Equal(t, stackitem.Make(result).Value(), v.estack.Pop().Value()) } } return getCustomTestFuncForVM(prog, f, args...) @@ -1761,7 +1761,7 @@ func TestPACK_UNPACK_MaxSize(t *testing.T) { vm.estack.PushVal(len(elements)) runVM(t, vm) // check reference counter = 1+1+1024 - assert.Equal(t, 1+1+len(elements), vm.refs.size) + assert.Equal(t, 1+1+len(elements), int(vm.refs)) assert.Equal(t, 1+1+len(elements), vm.estack.Len()) // canary + length + elements assert.Equal(t, int64(len(elements)), vm.estack.Peek(0).Value().(*big.Int).Int64()) for i := 0; i < len(elements); i++ { @@ -1784,7 +1784,7 @@ func TestPACK_UNPACK_PACK_MaxSize(t *testing.T) { vm.estack.PushVal(len(elements)) runVM(t, vm) // check reference counter = 1+1+1024 - assert.Equal(t, 1+1+len(elements), vm.refs.size) + assert.Equal(t, 1+1+len(elements), int(vm.refs)) assert.Equal(t, 2, vm.estack.Len()) a := vm.estack.Peek(0).Array() assert.Equal(t, len(elements), len(a)) @@ -1959,7 +1959,7 @@ func testCLEARITEMS(t *testing.T, item stackitem.Item) { v.estack.PushVal(item) runVM(t, v) require.Equal(t, 2, v.estack.Len()) - require.EqualValues(t, 2, v.refs.size) // empty collection + it's size + require.EqualValues(t, 2, int(v.refs)) // empty collection + it's size require.EqualValues(t, 0, v.estack.Pop().BigInt().Int64()) }