diff --git a/pkg/vm/stack.go b/pkg/vm/stack.go index 16dda046a..af8f2cecc 100644 --- a/pkg/vm/stack.go +++ b/pkg/vm/stack.go @@ -389,6 +389,34 @@ func (s *Stack) Swap(n1, n2 int) error { return nil } +// Roll brings an item with the given index to the top of the stack, moving all +// the other elements down accordingly. It does all of that without popping and +// pushing elements. +func (s *Stack) Roll(n int) error { + if n < 0 { + return errors.New("negative index") + } + if n >= s.len { + return errors.New("too big index") + } + if n == 0 { + return nil + } + top := s.Peek(0) + e := s.Peek(n) + + e.prev.next = e.next + e.next.prev = e.prev + + top.prev = e + e.next = top + + e.prev = &s.top + s.top.next = e + + return nil +} + // popSigElements pops keys or signatures from the stack as needed for // CHECKMULTISIG. func (s *Stack) popSigElements() ([][]byte, error) { diff --git a/pkg/vm/stack_test.go b/pkg/vm/stack_test.go index 7e0244282..dc0fe8bd8 100644 --- a/pkg/vm/stack_test.go +++ b/pkg/vm/stack_test.go @@ -258,6 +258,46 @@ func TestSwapElemValues(t *testing.T) { assert.Equal(t, int64(1), s.Pop().BigInt().Int64()) } +func TestRoll(t *testing.T) { + s := NewStack("test") + + s.PushVal(1) + s.PushVal(2) + s.PushVal(3) + s.PushVal(4) + + assert.NoError(t, s.Roll(2)) + assert.Equal(t, int64(2), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(4), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(3), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(1), s.Pop().BigInt().Int64()) + + s.PushVal(1) + s.PushVal(2) + s.PushVal(3) + s.PushVal(4) + + assert.NoError(t, s.Roll(3)) + assert.Equal(t, int64(1), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(4), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(3), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(2), s.Pop().BigInt().Int64()) + + s.PushVal(1) + s.PushVal(2) + s.PushVal(3) + s.PushVal(4) + + assert.Error(t, s.Roll(-1)) + assert.Error(t, s.Roll(4)) + + assert.NoError(t, s.Roll(0)) + assert.Equal(t, int64(4), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(3), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(2), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(1), s.Pop().BigInt().Int64()) +} + func TestPopSigElements(t *testing.T) { s := NewStack("test") diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 2e8863ab3..4c57ea3bc 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -607,11 +607,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro v.estack.InsertAt(a, n) case opcode.ROT: - e := v.estack.RemoveAt(2) - if e == nil { - panic("no top-level element found") + err := v.estack.Roll(2) + if err != nil { + panic(err.Error()) } - v.estack.Push(e) case opcode.DEPTH: v.estack.PushVal(v.estack.Len()) @@ -642,15 +641,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.ROLL: n := int(v.estack.Pop().BigInt().Int64()) - if n < 0 { - panic("negative stack item returned") - } - if n > 0 { - e := v.estack.RemoveAt(n) - if e == nil { - panic("bad index") - } - v.estack.Push(e) + err := v.estack.Roll(n) + if err != nil { + panic(err.Error()) } case opcode.DROP: diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index d02e4fa10..838b1faf4 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -1518,6 +1518,40 @@ func TestROTGood(t *testing.T) { assert.Equal(t, makeStackItem(2), vm.estack.Pop().value) } +func TestROLLBad1(t *testing.T) { + prog := makeProgram(opcode.ROLL) + vm := load(prog) + vm.estack.PushVal(1) + vm.estack.PushVal(-1) + checkVMFailed(t, vm) +} + +func TestROLLBad2(t *testing.T) { + prog := makeProgram(opcode.ROLL) + vm := load(prog) + vm.estack.PushVal(1) + vm.estack.PushVal(2) + vm.estack.PushVal(3) + vm.estack.PushVal(3) + checkVMFailed(t, vm) +} + +func TestROLLGood(t *testing.T) { + prog := makeProgram(opcode.ROLL) + vm := load(prog) + vm.estack.PushVal(1) + vm.estack.PushVal(2) + vm.estack.PushVal(3) + vm.estack.PushVal(4) + vm.estack.PushVal(1) + runVM(t, vm) + assert.Equal(t, 4, vm.estack.Len()) + assert.Equal(t, makeStackItem(3), vm.estack.Pop().value) + assert.Equal(t, makeStackItem(4), vm.estack.Pop().value) + assert.Equal(t, makeStackItem(2), vm.estack.Pop().value) + assert.Equal(t, makeStackItem(1), vm.estack.Pop().value) +} + func TestXTUCKbadNoitem(t *testing.T) { prog := makeProgram(opcode.XTUCK) vm := load(prog)