diff --git a/pkg/core/interop_neo.go b/pkg/core/interop_neo.go index 33407e431..e1cf39d59 100644 --- a/pkg/core/interop_neo.go +++ b/pkg/core/interop_neo.go @@ -808,3 +808,28 @@ func (ic *interopContext) enumeratorNext(v *vm.VM) error { func (ic *interopContext) enumeratorValue(v *vm.VM) error { return vm.EnumeratorValue(v) } + +// iteratorConcat concatenates 2 iterators into a single one. +func (ic *interopContext) iteratorConcat(v *vm.VM) error { + return vm.IteratorConcat(v) +} + +// iteratorCreate creates an iterator from array-like or map stack item. +func (ic *interopContext) iteratorCreate(v *vm.VM) error { + return vm.IteratorCreate(v) +} + +// iteratorKey returns current iterator key. +func (ic *interopContext) iteratorKey(v *vm.VM) error { + return vm.IteratorKey(v) +} + +// iteratorKeys returns keys of the iterator. +func (ic *interopContext) iteratorKeys(v *vm.VM) error { + return vm.IteratorKeys(v) +} + +// iteratorValues returns values of the iterator. +func (ic *interopContext) iteratorValues(v *vm.VM) error { + return vm.IteratorValues(v) +} diff --git a/pkg/core/interops.go b/pkg/core/interops.go index b7c249a91..4f76f1ea6 100644 --- a/pkg/core/interops.go +++ b/pkg/core/interops.go @@ -157,6 +157,11 @@ var neoInterops = []interopedFunction{ {Name: "Neo.Input.GetHash", Func: (*interopContext).inputGetHash, Price: 1}, {Name: "Neo.Input.GetIndex", Func: (*interopContext).inputGetIndex, Price: 1}, {Name: "Neo.InvocationTransaction.GetScript", Func: (*interopContext).invocationTxGetScript, Price: 1}, + {Name: "Neo.Iterator.Concat", Func: (*interopContext).iteratorConcat, Price: 1}, + {Name: "Neo.Iterator.Create", Func: (*interopContext).iteratorCreate, Price: 1}, + {Name: "Neo.Iterator.Key", Func: (*interopContext).iteratorKey, Price: 1}, + {Name: "Neo.Iterator.Keys", Func: (*interopContext).iteratorKeys, Price: 1}, + {Name: "Neo.Iterator.Values", Func: (*interopContext).iteratorValues, Price: 1}, {Name: "Neo.Output.GetAssetId", Func: (*interopContext).outputGetAssetID, Price: 1}, {Name: "Neo.Output.GetScriptHash", Func: (*interopContext).outputGetScriptHash, Price: 1}, {Name: "Neo.Output.GetValue", Func: (*interopContext).outputGetValue, Price: 1}, @@ -182,16 +187,11 @@ var neoInterops = []interopedFunction{ {Name: "Neo.Transaction.GetUnspentCoins", Func: (*interopContext).txGetUnspentCoins, Price: 200}, {Name: "Neo.Transaction.GetWitnesses", Func: (*interopContext).txGetWitnesses, Price: 200}, {Name: "Neo.Witness.GetVerificationScript", Func: (*interopContext).witnessGetVerificationScript, Price: 100}, - // {Name: "Neo.Iterator.Concat", Func: (*interopContext).iteratorConcat, Price: 1}, - // {Name: "Neo.Iterator.Create", Func: (*interopContext).iteratorCreate, Price: 1}, - // {Name: "Neo.Iterator.Key", Func: (*interopContext).iteratorKey, Price: 1}, - // {Name: "Neo.Iterator.Keys", Func: (*interopContext).iteratorKeys, Price: 1}, - // {Name: "Neo.Iterator.Values", Func: (*interopContext).iteratorValues, Price: 1}, // {Name: "Neo.Storage.Find", Func: (*interopContext).storageFind, Price: 1}, // Aliases. - // {Name: "Neo.Iterator.Next", Func: (*interopContext).enumeratorNext, Price: 1}, - // {Name: "Neo.Iterator.Value", Func: (*interopContext).enumeratorValue, Price: 1}, + {Name: "Neo.Iterator.Next", Func: (*interopContext).enumeratorNext, Price: 1}, + {Name: "Neo.Iterator.Value", Func: (*interopContext).enumeratorValue, Price: 1}, // Old compatibility APIs. {Name: "AntShares.Account.GetBalance", Func: (*interopContext).accountGetBalance, Price: 1}, diff --git a/pkg/vm/interop.go b/pkg/vm/interop.go index 3489a925a..419c5f5d3 100644 --- a/pkg/vm/interop.go +++ b/pkg/vm/interop.go @@ -48,6 +48,16 @@ var defaultVMInterops = []interopIDFuncPrice{ InteropFuncPrice{EnumeratorConcat, 1}}, {InteropNameToID([]byte("Neo.Enumerator.Value")), InteropFuncPrice{EnumeratorValue, 1}}, + {InteropNameToID([]byte("Neo.Iterator.Create")), + InteropFuncPrice{IteratorCreate, 1}}, + {InteropNameToID([]byte("Neo.Iterator.Concat")), + InteropFuncPrice{IteratorConcat, 1}}, + {InteropNameToID([]byte("Neo.Iterator.Key")), + InteropFuncPrice{IteratorKey, 1}}, + {InteropNameToID([]byte("Neo.Iterator.Keys")), + InteropFuncPrice{IteratorKeys, 1}}, + {InteropNameToID([]byte("Neo.Iterator.Values")), + InteropFuncPrice{IteratorValues, 1}}, } func getDefaultVMInterop(id uint32) *InteropFuncPrice { @@ -163,3 +173,80 @@ func EnumeratorConcat(v *VM) error { return nil } + +// IteratorCreate handles syscall Neo.Iterator.Create. +func IteratorCreate(v *VM) error { + data := v.Estack().Pop() + var item interface{} + switch t := data.value.(type) { + case *ArrayItem, *StructItem: + item = &arrayWrapper{ + index: -1, + value: t.Value().([]StackItem), + } + case *MapItem: + keys := make([]interface{}, 0, len(t.value)) + for k := range t.value { + keys = append(keys, k) + } + + item = &mapWrapper{ + index: -1, + keys: keys, + m: t.value, + } + default: + return errors.New("non-iterable type") + } + + v.Estack().Push(&Element{value: NewInteropItem(item)}) + return nil +} + +// IteratorConcat handles syscall Neo.Iterator.Concat. +func IteratorConcat(v *VM) error { + iop1 := v.Estack().Pop().Interop() + iter1 := iop1.value.(iterator) + iop2 := v.Estack().Pop().Interop() + iter2 := iop2.value.(iterator) + + v.Estack().Push(&Element{value: NewInteropItem( + &concatIter{ + current: iter1, + second: iter2, + }, + )}) + + return nil +} + +// IteratorKey handles syscall Neo.Iterator.Key. +func IteratorKey(v *VM) error { + iop := v.estack.Pop().Interop() + iter := iop.value.(iterator) + v.Estack().Push(&Element{value: iter.Key()}) + + return nil +} + +// IteratorKeys handles syscall Neo.Iterator.Keys. +func IteratorKeys(v *VM) error { + iop := v.estack.Pop().Interop() + iter := iop.value.(iterator) + v.Estack().Push(&Element{value: NewInteropItem( + &keysWrapper{iter}, + )}) + + return nil +} + +// IteratorValues handles syscall Neo.Iterator.Values. +func IteratorValues(v *VM) error { + iop := v.estack.Pop().Interop() + iter := iop.value.(iterator) + v.Estack().Push(&Element{value: NewInteropItem( + &valuesWrapper{iter}, + )}) + + return nil +} diff --git a/pkg/vm/interop_iterators.go b/pkg/vm/interop_iterators.go index cb1254874..62a8b507e 100644 --- a/pkg/vm/interop_iterators.go +++ b/pkg/vm/interop_iterators.go @@ -17,6 +17,32 @@ type ( } ) +type ( + iterator interface { + enumerator + Key() StackItem + } + + mapWrapper struct { + index int + keys []interface{} + m map[interface{}]StackItem + } + + concatIter struct { + current iterator + second iterator + } + + keysWrapper struct { + iter iterator + } + + valuesWrapper struct { + iter iterator + } +) + func (a *arrayWrapper) Next() bool { if next := a.index + 1; next < len(a.value) { a.index = next @@ -30,6 +56,10 @@ func (a *arrayWrapper) Value() StackItem { return a.value[a.index] } +func (a *arrayWrapper) Key() StackItem { + return makeStackItem(a.index) +} + func (c *concatEnum) Next() bool { if c.current.Next() { return true @@ -42,3 +72,53 @@ func (c *concatEnum) Next() bool { func (c *concatEnum) Value() StackItem { return c.current.Value() } + +func (i *concatIter) Next() bool { + if i.current.Next() { + return true + } + i.current = i.second + + return i.second.Next() +} + +func (i *concatIter) Value() StackItem { + return i.current.Value() +} + +func (i *concatIter) Key() StackItem { + return i.current.Key() +} + +func (m *mapWrapper) Next() bool { + if next := m.index + 1; next < len(m.keys) { + m.index = next + return true + } + + return false +} + +func (m *mapWrapper) Value() StackItem { + return m.m[m.keys[m.index]] +} + +func (m *mapWrapper) Key() StackItem { + return makeStackItem(m.keys[m.index]) +} + +func (e *keysWrapper) Next() bool { + return e.iter.Next() +} + +func (e *keysWrapper) Value() StackItem { + return e.iter.Key() +} + +func (e *valuesWrapper) Next() bool { + return e.iter.Next() +} + +func (e *valuesWrapper) Value() StackItem { + return e.iter.Value() +} diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index ae211486f..2fb8fc634 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -333,13 +333,17 @@ func TestPushData4Good(t *testing.T) { assert.Equal(t, []byte{1, 2, 3}, vm.estack.Pop().Bytes()) } -func getEnumeratorProg(n int) (prog []byte) { +func getEnumeratorProg(n int, isIter bool) (prog []byte) { prog = append(prog, byte(opcode.TOALTSTACK)) for i := 0; i < n; i++ { prog = append(prog, byte(opcode.DUPFROMALTSTACK)) prog = append(prog, getSyscallProg("Neo.Enumerator.Next")...) prog = append(prog, byte(opcode.DUPFROMALTSTACK)) prog = append(prog, getSyscallProg("Neo.Enumerator.Value")...) + if isIter { + prog = append(prog, byte(opcode.DUPFROMALTSTACK)) + prog = append(prog, getSyscallProg("Neo.Iterator.Key")...) + } } prog = append(prog, byte(opcode.DUPFROMALTSTACK)) prog = append(prog, getSyscallProg("Neo.Enumerator.Next")...) @@ -356,8 +360,9 @@ func checkEnumeratorStack(t *testing.T, vm *VM, arr []StackItem) { } func testIterableCreate(t *testing.T, typ string) { + isIter := typ == "Iterator" prog := getSyscallProg("Neo." + typ + ".Create") - prog = append(prog, getEnumeratorProg(2)...) + prog = append(prog, getEnumeratorProg(2, isIter)...) vm := load(prog) arr := []StackItem{ @@ -367,22 +372,34 @@ func testIterableCreate(t *testing.T, typ string) { vm.estack.Push(&Element{value: NewArrayItem(arr)}) runVM(t, vm) - checkEnumeratorStack(t, vm, []StackItem{ - arr[1], NewBoolItem(true), - arr[0], NewBoolItem(true), - }) + if isIter { + checkEnumeratorStack(t, vm, []StackItem{ + makeStackItem(1), arr[1], NewBoolItem(true), + makeStackItem(0), arr[0], NewBoolItem(true), + }) + } else { + checkEnumeratorStack(t, vm, []StackItem{ + arr[1], NewBoolItem(true), + arr[0], NewBoolItem(true), + }) + } } func TestEnumeratorCreate(t *testing.T) { testIterableCreate(t, "Enumerator") } -func TestEnumeratorConcat(t *testing.T) { - prog := getSyscallProg("Neo.Enumerator.Create") +func TestIteratorCreate(t *testing.T) { + testIterableCreate(t, "Iterator") +} + +func testIterableConcat(t *testing.T, typ string) { + isIter := typ == "Iterator" + prog := getSyscallProg("Neo." + typ + ".Create") prog = append(prog, byte(opcode.SWAP)) - prog = append(prog, getSyscallProg("Neo.Enumerator.Create")...) - prog = append(prog, getSyscallProg("Neo.Enumerator.Concat")...) - prog = append(prog, getEnumeratorProg(3)...) + prog = append(prog, getSyscallProg("Neo."+typ+".Create")...) + prog = append(prog, getSyscallProg("Neo."+typ+".Concat")...) + prog = append(prog, getEnumeratorProg(3, isIter)...) vm := load(prog) arr := []StackItem{ @@ -394,11 +411,80 @@ func TestEnumeratorConcat(t *testing.T) { vm.estack.Push(&Element{value: NewArrayItem(arr[1:])}) runVM(t, vm) - checkEnumeratorStack(t, vm, []StackItem{ - arr[2], NewBoolItem(true), - arr[1], NewBoolItem(true), - arr[0], NewBoolItem(true), + + if isIter { + // Yes, this is how iterators are concatenated in reference VM + // https://github.com/neo-project/neo/blob/master-2.x/neo.UnitTests/UT_ConcatenatedIterator.cs#L54 + checkEnumeratorStack(t, vm, []StackItem{ + makeStackItem(1), arr[2], NewBoolItem(true), + makeStackItem(0), arr[1], NewBoolItem(true), + makeStackItem(0), arr[0], NewBoolItem(true), + }) + } else { + checkEnumeratorStack(t, vm, []StackItem{ + arr[2], NewBoolItem(true), + arr[1], NewBoolItem(true), + arr[0], NewBoolItem(true), + }) + } +} + +func TestEnumeratorConcat(t *testing.T) { + testIterableConcat(t, "Enumerator") +} + +func TestIteratorConcat(t *testing.T) { + testIterableConcat(t, "Iterator") +} + +func TestIteratorKeys(t *testing.T) { + prog := getSyscallProg("Neo.Iterator.Create") + prog = append(prog, getSyscallProg("Neo.Iterator.Keys")...) + prog = append(prog, byte(opcode.TOALTSTACK), byte(opcode.DUPFROMALTSTACK)) + prog = append(prog, getEnumeratorProg(2, false)...) + + v := load(prog) + arr := NewArrayItem([]StackItem{ + NewBoolItem(false), + NewBigIntegerItem(42), }) + v.estack.PushVal(arr) + + runVM(t, v) + + checkEnumeratorStack(t, v, []StackItem{ + NewBigIntegerItem(1), NewBoolItem(true), + NewBigIntegerItem(0), NewBoolItem(true), + }) +} + +func TestIteratorValues(t *testing.T) { + prog := getSyscallProg("Neo.Iterator.Create") + prog = append(prog, getSyscallProg("Neo.Iterator.Values")...) + prog = append(prog, byte(opcode.TOALTSTACK), byte(opcode.DUPFROMALTSTACK)) + prog = append(prog, getEnumeratorProg(2, false)...) + + v := load(prog) + m := NewMapItem() + m.Add(NewBigIntegerItem(1), NewBoolItem(false)) + m.Add(NewByteArrayItem([]byte{32}), NewByteArrayItem([]byte{7})) + v.estack.PushVal(m) + + runVM(t, v) + require.Equal(t, 5, v.estack.Len()) + require.Equal(t, NewBoolItem(false), v.estack.Peek(0).value) + + // Map values can be enumerated in any order. + i1, i2 := 1, 3 + if _, ok := v.estack.Peek(i1).value.(*BoolItem); !ok { + i1, i2 = i2, i1 + } + + require.Equal(t, NewBoolItem(false), v.estack.Peek(i1).value) + require.Equal(t, NewByteArrayItem([]byte{7}), v.estack.Peek(i2).value) + + require.Equal(t, NewBoolItem(true), v.estack.Peek(2).value) + require.Equal(t, NewBoolItem(true), v.estack.Peek(4).value) } func getSyscallProg(name string) (prog []byte) {