vm: count map key in the refcounter as well

Thanks @ixje for spotting this.
This commit is contained in:
Roman Khimov 2022-05-16 16:07:25 +03:00
parent 81fa751000
commit 18d627e7f7
4 changed files with 41 additions and 17 deletions

View file

@ -37,8 +37,10 @@ func (r *refCounter) Add(item stackitem.Item) {
r.Add(it) r.Add(it)
} }
case *stackitem.Map: case *stackitem.Map:
for i := range t.Value().([]stackitem.MapElement) { elems := t.Value().([]stackitem.MapElement)
r.Add(t.Value().([]stackitem.MapElement)[i].Value) for i := range elems {
r.Add(elems[i].Key)
r.Add(elems[i].Value)
} }
} }
} }
@ -60,8 +62,10 @@ func (r *refCounter) Remove(item stackitem.Item) {
r.Remove(it) r.Remove(it)
} }
case *stackitem.Map: case *stackitem.Map:
for i := range t.Value().([]stackitem.MapElement) { elems := t.Value().([]stackitem.MapElement)
r.Remove(t.Value().([]stackitem.MapElement)[i].Value) for i := range elems {
r.Remove(elems[i].Key)
r.Remove(elems[i].Value)
} }
} }
} }

View file

@ -30,6 +30,20 @@ func TestRefCounter_Add(t *testing.T) {
r.Remove(arr) r.Remove(arr)
require.Equal(t, 2, int(*r)) require.Equal(t, 2, int(*r))
m := stackitem.NewMap()
m.Add(stackitem.NewByteArray([]byte("some")), stackitem.NewBool(false))
r.Add(m)
require.Equal(t, 5, int(*r)) // map + key + value
r.Add(m)
require.Equal(t, 6, int(*r)) // map only
r.Remove(m)
require.Equal(t, 5, int(*r))
r.Remove(m)
require.Equal(t, 2, int(*r))
} }
func BenchmarkRefCounter_Add(b *testing.B) { func BenchmarkRefCounter_Add(b *testing.B) {

View file

@ -1250,6 +1250,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case *stackitem.Map: case *stackitem.Map:
if i := t.Index(key.value); i >= 0 { if i := t.Index(key.value); i >= 0 {
v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value) v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value)
} else {
v.refs.Add(key.value)
} }
t.Add(key.value, item) t.Add(key.value, item)
v.refs.Add(item) v.refs.Add(item)
@ -1312,7 +1314,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
index := t.Index(key.Item()) index := t.Index(key.Item())
// NEO 2.0 doesn't error on missing key. // NEO 2.0 doesn't error on missing key.
if index >= 0 { if index >= 0 {
v.refs.Remove(t.Value().([]stackitem.MapElement)[index].Value) elems := t.Value().([]stackitem.MapElement)
v.refs.Remove(elems[index].Key)
v.refs.Remove(elems[index].Value)
t.Drop(index) t.Drop(index)
} }
default: default:
@ -1333,8 +1337,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
} }
t.Clear() t.Clear()
case *stackitem.Map: case *stackitem.Map:
for i := range t.Value().([]stackitem.MapElement) { elems := t.Value().([]stackitem.MapElement)
v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value) for i := range elems {
v.refs.Remove(elems[i].Key)
v.refs.Remove(elems[i].Value)
} }
t.Clear() t.Clear()
default: default:

View file

@ -395,13 +395,13 @@ func TestStackLimit(t *testing.T) {
{opcode.DUP, 6}, {opcode.DUP, 6},
{opcode.PUSH2, 7}, {opcode.PUSH2, 7},
{opcode.LDSFLD0, 8}, {opcode.LDSFLD0, 8},
{opcode.SETITEM, 6}, // -3 items and 1 new element in map {opcode.SETITEM, 7}, // -3 items and 1 new kv pair in map
{opcode.DUP, 7}, {opcode.DUP, 8},
{opcode.PUSH2, 8}, {opcode.PUSH2, 9},
{opcode.LDSFLD0, 9}, {opcode.LDSFLD0, 10},
{opcode.SETITEM, 6}, // -3 items and no new elements in map {opcode.SETITEM, 7}, // -3 items and no new elements in map
{opcode.DUP, 7}, {opcode.DUP, 8},
{opcode.PUSH2, 8}, {opcode.PUSH2, 9},
{opcode.REMOVE, 5}, // as we have right after NEWMAP {opcode.REMOVE, 5}, // as we have right after NEWMAP
{opcode.DROP, 4}, // DROP map with no elements {opcode.DROP, 4}, // DROP map with no elements
} }
@ -1402,7 +1402,7 @@ func TestSETITEMBigMapBad(t *testing.T) {
// 2. SETITEM each of them to a map. // 2. SETITEM each of them to a map.
// 3. Replace each of them with a scalar value. // 3. Replace each of them with a scalar value.
func TestSETITEMMapStackLimit(t *testing.T) { func TestSETITEMMapStackLimit(t *testing.T) {
size := MaxStackSize/2 - 3 size := MaxStackSize/2 - 4
m := stackitem.NewMap() m := stackitem.NewMap()
m.Add(stackitem.NewBigInteger(big.NewInt(1)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT))) m.Add(stackitem.NewBigInteger(big.NewInt(1)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT)))
m.Add(stackitem.NewBigInteger(big.NewInt(2)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT))) m.Add(stackitem.NewBigInteger(big.NewInt(2)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT)))
@ -2036,8 +2036,8 @@ func TestPACKMAP_UNPACK_PACKMAP_MaxSize(t *testing.T) {
} }
vm.estack.PushVal(len(elements)) vm.estack.PushVal(len(elements))
runVM(t, vm) runVM(t, vm)
// check reference counter = 1+1+1024 // check reference counter = 1+1+1024*2
assert.Equal(t, 1+1+len(elements), int(vm.refs)) assert.Equal(t, 1+1+len(elements)*2, int(vm.refs))
assert.Equal(t, 2, vm.estack.Len()) assert.Equal(t, 2, vm.estack.Len())
m := vm.estack.Peek(0).value.(*stackitem.Map).Value().([]stackitem.MapElement) m := vm.estack.Peek(0).value.(*stackitem.Map).Value().([]stackitem.MapElement)
assert.Equal(t, len(elements), len(m)) assert.Equal(t, len(elements), len(m))