diff --git a/pkg/vm/opcodebench_test.go b/pkg/vm/opcodebench_test.go index 465a9943e..0c68ae647 100644 --- a/pkg/vm/opcodebench_test.go +++ b/pkg/vm/opcodebench_test.go @@ -56,11 +56,11 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int, return nil } if sslot != 0 { - v.Context().static.init(sslot) + v.Context().static.init(sslot, &v.refs) } if slotloc != 0 && slotarg != 0 { - v.Context().local.init(slotloc) - v.Context().arguments.init(slotarg) + v.Context().local.init(slotloc, &v.refs) + v.Context().arguments.init(slotarg, &v.refs) } for i := range items { item, ok := items[i].(stackitem.Item) diff --git a/pkg/vm/ref_counter.go b/pkg/vm/ref_counter.go index f119e3959..a493a43ce 100644 --- a/pkg/vm/ref_counter.go +++ b/pkg/vm/ref_counter.go @@ -37,8 +37,10 @@ func (r *refCounter) Add(item stackitem.Item) { r.Add(it) } case *stackitem.Map: - for i := range t.Value().([]stackitem.MapElement) { - r.Add(t.Value().([]stackitem.MapElement)[i].Value) + elems := t.Value().([]stackitem.MapElement) + for i := range elems { + r.Add(elems[i].Key) + r.Add(elems[i].Value) } } } @@ -60,8 +62,10 @@ func (r *refCounter) Remove(item stackitem.Item) { r.Remove(it) } case *stackitem.Map: - for i := range t.Value().([]stackitem.MapElement) { - r.Remove(t.Value().([]stackitem.MapElement)[i].Value) + elems := t.Value().([]stackitem.MapElement) + for i := range elems { + r.Remove(elems[i].Key) + r.Remove(elems[i].Value) } } } diff --git a/pkg/vm/ref_counter_test.go b/pkg/vm/ref_counter_test.go index 9e0a82d99..f0e63b318 100644 --- a/pkg/vm/ref_counter_test.go +++ b/pkg/vm/ref_counter_test.go @@ -30,6 +30,20 @@ func TestRefCounter_Add(t *testing.T) { r.Remove(arr) require.Equal(t, 2, int(*r)) + + m := stackitem.NewMap() + m.Add(stackitem.NewByteArray([]byte("some")), stackitem.NewBool(false)) + r.Add(m) + require.Equal(t, 5, int(*r)) // map + key + value + + r.Add(m) + require.Equal(t, 6, int(*r)) // map only + + r.Remove(m) + require.Equal(t, 5, int(*r)) + + r.Remove(m) + require.Equal(t, 2, int(*r)) } func BenchmarkRefCounter_Add(b *testing.B) { diff --git a/pkg/vm/slot.go b/pkg/vm/slot.go index 69645f689..4544ab8dd 100644 --- a/pkg/vm/slot.go +++ b/pkg/vm/slot.go @@ -10,11 +10,12 @@ import ( type slot []stackitem.Item // init sets static slot size to n. It is intended to be used only by INITSSLOT. -func (s *slot) init(n int) { +func (s *slot) init(n int, rc *refCounter) { if *s != nil { panic("already initialized") } *s = make([]stackitem.Item, n) + *rc += refCounter(n) // Virtual "Null" elements. } // Set sets i-th storage slot. @@ -26,6 +27,8 @@ func (s slot) Set(i int, item stackitem.Item, refs *refCounter) { s[i] = item if old != nil { refs.Remove(old) + } else { + *refs-- // Not really existing, but counted Null element. } refs.Add(item) } @@ -38,8 +41,8 @@ func (s slot) Get(i int) stackitem.Item { return stackitem.Null{} } -// Clear removes all slot variables from the reference counter. -func (s slot) Clear(refs *refCounter) { +// ClearRefs removes all slot variables from the reference counter. +func (s slot) ClearRefs(refs *refCounter) { for _, item := range s { refs.Remove(item) } diff --git a/pkg/vm/slot_test.go b/pkg/vm/slot_test.go index 212470a9f..ec10a7ffc 100644 --- a/pkg/vm/slot_test.go +++ b/pkg/vm/slot_test.go @@ -13,8 +13,9 @@ func TestSlot_Get(t *testing.T) { var s slot require.Panics(t, func() { s.Size() }) - s.init(3) + s.init(3, rc) require.Equal(t, 3, s.Size()) + require.Equal(t, 3, int(*rc)) // Null is the default item := s.Get(2) @@ -22,4 +23,5 @@ func TestSlot_Get(t *testing.T) { s.Set(1, stackitem.NewBigInteger(big.NewInt(42)), rc) require.Equal(t, stackitem.NewBigInteger(big.NewInt(42)), s.Get(1)) + require.Equal(t, 3, int(*rc)) } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 717d5468a..801d68846 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -616,7 +616,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if parameter[0] == 0 { panic("zero argument") } - ctx.static.init(int(parameter[0])) + ctx.static.init(int(parameter[0]), &v.refs) case opcode.INITSLOT: if ctx.local != nil || ctx.arguments != nil { @@ -626,11 +626,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro panic("zero argument") } if parameter[0] > 0 { - ctx.local.init(int(parameter[0])) + ctx.local.init(int(parameter[0]), &v.refs) } if parameter[1] > 0 { sz := int(parameter[1]) - ctx.arguments.init(sz) + ctx.arguments.init(sz, &v.refs) for i := 0; i < sz; i++ { ctx.arguments.Set(i, v.estack.Pop().Item(), &v.refs) } @@ -1250,6 +1250,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case *stackitem.Map: if i := t.Index(key.value); i >= 0 { v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value) + } else { + v.refs.Add(key.value) } t.Add(key.value, item) v.refs.Add(item) @@ -1312,7 +1314,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro index := t.Index(key.Item()) // NEO 2.0 doesn't error on missing key. if index >= 0 { - v.refs.Remove(t.Value().([]stackitem.MapElement)[index].Value) + elems := t.Value().([]stackitem.MapElement) + v.refs.Remove(elems[index].Key) + v.refs.Remove(elems[index].Value) t.Drop(index) } default: @@ -1333,8 +1337,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } t.Clear() case *stackitem.Map: - for i := range t.Value().([]stackitem.MapElement) { - v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value) + elems := t.Value().([]stackitem.MapElement) + for i := range elems { + v.refs.Remove(elems[i].Key) + v.refs.Remove(elems[i].Value) } t.Clear() default: @@ -1576,14 +1582,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro func (v *VM) unloadContext(ctx *Context) { if ctx.local != nil { - ctx.local.Clear(&v.refs) + ctx.local.ClearRefs(&v.refs) } if ctx.arguments != nil { - ctx.arguments.Clear(&v.refs) + ctx.arguments.ClearRefs(&v.refs) } currCtx := v.Context() if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { - ctx.static.Clear(&v.refs) + ctx.static.ClearRefs(&v.refs) } } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index fcd25ef4a..0e88a14e4 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -387,21 +387,21 @@ func TestStackLimit(t *testing.T) { inst opcode.Opcode size int }{ - {opcode.PUSH2, 1}, - {opcode.NEWARRAY, 3}, // array + 2 items + {opcode.PUSH2, 2}, // 1 from INITSSLOT and 1 for integer 2 + {opcode.NEWARRAY, 4}, // array + 2 items {opcode.STSFLD0, 3}, {opcode.LDSFLD0, 4}, {opcode.NEWMAP, 5}, {opcode.DUP, 6}, {opcode.PUSH2, 7}, {opcode.LDSFLD0, 8}, - {opcode.SETITEM, 6}, // -3 items and 1 new element in map - {opcode.DUP, 7}, - {opcode.PUSH2, 8}, - {opcode.LDSFLD0, 9}, - {opcode.SETITEM, 6}, // -3 items and no new elements in map - {opcode.DUP, 7}, - {opcode.PUSH2, 8}, + {opcode.SETITEM, 7}, // -3 items and 1 new kv pair in map + {opcode.DUP, 8}, + {opcode.PUSH2, 9}, + {opcode.LDSFLD0, 10}, + {opcode.SETITEM, 7}, // -3 items and no new elements in map + {opcode.DUP, 8}, + {opcode.PUSH2, 9}, {opcode.REMOVE, 5}, // as we have right after NEWMAP {opcode.DROP, 4}, // DROP map with no elements } @@ -1402,7 +1402,7 @@ func TestSETITEMBigMapBad(t *testing.T) { // 2. SETITEM each of them to a map. // 3. Replace each of them with a scalar value. func TestSETITEMMapStackLimit(t *testing.T) { - size := MaxStackSize/2 - 3 + size := MaxStackSize/2 - 4 m := stackitem.NewMap() m.Add(stackitem.NewBigInteger(big.NewInt(1)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT))) m.Add(stackitem.NewBigInteger(big.NewInt(2)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT))) @@ -2036,8 +2036,8 @@ func TestPACKMAP_UNPACK_PACKMAP_MaxSize(t *testing.T) { } vm.estack.PushVal(len(elements)) runVM(t, vm) - // check reference counter = 1+1+1024 - assert.Equal(t, 1+1+len(elements), int(vm.refs)) + // check reference counter = 1+1+1024*2 + assert.Equal(t, 1+1+len(elements)*2, int(vm.refs)) assert.Equal(t, 2, vm.estack.Len()) m := vm.estack.Peek(0).value.(*stackitem.Map).Value().([]stackitem.MapElement) assert.Equal(t, len(elements), len(m))