diff --git a/pkg/vm/ref_counter.go b/pkg/vm/ref_counter.go new file mode 100644 index 000000000..c89f07983 --- /dev/null +++ b/pkg/vm/ref_counter.go @@ -0,0 +1,62 @@ +package vm + +// refCounter represents reference counter for the VM. +type refCounter struct { + items map[StackItem]int + size int +} + +func newRefCounter() *refCounter { + return &refCounter{ + items: make(map[StackItem]int), + } +} + +// Add adds an item to the reference counter. +func (r *refCounter) Add(item StackItem) { + r.size++ + + switch item.(type) { + case *ArrayItem, *StructItem, *MapItem: + if r.items[item]++; r.items[item] > 1 { + return + } + + switch t := item.(type) { + case *ArrayItem, *StructItem: + for _, it := range item.Value().([]StackItem) { + r.Add(it) + } + case *MapItem: + for i := range t.value { + r.Add(t.value[i].Value) + } + } + } +} + +// Remove removes item from the reference counter. +func (r *refCounter) Remove(item StackItem) { + r.size-- + + switch item.(type) { + case *ArrayItem, *StructItem, *MapItem: + if r.items[item] > 1 { + r.items[item]-- + return + } + + delete(r.items, item) + + switch t := item.(type) { + case *ArrayItem, *StructItem: + for _, it := range item.Value().([]StackItem) { + r.Remove(it) + } + case *MapItem: + for i := range t.value { + r.Remove(t.value[i].Value) + } + } + } +} diff --git a/pkg/vm/ref_counter_test.go b/pkg/vm/ref_counter_test.go new file mode 100644 index 000000000..731a5c46a --- /dev/null +++ b/pkg/vm/ref_counter_test.go @@ -0,0 +1,32 @@ +package vm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRefCounter_Add(t *testing.T) { + r := newRefCounter() + + require.Equal(t, 0, r.size) + + r.Add(NullItem{}) + require.Equal(t, 1, r.size) + + r.Add(NullItem{}) + require.Equal(t, 2, r.size) // count scalar items twice + + arr := NewArrayItem([]StackItem{NewByteArrayItem([]byte{1}), NewBoolItem(false)}) + r.Add(arr) + require.Equal(t, 5, r.size) // array + 2 elements + + r.Add(arr) + require.Equal(t, 6, r.size) // count only array + + r.Remove(arr) + require.Equal(t, 5, r.size) + + r.Remove(arr) + require.Equal(t, 2, r.size) +} diff --git a/pkg/vm/stack.go b/pkg/vm/stack.go index e65bcada2..19e9e9e8f 100644 --- a/pkg/vm/stack.go +++ b/pkg/vm/stack.go @@ -125,9 +125,7 @@ type Stack struct { top Element name string len int - - itemCount map[StackItem]int - size *int + refs *refCounter } // NewStack returns a new stack name by the given name. @@ -138,8 +136,7 @@ func NewStack(n string) *Stack { s.top.next = &s.top s.top.prev = &s.top s.len = 0 - s.itemCount = make(map[StackItem]int) - s.size = new(int) + s.refs = newRefCounter() return s } @@ -171,58 +168,11 @@ func (s *Stack) insert(e, at *Element) *Element { e.stack = s s.len++ - s.updateSizeAdd(e.value) + s.refs.Add(e.value) return e } -func (s *Stack) updateSizeAdd(item StackItem) { - *s.size++ - - switch item.(type) { - case *ArrayItem, *StructItem, *MapItem: - if s.itemCount[item]++; s.itemCount[item] > 1 { - return - } - - switch t := item.(type) { - case *ArrayItem, *StructItem: - for _, it := range item.Value().([]StackItem) { - s.updateSizeAdd(it) - } - case *MapItem: - for i := range t.value { - s.updateSizeAdd(t.value[i].Value) - } - } - } -} - -func (s *Stack) updateSizeRemove(item StackItem) { - *s.size-- - - switch item.(type) { - case *ArrayItem, *StructItem, *MapItem: - if s.itemCount[item] > 1 { - s.itemCount[item]-- - return - } - - delete(s.itemCount, item) - - switch t := item.(type) { - case *ArrayItem, *StructItem: - for _, it := range item.Value().([]StackItem) { - s.updateSizeRemove(it) - } - case *MapItem: - for i := range t.value { - s.updateSizeRemove(t.value[i].Value) - } - } - } -} - // InsertAt inserts the given item (n) deep on the stack. // Be very careful using it and _always_ check both e and n before invocation // as it will silently do wrong things otherwise. @@ -300,7 +250,7 @@ func (s *Stack) Remove(e *Element) *Element { e.stack = nil s.len-- - s.updateSizeRemove(e.value) + s.refs.Remove(e.value) return e } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 694ae147c..c96de3977 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -78,8 +78,7 @@ type VM struct { // Hash to verify in CHECKSIG/CHECKMULTISIG. checkhash []byte - itemCount map[StackItem]int - size int + refs *refCounter gasConsumed util.Fixed8 gasLimit util.Fixed8 @@ -94,9 +93,8 @@ func New() *VM { getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage. state: haltState, istack: NewStack("invocation"), - - itemCount: make(map[StackItem]int), - keys: make(map[string]*keys.PublicKey), + refs: newRefCounter(), + keys: make(map[string]*keys.PublicKey), } vm.estack = vm.newItemStack("evaluation") @@ -108,8 +106,7 @@ func New() *VM { func (v *VM) newItemStack(n string) *Stack { s := NewStack(n) - s.size = &v.size - s.itemCount = v.itemCount + s.refs = v.refs return s } @@ -499,7 +496,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if errRecover := recover(); errRecover != nil { v.state = faultState err = newError(ctx.ip, op, errRecover) - } else if v.size > MaxStackSize { + } else if v.refs.size > MaxStackSize { v.state = faultState err = newError(ctx.ip, op, "stack is too big") } @@ -955,7 +952,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro panic("APPEND: not of underlying type Array") } - v.estack.updateSizeAdd(val) + v.refs.Add(val) case opcode.PACK: n := int(v.estack.Pop().BigInt().Int64()) @@ -1024,17 +1021,17 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if index < 0 || index >= len(arr) { panic("SETITEM: invalid index") } - v.estack.updateSizeRemove(arr[index]) + v.refs.Remove(arr[index]) arr[index] = item - v.estack.updateSizeAdd(arr[index]) + v.refs.Add(arr[index]) case *MapItem: if t.Has(key.value) { - v.estack.updateSizeRemove(item) + v.refs.Remove(item) } else if len(t.value) >= MaxArraySize { panic("too big map") } t.Add(key.value, item) - v.estack.updateSizeAdd(item) + v.refs.Add(item) default: panic(fmt.Sprintf("SETITEM: invalid item type %s", t)) @@ -1059,7 +1056,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if k < 0 || k >= len(a) { panic("REMOVE: invalid index") } - v.estack.updateSizeRemove(a[k]) + v.refs.Remove(a[k]) a = append(a[:k], a[k+1:]...) t.value = a case *StructItem: @@ -1068,14 +1065,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if k < 0 || k >= len(a) { panic("REMOVE: invalid index") } - v.estack.updateSizeRemove(a[k]) + v.refs.Remove(a[k]) a = append(a[:k], a[k+1:]...) t.value = a case *MapItem: index := t.Index(key.Item()) // NEO 2.0 doesn't error on missing key. if index >= 0 { - v.estack.updateSizeRemove(t.value[index].Value) + v.refs.Remove(t.value[index].Value) t.Drop(index) } default: @@ -1087,17 +1084,17 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro switch t := elem.value.(type) { case *ArrayItem: for _, item := range t.value { - v.estack.updateSizeRemove(item) + v.refs.Remove(item) } t.value = t.value[:0] case *StructItem: for _, item := range t.value { - v.estack.updateSizeRemove(item) + v.refs.Remove(item) } t.value = t.value[:0] case *MapItem: for i := range t.value { - v.estack.updateSizeRemove(t.value[i].Value) + v.refs.Remove(t.value[i].Value) } t.value = t.value[:0] default: diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 5a313bfa5..262e77cc5 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -423,7 +423,7 @@ func TestStackLimit(t *testing.T) { vm := load(makeProgram(prog...)) for i := range expected { require.NoError(t, vm.Step()) - require.Equal(t, expected[i].size, vm.size) + require.Equal(t, expected[i].size, vm.refs.size) } } @@ -1980,7 +1980,7 @@ func testCLEARITEMS(t *testing.T, item StackItem) { v.estack.PushVal(item) runVM(t, v) require.Equal(t, 2, v.estack.Len()) - require.EqualValues(t, 2, v.size) // empty collection + it's size + require.EqualValues(t, 2, v.refs.size) // empty collection + it's size require.EqualValues(t, 0, v.estack.Pop().BigInt().Int64()) }