Merge pull request #3274 from fyfyrchik/partial-iterator
Allow to partially unwrap session iterator
This commit is contained in:
commit
0d3f749b5e
4 changed files with 271 additions and 21 deletions
|
@ -167,12 +167,9 @@ func SessionIterator(r *result.Invoke, err error) (uuid.UUID, result.Iterator, e
|
|||
if err != nil {
|
||||
return uuid.UUID{}, result.Iterator{}, err
|
||||
}
|
||||
if t := itm.Type(); t != stackitem.InteropT {
|
||||
return uuid.UUID{}, result.Iterator{}, fmt.Errorf("expected InteropInterface, got %s", t)
|
||||
}
|
||||
iter, ok := itm.Value().(result.Iterator)
|
||||
if !ok {
|
||||
return uuid.UUID{}, result.Iterator{}, errors.New("the item is InteropInterface, but not an Iterator")
|
||||
iter, err := itemToSessionIterator(itm)
|
||||
if err != nil {
|
||||
return uuid.UUID{}, result.Iterator{}, err
|
||||
}
|
||||
if (r.Session == uuid.UUID{}) && iter.ID != nil {
|
||||
return uuid.UUID{}, result.Iterator{}, ErrNoSessionID
|
||||
|
@ -180,6 +177,54 @@ func SessionIterator(r *result.Invoke, err error) (uuid.UUID, result.Iterator, e
|
|||
return r.Session, iter, nil
|
||||
}
|
||||
|
||||
// ArrayAndSessionIterator expects correct execution (HALT state) with one or two stack
|
||||
// items returned. If there is 1 item, it must be an array. If there is a second item,
|
||||
// it must be an iterator. This is exactly the result of smartcontract.CreateCallAndPrefetchIteratorScript.
|
||||
// Sessions must be enabled on the RPC server for this to function correctly.
|
||||
func ArrayAndSessionIterator(r *result.Invoke, err error) ([]stackitem.Item, uuid.UUID, result.Iterator, error) {
|
||||
if err := checkResOK(r, err); err != nil {
|
||||
return nil, uuid.UUID{}, result.Iterator{}, err
|
||||
}
|
||||
if len(r.Stack) == 0 {
|
||||
return nil, uuid.UUID{}, result.Iterator{}, errors.New("result stack is empty")
|
||||
}
|
||||
if len(r.Stack) != 1 && len(r.Stack) != 2 {
|
||||
return nil, uuid.UUID{}, result.Iterator{}, fmt.Errorf("expected 1 or 2 result items, got %d", len(r.Stack))
|
||||
}
|
||||
|
||||
// Unwrap array.
|
||||
itm := r.Stack[0]
|
||||
arr, ok := itm.Value().([]stackitem.Item)
|
||||
if !ok {
|
||||
return nil, uuid.UUID{}, result.Iterator{}, errors.New("not an array")
|
||||
}
|
||||
|
||||
// Check whether iterator exists and unwrap it.
|
||||
if len(r.Stack) == 1 {
|
||||
return arr, uuid.UUID{}, result.Iterator{}, nil
|
||||
}
|
||||
|
||||
iter, err := itemToSessionIterator(r.Stack[1])
|
||||
if err != nil {
|
||||
return nil, uuid.UUID{}, result.Iterator{}, err
|
||||
}
|
||||
if (r.Session == uuid.UUID{}) {
|
||||
return nil, uuid.UUID{}, result.Iterator{}, ErrNoSessionID
|
||||
}
|
||||
return arr, r.Session, iter, nil
|
||||
}
|
||||
|
||||
func itemToSessionIterator(itm stackitem.Item) (result.Iterator, error) {
|
||||
if t := itm.Type(); t != stackitem.InteropT {
|
||||
return result.Iterator{}, fmt.Errorf("expected InteropInterface, got %s", t)
|
||||
}
|
||||
iter, ok := itm.Value().(result.Iterator)
|
||||
if !ok {
|
||||
return result.Iterator{}, errors.New("the item is InteropInterface, but not an Iterator")
|
||||
}
|
||||
return iter, nil
|
||||
}
|
||||
|
||||
// Array expects correct execution (HALT state) with a single array stack item
|
||||
// returned. This item is returned to the caller. Notice that this function can
|
||||
// be used for structures as well since they're also represented as slices of
|
||||
|
|
|
@ -50,6 +50,10 @@ func TestStdErrors(t *testing.T) {
|
|||
_, _, err = SessionIterator(r, err)
|
||||
return nil, err
|
||||
},
|
||||
func(r *result.Invoke, err error) (any, error) {
|
||||
_, _, _, err = ArrayAndSessionIterator(r, err)
|
||||
return nil, err
|
||||
},
|
||||
func(r *result.Invoke, err error) (any, error) {
|
||||
return Array(r, err)
|
||||
},
|
||||
|
@ -97,6 +101,12 @@ func TestStdErrors(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
t.Run("HALT state with empty stack", func(t *testing.T) {
|
||||
for _, f := range funcs {
|
||||
_, err := f(&result.Invoke{State: "HALT"}, nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
t.Run("multiple return values", func(t *testing.T) {
|
||||
for _, f := range funcs {
|
||||
_, err := f(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42), stackitem.Make(42)}}, nil)
|
||||
|
@ -256,6 +266,39 @@ func TestSessionIterator(t *testing.T) {
|
|||
require.Equal(t, iter, ri)
|
||||
}
|
||||
|
||||
func TestArraySessionIterator(t *testing.T) {
|
||||
_, _, _, err := ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil)
|
||||
require.Error(t, err)
|
||||
|
||||
_, _, _, err = ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.NewInterop(42)}}, nil)
|
||||
require.Error(t, err)
|
||||
|
||||
arr := stackitem.NewArray([]stackitem.Item{stackitem.Make(42)})
|
||||
ra, rs, ri, err := ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{arr}}, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, arr.Value(), ra)
|
||||
require.Empty(t, rs)
|
||||
require.Empty(t, ri)
|
||||
|
||||
_, _, _, err = ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{arr, stackitem.NewInterop(42)}}, nil)
|
||||
require.Error(t, err)
|
||||
|
||||
iid := uuid.New()
|
||||
iter := result.Iterator{ID: &iid}
|
||||
_, _, _, err = ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{arr, stackitem.NewInterop(iter)}}, nil)
|
||||
require.ErrorIs(t, err, ErrNoSessionID)
|
||||
|
||||
sid := uuid.New()
|
||||
_, rs, ri, err = ArrayAndSessionIterator(&result.Invoke{Session: sid, State: "HALT", Stack: []stackitem.Item{arr, stackitem.NewInterop(iter)}}, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, arr.Value(), ra)
|
||||
require.Equal(t, sid, rs)
|
||||
require.Equal(t, iter, ri)
|
||||
|
||||
_, _, _, err = ArrayAndSessionIterator(&result.Invoke{Session: sid, State: "HALT", Stack: []stackitem.Item{arr, stackitem.NewInterop(iter), stackitem.Make(42)}}, nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestArray(t *testing.T) {
|
||||
_, err := Array(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil)
|
||||
require.Error(t, err)
|
||||
|
|
|
@ -20,6 +20,61 @@ import (
|
|||
// for interactions with RPC server that have iterator sessions disabled.
|
||||
func CreateCallAndUnwrapIteratorScript(contract util.Uint160, operation string, maxIteratorResultItems int, params ...any) ([]byte, error) {
|
||||
script := io.NewBufBinWriter()
|
||||
jmpIfNotOffset, jmpIfMaxReachedOffset := emitCallAndUnwrapIteratorScript(script, contract, operation, maxIteratorResultItems, params...)
|
||||
|
||||
// End of the program: push the result on stack and return.
|
||||
loadResultOffset := script.Len()
|
||||
emit.Opcodes(script.BinWriter, opcode.NIP, // Remove iterator from the 1-st cell of estack
|
||||
opcode.NIP) // Remove maxIteratorResultItems from the 1-st cell of estack, so that only resulting array is left on estack.
|
||||
if err := script.Err; err != nil {
|
||||
return nil, fmt.Errorf("emitting iterator unwrapper script: %w", err)
|
||||
}
|
||||
|
||||
// Fill in JMPIFNOT instruction parameter.
|
||||
bytes := script.Bytes()
|
||||
bytes[jmpIfNotOffset+1] = uint8(loadResultOffset - jmpIfNotOffset) // +1 is for JMPIFNOT itself; offset is relative to JMPIFNOT position.
|
||||
// Fill in jmpIfMaxReachedOffset instruction parameter.
|
||||
bytes[jmpIfMaxReachedOffset+1] = uint8(loadResultOffset - jmpIfMaxReachedOffset) // +1 is for JMPIF itself; offset is relative to JMPIF position.
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
// CreateCallAndPrefetchIteratorScript creates a script that calls 'operation' method
|
||||
// of the 'contract' with the specified arguments. This method is expected to return
|
||||
// an array of the first iterator items (up to maxIteratorResultItems, which cannot exceed VM limits)
|
||||
// and, optionally, an iterator that then is traversed (using iterator.Next).
|
||||
// The result of the script is an array containing extracted value elements and an iterator, if it can contain more items.
|
||||
// If the iterator is present, it lies on top of the stack.
|
||||
// Note, however, that if an iterator is returned, the number of remaining items can still be 0.
|
||||
// This script should only be used for interactions with RPC server that have iterator sessions enabled.
|
||||
func CreateCallAndPrefetchIteratorScript(contract util.Uint160, operation string, maxIteratorResultItems int, params ...any) ([]byte, error) {
|
||||
script := io.NewBufBinWriter()
|
||||
jmpIfNotOffset, jmpIfMaxReachedOffset := emitCallAndUnwrapIteratorScript(script, contract, operation, maxIteratorResultItems, params...)
|
||||
|
||||
// 1st possibility: jump here when the maximum number of items was reached.
|
||||
retainIteratorOffset := script.Len()
|
||||
emit.Opcodes(script.BinWriter, opcode.ROT, // Put maxIteratorResultItems from the 2-nd cell of estack, to the top
|
||||
opcode.DROP, // ... and then drop it.
|
||||
opcode.SWAP, // Put the iterator on top of the stack.
|
||||
opcode.RET)
|
||||
|
||||
// 2nd possibility: jump here when the iterator has no more items.
|
||||
loadResultOffset := script.Len()
|
||||
emit.Opcodes(script.BinWriter, opcode.ROT, // Put maxIteratorResultItems from the 2-nd cell of estack, to the top
|
||||
opcode.DROP, // ... and then drop it.
|
||||
opcode.NIP) // Drop iterator as the 1-st cell on the stack.
|
||||
if err := script.Err; err != nil {
|
||||
return nil, fmt.Errorf("emitting iterator unwrapper script: %w", err)
|
||||
}
|
||||
|
||||
// Fill in JMPIFNOT instruction parameter.
|
||||
bytes := script.Bytes()
|
||||
bytes[jmpIfNotOffset+1] = uint8(loadResultOffset - jmpIfNotOffset) // +1 is for JMPIFNOT itself; offset is relative to JMPIFNOT position.
|
||||
// Fill in jmpIfMaxReachedOffset instruction parameter.
|
||||
bytes[jmpIfMaxReachedOffset+1] = uint8(retainIteratorOffset - jmpIfMaxReachedOffset) // +1 is for JMPIF itself; offset is relative to JMPIF position.
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
func emitCallAndUnwrapIteratorScript(script *io.BufBinWriter, contract util.Uint160, operation string, maxIteratorResultItems int, params ...any) (int, int) {
|
||||
emit.Int(script.BinWriter, int64(maxIteratorResultItems))
|
||||
emit.AppCall(script.BinWriter, contract, operation, callflag.All, params...) // The System.Contract.Call itself, it will push Iterator on estack.
|
||||
emit.Opcodes(script.BinWriter, opcode.NEWARRAY0) // Push new empty array to estack. This array will store iterator's elements.
|
||||
|
@ -51,21 +106,7 @@ func CreateCallAndUnwrapIteratorScript(contract util.Uint160, operation string,
|
|||
[]byte{
|
||||
uint8(iteratorTraverseCycleStartOffset - jmpOffset), // jump to iteratorTraverseCycleStartOffset; offset is relative to JMP position.
|
||||
})
|
||||
|
||||
// End of the program: push the result on stack and return.
|
||||
loadResultOffset := script.Len()
|
||||
emit.Opcodes(script.BinWriter, opcode.NIP, // Remove iterator from the 1-st cell of estack
|
||||
opcode.NIP) // Remove maxIteratorResultItems from the 1-st cell of estack, so that only resulting array is left on estack.
|
||||
if err := script.Err; err != nil {
|
||||
return nil, fmt.Errorf("emitting iterator unwrapper script: %w", err)
|
||||
}
|
||||
|
||||
// Fill in JMPIFNOT instruction parameter.
|
||||
bytes := script.Bytes()
|
||||
bytes[jmpIfNotOffset+1] = uint8(loadResultOffset - jmpIfNotOffset) // +1 is for JMPIFNOT itself; offset is relative to JMPIFNOT position.
|
||||
// Fill in jmpIfMaxReachedOffset instruction parameter.
|
||||
bytes[jmpIfMaxReachedOffset+1] = uint8(loadResultOffset - jmpIfMaxReachedOffset) // +1 is for JMPIF itself; offset is relative to JMPIF position.
|
||||
return bytes, nil
|
||||
return jmpIfNotOffset, jmpIfMaxReachedOffset
|
||||
}
|
||||
|
||||
// CreateCallScript returns a script that calls contract's method with
|
||||
|
|
121
pkg/vm/iterator_test.go
Normal file
121
pkg/vm/iterator_test.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package vm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/internal/random"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type arrayIterator struct {
|
||||
index int
|
||||
values []stackitem.Item
|
||||
}
|
||||
|
||||
func TestCreateCallAndUnwrapIteratorScript(t *testing.T) {
|
||||
ctrHash := random.Uint160()
|
||||
ctrMethod := "mymethod"
|
||||
param := stackitem.NewBigInteger(big.NewInt(42))
|
||||
|
||||
const totalItems = 8
|
||||
values := make([]stackitem.Item, totalItems)
|
||||
for i := range values {
|
||||
values[i] = stackitem.NewBigInteger(big.NewInt(int64(i)))
|
||||
}
|
||||
|
||||
checkStack := func(t *testing.T, script []byte, index int, prefetch bool) {
|
||||
v := load(script)
|
||||
it := &arrayIterator{index: -1, values: values}
|
||||
v.SyscallHandler = func(v *VM, id uint32) error {
|
||||
switch id {
|
||||
case interopnames.ToID([]byte(interopnames.SystemContractCall)):
|
||||
require.Equal(t, ctrHash.BytesBE(), v.Estack().Pop().Value())
|
||||
require.Equal(t, []byte(ctrMethod), v.Estack().Pop().Value())
|
||||
require.Equal(t, big.NewInt(int64(callflag.All)), v.Estack().Pop().Value())
|
||||
require.Equal(t, []stackitem.Item{param}, v.Estack().Pop().Value())
|
||||
v.Estack().PushItem(stackitem.NewInterop(it))
|
||||
case interopnames.ToID([]byte(interopnames.SystemIteratorNext)):
|
||||
require.Equal(t, it, v.Estack().Pop().Value())
|
||||
it.index++
|
||||
v.Estack().PushVal(it.index < len(it.values))
|
||||
case interopnames.ToID([]byte(interopnames.SystemIteratorValue)):
|
||||
require.Equal(t, it, v.Estack().Pop().Value())
|
||||
v.Estack().PushVal(it.values[it.index])
|
||||
default:
|
||||
return fmt.Errorf("unexpected syscall: %d", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
require.NoError(t, v.Run())
|
||||
|
||||
if prefetch && index <= len(values) {
|
||||
require.Equal(t, 2, v.Estack().Len())
|
||||
|
||||
it, ok := v.Estack().Pop().Interop().Value().(*arrayIterator)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, index-1, it.index)
|
||||
require.Equal(t, values[:index], v.Estack().Pop().Array())
|
||||
return
|
||||
}
|
||||
if len(values) < index {
|
||||
index = len(values)
|
||||
}
|
||||
require.Equal(t, 1, v.Estack().Len())
|
||||
require.Equal(t, values[:index], v.Estack().Pop().Array())
|
||||
}
|
||||
|
||||
t.Run("truncate", func(t *testing.T) {
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
const index = 0
|
||||
script, err := smartcontract.CreateCallAndUnwrapIteratorScript(ctrHash, ctrMethod, index, param)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The behaviour is a bit unexpected, but not a problem (why would anyone fetch 0 items).
|
||||
// Let's have test, to make it obvious.
|
||||
checkStack(t, script, index+1, false)
|
||||
})
|
||||
t.Run("all", func(t *testing.T) {
|
||||
const index = totalItems + 1
|
||||
script, err := smartcontract.CreateCallAndUnwrapIteratorScript(ctrHash, ctrMethod, index, param)
|
||||
require.NoError(t, err)
|
||||
|
||||
checkStack(t, script, index, false)
|
||||
})
|
||||
t.Run("partial", func(t *testing.T) {
|
||||
const index = totalItems / 2
|
||||
script, err := smartcontract.CreateCallAndUnwrapIteratorScript(ctrHash, ctrMethod, index, param)
|
||||
require.NoError(t, err)
|
||||
|
||||
checkStack(t, script, index, false)
|
||||
})
|
||||
})
|
||||
t.Run("prefetch", func(t *testing.T) {
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
const index = 0
|
||||
script, err := smartcontract.CreateCallAndPrefetchIteratorScript(ctrHash, ctrMethod, index, param)
|
||||
require.NoError(t, err)
|
||||
|
||||
checkStack(t, script, index+1, true)
|
||||
})
|
||||
t.Run("all", func(t *testing.T) {
|
||||
const index = totalItems + 1 // +1 to test with iterator dropped
|
||||
script, err := smartcontract.CreateCallAndPrefetchIteratorScript(ctrHash, ctrMethod, index, param)
|
||||
require.NoError(t, err)
|
||||
|
||||
checkStack(t, script, index, true)
|
||||
})
|
||||
t.Run("partial", func(t *testing.T) {
|
||||
const index = totalItems / 2
|
||||
script, err := smartcontract.CreateCallAndPrefetchIteratorScript(ctrHash, ctrMethod, index, param)
|
||||
require.NoError(t, err)
|
||||
|
||||
checkStack(t, script, index, true)
|
||||
})
|
||||
})
|
||||
}
|
Loading…
Reference in a new issue