forked from TrueCloudLab/neoneo-go
vm: implement reference counter
It is convenient to have all reference-counting logic in a separate struct.
This commit is contained in:
parent
81cbf183af
commit
af2abedd86
5 changed files with 116 additions and 75 deletions
62
pkg/vm/ref_counter.go
Normal file
62
pkg/vm/ref_counter.go
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
package vm
|
||||||
|
|
||||||
|
// refCounter represents reference counter for the VM.
|
||||||
|
type refCounter struct {
|
||||||
|
items map[StackItem]int
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRefCounter() *refCounter {
|
||||||
|
return &refCounter{
|
||||||
|
items: make(map[StackItem]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds an item to the reference counter.
|
||||||
|
func (r *refCounter) Add(item StackItem) {
|
||||||
|
r.size++
|
||||||
|
|
||||||
|
switch item.(type) {
|
||||||
|
case *ArrayItem, *StructItem, *MapItem:
|
||||||
|
if r.items[item]++; r.items[item] > 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t := item.(type) {
|
||||||
|
case *ArrayItem, *StructItem:
|
||||||
|
for _, it := range item.Value().([]StackItem) {
|
||||||
|
r.Add(it)
|
||||||
|
}
|
||||||
|
case *MapItem:
|
||||||
|
for i := range t.value {
|
||||||
|
r.Add(t.value[i].Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes item from the reference counter.
|
||||||
|
func (r *refCounter) Remove(item StackItem) {
|
||||||
|
r.size--
|
||||||
|
|
||||||
|
switch item.(type) {
|
||||||
|
case *ArrayItem, *StructItem, *MapItem:
|
||||||
|
if r.items[item] > 1 {
|
||||||
|
r.items[item]--
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(r.items, item)
|
||||||
|
|
||||||
|
switch t := item.(type) {
|
||||||
|
case *ArrayItem, *StructItem:
|
||||||
|
for _, it := range item.Value().([]StackItem) {
|
||||||
|
r.Remove(it)
|
||||||
|
}
|
||||||
|
case *MapItem:
|
||||||
|
for i := range t.value {
|
||||||
|
r.Remove(t.value[i].Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
32
pkg/vm/ref_counter_test.go
Normal file
32
pkg/vm/ref_counter_test.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package vm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRefCounter_Add(t *testing.T) {
|
||||||
|
r := newRefCounter()
|
||||||
|
|
||||||
|
require.Equal(t, 0, r.size)
|
||||||
|
|
||||||
|
r.Add(NullItem{})
|
||||||
|
require.Equal(t, 1, r.size)
|
||||||
|
|
||||||
|
r.Add(NullItem{})
|
||||||
|
require.Equal(t, 2, r.size) // count scalar items twice
|
||||||
|
|
||||||
|
arr := NewArrayItem([]StackItem{NewByteArrayItem([]byte{1}), NewBoolItem(false)})
|
||||||
|
r.Add(arr)
|
||||||
|
require.Equal(t, 5, r.size) // array + 2 elements
|
||||||
|
|
||||||
|
r.Add(arr)
|
||||||
|
require.Equal(t, 6, r.size) // count only array
|
||||||
|
|
||||||
|
r.Remove(arr)
|
||||||
|
require.Equal(t, 5, r.size)
|
||||||
|
|
||||||
|
r.Remove(arr)
|
||||||
|
require.Equal(t, 2, r.size)
|
||||||
|
}
|
|
@ -125,9 +125,7 @@ type Stack struct {
|
||||||
top Element
|
top Element
|
||||||
name string
|
name string
|
||||||
len int
|
len int
|
||||||
|
refs *refCounter
|
||||||
itemCount map[StackItem]int
|
|
||||||
size *int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStack returns a new stack name by the given name.
|
// NewStack returns a new stack name by the given name.
|
||||||
|
@ -138,8 +136,7 @@ func NewStack(n string) *Stack {
|
||||||
s.top.next = &s.top
|
s.top.next = &s.top
|
||||||
s.top.prev = &s.top
|
s.top.prev = &s.top
|
||||||
s.len = 0
|
s.len = 0
|
||||||
s.itemCount = make(map[StackItem]int)
|
s.refs = newRefCounter()
|
||||||
s.size = new(int)
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,58 +168,11 @@ func (s *Stack) insert(e, at *Element) *Element {
|
||||||
e.stack = s
|
e.stack = s
|
||||||
s.len++
|
s.len++
|
||||||
|
|
||||||
s.updateSizeAdd(e.value)
|
s.refs.Add(e.value)
|
||||||
|
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stack) updateSizeAdd(item StackItem) {
|
|
||||||
*s.size++
|
|
||||||
|
|
||||||
switch item.(type) {
|
|
||||||
case *ArrayItem, *StructItem, *MapItem:
|
|
||||||
if s.itemCount[item]++; s.itemCount[item] > 1 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch t := item.(type) {
|
|
||||||
case *ArrayItem, *StructItem:
|
|
||||||
for _, it := range item.Value().([]StackItem) {
|
|
||||||
s.updateSizeAdd(it)
|
|
||||||
}
|
|
||||||
case *MapItem:
|
|
||||||
for i := range t.value {
|
|
||||||
s.updateSizeAdd(t.value[i].Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stack) updateSizeRemove(item StackItem) {
|
|
||||||
*s.size--
|
|
||||||
|
|
||||||
switch item.(type) {
|
|
||||||
case *ArrayItem, *StructItem, *MapItem:
|
|
||||||
if s.itemCount[item] > 1 {
|
|
||||||
s.itemCount[item]--
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(s.itemCount, item)
|
|
||||||
|
|
||||||
switch t := item.(type) {
|
|
||||||
case *ArrayItem, *StructItem:
|
|
||||||
for _, it := range item.Value().([]StackItem) {
|
|
||||||
s.updateSizeRemove(it)
|
|
||||||
}
|
|
||||||
case *MapItem:
|
|
||||||
for i := range t.value {
|
|
||||||
s.updateSizeRemove(t.value[i].Value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertAt inserts the given item (n) deep on the stack.
|
// InsertAt inserts the given item (n) deep on the stack.
|
||||||
// Be very careful using it and _always_ check both e and n before invocation
|
// Be very careful using it and _always_ check both e and n before invocation
|
||||||
// as it will silently do wrong things otherwise.
|
// as it will silently do wrong things otherwise.
|
||||||
|
@ -300,7 +250,7 @@ func (s *Stack) Remove(e *Element) *Element {
|
||||||
e.stack = nil
|
e.stack = nil
|
||||||
s.len--
|
s.len--
|
||||||
|
|
||||||
s.updateSizeRemove(e.value)
|
s.refs.Remove(e.value)
|
||||||
|
|
||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
33
pkg/vm/vm.go
33
pkg/vm/vm.go
|
@ -78,8 +78,7 @@ type VM struct {
|
||||||
// Hash to verify in CHECKSIG/CHECKMULTISIG.
|
// Hash to verify in CHECKSIG/CHECKMULTISIG.
|
||||||
checkhash []byte
|
checkhash []byte
|
||||||
|
|
||||||
itemCount map[StackItem]int
|
refs *refCounter
|
||||||
size int
|
|
||||||
|
|
||||||
gasConsumed util.Fixed8
|
gasConsumed util.Fixed8
|
||||||
gasLimit util.Fixed8
|
gasLimit util.Fixed8
|
||||||
|
@ -94,8 +93,7 @@ func New() *VM {
|
||||||
getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage.
|
getInterop: make([]InteropGetterFunc, 0, 3), // 3 functions is typical for our default usage.
|
||||||
state: haltState,
|
state: haltState,
|
||||||
istack: NewStack("invocation"),
|
istack: NewStack("invocation"),
|
||||||
|
refs: newRefCounter(),
|
||||||
itemCount: make(map[StackItem]int),
|
|
||||||
keys: make(map[string]*keys.PublicKey),
|
keys: make(map[string]*keys.PublicKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,8 +106,7 @@ func New() *VM {
|
||||||
|
|
||||||
func (v *VM) newItemStack(n string) *Stack {
|
func (v *VM) newItemStack(n string) *Stack {
|
||||||
s := NewStack(n)
|
s := NewStack(n)
|
||||||
s.size = &v.size
|
s.refs = v.refs
|
||||||
s.itemCount = v.itemCount
|
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -499,7 +496,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
||||||
if errRecover := recover(); errRecover != nil {
|
if errRecover := recover(); errRecover != nil {
|
||||||
v.state = faultState
|
v.state = faultState
|
||||||
err = newError(ctx.ip, op, errRecover)
|
err = newError(ctx.ip, op, errRecover)
|
||||||
} else if v.size > MaxStackSize {
|
} else if v.refs.size > MaxStackSize {
|
||||||
v.state = faultState
|
v.state = faultState
|
||||||
err = newError(ctx.ip, op, "stack is too big")
|
err = newError(ctx.ip, op, "stack is too big")
|
||||||
}
|
}
|
||||||
|
@ -955,7 +952,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
||||||
panic("APPEND: not of underlying type Array")
|
panic("APPEND: not of underlying type Array")
|
||||||
}
|
}
|
||||||
|
|
||||||
v.estack.updateSizeAdd(val)
|
v.refs.Add(val)
|
||||||
|
|
||||||
case opcode.PACK:
|
case opcode.PACK:
|
||||||
n := int(v.estack.Pop().BigInt().Int64())
|
n := int(v.estack.Pop().BigInt().Int64())
|
||||||
|
@ -1024,17 +1021,17 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
||||||
if index < 0 || index >= len(arr) {
|
if index < 0 || index >= len(arr) {
|
||||||
panic("SETITEM: invalid index")
|
panic("SETITEM: invalid index")
|
||||||
}
|
}
|
||||||
v.estack.updateSizeRemove(arr[index])
|
v.refs.Remove(arr[index])
|
||||||
arr[index] = item
|
arr[index] = item
|
||||||
v.estack.updateSizeAdd(arr[index])
|
v.refs.Add(arr[index])
|
||||||
case *MapItem:
|
case *MapItem:
|
||||||
if t.Has(key.value) {
|
if t.Has(key.value) {
|
||||||
v.estack.updateSizeRemove(item)
|
v.refs.Remove(item)
|
||||||
} else if len(t.value) >= MaxArraySize {
|
} else if len(t.value) >= MaxArraySize {
|
||||||
panic("too big map")
|
panic("too big map")
|
||||||
}
|
}
|
||||||
t.Add(key.value, item)
|
t.Add(key.value, item)
|
||||||
v.estack.updateSizeAdd(item)
|
v.refs.Add(item)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("SETITEM: invalid item type %s", t))
|
panic(fmt.Sprintf("SETITEM: invalid item type %s", t))
|
||||||
|
@ -1059,7 +1056,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
||||||
if k < 0 || k >= len(a) {
|
if k < 0 || k >= len(a) {
|
||||||
panic("REMOVE: invalid index")
|
panic("REMOVE: invalid index")
|
||||||
}
|
}
|
||||||
v.estack.updateSizeRemove(a[k])
|
v.refs.Remove(a[k])
|
||||||
a = append(a[:k], a[k+1:]...)
|
a = append(a[:k], a[k+1:]...)
|
||||||
t.value = a
|
t.value = a
|
||||||
case *StructItem:
|
case *StructItem:
|
||||||
|
@ -1068,14 +1065,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
||||||
if k < 0 || k >= len(a) {
|
if k < 0 || k >= len(a) {
|
||||||
panic("REMOVE: invalid index")
|
panic("REMOVE: invalid index")
|
||||||
}
|
}
|
||||||
v.estack.updateSizeRemove(a[k])
|
v.refs.Remove(a[k])
|
||||||
a = append(a[:k], a[k+1:]...)
|
a = append(a[:k], a[k+1:]...)
|
||||||
t.value = a
|
t.value = a
|
||||||
case *MapItem:
|
case *MapItem:
|
||||||
index := t.Index(key.Item())
|
index := t.Index(key.Item())
|
||||||
// NEO 2.0 doesn't error on missing key.
|
// NEO 2.0 doesn't error on missing key.
|
||||||
if index >= 0 {
|
if index >= 0 {
|
||||||
v.estack.updateSizeRemove(t.value[index].Value)
|
v.refs.Remove(t.value[index].Value)
|
||||||
t.Drop(index)
|
t.Drop(index)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -1087,17 +1084,17 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro
|
||||||
switch t := elem.value.(type) {
|
switch t := elem.value.(type) {
|
||||||
case *ArrayItem:
|
case *ArrayItem:
|
||||||
for _, item := range t.value {
|
for _, item := range t.value {
|
||||||
v.estack.updateSizeRemove(item)
|
v.refs.Remove(item)
|
||||||
}
|
}
|
||||||
t.value = t.value[:0]
|
t.value = t.value[:0]
|
||||||
case *StructItem:
|
case *StructItem:
|
||||||
for _, item := range t.value {
|
for _, item := range t.value {
|
||||||
v.estack.updateSizeRemove(item)
|
v.refs.Remove(item)
|
||||||
}
|
}
|
||||||
t.value = t.value[:0]
|
t.value = t.value[:0]
|
||||||
case *MapItem:
|
case *MapItem:
|
||||||
for i := range t.value {
|
for i := range t.value {
|
||||||
v.estack.updateSizeRemove(t.value[i].Value)
|
v.refs.Remove(t.value[i].Value)
|
||||||
}
|
}
|
||||||
t.value = t.value[:0]
|
t.value = t.value[:0]
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -423,7 +423,7 @@ func TestStackLimit(t *testing.T) {
|
||||||
vm := load(makeProgram(prog...))
|
vm := load(makeProgram(prog...))
|
||||||
for i := range expected {
|
for i := range expected {
|
||||||
require.NoError(t, vm.Step())
|
require.NoError(t, vm.Step())
|
||||||
require.Equal(t, expected[i].size, vm.size)
|
require.Equal(t, expected[i].size, vm.refs.size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1980,7 +1980,7 @@ func testCLEARITEMS(t *testing.T, item StackItem) {
|
||||||
v.estack.PushVal(item)
|
v.estack.PushVal(item)
|
||||||
runVM(t, v)
|
runVM(t, v)
|
||||||
require.Equal(t, 2, v.estack.Len())
|
require.Equal(t, 2, v.estack.Len())
|
||||||
require.EqualValues(t, 2, v.size) // empty collection + it's size
|
require.EqualValues(t, 2, v.refs.size) // empty collection + it's size
|
||||||
require.EqualValues(t, 0, v.estack.Pop().BigInt().Int64())
|
require.EqualValues(t, 0, v.estack.Pop().BigInt().Int64())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue