Merge pull request #2498 from nspcc-dev/fix-map-key-refcounting

Fix some refcounter issues
This commit is contained in:
Roman Khimov 2022-05-17 09:51:41 +03:00 committed by GitHub
commit 1c0fae2658
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 61 additions and 32 deletions

View file

@ -56,11 +56,11 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int,
return nil return nil
} }
if sslot != 0 { if sslot != 0 {
v.Context().static.init(sslot) v.Context().static.init(sslot, &v.refs)
} }
if slotloc != 0 && slotarg != 0 { if slotloc != 0 && slotarg != 0 {
v.Context().local.init(slotloc) v.Context().local.init(slotloc, &v.refs)
v.Context().arguments.init(slotarg) v.Context().arguments.init(slotarg, &v.refs)
} }
for i := range items { for i := range items {
item, ok := items[i].(stackitem.Item) item, ok := items[i].(stackitem.Item)

View file

@ -37,8 +37,10 @@ func (r *refCounter) Add(item stackitem.Item) {
r.Add(it) r.Add(it)
} }
case *stackitem.Map: case *stackitem.Map:
for i := range t.Value().([]stackitem.MapElement) { elems := t.Value().([]stackitem.MapElement)
r.Add(t.Value().([]stackitem.MapElement)[i].Value) for i := range elems {
r.Add(elems[i].Key)
r.Add(elems[i].Value)
} }
} }
} }
@ -60,8 +62,10 @@ func (r *refCounter) Remove(item stackitem.Item) {
r.Remove(it) r.Remove(it)
} }
case *stackitem.Map: case *stackitem.Map:
for i := range t.Value().([]stackitem.MapElement) { elems := t.Value().([]stackitem.MapElement)
r.Remove(t.Value().([]stackitem.MapElement)[i].Value) for i := range elems {
r.Remove(elems[i].Key)
r.Remove(elems[i].Value)
} }
} }
} }

View file

@ -30,6 +30,20 @@ func TestRefCounter_Add(t *testing.T) {
r.Remove(arr) r.Remove(arr)
require.Equal(t, 2, int(*r)) require.Equal(t, 2, int(*r))
m := stackitem.NewMap()
m.Add(stackitem.NewByteArray([]byte("some")), stackitem.NewBool(false))
r.Add(m)
require.Equal(t, 5, int(*r)) // map + key + value
r.Add(m)
require.Equal(t, 6, int(*r)) // map only
r.Remove(m)
require.Equal(t, 5, int(*r))
r.Remove(m)
require.Equal(t, 2, int(*r))
} }
func BenchmarkRefCounter_Add(b *testing.B) { func BenchmarkRefCounter_Add(b *testing.B) {

View file

@ -10,11 +10,12 @@ import (
type slot []stackitem.Item type slot []stackitem.Item
// init sets static slot size to n. It is intended to be used only by INITSSLOT. // init sets static slot size to n. It is intended to be used only by INITSSLOT.
func (s *slot) init(n int) { func (s *slot) init(n int, rc *refCounter) {
if *s != nil { if *s != nil {
panic("already initialized") panic("already initialized")
} }
*s = make([]stackitem.Item, n) *s = make([]stackitem.Item, n)
*rc += refCounter(n) // Virtual "Null" elements.
} }
// Set sets i-th storage slot. // Set sets i-th storage slot.
@ -26,6 +27,8 @@ func (s slot) Set(i int, item stackitem.Item, refs *refCounter) {
s[i] = item s[i] = item
if old != nil { if old != nil {
refs.Remove(old) refs.Remove(old)
} else {
*refs-- // Not really existing, but counted Null element.
} }
refs.Add(item) refs.Add(item)
} }
@ -38,8 +41,8 @@ func (s slot) Get(i int) stackitem.Item {
return stackitem.Null{} return stackitem.Null{}
} }
// Clear removes all slot variables from the reference counter. // ClearRefs removes all slot variables from the reference counter.
func (s slot) Clear(refs *refCounter) { func (s slot) ClearRefs(refs *refCounter) {
for _, item := range s { for _, item := range s {
refs.Remove(item) refs.Remove(item)
} }

View file

@ -13,8 +13,9 @@ func TestSlot_Get(t *testing.T) {
var s slot var s slot
require.Panics(t, func() { s.Size() }) require.Panics(t, func() { s.Size() })
s.init(3) s.init(3, rc)
require.Equal(t, 3, s.Size()) require.Equal(t, 3, s.Size())
require.Equal(t, 3, int(*rc))
// Null is the default // Null is the default
item := s.Get(2) item := s.Get(2)
@ -22,4 +23,5 @@ func TestSlot_Get(t *testing.T) {
s.Set(1, stackitem.NewBigInteger(big.NewInt(42)), rc) s.Set(1, stackitem.NewBigInteger(big.NewInt(42)), rc)
require.Equal(t, stackitem.NewBigInteger(big.NewInt(42)), s.Get(1)) require.Equal(t, stackitem.NewBigInteger(big.NewInt(42)), s.Get(1))
require.Equal(t, 3, int(*rc))
} }

View file

@ -616,7 +616,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
if parameter[0] == 0 { if parameter[0] == 0 {
panic("zero argument") panic("zero argument")
} }
ctx.static.init(int(parameter[0])) ctx.static.init(int(parameter[0]), &v.refs)
case opcode.INITSLOT: case opcode.INITSLOT:
if ctx.local != nil || ctx.arguments != nil { if ctx.local != nil || ctx.arguments != nil {
@ -626,11 +626,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
panic("zero argument") panic("zero argument")
} }
if parameter[0] > 0 { if parameter[0] > 0 {
ctx.local.init(int(parameter[0])) ctx.local.init(int(parameter[0]), &v.refs)
} }
if parameter[1] > 0 { if parameter[1] > 0 {
sz := int(parameter[1]) sz := int(parameter[1])
ctx.arguments.init(sz) ctx.arguments.init(sz, &v.refs)
for i := 0; i < sz; i++ { for i := 0; i < sz; i++ {
ctx.arguments.Set(i, v.estack.Pop().Item(), &v.refs) ctx.arguments.Set(i, v.estack.Pop().Item(), &v.refs)
} }
@ -1250,6 +1250,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
case *stackitem.Map: case *stackitem.Map:
if i := t.Index(key.value); i >= 0 { if i := t.Index(key.value); i >= 0 {
v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value) v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value)
} else {
v.refs.Add(key.value)
} }
t.Add(key.value, item) t.Add(key.value, item)
v.refs.Add(item) v.refs.Add(item)
@ -1312,7 +1314,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
index := t.Index(key.Item()) index := t.Index(key.Item())
// NEO 2.0 doesn't error on missing key. // NEO 2.0 doesn't error on missing key.
if index >= 0 { if index >= 0 {
v.refs.Remove(t.Value().([]stackitem.MapElement)[index].Value) elems := t.Value().([]stackitem.MapElement)
v.refs.Remove(elems[index].Key)
v.refs.Remove(elems[index].Value)
t.Drop(index) t.Drop(index)
} }
default: default:
@ -1333,8 +1337,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
} }
t.Clear() t.Clear()
case *stackitem.Map: case *stackitem.Map:
for i := range t.Value().([]stackitem.MapElement) { elems := t.Value().([]stackitem.MapElement)
v.refs.Remove(t.Value().([]stackitem.MapElement)[i].Value) for i := range elems {
v.refs.Remove(elems[i].Key)
v.refs.Remove(elems[i].Value)
} }
t.Clear() t.Clear()
default: default:
@ -1576,14 +1582,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
func (v *VM) unloadContext(ctx *Context) { func (v *VM) unloadContext(ctx *Context) {
if ctx.local != nil { if ctx.local != nil {
ctx.local.Clear(&v.refs) ctx.local.ClearRefs(&v.refs)
} }
if ctx.arguments != nil { if ctx.arguments != nil {
ctx.arguments.Clear(&v.refs) ctx.arguments.ClearRefs(&v.refs)
} }
currCtx := v.Context() currCtx := v.Context()
if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static {
ctx.static.Clear(&v.refs) ctx.static.ClearRefs(&v.refs)
} }
} }

View file

@ -387,21 +387,21 @@ func TestStackLimit(t *testing.T) {
inst opcode.Opcode inst opcode.Opcode
size int size int
}{ }{
{opcode.PUSH2, 1}, {opcode.PUSH2, 2}, // 1 from INITSSLOT and 1 for integer 2
{opcode.NEWARRAY, 3}, // array + 2 items {opcode.NEWARRAY, 4}, // array + 2 items
{opcode.STSFLD0, 3}, {opcode.STSFLD0, 3},
{opcode.LDSFLD0, 4}, {opcode.LDSFLD0, 4},
{opcode.NEWMAP, 5}, {opcode.NEWMAP, 5},
{opcode.DUP, 6}, {opcode.DUP, 6},
{opcode.PUSH2, 7}, {opcode.PUSH2, 7},
{opcode.LDSFLD0, 8}, {opcode.LDSFLD0, 8},
{opcode.SETITEM, 6}, // -3 items and 1 new element in map {opcode.SETITEM, 7}, // -3 items and 1 new kv pair in map
{opcode.DUP, 7}, {opcode.DUP, 8},
{opcode.PUSH2, 8}, {opcode.PUSH2, 9},
{opcode.LDSFLD0, 9}, {opcode.LDSFLD0, 10},
{opcode.SETITEM, 6}, // -3 items and no new elements in map {opcode.SETITEM, 7}, // -3 items and no new elements in map
{opcode.DUP, 7}, {opcode.DUP, 8},
{opcode.PUSH2, 8}, {opcode.PUSH2, 9},
{opcode.REMOVE, 5}, // as we have right after NEWMAP {opcode.REMOVE, 5}, // as we have right after NEWMAP
{opcode.DROP, 4}, // DROP map with no elements {opcode.DROP, 4}, // DROP map with no elements
} }
@ -1402,7 +1402,7 @@ func TestSETITEMBigMapBad(t *testing.T) {
// 2. SETITEM each of them to a map. // 2. SETITEM each of them to a map.
// 3. Replace each of them with a scalar value. // 3. Replace each of them with a scalar value.
func TestSETITEMMapStackLimit(t *testing.T) { func TestSETITEMMapStackLimit(t *testing.T) {
size := MaxStackSize/2 - 3 size := MaxStackSize/2 - 4
m := stackitem.NewMap() m := stackitem.NewMap()
m.Add(stackitem.NewBigInteger(big.NewInt(1)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT))) m.Add(stackitem.NewBigInteger(big.NewInt(1)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT)))
m.Add(stackitem.NewBigInteger(big.NewInt(2)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT))) m.Add(stackitem.NewBigInteger(big.NewInt(2)), stackitem.NewArray(makeArrayOfType(size, stackitem.BooleanT)))
@ -2036,8 +2036,8 @@ func TestPACKMAP_UNPACK_PACKMAP_MaxSize(t *testing.T) {
} }
vm.estack.PushVal(len(elements)) vm.estack.PushVal(len(elements))
runVM(t, vm) runVM(t, vm)
// check reference counter = 1+1+1024 // check reference counter = 1+1+1024*2
assert.Equal(t, 1+1+len(elements), int(vm.refs)) assert.Equal(t, 1+1+len(elements)*2, int(vm.refs))
assert.Equal(t, 2, vm.estack.Len()) assert.Equal(t, 2, vm.estack.Len())
m := vm.estack.Peek(0).value.(*stackitem.Map).Value().([]stackitem.MapElement) m := vm.estack.Peek(0).value.(*stackitem.Map).Value().([]stackitem.MapElement)
assert.Equal(t, len(elements), len(m)) assert.Equal(t, len(elements), len(m))