diff --git a/pkg/vm/slot.go b/pkg/vm/slot.go index 1f4107ee2..18a65ac85 100644 --- a/pkg/vm/slot.go +++ b/pkg/vm/slot.go @@ -45,3 +45,10 @@ func (s *Slot) Get(i int) stackitem.Item { // Size returns slot size. func (s *Slot) Size() int { return len(s.storage) } + +// Clear removes all slot variables from reference counter. +func (s *Slot) Clear() { + for _, item := range s.storage { + s.refs.Remove(item) + } +} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 627ed480c..cd2e0e74a 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1262,9 +1262,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro } case opcode.RET: - v.istack.Pop() + oldCtx := v.istack.Pop().Value().(*Context) oldEstack := v.estack + v.unloadContext(oldCtx) if v.istack.Len() == 0 { v.state = haltState break @@ -1372,6 +1373,19 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro return } +func (v *VM) unloadContext(ctx *Context) { + if ctx.local != nil { + ctx.local.Clear() + } + if ctx.arguments != nil { + ctx.arguments.Clear() + } + currCtx := v.Context() + if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { + ctx.static.Clear() + } +} + // getJumpCondition performs opcode specific comparison of a and b func getJumpCondition(op opcode.Opcode, a, b *big.Int) bool { cmp := a.Cmp(b) diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 8075e94e6..ca2c3cdca 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -1142,6 +1142,46 @@ func getTestFuncForVM(prog []byte, result interface{}, args ...interface{}) func return getCustomTestFuncForVM(prog, f, args...) } +func makeRETProgram(t *testing.T, argCount, localCount int) []byte { + require.True(t, argCount+localCount <= 255) + + fProg := []opcode.Opcode{opcode.INITSLOT, opcode.Opcode(localCount), opcode.Opcode(argCount)} + for i := 0; i < localCount; i++ { + fProg = append(fProg, opcode.PUSH8, opcode.STLOC, opcode.Opcode(i)) + } + fProg = append(fProg, opcode.RET) + + offset := uint32(len(fProg) + 5) + param := make([]byte, 4) + binary.LittleEndian.PutUint32(param, offset) + + ops := []opcode.Opcode{ + opcode.INITSSLOT, 0x01, + opcode.PUSHA, 11, 0, 0, 0, + opcode.STSFLD0, + opcode.JMPL, opcode.Opcode(param[0]), opcode.Opcode(param[1]), opcode.Opcode(param[2]), opcode.Opcode(param[3]), + } + ops = append(ops, fProg...) + + // execute func multiple times to ensure total reference count is less than max + callCount := MaxStackSize/(argCount+localCount) + 1 + args := make([]opcode.Opcode, argCount) + for i := range args { + args[i] = opcode.PUSH7 + } + for i := 0; i < callCount; i++ { + ops = append(ops, args...) + ops = append(ops, opcode.LDSFLD0, opcode.CALLA) + } + return makeProgram(ops...) +} + +func TestRETReferenceClear(t *testing.T) { + // 42 is a canary + t.Run("Argument", getTestFuncForVM(makeRETProgram(t, 100, 0), 42, 42)) + t.Run("Local", getTestFuncForVM(makeRETProgram(t, 0, 100), 42, 42)) +} + func TestNOTEQUALByteArray(t *testing.T) { prog := makeProgram(opcode.NOTEQUAL) t.Run("True", getTestFuncForVM(prog, true, []byte{1, 2}, []byte{0, 1, 2}))