diff --git a/pkg/rpcclient/unwrap/unwrap.go b/pkg/rpcclient/unwrap/unwrap.go index 9ca61b571..7276adc39 100644 --- a/pkg/rpcclient/unwrap/unwrap.go +++ b/pkg/rpcclient/unwrap/unwrap.go @@ -167,12 +167,9 @@ func SessionIterator(r *result.Invoke, err error) (uuid.UUID, result.Iterator, e if err != nil { return uuid.UUID{}, result.Iterator{}, err } - if t := itm.Type(); t != stackitem.InteropT { - return uuid.UUID{}, result.Iterator{}, fmt.Errorf("expected InteropInterface, got %s", t) - } - iter, ok := itm.Value().(result.Iterator) - if !ok { - return uuid.UUID{}, result.Iterator{}, errors.New("the item is InteropInterface, but not an Iterator") + iter, err := itemToSessionIterator(itm) + if err != nil { + return uuid.UUID{}, result.Iterator{}, err } if (r.Session == uuid.UUID{}) && iter.ID != nil { return uuid.UUID{}, result.Iterator{}, ErrNoSessionID @@ -180,6 +177,54 @@ func SessionIterator(r *result.Invoke, err error) (uuid.UUID, result.Iterator, e return r.Session, iter, nil } +// ArrayAndSessionIterator expects correct execution (HALT state) with one or two stack +// items returned. If there is 1 item, it must be an array. If there is a second item, +// it must be an iterator. This is exactly the result of smartcontract.CreateCallAndPrefetchIteratorScript. +// Sessions must be enabled on the RPC server for this to function correctly. +func ArrayAndSessionIterator(r *result.Invoke, err error) ([]stackitem.Item, uuid.UUID, result.Iterator, error) { + if err := checkResOK(r, err); err != nil { + return nil, uuid.UUID{}, result.Iterator{}, err + } + if len(r.Stack) == 0 { + return nil, uuid.UUID{}, result.Iterator{}, errors.New("result stack is empty") + } + if len(r.Stack) != 1 && len(r.Stack) != 2 { + return nil, uuid.UUID{}, result.Iterator{}, fmt.Errorf("expected 1 or 2 result items, got %d", len(r.Stack)) + } + + // Unwrap array. + itm := r.Stack[0] + arr, ok := itm.Value().([]stackitem.Item) + if !ok { + return nil, uuid.UUID{}, result.Iterator{}, errors.New("not an array") + } + + // Check whether iterator exists and unwrap it. + if len(r.Stack) == 1 { + return arr, uuid.UUID{}, result.Iterator{}, nil + } + + iter, err := itemToSessionIterator(r.Stack[1]) + if err != nil { + return nil, uuid.UUID{}, result.Iterator{}, err + } + if (r.Session == uuid.UUID{}) { + return nil, uuid.UUID{}, result.Iterator{}, ErrNoSessionID + } + return arr, r.Session, iter, nil +} + +func itemToSessionIterator(itm stackitem.Item) (result.Iterator, error) { + if t := itm.Type(); t != stackitem.InteropT { + return result.Iterator{}, fmt.Errorf("expected InteropInterface, got %s", t) + } + iter, ok := itm.Value().(result.Iterator) + if !ok { + return result.Iterator{}, errors.New("the item is InteropInterface, but not an Iterator") + } + return iter, nil +} + // Array expects correct execution (HALT state) with a single array stack item // returned. This item is returned to the caller. Notice that this function can // be used for structures as well since they're also represented as slices of diff --git a/pkg/rpcclient/unwrap/unwrap_test.go b/pkg/rpcclient/unwrap/unwrap_test.go index cab22c23b..2014d4db1 100644 --- a/pkg/rpcclient/unwrap/unwrap_test.go +++ b/pkg/rpcclient/unwrap/unwrap_test.go @@ -50,6 +50,10 @@ func TestStdErrors(t *testing.T) { _, _, err = SessionIterator(r, err) return nil, err }, + func(r *result.Invoke, err error) (any, error) { + _, _, _, err = ArrayAndSessionIterator(r, err) + return nil, err + }, func(r *result.Invoke, err error) (any, error) { return Array(r, err) }, @@ -97,6 +101,12 @@ func TestStdErrors(t *testing.T) { require.Error(t, err) } }) + t.Run("HALT state with empty stack", func(t *testing.T) { + for _, f := range funcs { + _, err := f(&result.Invoke{State: "HALT"}, nil) + require.Error(t, err) + } + }) t.Run("multiple return values", func(t *testing.T) { for _, f := range funcs { _, err := f(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42), stackitem.Make(42)}}, nil) @@ -256,6 +266,39 @@ func TestSessionIterator(t *testing.T) { require.Equal(t, iter, ri) } +func TestArraySessionIterator(t *testing.T) { + _, _, _, err := ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil) + require.Error(t, err) + + _, _, _, err = ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.NewInterop(42)}}, nil) + require.Error(t, err) + + arr := stackitem.NewArray([]stackitem.Item{stackitem.Make(42)}) + ra, rs, ri, err := ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{arr}}, nil) + require.NoError(t, err) + require.Equal(t, arr.Value(), ra) + require.Empty(t, rs) + require.Empty(t, ri) + + _, _, _, err = ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{arr, stackitem.NewInterop(42)}}, nil) + require.Error(t, err) + + iid := uuid.New() + iter := result.Iterator{ID: &iid} + _, _, _, err = ArrayAndSessionIterator(&result.Invoke{State: "HALT", Stack: []stackitem.Item{arr, stackitem.NewInterop(iter)}}, nil) + require.ErrorIs(t, err, ErrNoSessionID) + + sid := uuid.New() + _, rs, ri, err = ArrayAndSessionIterator(&result.Invoke{Session: sid, State: "HALT", Stack: []stackitem.Item{arr, stackitem.NewInterop(iter)}}, nil) + require.NoError(t, err) + require.Equal(t, arr.Value(), ra) + require.Equal(t, sid, rs) + require.Equal(t, iter, ri) + + _, _, _, err = ArrayAndSessionIterator(&result.Invoke{Session: sid, State: "HALT", Stack: []stackitem.Item{arr, stackitem.NewInterop(iter), stackitem.Make(42)}}, nil) + require.Error(t, err) +} + func TestArray(t *testing.T) { _, err := Array(&result.Invoke{State: "HALT", Stack: []stackitem.Item{stackitem.Make(42)}}, nil) require.Error(t, err) diff --git a/pkg/smartcontract/entry.go b/pkg/smartcontract/entry.go index 99caed055..3b6ee7682 100644 --- a/pkg/smartcontract/entry.go +++ b/pkg/smartcontract/entry.go @@ -20,6 +20,61 @@ import ( // for interactions with RPC server that have iterator sessions disabled. func CreateCallAndUnwrapIteratorScript(contract util.Uint160, operation string, maxIteratorResultItems int, params ...any) ([]byte, error) { script := io.NewBufBinWriter() + jmpIfNotOffset, jmpIfMaxReachedOffset := emitCallAndUnwrapIteratorScript(script, contract, operation, maxIteratorResultItems, params...) + + // End of the program: push the result on stack and return. + loadResultOffset := script.Len() + emit.Opcodes(script.BinWriter, opcode.NIP, // Remove iterator from the 1-st cell of estack + opcode.NIP) // Remove maxIteratorResultItems from the 1-st cell of estack, so that only resulting array is left on estack. + if err := script.Err; err != nil { + return nil, fmt.Errorf("emitting iterator unwrapper script: %w", err) + } + + // Fill in JMPIFNOT instruction parameter. + bytes := script.Bytes() + bytes[jmpIfNotOffset+1] = uint8(loadResultOffset - jmpIfNotOffset) // +1 is for JMPIFNOT itself; offset is relative to JMPIFNOT position. + // Fill in jmpIfMaxReachedOffset instruction parameter. + bytes[jmpIfMaxReachedOffset+1] = uint8(loadResultOffset - jmpIfMaxReachedOffset) // +1 is for JMPIF itself; offset is relative to JMPIF position. + return bytes, nil +} + +// CreateCallAndPrefetchIteratorScript creates a script that calls 'operation' method +// of the 'contract' with the specified arguments. This method is expected to return +// an array of the first iterator items (up to maxIteratorResultItems, which cannot exceed VM limits) +// and, optionally, an iterator that then is traversed (using iterator.Next). +// The result of the script is an array containing extracted value elements and an iterator, if it can contain more items. +// If the iterator is present, it lies on top of the stack. +// Note, however, that if an iterator is returned, the number of remaining items can still be 0. +// This script should only be used for interactions with RPC server that have iterator sessions enabled. +func CreateCallAndPrefetchIteratorScript(contract util.Uint160, operation string, maxIteratorResultItems int, params ...any) ([]byte, error) { + script := io.NewBufBinWriter() + jmpIfNotOffset, jmpIfMaxReachedOffset := emitCallAndUnwrapIteratorScript(script, contract, operation, maxIteratorResultItems, params...) + + // 1st possibility: jump here when the maximum number of items was reached. + retainIteratorOffset := script.Len() + emit.Opcodes(script.BinWriter, opcode.ROT, // Put maxIteratorResultItems from the 2-nd cell of estack, to the top + opcode.DROP, // ... and then drop it. + opcode.SWAP, // Put the iterator on top of the stack. + opcode.RET) + + // 2nd possibility: jump here when the iterator has no more items. + loadResultOffset := script.Len() + emit.Opcodes(script.BinWriter, opcode.ROT, // Put maxIteratorResultItems from the 2-nd cell of estack, to the top + opcode.DROP, // ... and then drop it. + opcode.NIP) // Drop iterator as the 1-st cell on the stack. + if err := script.Err; err != nil { + return nil, fmt.Errorf("emitting iterator unwrapper script: %w", err) + } + + // Fill in JMPIFNOT instruction parameter. + bytes := script.Bytes() + bytes[jmpIfNotOffset+1] = uint8(loadResultOffset - jmpIfNotOffset) // +1 is for JMPIFNOT itself; offset is relative to JMPIFNOT position. + // Fill in jmpIfMaxReachedOffset instruction parameter. + bytes[jmpIfMaxReachedOffset+1] = uint8(retainIteratorOffset - jmpIfMaxReachedOffset) // +1 is for JMPIF itself; offset is relative to JMPIF position. + return bytes, nil +} + +func emitCallAndUnwrapIteratorScript(script *io.BufBinWriter, contract util.Uint160, operation string, maxIteratorResultItems int, params ...any) (int, int) { emit.Int(script.BinWriter, int64(maxIteratorResultItems)) emit.AppCall(script.BinWriter, contract, operation, callflag.All, params...) // The System.Contract.Call itself, it will push Iterator on estack. emit.Opcodes(script.BinWriter, opcode.NEWARRAY0) // Push new empty array to estack. This array will store iterator's elements. @@ -51,21 +106,7 @@ func CreateCallAndUnwrapIteratorScript(contract util.Uint160, operation string, []byte{ uint8(iteratorTraverseCycleStartOffset - jmpOffset), // jump to iteratorTraverseCycleStartOffset; offset is relative to JMP position. }) - - // End of the program: push the result on stack and return. - loadResultOffset := script.Len() - emit.Opcodes(script.BinWriter, opcode.NIP, // Remove iterator from the 1-st cell of estack - opcode.NIP) // Remove maxIteratorResultItems from the 1-st cell of estack, so that only resulting array is left on estack. - if err := script.Err; err != nil { - return nil, fmt.Errorf("emitting iterator unwrapper script: %w", err) - } - - // Fill in JMPIFNOT instruction parameter. - bytes := script.Bytes() - bytes[jmpIfNotOffset+1] = uint8(loadResultOffset - jmpIfNotOffset) // +1 is for JMPIFNOT itself; offset is relative to JMPIFNOT position. - // Fill in jmpIfMaxReachedOffset instruction parameter. - bytes[jmpIfMaxReachedOffset+1] = uint8(loadResultOffset - jmpIfMaxReachedOffset) // +1 is for JMPIF itself; offset is relative to JMPIF position. - return bytes, nil + return jmpIfNotOffset, jmpIfMaxReachedOffset } // CreateCallScript returns a script that calls contract's method with diff --git a/pkg/vm/iterator_test.go b/pkg/vm/iterator_test.go new file mode 100644 index 000000000..0908bee26 --- /dev/null +++ b/pkg/vm/iterator_test.go @@ -0,0 +1,121 @@ +package vm + +import ( + "fmt" + "math/big" + "testing" + + "github.com/nspcc-dev/neo-go/internal/random" + "github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/stretchr/testify/require" +) + +type arrayIterator struct { + index int + values []stackitem.Item +} + +func TestCreateCallAndUnwrapIteratorScript(t *testing.T) { + ctrHash := random.Uint160() + ctrMethod := "mymethod" + param := stackitem.NewBigInteger(big.NewInt(42)) + + const totalItems = 8 + values := make([]stackitem.Item, totalItems) + for i := range values { + values[i] = stackitem.NewBigInteger(big.NewInt(int64(i))) + } + + checkStack := func(t *testing.T, script []byte, index int, prefetch bool) { + v := load(script) + it := &arrayIterator{index: -1, values: values} + v.SyscallHandler = func(v *VM, id uint32) error { + switch id { + case interopnames.ToID([]byte(interopnames.SystemContractCall)): + require.Equal(t, ctrHash.BytesBE(), v.Estack().Pop().Value()) + require.Equal(t, []byte(ctrMethod), v.Estack().Pop().Value()) + require.Equal(t, big.NewInt(int64(callflag.All)), v.Estack().Pop().Value()) + require.Equal(t, []stackitem.Item{param}, v.Estack().Pop().Value()) + v.Estack().PushItem(stackitem.NewInterop(it)) + case interopnames.ToID([]byte(interopnames.SystemIteratorNext)): + require.Equal(t, it, v.Estack().Pop().Value()) + it.index++ + v.Estack().PushVal(it.index < len(it.values)) + case interopnames.ToID([]byte(interopnames.SystemIteratorValue)): + require.Equal(t, it, v.Estack().Pop().Value()) + v.Estack().PushVal(it.values[it.index]) + default: + return fmt.Errorf("unexpected syscall: %d", id) + } + return nil + } + require.NoError(t, v.Run()) + + if prefetch && index <= len(values) { + require.Equal(t, 2, v.Estack().Len()) + + it, ok := v.Estack().Pop().Interop().Value().(*arrayIterator) + require.True(t, ok) + require.Equal(t, index-1, it.index) + require.Equal(t, values[:index], v.Estack().Pop().Array()) + return + } + if len(values) < index { + index = len(values) + } + require.Equal(t, 1, v.Estack().Len()) + require.Equal(t, values[:index], v.Estack().Pop().Array()) + } + + t.Run("truncate", func(t *testing.T) { + t.Run("zero", func(t *testing.T) { + const index = 0 + script, err := smartcontract.CreateCallAndUnwrapIteratorScript(ctrHash, ctrMethod, index, param) + require.NoError(t, err) + + // The behaviour is a bit unexpected, but not a problem (why would anyone fetch 0 items). + // Let's have test, to make it obvious. + checkStack(t, script, index+1, false) + }) + t.Run("all", func(t *testing.T) { + const index = totalItems + 1 + script, err := smartcontract.CreateCallAndUnwrapIteratorScript(ctrHash, ctrMethod, index, param) + require.NoError(t, err) + + checkStack(t, script, index, false) + }) + t.Run("partial", func(t *testing.T) { + const index = totalItems / 2 + script, err := smartcontract.CreateCallAndUnwrapIteratorScript(ctrHash, ctrMethod, index, param) + require.NoError(t, err) + + checkStack(t, script, index, false) + }) + }) + t.Run("prefetch", func(t *testing.T) { + t.Run("zero", func(t *testing.T) { + const index = 0 + script, err := smartcontract.CreateCallAndPrefetchIteratorScript(ctrHash, ctrMethod, index, param) + require.NoError(t, err) + + checkStack(t, script, index+1, true) + }) + t.Run("all", func(t *testing.T) { + const index = totalItems + 1 // +1 to test with iterator dropped + script, err := smartcontract.CreateCallAndPrefetchIteratorScript(ctrHash, ctrMethod, index, param) + require.NoError(t, err) + + checkStack(t, script, index, true) + }) + t.Run("partial", func(t *testing.T) { + const index = totalItems / 2 + script, err := smartcontract.CreateCallAndPrefetchIteratorScript(ctrHash, ctrMethod, index, param) + require.NoError(t, err) + + checkStack(t, script, index, true) + }) + }) +}