diff --git a/pkg/vm/stack.go b/pkg/vm/stack.go index cc8dabfb2..af8f2cecc 100644 --- a/pkg/vm/stack.go +++ b/pkg/vm/stack.go @@ -2,6 +2,7 @@ package vm import ( "encoding/json" + "errors" "fmt" "math/big" @@ -371,6 +372,51 @@ func (s *Stack) IterBack(f func(*Element)) { } } +// Swap swaps two elements on the stack without popping and pushing them. +func (s *Stack) Swap(n1, n2 int) error { + if n1 < 0 || n2 < 0 { + return errors.New("negative index") + } + if n1 >= s.len || n2 >= s.len { + return errors.New("too big index") + } + if n1 == n2 { + return nil + } + a := s.Peek(n1) + b := s.Peek(n2) + a.value, b.value = b.value, a.value + return nil +} + +// Roll brings an item with the given index to the top of the stack, moving all +// the other elements down accordingly. It does all of that without popping and +// pushing elements. +func (s *Stack) Roll(n int) error { + if n < 0 { + return errors.New("negative index") + } + if n >= s.len { + return errors.New("too big index") + } + if n == 0 { + return nil + } + top := s.Peek(0) + e := s.Peek(n) + + e.prev.next = e.next + e.next.prev = e.prev + + top.prev = e + e.next = top + + e.prev = &s.top + s.top.next = e + + return nil +} + // popSigElements pops keys or signatures from the stack as needed for // CHECKMULTISIG. func (s *Stack) popSigElements() ([][]byte, error) { diff --git a/pkg/vm/stack_test.go b/pkg/vm/stack_test.go index c14a84414..dc0fe8bd8 100644 --- a/pkg/vm/stack_test.go +++ b/pkg/vm/stack_test.go @@ -226,22 +226,76 @@ func TestSwapElemValues(t *testing.T) { s.PushVal(2) s.PushVal(4) - a := s.Peek(0) - b := s.Peek(1) - - // [ 4 ] -> a - // [ 2 ] -> b - - aval := a.value - bval := b.value - a.value = bval - b.value = aval - - // [ 2 ] -> a - // [ 4 ] -> b - + assert.NoError(t, s.Swap(0, 1)) assert.Equal(t, int64(2), s.Pop().BigInt().Int64()) assert.Equal(t, int64(4), s.Pop().BigInt().Int64()) + + s.PushVal(1) + s.PushVal(2) + s.PushVal(3) + s.PushVal(4) + + assert.NoError(t, s.Swap(1, 3)) + assert.Equal(t, int64(4), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(1), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(2), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(3), s.Pop().BigInt().Int64()) + + s.PushVal(1) + s.PushVal(2) + s.PushVal(3) + s.PushVal(4) + + assert.Error(t, s.Swap(-1, 0)) + assert.Error(t, s.Swap(0, -3)) + assert.Error(t, s.Swap(0, 4)) + assert.Error(t, s.Swap(5, 0)) + + assert.NoError(t, s.Swap(1, 1)) + assert.Equal(t, int64(4), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(3), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(2), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(1), s.Pop().BigInt().Int64()) +} + +func TestRoll(t *testing.T) { + s := NewStack("test") + + s.PushVal(1) + s.PushVal(2) + s.PushVal(3) + s.PushVal(4) + + assert.NoError(t, s.Roll(2)) + assert.Equal(t, int64(2), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(4), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(3), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(1), s.Pop().BigInt().Int64()) + + s.PushVal(1) + s.PushVal(2) + s.PushVal(3) + s.PushVal(4) + + assert.NoError(t, s.Roll(3)) + assert.Equal(t, int64(1), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(4), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(3), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(2), s.Pop().BigInt().Int64()) + + s.PushVal(1) + s.PushVal(2) + s.PushVal(3) + s.PushVal(4) + + assert.Error(t, s.Roll(-1)) + assert.Error(t, s.Roll(4)) + + assert.NoError(t, s.Roll(0)) + assert.Equal(t, int64(4), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(3), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(2), s.Pop().BigInt().Int64()) + assert.Equal(t, int64(1), s.Pop().BigInt().Int64()) } func TestPopSigElements(t *testing.T) { diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 5d784e83e..4c57ea3bc 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -510,10 +510,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro v.estack.Push(v.estack.Dup(0)) case opcode.SWAP: - a := v.estack.Pop() - b := v.estack.Pop() - v.estack.Push(a) - v.estack.Push(b) + err := v.estack.Swap(1, 0) + if err != nil { + panic(err.Error()) + } case opcode.TUCK: a := v.estack.Dup(0) @@ -587,18 +587,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.XSWAP: n := int(v.estack.Pop().BigInt().Int64()) - if n < 0 { - panic("XSWAP: invalid length") - } - - // Swap values of elements instead of reordering stack elements. - if n > 0 { - a := v.estack.Peek(n) - b := v.estack.Peek(0) - aval := a.value - bval := b.value - a.value = bval - b.value = aval + err := v.estack.Swap(n, 0) + if err != nil { + panic(err.Error()) } case opcode.XTUCK: @@ -616,11 +607,10 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro v.estack.InsertAt(a, n) case opcode.ROT: - e := v.estack.RemoveAt(2) - if e == nil { - panic("no top-level element found") + err := v.estack.Roll(2) + if err != nil { + panic(err.Error()) } - v.estack.Push(e) case opcode.DEPTH: v.estack.PushVal(v.estack.Len()) @@ -651,15 +641,9 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.ROLL: n := int(v.estack.Pop().BigInt().Int64()) - if n < 0 { - panic("negative stack item returned") - } - if n > 0 { - e := v.estack.RemoveAt(n) - if e == nil { - panic("bad index") - } - v.estack.Push(e) + err := v.estack.Roll(n) + if err != nil { + panic(err.Error()) } case opcode.DROP: diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 8a512a0ef..838b1faf4 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -1518,6 +1518,40 @@ func TestROTGood(t *testing.T) { assert.Equal(t, makeStackItem(2), vm.estack.Pop().value) } +func TestROLLBad1(t *testing.T) { + prog := makeProgram(opcode.ROLL) + vm := load(prog) + vm.estack.PushVal(1) + vm.estack.PushVal(-1) + checkVMFailed(t, vm) +} + +func TestROLLBad2(t *testing.T) { + prog := makeProgram(opcode.ROLL) + vm := load(prog) + vm.estack.PushVal(1) + vm.estack.PushVal(2) + vm.estack.PushVal(3) + vm.estack.PushVal(3) + checkVMFailed(t, vm) +} + +func TestROLLGood(t *testing.T) { + prog := makeProgram(opcode.ROLL) + vm := load(prog) + vm.estack.PushVal(1) + vm.estack.PushVal(2) + vm.estack.PushVal(3) + vm.estack.PushVal(4) + vm.estack.PushVal(1) + runVM(t, vm) + assert.Equal(t, 4, vm.estack.Len()) + assert.Equal(t, makeStackItem(3), vm.estack.Pop().value) + assert.Equal(t, makeStackItem(4), vm.estack.Pop().value) + assert.Equal(t, makeStackItem(2), vm.estack.Pop().value) + assert.Equal(t, makeStackItem(1), vm.estack.Pop().value) +} + func TestXTUCKbadNoitem(t *testing.T) { prog := makeProgram(opcode.XTUCK) vm := load(prog) @@ -2409,6 +2443,68 @@ func TestCHECKMULTISIGGood(t *testing.T) { assert.Equal(t, true, vm.estack.Pop().Bool()) } +func TestSWAPGood(t *testing.T) { + prog := makeProgram(opcode.SWAP) + vm := load(prog) + vm.estack.PushVal(2) + vm.estack.PushVal(4) + runVM(t, vm) + assert.Equal(t, 2, vm.estack.Len()) + assert.Equal(t, int64(2), vm.estack.Pop().BigInt().Int64()) + assert.Equal(t, int64(4), vm.estack.Pop().BigInt().Int64()) +} + +func TestSWAPBad1(t *testing.T) { + prog := makeProgram(opcode.SWAP) + vm := load(prog) + vm.estack.PushVal(4) + checkVMFailed(t, vm) +} + +func TestSWAPBad2(t *testing.T) { + prog := makeProgram(opcode.SWAP) + vm := load(prog) + checkVMFailed(t, vm) +} + +func TestXSWAPGood(t *testing.T) { + prog := makeProgram(opcode.XSWAP) + vm := load(prog) + vm.estack.PushVal(1) + vm.estack.PushVal(2) + vm.estack.PushVal(3) + vm.estack.PushVal(4) + vm.estack.PushVal(5) + vm.estack.PushVal(3) + runVM(t, vm) + assert.Equal(t, 5, vm.estack.Len()) + assert.Equal(t, int64(2), vm.estack.Pop().BigInt().Int64()) + assert.Equal(t, int64(4), vm.estack.Pop().BigInt().Int64()) + assert.Equal(t, int64(3), vm.estack.Pop().BigInt().Int64()) + assert.Equal(t, int64(5), vm.estack.Pop().BigInt().Int64()) + assert.Equal(t, int64(1), vm.estack.Pop().BigInt().Int64()) +} + +func TestXSWAPBad1(t *testing.T) { + prog := makeProgram(opcode.XSWAP) + vm := load(prog) + vm.estack.PushVal(1) + vm.estack.PushVal(2) + vm.estack.PushVal(-1) + checkVMFailed(t, vm) +} + +func TestXSWAPBad2(t *testing.T) { + prog := makeProgram(opcode.XSWAP) + vm := load(prog) + vm.estack.PushVal(1) + vm.estack.PushVal(2) + vm.estack.PushVal(3) + vm.estack.PushVal(4) + vm.estack.PushVal(4) + checkVMFailed(t, vm) +} + func makeProgram(opcodes ...opcode.Opcode) []byte { prog := make([]byte, len(opcodes)+1) // RET for i := 0; i < len(opcodes); i++ {