diff --git a/pkg/vm/slot.go b/pkg/vm/slot.go index 4544ab8dd..661e796d8 100644 --- a/pkg/vm/slot.go +++ b/pkg/vm/slot.go @@ -23,13 +23,8 @@ func (s slot) Set(i int, item stackitem.Item, refs *refCounter) { if s[i] == item { return } - old := s[i] + refs.Remove(s[i]) s[i] = item - if old != nil { - refs.Remove(old) - } else { - *refs-- // Not really existing, but counted Null element. - } refs.Add(item) } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 801d68846..44f905cbb 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -1588,7 +1588,7 @@ func (v *VM) unloadContext(ctx *Context) { ctx.arguments.ClearRefs(&v.refs) } currCtx := v.Context() - if ctx.static != nil && currCtx != nil && ctx.static != currCtx.static { + if ctx.static != nil && (currCtx == nil || ctx.static != currCtx.static) { ctx.static.ClearRefs(&v.refs) } } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 0e88a14e4..e3bd59256 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -2718,6 +2718,35 @@ func TestNestedStructEquals(t *testing.T) { checkVMFailed(t, vm) } +func TestRemoveReferrer(t *testing.T) { + h := "560110c34a10c36058cf4540" // #2501 + prog, err := hex.DecodeString(h) + require.NoError(t, err) + vm := load(prog) + require.NoError(t, vm.StepInto()) // INITSSLOT + assert.Equal(t, 1, int(vm.refs)) + require.NoError(t, vm.StepInto()) // PUSH0 + assert.Equal(t, 2, int(vm.refs)) + require.NoError(t, vm.StepInto()) // NEWARRAY + assert.Equal(t, 2, int(vm.refs)) + require.NoError(t, vm.StepInto()) // DUP + assert.Equal(t, 3, int(vm.refs)) + require.NoError(t, vm.StepInto()) // PUSH0 + assert.Equal(t, 4, int(vm.refs)) + require.NoError(t, vm.StepInto()) // NEWARRAY + assert.Equal(t, 4, int(vm.refs)) + require.NoError(t, vm.StepInto()) // STSFLD0 + assert.Equal(t, 3, int(vm.refs)) + require.NoError(t, vm.StepInto()) // LDSFLD0 + assert.Equal(t, 4, int(vm.refs)) + require.NoError(t, vm.StepInto()) // APPEND + assert.Equal(t, 3, int(vm.refs)) + require.NoError(t, vm.StepInto()) // DROP + assert.Equal(t, 1, int(vm.refs)) + require.NoError(t, vm.StepInto()) // RET + assert.Equal(t, 0, int(vm.refs)) +} + func makeProgram(opcodes ...opcode.Opcode) []byte { prog := make([]byte, len(opcodes)+1) // RET for i := 0; i < len(opcodes); i++ {