Merge pull request #2498 from nspcc-dev/fix-map-key-refcounting

Fix some refcounter issues
This commit is contained in:
Roman Khimov 2022-05-17 09:51:41 +03:00 committed by GitHub
commit 1c0fae2658
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 61 additions and 32 deletions

View file

@ -56,11 +56,11 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int,
return nil
}
if sslot != 0 {
v.Context().static.init(sslot)
v.Context().static.init(sslot, &v.refs)
}
if slotloc != 0 && slotarg != 0 {
v.Context().local.init(slotloc)
v.Context().arguments.init(slotarg)
v.Context().local.init(slotloc, &v.refs)
v.Context().arguments.init(slotarg, &v.refs)
}
for i := range items {
item, ok := items[i].(stackitem.Item)

View file

@ -37,8 +37,10 @@ func (r *refCounter) Add(item stackitem.Item) {
r.Add(it)
}
case *stackitem.Map:
for i := range t.Value().([]stackitem.MapElement) {
r.Add(t.Value().([]stackitem.MapElement)[i].Value)
elems := t.Value().([]stackitem.MapElement)
for i := range elems {
r.Add(elems[i].Key)
r.Add(elems[i].Value)
}
}
}
@ -60,8 +62,10 @@ func (r *refCounter) Remove(item stackitem.Item) {
r.Remove(it)
}
case *stackitem.Map:
for i := range t.Value().([]stackitem.MapElement) {
r.Remove(t.Value().([]stackitem.MapElement)[i].Value)
elems := t.Value().([]stackitem.MapElement)
for i := range elems {
r.Remove(elems[i].Key)
r.Remove(elems[i].Value)
}
}
}

View file

@ -30,6 +30,20 @@ func TestRefCounter_Add(t *testing.T) {
r.Remove(arr)
require.Equal(t, 2, int(*r))
m := stackitem.NewMap()
m.Add(stackitem.NewByteArray([]byte("some")), stackitem.NewBool(false))
r.Add(m)
require.Equal(t, 5, int(*r)) // map + key + value
r.Add(m)
require.Equal(t, 6, int(*r)) // map only
r.Remove(m)
require.Equal(t, 5, int(*r))
r.Remove(m)
require.Equal(t, 2, int(*r))
}
func BenchmarkRefCounter_Add(b *testing.B) {

View file

@ -10,11 +10,12 @@ import (
type slot []stackitem.Item
// init sets static slot size to n. It is intended to be used only by INITSSLOT.
func (s *slot) init(n int) {
func (s *slot) init(n int, rc *refCounter) {
if *s != nil {
panic("already initialized")
}
*s = make([]stackitem.Item, n)
*rc += refCounter(n) // Virtual "Null" elements.
}
// Set sets i-th storage slot.
@ -26,6 +27,8 @@ func (s slot) Set(i int, item stackitem.Item, refs *refCounter) {
s[i] = item
if old != nil {
refs.Remove(old)
} else {
*refs-- // Not really existing, but counted Null element.
}
refs.Add(item)
}
@ -38,8 +41,8 @@ func (s slot) Get(i int) stackitem.Item {
return stackitem.Null{}
}
// Clear removes all slot variables from the reference counter.
func (s slot) Clear(refs *refCounter) {
// ClearRefs removes all slot variables from the reference counter.
func (s slot) ClearRefs(refs *refCounter) {
for _, item := range s {
refs.Remove(item)
}

View file

@ -13,8 +13,9 @@ func TestSlot_Get(t *testing.T) {
var s slot
require.Panics(t, func() { s.Size() })
s.init(3)
s.init(3, rc)
require.Equal(t, 3, s.Size())
require.Equal(t, 3, int(*rc))
// Null is the default
item := s.Get(2)
@ -22,4 +23,5 @@ func TestSlot_Get(t *testing.T) {
s.Set(1, stackitem.NewBigInteger(big.NewInt(42)), rc)
require.Equal(t, stackitem.NewBigInteger(big.NewInt(42)), s.Get(1))
require.Equal(t, 3, int(*rc))
}

View file

@ -616,7 +616,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
if parameter[0] == 0 {
panic("zero argument")
}
ctx.static.init(int(parameter[0]))
ctx.static.init(int(parameter[0]), &v.refs)
case opcode.INITSLOT:
if ctx.local != nil || ctx.arguments != nil {
@ -626,11 +626,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
panic("zero argument")
}
if parameter[0] > 0 {
ctx.local.init(int(parameter[0]))
ctx.local.init(int(parameter[0]), &v.refs)
}
if parameter[1] > 0 {
sz := int(parameter[1])
ctx.arguments.init(sz)
ctx.arguments.init(sz, &v.refs)
for i := 0; i < sz; i++ {
ctx.arguments.Set(i, v.estack.Pop().Item(), &v.refs)
}
@ -1250,6 +1250,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case *stackitem.Map:
if i := t.Index(key.value); i >= 0 {
v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value)
} else {
v.refs.Add(key.value)
}
t.Add(key.value, item)
v.refs.Add(item)
@ -1312,7 +1314,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
index := t.Index(key.Item())
// NEO 2.0 doesn't error on missing key.
if index >= 0 {
v.refs.Remove(t.Value().([]stackitem.MapElement)[index].Value)
elems := t.Value().([]stackitem.MapElement)
v.refs.Remove(elems[index].Key)
v.refs.Remove(elems[index].Value)
t.Drop(index)
}
default:
@ -1333,8 +1337,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
}
t.Clear()
case *stackitem.Map:
for i := range t.Value().([]stackitem.MapElement) {
v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value)
elems := t.Value().([]stackitem.MapElement)
for i := range elems {
v.refs.Remove(elems[i].Key)
v.refs.Remove(elems[i].Value)
}
t.Clear()
default:
@ -1576,14 +1582,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
func (v *VM) unloadContext(ctx *Context) {
if ctx.local != nil {
ctx.local.Clear(&v.refs)
ctx.local.ClearRefs(&v.refs)
}
if ctx.arguments != nil {
ctx.arguments.Clear(&v.refs)
ctx.arguments.ClearRefs(&v.refs)
}
currCtx := v.Context()
if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static {
ctx.static.Clear(&v.refs)
ctx.static.ClearRefs(&v.refs)
}
}

View file

@ -387,21 +387,21 @@ func TestStackLimit(t *testing.T) {
inst opcode.Opcode
size int
}{
{opcode.PUSH2, 1},
{opcode.NEWARRAY, 3}, // array + 2 items
{opcode.PUSH2, 2}, // 1 from INITSSLOT and 1 for integer 2
{opcode.NEWARRAY, 4}, // array + 2 items
{opcode.STSFLD0, 3},
{opcode.LDSFLD0, 4},
{opcode.NEWMAP, 5},
{opcode.DUP, 6},
{opcode.PUSH2, 7},
{opcode.LDSFLD0, 8},
{opcode.SETITEM, 6}, // -3 items and 1 new element in map
{opcode.DUP, 7},
{opcode.PUSH2, 8},
{opcode.LDSFLD0, 9},
{opcode.SETITEM, 6}, // -3 items and no new elements in map
{opcode.DUP, 7},
{opcode.PUSH2, 8},
{opcode.SETITEM, 7}, // -3 items and 1 new kv pair in map
{opcode.DUP, 8},
{opcode.PUSH2, 9},
{opcode.LDSFLD0, 10},
{opcode.SETITEM, 7}, // -3 items and no new elements in map
{opcode.DUP, 8},
{opcode.PUSH2, 9},
{opcode.REMOVE, 5}, // as we have right after NEWMAP
{opcode.DROP, 4}, // DROP map with no elements
}
@ -1402,7 +1402,7 @@ func TestSETITEMBigMapBad(t *testing.T) {
// 2. SETITEM each of them to a map.
// 3. Replace each of them with a scalar value.
func TestSETITEMMapStackLimit(t *testing.T) {
size := MaxStackSize/2 - 3
size := MaxStackSize/2 - 4
m := stackitem.NewMap()
m.Add(stackitem.NewBigInteger(big.NewInt(1)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT)))
m.Add(stackitem.NewBigInteger(big.NewInt(2)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT)))
@ -2036,8 +2036,8 @@ func TestPACKMAP_UNPACK_PACKMAP_MaxSize(t *testing.T) {
}
vm.estack.PushVal(len(elements))
runVM(t, vm)
// check reference counter = 1+1+1024
assert.Equal(t, 1+1+len(elements), int(vm.refs))
// check reference counter = 1+1+1024*2
assert.Equal(t, 1+1+len(elements)*2, int(vm.refs))
assert.Equal(t, 2, vm.estack.Len())
m := vm.estack.Peek(0).value.(*stackitem.Map).Value().([]stackitem.MapElement)
assert.Equal(t, len(elements), len(m))