diff --git a/pkg/vm/context.go b/pkg/vm/context.go index 1eaba0ce0..3cbace4b7 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -32,9 +32,9 @@ type Context struct { // Evaluation stack pointer. estack *Stack - static *Slot - local *Slot - arguments *Slot + static *slot + local slot + arguments slot // Exception context stack. tryStack Stack @@ -277,16 +277,19 @@ func (c *Context) DumpStaticSlot() string { // DumpLocalSlot returns json formatted representation of the given slot. func (c *Context) DumpLocalSlot() string { - return dumpSlot(c.local) + return dumpSlot(&c.local) } // DumpArgumentsSlot returns json formatted representation of the given slot. func (c *Context) DumpArgumentsSlot() string { - return dumpSlot(c.arguments) + return dumpSlot(&c.arguments) } // dumpSlot returns json formatted representation of the given slot. -func dumpSlot(s *Slot) string { +func dumpSlot(s *slot) string { + if s == nil || *s == nil { + return "[]" + } b, _ := json.MarshalIndent(s, "", " ") return string(b) } diff --git a/pkg/vm/json_test.go b/pkg/vm/json_test.go index 5d7f6acf5..bd15eba66 100644 --- a/pkg/vm/json_test.go +++ b/pkg/vm/json_test.go @@ -239,8 +239,8 @@ func compareStacks(t *testing.T, expected []vmUTStackItem, actual *Stack) { compareItemArrays(t, expected, actual.Len(), func(i int) stackitem.Item { return actual.Peek(i).Item() }) } -func compareSlots(t *testing.T, expected []vmUTStackItem, actual *Slot) { - if actual.storage == nil && len(expected) == 0 { +func compareSlots(t *testing.T, expected []vmUTStackItem, actual *slot) { + if (actual == nil || *actual == nil) && len(expected) == 0 { return } require.NotNil(t, actual) diff --git a/pkg/vm/opcodebench_test.go b/pkg/vm/opcodebench_test.go index 1657d4151..465a9943e 100644 --- a/pkg/vm/opcodebench_test.go +++ b/pkg/vm/opcodebench_test.go @@ -59,8 +59,8 @@ func opParamSlotsPushVM(op opcode.Opcode, param []byte, sslot int, slotloc int, v.Context().static.init(sslot) } if slotloc != 0 && slotarg != 0 { - v.Context().local = v.newSlot(slotloc) - v.Context().arguments = v.newSlot(slotarg) + v.Context().local.init(slotloc) + v.Context().arguments.init(slotarg) } for i := range items { item, ok := items[i].(stackitem.Item) diff --git a/pkg/vm/slot.go b/pkg/vm/slot.go index 634891f18..132ee220c 100644 --- a/pkg/vm/slot.go +++ b/pkg/vm/slot.go @@ -6,75 +6,58 @@ import ( "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) -// Slot is a fixed-size slice of stack items. -type Slot struct { - storage []stackitem.Item - refs *refCounter -} - -// newSlot returns new slot with the provided reference counter. -func newSlot(refs *refCounter) *Slot { - return &Slot{ - refs: refs, - } -} +// slot is a fixed-size slice of stack items. +type slot []stackitem.Item // init sets static slot size to n. It is intended to be used only by INITSSLOT. -func (s *Slot) init(n int) { - if s.storage != nil { +func (s *slot) init(n int) { + if *s != nil { panic("already initialized") } - s.storage = make([]stackitem.Item, n) -} - -func (v *VM) newSlot(n int) *Slot { - s := newSlot(&v.refs) - s.init(n) - return s + *s = make([]stackitem.Item, n) } // Set sets i-th storage slot. -func (s *Slot) Set(i int, item stackitem.Item) { - if s.storage[i] == item { +func (s slot) Set(i int, item stackitem.Item, refs *refCounter) { + if s[i] == item { return } - old := s.storage[i] - s.storage[i] = item + old := s[i] + s[i] = item if old != nil { - s.refs.Remove(old) + refs.Remove(old) } - s.refs.Add(item) + refs.Add(item) } // Get returns item contained in i-th slot. -func (s *Slot) Get(i int) stackitem.Item { - if item := s.storage[i]; item != nil { +func (s slot) Get(i int) stackitem.Item { + if item := s[i]; item != nil { return item } return stackitem.Null{} } // Clear removes all slot variables from reference counter. -func (s *Slot) Clear() { - for _, item := range s.storage { - s.refs.Remove(item) +func (s slot) Clear(refs *refCounter) { + for _, item := range s { + refs.Remove(item) } } // Size returns slot size. -func (s *Slot) Size() int { - if s.storage == nil { +func (s slot) Size() int { + if s == nil { panic("not initialized") } - return len(s.storage) + return len(s) } // MarshalJSON implements JSON marshalling interface. -func (s *Slot) MarshalJSON() ([]byte, error) { - items := s.storage - arr := make([]json.RawMessage, len(items)) - for i := range items { - data, err := stackitem.ToJSONWithTypes(items[i]) +func (s slot) MarshalJSON() ([]byte, error) { + arr := make([]json.RawMessage, len(s)) + for i := range s { + data, err := stackitem.ToJSONWithTypes(s[i]) if err == nil { arr[i] = data } diff --git a/pkg/vm/slot_test.go b/pkg/vm/slot_test.go index 434464476..212470a9f 100644 --- a/pkg/vm/slot_test.go +++ b/pkg/vm/slot_test.go @@ -9,8 +9,8 @@ import ( ) func TestSlot_Get(t *testing.T) { - s := newSlot(newRefCounter()) - require.NotNil(t, s) + rc := newRefCounter() + var s slot require.Panics(t, func() { s.Size() }) s.init(3) @@ -20,6 +20,6 @@ func TestSlot_Get(t *testing.T) { item := s.Get(2) require.Equal(t, stackitem.Null{}, item) - s.Set(1, stackitem.NewBigInteger(big.NewInt(42))) + s.Set(1, stackitem.NewBigInteger(big.NewInt(42)), rc) require.Equal(t, stackitem.NewBigInteger(big.NewInt(42)), s.Get(1)) } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 960b0e4f7..07c6739cc 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -310,13 +310,15 @@ func (v *VM) LoadNEFMethod(exe *nef.File, caller util.Uint160, hash util.Uint160 // It should be used for calling from native contracts. func (v *VM) loadScriptWithCallingHash(b []byte, exe *nef.File, caller util.Uint160, hash util.Uint160, f callflag.CallFlag, rvcount int, offset int) { + var sl slot + v.checkInvocationStackSize() ctx := NewContextWithParams(b, rvcount, offset) v.estack = newStack("evaluation", &v.refs) ctx.estack = v.estack initStack(&ctx.tryStack, "exception", nil) ctx.callFlag = f - ctx.static = newSlot(&v.refs) + ctx.static = &sl ctx.scriptHash = hash ctx.callingScriptHash = caller ctx.NEF = exe @@ -615,13 +617,13 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro panic("zero argument") } if parameter[0] > 0 { - ctx.local = v.newSlot(int(parameter[0])) + ctx.local.init(int(parameter[0])) } if parameter[1] > 0 { sz := int(parameter[1]) - ctx.arguments = v.newSlot(sz) + ctx.arguments.init(sz) for i := 0; i < sz; i++ { - ctx.arguments.Set(i, v.estack.Pop().Item()) + ctx.arguments.Set(i, v.estack.Pop().Item(), &v.refs) } } @@ -635,11 +637,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.STSFLD0, opcode.STSFLD1, opcode.STSFLD2, opcode.STSFLD3, opcode.STSFLD4, opcode.STSFLD5, opcode.STSFLD6: item := v.estack.Pop().Item() - ctx.static.Set(int(op-opcode.STSFLD0), item) + ctx.static.Set(int(op-opcode.STSFLD0), item, &v.refs) case opcode.STSFLD: item := v.estack.Pop().Item() - ctx.static.Set(int(parameter[0]), item) + ctx.static.Set(int(parameter[0]), item, &v.refs) case opcode.LDLOC0, opcode.LDLOC1, opcode.LDLOC2, opcode.LDLOC3, opcode.LDLOC4, opcode.LDLOC5, opcode.LDLOC6: item := ctx.local.Get(int(op - opcode.LDLOC0)) @@ -651,11 +653,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.STLOC0, opcode.STLOC1, opcode.STLOC2, opcode.STLOC3, opcode.STLOC4, opcode.STLOC5, opcode.STLOC6: item := v.estack.Pop().Item() - ctx.local.Set(int(op-opcode.STLOC0), item) + ctx.local.Set(int(op-opcode.STLOC0), item, &v.refs) case opcode.STLOC: item := v.estack.Pop().Item() - ctx.local.Set(int(parameter[0]), item) + ctx.local.Set(int(parameter[0]), item, &v.refs) case opcode.LDARG0, opcode.LDARG1, opcode.LDARG2, opcode.LDARG3, opcode.LDARG4, opcode.LDARG5, opcode.LDARG6: item := ctx.arguments.Get(int(op - opcode.LDARG0)) @@ -667,11 +669,11 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.STARG0, opcode.STARG1, opcode.STARG2, opcode.STARG3, opcode.STARG4, opcode.STARG5, opcode.STARG6: item := v.estack.Pop().Item() - ctx.arguments.Set(int(op-opcode.STARG0), item) + ctx.arguments.Set(int(op-opcode.STARG0), item, &v.refs) case opcode.STARG: item := v.estack.Pop().Item() - ctx.arguments.Set(int(parameter[0]), item) + ctx.arguments.Set(int(parameter[0]), item, &v.refs) case opcode.NEWBUFFER: n := toInt(v.estack.Pop().BigInt()) @@ -1527,14 +1529,14 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro func (v *VM) unloadContext(ctx *Context) { if ctx.local != nil { - ctx.local.Clear() + ctx.local.Clear(&v.refs) } if ctx.arguments != nil { - ctx.arguments.Clear() + ctx.arguments.Clear(&v.refs) } currCtx := v.Context() if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { - ctx.static.Clear() + ctx.static.Clear(&v.refs) } }