vm: implement reference counter

It is convenient to have all reference-counting logic
in a separate struct.
This commit is contained in:
Evgenii Stratonikov 2020-05-12 16:05:10 +03:00
parent 81cbf183af
commit af2abedd86
5 changed files with 116 additions and 75 deletions

62
pkg/vm/ref_counter.go Normal file
View file

@ -0,0 +1,62 @@
package vm
// refCounter represents reference counter for the VM.
type refCounter struct {
items map[StackItem]int
size int
}
func newRefCounter() *refCounter {
return &refCounter{
items: make(map[StackItem]int),
}
}
// Add adds an item to the reference counter.
func (r *refCounter) Add(item StackItem) {
r.size++
switch item.(type) {
case *ArrayItem, *StructItem, *MapItem:
if r.items[item]++; r.items[item] > 1 {
return
}
switch t := item.(type) {
case *ArrayItem, *StructItem:
for _, it := range item.Value().([]StackItem) {
r.Add(it)
}
case *MapItem:
for i := range t.value {
r.Add(t.value[i].Value)
}
}
}
}
// Remove removes item from the reference counter.
func (r *refCounter) Remove(item StackItem) {
r.size--
switch item.(type) {
case *ArrayItem, *StructItem, *MapItem:
if r.items[item] > 1 {
r.items[item]--
return
}
delete(r.items, item)
switch t := item.(type) {
case *ArrayItem, *StructItem:
for _, it := range item.Value().([]StackItem) {
r.Remove(it)
}
case *MapItem:
for i := range t.value {
r.Remove(t.value[i].Value)
}
}
}
}

View file

@ -0,0 +1,32 @@
package vm
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestRefCounter_Add(t *testing.T) {
r := newRefCounter()
require.Equal(t, 0, r.size)
r.Add(NullItem{})
require.Equal(t, 1, r.size)
r.Add(NullItem{})
require.Equal(t, 2, r.size) // count scalar items twice
arr := NewArrayItem([]StackItem{NewByteArrayItem([]byte{1}), NewBoolItem(false)})
r.Add(arr)
require.Equal(t, 5, r.size) // array + 2 elements
r.Add(arr)
require.Equal(t, 6, r.size) // count only array
r.Remove(arr)
require.Equal(t, 5, r.size)
r.Remove(arr)
require.Equal(t, 2, r.size)
}

View file

@ -125,9 +125,7 @@ type Stack struct {
top Element
name string
len int
itemCount map[StackItem]int
size *int
refs *refCounter
}
// NewStack returns a new stack name by the given name.
@ -138,8 +136,7 @@ func NewStack(n string) *Stack {
s.top.next = &s.top
s.top.prev = &s.top
s.len = 0
s.itemCount = make(map[StackItem]int)
s.size = new(int)
s.refs = newRefCounter()
return s
}
@ -171,58 +168,11 @@ func (s *Stack) insert(e, at *Element) *Element {
e.stack = s
s.len++
s.updateSizeAdd(e.value)
s.refs.Add(e.value)
return e
}
func (s *Stack) updateSizeAdd(item StackItem) {
*s.size++
switch item.(type) {
case *ArrayItem, *StructItem, *MapItem:
if s.itemCount[item]++; s.itemCount[item] > 1 {
return
}
switch t := item.(type) {
case *ArrayItem, *StructItem:
for _, it := range item.Value().([]StackItem) {
s.updateSizeAdd(it)
}
case *MapItem:
for i := range t.value {
s.updateSizeAdd(t.value[i].Value)
}
}
}
}
func (s *Stack) updateSizeRemove(item StackItem) {
*s.size--
switch item.(type) {
case *ArrayItem, *StructItem, *MapItem:
if s.itemCount[item] > 1 {
s.itemCount[item]--
return
}
delete(s.itemCount, item)
switch t := item.(type) {
case *ArrayItem, *StructItem:
for _, it := range item.Value().([]StackItem) {
s.updateSizeRemove(it)
}
case *MapItem:
for i := range t.value {
s.updateSizeRemove(t.value[i].Value)
}
}
}
}
// InsertAt inserts the given item (n) deep on the stack.
// Be very careful using it and _always_ check both e and n before invocation
// as it will silently do wrong things otherwise.
@ -300,7 +250,7 @@ func (s *Stack) Remove(e *Element) *Element {
e.stack = nil
s.len--
s.updateSizeRemove(e.value)
s.refs.Remove(e.value)
return e
}

View file

@ -78,8 +78,7 @@ type VM struct {
// Hash to verify in CHECKSIG/CHECKMULTISIG.
checkhash []byte
itemCount map[StackItem]int
size int
refs *refCounter
gasConsumed util.Fixed8
gasLimit util.Fixed8
@ -94,8 +93,7 @@ func New() *VM {
getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage.
state: haltState,
istack: NewStack("invocation"),
itemCount: make(map[StackItem]int),
refs: newRefCounter(),
keys: make(map[string]*keys.PublicKey),
}
@ -108,8 +106,7 @@ func New() *VM {
func (v *VM) newItemStack(n string) *Stack {
s := NewStack(n)
s.size = &v.size
s.itemCount = v.itemCount
s.refs = v.refs
return s
}
@ -499,7 +496,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
if errRecover := recover(); errRecover != nil {
v.state = faultState
err = newError(ctx.ip, op, errRecover)
} else if v.size > MaxStackSize {
} else if v.refs.size > MaxStackSize {
v.state = faultState
err = newError(ctx.ip, op, "stack is too big")
}
@ -955,7 +952,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
panic("APPEND: not of underlying type Array")
}
v.estack.updateSizeAdd(val)
v.refs.Add(val)
case opcode.PACK:
n := int(v.estack.Pop().BigInt().Int64())
@ -1024,17 +1021,17 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
if index < 0 || index >= len(arr) {
panic("SETITEM: invalid index")
}
v.estack.updateSizeRemove(arr[index])
v.refs.Remove(arr[index])
arr[index] = item
v.estack.updateSizeAdd(arr[index])
v.refs.Add(arr[index])
case *MapItem:
if t.Has(key.value) {
v.estack.updateSizeRemove(item)
v.refs.Remove(item)
} else if len(t.value) >= MaxArraySize {
panic("too big map")
}
t.Add(key.value, item)
v.estack.updateSizeAdd(item)
v.refs.Add(item)
default:
panic(fmt.Sprintf("SETITEM: invalid item type %s", t))
@ -1059,7 +1056,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
if k < 0 || k >= len(a) {
panic("REMOVE: invalid index")
}
v.estack.updateSizeRemove(a[k])
v.refs.Remove(a[k])
a = append(a[:k], a[k+1:]...)
t.value = a
case *StructItem:
@ -1068,14 +1065,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
if k < 0 || k >= len(a) {
panic("REMOVE: invalid index")
}
v.estack.updateSizeRemove(a[k])
v.refs.Remove(a[k])
a = append(a[:k], a[k+1:]...)
t.value = a
case *MapItem:
index := t.Index(key.Item())
// NEO 2.0 doesn't error on missing key.
if index >= 0 {
v.estack.updateSizeRemove(t.value[index].Value)
v.refs.Remove(t.value[index].Value)
t.Drop(index)
}
default:
@ -1087,17 +1084,17 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
switch t := elem.value.(type) {
case *ArrayItem:
for _, item := range t.value {
v.estack.updateSizeRemove(item)
v.refs.Remove(item)
}
t.value = t.value[:0]
case *StructItem:
for _, item := range t.value {
v.estack.updateSizeRemove(item)
v.refs.Remove(item)
}
t.value = t.value[:0]
case *MapItem:
for i := range t.value {
v.estack.updateSizeRemove(t.value[i].Value)
v.refs.Remove(t.value[i].Value)
}
t.value = t.value[:0]
default:

View file

@ -423,7 +423,7 @@ func TestStackLimit(t *testing.T) {
vm := load(makeProgram(prog...))
for i := range expected {
require.NoError(t, vm.Step())
require.Equal(t, expected[i].size, vm.size)
require.Equal(t, expected[i].size, vm.refs.size)
}
}
@ -1980,7 +1980,7 @@ func testCLEARITEMS(t *testing.T, item StackItem) {
v.estack.PushVal(item)
runVM(t, v)
require.Equal(t, 2, v.estack.Len())
require.EqualValues(t, 2, v.size) // empty collection + it's size
require.EqualValues(t, 2, v.refs.size) // empty collection + it's size
require.EqualValues(t, 0, v.estack.Pop().BigInt().Int64())
}