vm: implement Neo.Iterator.* interops

This commit is contained in:
Evgenii Stratonikov 2019-11-13 15:29:27 +03:00
parent 3ff7fd5262
commit 5bc32b523a
5 changed files with 300 additions and 22 deletions

View file

@ -808,3 +808,28 @@ func (ic *interopContext) enumeratorNext(v *vm.VM) error {
func (ic *interopContext) enumeratorValue(v *vm.VM) error { func (ic *interopContext) enumeratorValue(v *vm.VM) error {
return vm.EnumeratorValue(v) return vm.EnumeratorValue(v)
} }
// iteratorConcat concatenates 2 iterators into a single one.
func (ic *interopContext) iteratorConcat(v *vm.VM) error {
return vm.IteratorConcat(v)
}
// iteratorCreate creates an iterator from array-like or map stack item.
func (ic *interopContext) iteratorCreate(v *vm.VM) error {
return vm.IteratorCreate(v)
}
// iteratorKey returns current iterator key.
func (ic *interopContext) iteratorKey(v *vm.VM) error {
return vm.IteratorKey(v)
}
// iteratorKeys returns keys of the iterator.
func (ic *interopContext) iteratorKeys(v *vm.VM) error {
return vm.IteratorKeys(v)
}
// iteratorValues returns values of the iterator.
func (ic *interopContext) iteratorValues(v *vm.VM) error {
return vm.IteratorValues(v)
}

View file

@ -157,6 +157,11 @@ var neoInterops = []interopedFunction{
{Name: "Neo.Input.GetHash", Func: (*interopContext).inputGetHash, Price: 1}, {Name: "Neo.Input.GetHash", Func: (*interopContext).inputGetHash, Price: 1},
{Name: "Neo.Input.GetIndex", Func: (*interopContext).inputGetIndex, Price: 1}, {Name: "Neo.Input.GetIndex", Func: (*interopContext).inputGetIndex, Price: 1},
{Name: "Neo.InvocationTransaction.GetScript", Func: (*interopContext).invocationTxGetScript, Price: 1}, {Name: "Neo.InvocationTransaction.GetScript", Func: (*interopContext).invocationTxGetScript, Price: 1},
{Name: "Neo.Iterator.Concat", Func: (*interopContext).iteratorConcat, Price: 1},
{Name: "Neo.Iterator.Create", Func: (*interopContext).iteratorCreate, Price: 1},
{Name: "Neo.Iterator.Key", Func: (*interopContext).iteratorKey, Price: 1},
{Name: "Neo.Iterator.Keys", Func: (*interopContext).iteratorKeys, Price: 1},
{Name: "Neo.Iterator.Values", Func: (*interopContext).iteratorValues, Price: 1},
{Name: "Neo.Output.GetAssetId", Func: (*interopContext).outputGetAssetID, Price: 1}, {Name: "Neo.Output.GetAssetId", Func: (*interopContext).outputGetAssetID, Price: 1},
{Name: "Neo.Output.GetScriptHash", Func: (*interopContext).outputGetScriptHash, Price: 1}, {Name: "Neo.Output.GetScriptHash", Func: (*interopContext).outputGetScriptHash, Price: 1},
{Name: "Neo.Output.GetValue", Func: (*interopContext).outputGetValue, Price: 1}, {Name: "Neo.Output.GetValue", Func: (*interopContext).outputGetValue, Price: 1},
@ -182,16 +187,11 @@ var neoInterops = []interopedFunction{
{Name: "Neo.Transaction.GetUnspentCoins", Func: (*interopContext).txGetUnspentCoins, Price: 200}, {Name: "Neo.Transaction.GetUnspentCoins", Func: (*interopContext).txGetUnspentCoins, Price: 200},
{Name: "Neo.Transaction.GetWitnesses", Func: (*interopContext).txGetWitnesses, Price: 200}, {Name: "Neo.Transaction.GetWitnesses", Func: (*interopContext).txGetWitnesses, Price: 200},
{Name: "Neo.Witness.GetVerificationScript", Func: (*interopContext).witnessGetVerificationScript, Price: 100}, {Name: "Neo.Witness.GetVerificationScript", Func: (*interopContext).witnessGetVerificationScript, Price: 100},
// {Name: "Neo.Iterator.Concat", Func: (*interopContext).iteratorConcat, Price: 1},
// {Name: "Neo.Iterator.Create", Func: (*interopContext).iteratorCreate, Price: 1},
// {Name: "Neo.Iterator.Key", Func: (*interopContext).iteratorKey, Price: 1},
// {Name: "Neo.Iterator.Keys", Func: (*interopContext).iteratorKeys, Price: 1},
// {Name: "Neo.Iterator.Values", Func: (*interopContext).iteratorValues, Price: 1},
// {Name: "Neo.Storage.Find", Func: (*interopContext).storageFind, Price: 1}, // {Name: "Neo.Storage.Find", Func: (*interopContext).storageFind, Price: 1},
// Aliases. // Aliases.
// {Name: "Neo.Iterator.Next", Func: (*interopContext).enumeratorNext, Price: 1}, {Name: "Neo.Iterator.Next", Func: (*interopContext).enumeratorNext, Price: 1},
// {Name: "Neo.Iterator.Value", Func: (*interopContext).enumeratorValue, Price: 1}, {Name: "Neo.Iterator.Value", Func: (*interopContext).enumeratorValue, Price: 1},
// Old compatibility APIs. // Old compatibility APIs.
{Name: "AntShares.Account.GetBalance", Func: (*interopContext).accountGetBalance, Price: 1}, {Name: "AntShares.Account.GetBalance", Func: (*interopContext).accountGetBalance, Price: 1},

View file

@ -48,6 +48,16 @@ var defaultVMInterops = []interopIDFuncPrice{
InteropFuncPrice{EnumeratorConcat, 1}}, InteropFuncPrice{EnumeratorConcat, 1}},
{InteropNameToID([]byte("Neo.Enumerator.Value")), {InteropNameToID([]byte("Neo.Enumerator.Value")),
InteropFuncPrice{EnumeratorValue, 1}}, InteropFuncPrice{EnumeratorValue, 1}},
{InteropNameToID([]byte("Neo.Iterator.Create")),
InteropFuncPrice{IteratorCreate, 1}},
{InteropNameToID([]byte("Neo.Iterator.Concat")),
InteropFuncPrice{IteratorConcat, 1}},
{InteropNameToID([]byte("Neo.Iterator.Key")),
InteropFuncPrice{IteratorKey, 1}},
{InteropNameToID([]byte("Neo.Iterator.Keys")),
InteropFuncPrice{IteratorKeys, 1}},
{InteropNameToID([]byte("Neo.Iterator.Values")),
InteropFuncPrice{IteratorValues, 1}},
} }
func getDefaultVMInterop(id uint32) *InteropFuncPrice { func getDefaultVMInterop(id uint32) *InteropFuncPrice {
@ -163,3 +173,80 @@ func EnumeratorConcat(v *VM) error {
return nil return nil
} }
// IteratorCreate handles syscall Neo.Iterator.Create.
func IteratorCreate(v *VM) error {
data := v.Estack().Pop()
var item interface{}
switch t := data.value.(type) {
case *ArrayItem, *StructItem:
item = &arrayWrapper{
index: -1,
value: t.Value().([]StackItem),
}
case *MapItem:
keys := make([]interface{}, 0, len(t.value))
for k := range t.value {
keys = append(keys, k)
}
item = &mapWrapper{
index: -1,
keys: keys,
m: t.value,
}
default:
return errors.New("non-iterable type")
}
v.Estack().Push(&Element{value: NewInteropItem(item)})
return nil
}
// IteratorConcat handles syscall Neo.Iterator.Concat.
func IteratorConcat(v *VM) error {
iop1 := v.Estack().Pop().Interop()
iter1 := iop1.value.(iterator)
iop2 := v.Estack().Pop().Interop()
iter2 := iop2.value.(iterator)
v.Estack().Push(&Element{value: NewInteropItem(
&concatIter{
current: iter1,
second: iter2,
},
)})
return nil
}
// IteratorKey handles syscall Neo.Iterator.Key.
func IteratorKey(v *VM) error {
iop := v.estack.Pop().Interop()
iter := iop.value.(iterator)
v.Estack().Push(&Element{value: iter.Key()})
return nil
}
// IteratorKeys handles syscall Neo.Iterator.Keys.
func IteratorKeys(v *VM) error {
iop := v.estack.Pop().Interop()
iter := iop.value.(iterator)
v.Estack().Push(&Element{value: NewInteropItem(
&keysWrapper{iter},
)})
return nil
}
// IteratorValues handles syscall Neo.Iterator.Values.
func IteratorValues(v *VM) error {
iop := v.estack.Pop().Interop()
iter := iop.value.(iterator)
v.Estack().Push(&Element{value: NewInteropItem(
&valuesWrapper{iter},
)})
return nil
}

View file

@ -17,6 +17,32 @@ type (
} }
) )
type (
iterator interface {
enumerator
Key() StackItem
}
mapWrapper struct {
index int
keys []interface{}
m map[interface{}]StackItem
}
concatIter struct {
current iterator
second iterator
}
keysWrapper struct {
iter iterator
}
valuesWrapper struct {
iter iterator
}
)
func (a *arrayWrapper) Next() bool { func (a *arrayWrapper) Next() bool {
if next := a.index + 1; next < len(a.value) { if next := a.index + 1; next < len(a.value) {
a.index = next a.index = next
@ -30,6 +56,10 @@ func (a *arrayWrapper) Value() StackItem {
return a.value[a.index] return a.value[a.index]
} }
func (a *arrayWrapper) Key() StackItem {
return makeStackItem(a.index)
}
func (c *concatEnum) Next() bool { func (c *concatEnum) Next() bool {
if c.current.Next() { if c.current.Next() {
return true return true
@ -42,3 +72,53 @@ func (c *concatEnum) Next() bool {
func (c *concatEnum) Value() StackItem { func (c *concatEnum) Value() StackItem {
return c.current.Value() return c.current.Value()
} }
func (i *concatIter) Next() bool {
if i.current.Next() {
return true
}
i.current = i.second
return i.second.Next()
}
func (i *concatIter) Value() StackItem {
return i.current.Value()
}
func (i *concatIter) Key() StackItem {
return i.current.Key()
}
func (m *mapWrapper) Next() bool {
if next := m.index + 1; next < len(m.keys) {
m.index = next
return true
}
return false
}
func (m *mapWrapper) Value() StackItem {
return m.m[m.keys[m.index]]
}
func (m *mapWrapper) Key() StackItem {
return makeStackItem(m.keys[m.index])
}
func (e *keysWrapper) Next() bool {
return e.iter.Next()
}
func (e *keysWrapper) Value() StackItem {
return e.iter.Key()
}
func (e *valuesWrapper) Next() bool {
return e.iter.Next()
}
func (e *valuesWrapper) Value() StackItem {
return e.iter.Value()
}

View file

@ -333,13 +333,17 @@ func TestPushData4Good(t *testing.T) {
assert.Equal(t, []byte{1, 2, 3}, vm.estack.Pop().Bytes()) assert.Equal(t, []byte{1, 2, 3}, vm.estack.Pop().Bytes())
} }
func getEnumeratorProg(n int) (prog []byte) { func getEnumeratorProg(n int, isIter bool) (prog []byte) {
prog = append(prog, byte(opcode.TOALTSTACK)) prog = append(prog, byte(opcode.TOALTSTACK))
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
prog = append(prog, byte(opcode.DUPFROMALTSTACK)) prog = append(prog, byte(opcode.DUPFROMALTSTACK))
prog = append(prog, getSyscallProg("Neo.Enumerator.Next")...) prog = append(prog, getSyscallProg("Neo.Enumerator.Next")...)
prog = append(prog, byte(opcode.DUPFROMALTSTACK)) prog = append(prog, byte(opcode.DUPFROMALTSTACK))
prog = append(prog, getSyscallProg("Neo.Enumerator.Value")...) prog = append(prog, getSyscallProg("Neo.Enumerator.Value")...)
if isIter {
prog = append(prog, byte(opcode.DUPFROMALTSTACK))
prog = append(prog, getSyscallProg("Neo.Iterator.Key")...)
}
} }
prog = append(prog, byte(opcode.DUPFROMALTSTACK)) prog = append(prog, byte(opcode.DUPFROMALTSTACK))
prog = append(prog, getSyscallProg("Neo.Enumerator.Next")...) prog = append(prog, getSyscallProg("Neo.Enumerator.Next")...)
@ -356,8 +360,9 @@ func checkEnumeratorStack(t *testing.T, vm *VM, arr []StackItem) {
} }
func testIterableCreate(t *testing.T, typ string) { func testIterableCreate(t *testing.T, typ string) {
isIter := typ == "Iterator"
prog := getSyscallProg("Neo." + typ + ".Create") prog := getSyscallProg("Neo." + typ + ".Create")
prog = append(prog, getEnumeratorProg(2)...) prog = append(prog, getEnumeratorProg(2, isIter)...)
vm := load(prog) vm := load(prog)
arr := []StackItem{ arr := []StackItem{
@ -367,22 +372,34 @@ func testIterableCreate(t *testing.T, typ string) {
vm.estack.Push(&Element{value: NewArrayItem(arr)}) vm.estack.Push(&Element{value: NewArrayItem(arr)})
runVM(t, vm) runVM(t, vm)
if isIter {
checkEnumeratorStack(t, vm, []StackItem{
makeStackItem(1), arr[1], NewBoolItem(true),
makeStackItem(0), arr[0], NewBoolItem(true),
})
} else {
checkEnumeratorStack(t, vm, []StackItem{ checkEnumeratorStack(t, vm, []StackItem{
arr[1], NewBoolItem(true), arr[1], NewBoolItem(true),
arr[0], NewBoolItem(true), arr[0], NewBoolItem(true),
}) })
}
} }
func TestEnumeratorCreate(t *testing.T) { func TestEnumeratorCreate(t *testing.T) {
testIterableCreate(t, "Enumerator") testIterableCreate(t, "Enumerator")
} }
func TestEnumeratorConcat(t *testing.T) { func TestIteratorCreate(t *testing.T) {
prog := getSyscallProg("Neo.Enumerator.Create") testIterableCreate(t, "Iterator")
}
func testIterableConcat(t *testing.T, typ string) {
isIter := typ == "Iterator"
prog := getSyscallProg("Neo." + typ + ".Create")
prog = append(prog, byte(opcode.SWAP)) prog = append(prog, byte(opcode.SWAP))
prog = append(prog, getSyscallProg("Neo.Enumerator.Create")...) prog = append(prog, getSyscallProg("Neo."+typ+".Create")...)
prog = append(prog, getSyscallProg("Neo.Enumerator.Concat")...) prog = append(prog, getSyscallProg("Neo."+typ+".Concat")...)
prog = append(prog, getEnumeratorProg(3)...) prog = append(prog, getEnumeratorProg(3, isIter)...)
vm := load(prog) vm := load(prog)
arr := []StackItem{ arr := []StackItem{
@ -394,11 +411,80 @@ func TestEnumeratorConcat(t *testing.T) {
vm.estack.Push(&Element{value: NewArrayItem(arr[1:])}) vm.estack.Push(&Element{value: NewArrayItem(arr[1:])})
runVM(t, vm) runVM(t, vm)
if isIter {
// Yes, this is how iterators are concatenated in reference VM
// https://github.com/neo-project/neo/blob/master-2.x/neo.UnitTests/UT_ConcatenatedIterator.cs#L54
checkEnumeratorStack(t, vm, []StackItem{
makeStackItem(1), arr[2], NewBoolItem(true),
makeStackItem(0), arr[1], NewBoolItem(true),
makeStackItem(0), arr[0], NewBoolItem(true),
})
} else {
checkEnumeratorStack(t, vm, []StackItem{ checkEnumeratorStack(t, vm, []StackItem{
arr[2], NewBoolItem(true), arr[2], NewBoolItem(true),
arr[1], NewBoolItem(true), arr[1], NewBoolItem(true),
arr[0], NewBoolItem(true), arr[0], NewBoolItem(true),
}) })
}
}
func TestEnumeratorConcat(t *testing.T) {
testIterableConcat(t, "Enumerator")
}
func TestIteratorConcat(t *testing.T) {
testIterableConcat(t, "Iterator")
}
func TestIteratorKeys(t *testing.T) {
prog := getSyscallProg("Neo.Iterator.Create")
prog = append(prog, getSyscallProg("Neo.Iterator.Keys")...)
prog = append(prog, byte(opcode.TOALTSTACK), byte(opcode.DUPFROMALTSTACK))
prog = append(prog, getEnumeratorProg(2, false)...)
v := load(prog)
arr := NewArrayItem([]StackItem{
NewBoolItem(false),
NewBigIntegerItem(42),
})
v.estack.PushVal(arr)
runVM(t, v)
checkEnumeratorStack(t, v, []StackItem{
NewBigIntegerItem(1), NewBoolItem(true),
NewBigIntegerItem(0), NewBoolItem(true),
})
}
func TestIteratorValues(t *testing.T) {
prog := getSyscallProg("Neo.Iterator.Create")
prog = append(prog, getSyscallProg("Neo.Iterator.Values")...)
prog = append(prog, byte(opcode.TOALTSTACK), byte(opcode.DUPFROMALTSTACK))
prog = append(prog, getEnumeratorProg(2, false)...)
v := load(prog)
m := NewMapItem()
m.Add(NewBigIntegerItem(1), NewBoolItem(false))
m.Add(NewByteArrayItem([]byte{32}), NewByteArrayItem([]byte{7}))
v.estack.PushVal(m)
runVM(t, v)
require.Equal(t, 5, v.estack.Len())
require.Equal(t, NewBoolItem(false), v.estack.Peek(0).value)
// Map values can be enumerated in any order.
i1, i2 := 1, 3
if _, ok := v.estack.Peek(i1).value.(*BoolItem); !ok {
i1, i2 = i2, i1
}
require.Equal(t, NewBoolItem(false), v.estack.Peek(i1).value)
require.Equal(t, NewByteArrayItem([]byte{7}), v.estack.Peek(i2).value)
require.Equal(t, NewBoolItem(true), v.estack.Peek(2).value)
require.Equal(t, NewBoolItem(true), v.estack.Peek(4).value)
} }
func getSyscallProg(name string) (prog []byte) { func getSyscallProg(name string) (prog []byte) {