forked from TrueCloudLab/neoneo-go
vm: implement reference counter
It is convenient to have all reference-counting logic in a separate struct.
This commit is contained in:
parent
81cbf183af
commit
af2abedd86
5 changed files with 116 additions and 75 deletions
62
pkg/vm/ref_counter.go
Normal file
62
pkg/vm/ref_counter.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package vm
|
||||
|
||||
// refCounter represents reference counter for the VM.
|
||||
type refCounter struct {
|
||||
items map[StackItem]int
|
||||
size int
|
||||
}
|
||||
|
||||
func newRefCounter() *refCounter {
|
||||
return &refCounter{
|
||||
items: make(map[StackItem]int),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds an item to the reference counter.
|
||||
func (r *refCounter) Add(item StackItem) {
|
||||
r.size++
|
||||
|
||||
switch item.(type) {
|
||||
case *ArrayItem, *StructItem, *MapItem:
|
||||
if r.items[item]++; r.items[item] > 1 {
|
||||
return
|
||||
}
|
||||
|
||||
switch t := item.(type) {
|
||||
case *ArrayItem, *StructItem:
|
||||
for _, it := range item.Value().([]StackItem) {
|
||||
r.Add(it)
|
||||
}
|
||||
case *MapItem:
|
||||
for i := range t.value {
|
||||
r.Add(t.value[i].Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove removes item from the reference counter.
|
||||
func (r *refCounter) Remove(item StackItem) {
|
||||
r.size--
|
||||
|
||||
switch item.(type) {
|
||||
case *ArrayItem, *StructItem, *MapItem:
|
||||
if r.items[item] > 1 {
|
||||
r.items[item]--
|
||||
return
|
||||
}
|
||||
|
||||
delete(r.items, item)
|
||||
|
||||
switch t := item.(type) {
|
||||
case *ArrayItem, *StructItem:
|
||||
for _, it := range item.Value().([]StackItem) {
|
||||
r.Remove(it)
|
||||
}
|
||||
case *MapItem:
|
||||
for i := range t.value {
|
||||
r.Remove(t.value[i].Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
32
pkg/vm/ref_counter_test.go
Normal file
32
pkg/vm/ref_counter_test.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package vm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRefCounter_Add(t *testing.T) {
|
||||
r := newRefCounter()
|
||||
|
||||
require.Equal(t, 0, r.size)
|
||||
|
||||
r.Add(NullItem{})
|
||||
require.Equal(t, 1, r.size)
|
||||
|
||||
r.Add(NullItem{})
|
||||
require.Equal(t, 2, r.size) // count scalar items twice
|
||||
|
||||
arr := NewArrayItem([]StackItem{NewByteArrayItem([]byte{1}), NewBoolItem(false)})
|
||||
r.Add(arr)
|
||||
require.Equal(t, 5, r.size) // array + 2 elements
|
||||
|
||||
r.Add(arr)
|
||||
require.Equal(t, 6, r.size) // count only array
|
||||
|
||||
r.Remove(arr)
|
||||
require.Equal(t, 5, r.size)
|
||||
|
||||
r.Remove(arr)
|
||||
require.Equal(t, 2, r.size)
|
||||
}
|
|
@ -125,9 +125,7 @@ type Stack struct {
|
|||
top Element
|
||||
name string
|
||||
len int
|
||||
|
||||
itemCount map[StackItem]int
|
||||
size *int
|
||||
refs *refCounter
|
||||
}
|
||||
|
||||
// NewStack returns a new stack name by the given name.
|
||||
|
@ -138,8 +136,7 @@ func NewStack(n string) *Stack {
|
|||
s.top.next = &s.top
|
||||
s.top.prev = &s.top
|
||||
s.len = 0
|
||||
s.itemCount = make(map[StackItem]int)
|
||||
s.size = new(int)
|
||||
s.refs = newRefCounter()
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -171,58 +168,11 @@ func (s *Stack) insert(e, at *Element) *Element {
|
|||
e.stack = s
|
||||
s.len++
|
||||
|
||||
s.updateSizeAdd(e.value)
|
||||
s.refs.Add(e.value)
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
func (s *Stack) updateSizeAdd(item StackItem) {
|
||||
*s.size++
|
||||
|
||||
switch item.(type) {
|
||||
case *ArrayItem, *StructItem, *MapItem:
|
||||
if s.itemCount[item]++; s.itemCount[item] > 1 {
|
||||
return
|
||||
}
|
||||
|
||||
switch t := item.(type) {
|
||||
case *ArrayItem, *StructItem:
|
||||
for _, it := range item.Value().([]StackItem) {
|
||||
s.updateSizeAdd(it)
|
||||
}
|
||||
case *MapItem:
|
||||
for i := range t.value {
|
||||
s.updateSizeAdd(t.value[i].Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stack) updateSizeRemove(item StackItem) {
|
||||
*s.size--
|
||||
|
||||
switch item.(type) {
|
||||
case *ArrayItem, *StructItem, *MapItem:
|
||||
if s.itemCount[item] > 1 {
|
||||
s.itemCount[item]--
|
||||
return
|
||||
}
|
||||
|
||||
delete(s.itemCount, item)
|
||||
|
||||
switch t := item.(type) {
|
||||
case *ArrayItem, *StructItem:
|
||||
for _, it := range item.Value().([]StackItem) {
|
||||
s.updateSizeRemove(it)
|
||||
}
|
||||
case *MapItem:
|
||||
for i := range t.value {
|
||||
s.updateSizeRemove(t.value[i].Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InsertAt inserts the given item (n) deep on the stack.
|
||||
// Be very careful using it and _always_ check both e and n before invocation
|
||||
// as it will silently do wrong things otherwise.
|
||||
|
@ -300,7 +250,7 @@ func (s *Stack) Remove(e *Element) *Element {
|
|||
e.stack = nil
|
||||
s.len--
|
||||
|
||||
s.updateSizeRemove(e.value)
|
||||
s.refs.Remove(e.value)
|
||||
|
||||
return e
|
||||
}
|
||||
|
|
35
pkg/vm/vm.go
35
pkg/vm/vm.go
|
@ -78,8 +78,7 @@ type VM struct {
|
|||
// Hash to verify in CHECKSIG/CHECKMULTISIG.
|
||||
checkhash []byte
|
||||
|
||||
itemCount map[StackItem]int
|
||||
size int
|
||||
refs *refCounter
|
||||
|
||||
gasConsumed util.Fixed8
|
||||
gasLimit util.Fixed8
|
||||
|
@ -94,9 +93,8 @@ func New() *VM {
|
|||
getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage.
|
||||
state: haltState,
|
||||
istack: NewStack("invocation"),
|
||||
|
||||
itemCount: make(map[StackItem]int),
|
||||
keys: make(map[string]*keys.PublicKey),
|
||||
refs: newRefCounter(),
|
||||
keys: make(map[string]*keys.PublicKey),
|
||||
}
|
||||
|
||||
vm.estack = vm.newItemStack("evaluation")
|
||||
|
@ -108,8 +106,7 @@ func New() *VM {
|
|||
|
||||
func (v *VM) newItemStack(n string) *Stack {
|
||||
s := NewStack(n)
|
||||
s.size = &v.size
|
||||
s.itemCount = v.itemCount
|
||||
s.refs = v.refs
|
||||
|
||||
return s
|
||||
}
|
||||
|
@ -499,7 +496,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
|||
if errRecover := recover(); errRecover != nil {
|
||||
v.state = faultState
|
||||
err = newError(ctx.ip, op, errRecover)
|
||||
} else if v.size > MaxStackSize {
|
||||
} else if v.refs.size > MaxStackSize {
|
||||
v.state = faultState
|
||||
err = newError(ctx.ip, op, "stack is too big")
|
||||
}
|
||||
|
@ -955,7 +952,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
|||
panic("APPEND: not of underlying type Array")
|
||||
}
|
||||
|
||||
v.estack.updateSizeAdd(val)
|
||||
v.refs.Add(val)
|
||||
|
||||
case opcode.PACK:
|
||||
n := int(v.estack.Pop().BigInt().Int64())
|
||||
|
@ -1024,17 +1021,17 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
|||
if index < 0 || index >= len(arr) {
|
||||
panic("SETITEM: invalid index")
|
||||
}
|
||||
v.estack.updateSizeRemove(arr[index])
|
||||
v.refs.Remove(arr[index])
|
||||
arr[index] = item
|
||||
v.estack.updateSizeAdd(arr[index])
|
||||
v.refs.Add(arr[index])
|
||||
case *MapItem:
|
||||
if t.Has(key.value) {
|
||||
v.estack.updateSizeRemove(item)
|
||||
v.refs.Remove(item)
|
||||
} else if len(t.value) >= MaxArraySize {
|
||||
panic("too big map")
|
||||
}
|
||||
t.Add(key.value, item)
|
||||
v.estack.updateSizeAdd(item)
|
||||
v.refs.Add(item)
|
||||
|
||||
default:
|
||||
panic(fmt.Sprintf("SETITEM: invalid item type %s", t))
|
||||
|
@ -1059,7 +1056,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
|||
if k < 0 || k >= len(a) {
|
||||
panic("REMOVE: invalid index")
|
||||
}
|
||||
v.estack.updateSizeRemove(a[k])
|
||||
v.refs.Remove(a[k])
|
||||
a = append(a[:k], a[k+1:]...)
|
||||
t.value = a
|
||||
case *StructItem:
|
||||
|
@ -1068,14 +1065,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
|||
if k < 0 || k >= len(a) {
|
||||
panic("REMOVE: invalid index")
|
||||
}
|
||||
v.estack.updateSizeRemove(a[k])
|
||||
v.refs.Remove(a[k])
|
||||
a = append(a[:k], a[k+1:]...)
|
||||
t.value = a
|
||||
case *MapItem:
|
||||
index := t.Index(key.Item())
|
||||
// NEO 2.0 doesn't error on missing key.
|
||||
if index >= 0 {
|
||||
v.estack.updateSizeRemove(t.value[index].Value)
|
||||
v.refs.Remove(t.value[index].Value)
|
||||
t.Drop(index)
|
||||
}
|
||||
default:
|
||||
|
@ -1087,17 +1084,17 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
|||
switch t := elem.value.(type) {
|
||||
case *ArrayItem:
|
||||
for _, item := range t.value {
|
||||
v.estack.updateSizeRemove(item)
|
||||
v.refs.Remove(item)
|
||||
}
|
||||
t.value = t.value[:0]
|
||||
case *StructItem:
|
||||
for _, item := range t.value {
|
||||
v.estack.updateSizeRemove(item)
|
||||
v.refs.Remove(item)
|
||||
}
|
||||
t.value = t.value[:0]
|
||||
case *MapItem:
|
||||
for i := range t.value {
|
||||
v.estack.updateSizeRemove(t.value[i].Value)
|
||||
v.refs.Remove(t.value[i].Value)
|
||||
}
|
||||
t.value = t.value[:0]
|
||||
default:
|
||||
|
|
|
@ -423,7 +423,7 @@ func TestStackLimit(t *testing.T) {
|
|||
vm := load(makeProgram(prog...))
|
||||
for i := range expected {
|
||||
require.NoError(t, vm.Step())
|
||||
require.Equal(t, expected[i].size, vm.size)
|
||||
require.Equal(t, expected[i].size, vm.refs.size)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1980,7 +1980,7 @@ func testCLEARITEMS(t *testing.T, item StackItem) {
|
|||
v.estack.PushVal(item)
|
||||
runVM(t, v)
|
||||
require.Equal(t, 2, v.estack.Len())
|
||||
require.EqualValues(t, 2, v.size) // empty collection + it's size
|
||||
require.EqualValues(t, 2, v.refs.size) // empty collection + it's size
|
||||
require.EqualValues(t, 0, v.estack.Pop().BigInt().Int64())
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue