diff --git a/pkg/vm/stack.go b/pkg/vm/stack.go index a08882d8d..17c6e8143 100644 --- a/pkg/vm/stack.go +++ b/pkg/vm/stack.go @@ -84,6 +84,18 @@ func (e *Element) Bytes() []byte { return e.value.Value().([]byte) } +// Array attempts to get the underlying value of the element as an array of +// other items. Will panic if the item type is different which will be caught +// by the VM. +func (e *Element) Array() []StackItem { + switch t := e.value.(type) { + case *ArrayItem: + return t.value + default: + panic("element is not an array") + } +} + // Stack represents a Stack backed by a double linked list. type Stack struct { top Element diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 8109e8067..ebdb50cab 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -633,6 +633,46 @@ func TestRIGHTBadLen(t *testing.T) { assert.Equal(t, true, vm.state.HasFlag(faultState)) } +func TestPACKBadLen(t *testing.T) { + prog := makeProgram(PACK) + vm := load(prog) + vm.estack.PushVal(1) + vm.Run() + assert.Equal(t, true, vm.state.HasFlag(faultState)) +} + +func TestPACKGoodZeroLen(t *testing.T) { + prog := makeProgram(PACK) + vm := load(prog) + vm.estack.PushVal(0) + vm.Run() + assert.Equal(t, false, vm.state.HasFlag(faultState)) + assert.Equal(t, 1, vm.estack.Len()) + assert.Equal(t, []StackItem{}, vm.estack.Peek(0).Array()) +} + +func TestPACKGood(t *testing.T) { + prog := makeProgram(PACK) + elements := []int{55, 34, 42} + vm := load(prog) + // canary + vm.estack.PushVal(1) + for i := len(elements) - 1; i >= 0; i-- { + vm.estack.PushVal(elements[i]) + } + vm.estack.PushVal(len(elements)) + vm.Run() + assert.Equal(t, false, vm.state.HasFlag(faultState)) + assert.Equal(t, 2, vm.estack.Len()) + a := vm.estack.Peek(0).Array() + assert.Equal(t, len(elements), len(a)) + for i := 0; i < len(elements); i++ { + e := a[i].Value().(*big.Int) + assert.Equal(t, int64(elements[i]), e.Int64()) + } + assert.Equal(t, int64(1), vm.estack.Peek(1).BigInt().Int64()) +} + func makeProgram(opcodes ...Instruction) []byte { prog := make([]byte, len(opcodes)+1) // RET for i := 0; i < len(opcodes); i++ {