diff --git a/pkg/rpcclient/actor/actor_test.go b/pkg/rpcclient/actor/actor_test.go index 46cdd3906..32e9eca00 100644 --- a/pkg/rpcclient/actor/actor_test.go +++ b/pkg/rpcclient/actor/actor_test.go @@ -4,11 +4,13 @@ import ( "errors" "testing" + "github.com/google/uuid" "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/wallet" "github.com/stretchr/testify/require" ) @@ -44,7 +46,12 @@ func (r *RPCClient) GetVersion() (*result.Version, error) { func (r *RPCClient) SendRawTransaction(tx *transaction.Transaction) (util.Uint256, error) { return r.hash, r.err } - +func (r *RPCClient) TerminateSession(sessionID uuid.UUID) (bool, error) { + return false, nil // Just a stub, unused by actor. +} +func (r *RPCClient) TraverseIterator(sessionID, iteratorID uuid.UUID, maxItemsCount int) ([]stackitem.Item, error) { + return nil, nil // Just a stub, unused by actor. +} func testRPCAndAccount(t *testing.T) (*RPCClient, *wallet.Account) { client := &RPCClient{ version: &result.Version{ diff --git a/pkg/rpcclient/invoker/invoker.go b/pkg/rpcclient/invoker/invoker.go index 16a13c972..266d09687 100644 --- a/pkg/rpcclient/invoker/invoker.go +++ b/pkg/rpcclient/invoker/invoker.go @@ -1,17 +1,35 @@ package invoker import ( + "errors" "fmt" + "github.com/google/uuid" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) +// DefaultIteratorResultItems is the default number of results to +// request from the iterator. Typically it's the same as server's +// MaxIteratorResultItems, but different servers can have different +// settings. +const DefaultIteratorResultItems = 100 + +// RPCSessions is a set of RPC methods needed to retrieve values from the +// session-based iterators. +type RPCSessions interface { + TerminateSession(sessionID uuid.UUID) (bool, error) + TraverseIterator(sessionID, iteratorID uuid.UUID, maxItemsCount int) ([]stackitem.Item, error) +} + // RPCInvoke is a set of RPC methods needed to execute things at the current // blockchain height. type RPCInvoke interface { + RPCSessions + InvokeContractVerify(contract util.Uint160, params []smartcontract.Parameter, signers []transaction.Signer, witnesses ...transaction.Witness) (*result.Invoke, error) InvokeFunction(contract util.Uint160, operation string, params []smartcontract.Parameter, signers []transaction.Signer) (*result.Invoke, error) InvokeScript(script []byte, signers []transaction.Signer) (*result.Invoke, error) @@ -20,6 +38,8 @@ type RPCInvoke interface { // RPCInvokeHistoric is a set of RPC methods needed to execute things at some // fixed point in blockchain's life. type RPCInvokeHistoric interface { + RPCSessions + InvokeContractVerifyAtBlock(blockHash util.Uint256, contract util.Uint160, params []smartcontract.Parameter, signers []transaction.Signer, witnesses ...transaction.Witness) (*result.Invoke, error) InvokeContractVerifyAtHeight(height uint32, contract util.Uint160, params []smartcontract.Parameter, signers []transaction.Signer, witnesses ...transaction.Witness) (*result.Invoke, error) InvokeContractVerifyWithState(stateroot util.Uint256, contract util.Uint160, params []smartcontract.Parameter, signers []transaction.Signer, witnesses ...transaction.Witness) (*result.Invoke, error) @@ -117,6 +137,14 @@ func (h *historicConverter) InvokeContractVerify(contract util.Uint160, params [ panic("uninitialized historicConverter") } +func (h *historicConverter) TerminateSession(sessionID uuid.UUID) (bool, error) { + return h.client.TerminateSession(sessionID) +} + +func (h *historicConverter) TraverseIterator(sessionID, iteratorID uuid.UUID, maxItemsCount int) ([]stackitem.Item, error) { + return h.client.TraverseIterator(sessionID, iteratorID, maxItemsCount) +} + // Call invokes a method of the contract with the given parameters (and // Invoker-specific list of signers) and returns the result as is. func (v *Invoker) Call(contract util.Uint160, operation string, params ...interface{}) (*result.Invoke, error) { @@ -157,3 +185,48 @@ func (v *Invoker) Verify(contract util.Uint160, witnesses []transaction.Witness, func (v *Invoker) Run(script []byte) (*result.Invoke, error) { return v.client.InvokeScript(script, v.signers) } + +// TerminateSession closes the given session, returning an error if anything +// goes wrong. +func (v *Invoker) TerminateSession(sessionID uuid.UUID) error { + return termSession(v.client, sessionID) +} + +func termSession(rpc RPCSessions, sessionID uuid.UUID) error { + r, err := rpc.TerminateSession(sessionID) + if err != nil { + return err + } + if !r { + return errors.New("terminatesession returned false") + } + return nil +} + +// TraverseIterator allows to retrieve the next batch of items from the given +// iterator in the given session (previously returned from Call or Run). It works +// both with session-backed iterators and expanded ones (which one you have +// depends on the RPC server). It can change the state of the iterator in the +// process. If num <= 0 then DefaultIteratorResultItems number of elements is +// requested. If result contains no elements, then either Iterator has no +// elements or session was expired and terminated by the server. +func (v *Invoker) TraverseIterator(sessionID uuid.UUID, iterator *result.Iterator, num int) ([]stackitem.Item, error) { + return iterateNext(v.client, sessionID, iterator, num) +} + +func iterateNext(rpc RPCSessions, sessionID uuid.UUID, iterator *result.Iterator, num int) ([]stackitem.Item, error) { + if num <= 0 { + num = DefaultIteratorResultItems + } + + if iterator.ID != nil { + return rpc.TraverseIterator(sessionID, *iterator.ID, num) + } + if num > len(iterator.Values) { + num = len(iterator.Values) + } + items := iterator.Values[:num] + iterator.Values = iterator.Values[num:] + + return items, nil +} diff --git a/pkg/rpcclient/invoker/invoker_test.go b/pkg/rpcclient/invoker/invoker_test.go index a5e59af3e..e36f1f3bb 100644 --- a/pkg/rpcclient/invoker/invoker_test.go +++ b/pkg/rpcclient/invoker/invoker_test.go @@ -1,17 +1,22 @@ package invoker import ( + "errors" "testing" + "github.com/google/uuid" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/neorpc/result" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/require" ) type rpcInv struct { resInv *result.Invoke + resTrm bool + resItm []stackitem.Item err error } @@ -51,10 +56,16 @@ func (r *rpcInv) InvokeScriptAtHeight(height uint32, script []byte, signers []tr func (r *rpcInv) InvokeScriptWithState(stateroot util.Uint256, script []byte, signers []transaction.Signer) (*result.Invoke, error) { return r.resInv, r.err } +func (r *rpcInv) TerminateSession(sessionID uuid.UUID) (bool, error) { + return r.resTrm, r.err +} +func (r *rpcInv) TraverseIterator(sessionID, iteratorID uuid.UUID, maxItemsCount int) ([]stackitem.Item, error) { + return r.resItm, r.err +} func TestInvoker(t *testing.T) { resExp := &result.Invoke{State: "HALT"} - ri := &rpcInv{resExp, nil} + ri := &rpcInv{resExp, true, nil, nil} testInv := func(t *testing.T, inv *Invoker) { res, err := inv.Call(util.Uint160{}, "method") @@ -112,4 +123,50 @@ func TestInvoker(t *testing.T) { require.Panics(t, func() { _, _ = inv.Verify(util.Uint160{}, nil, "param") }) require.Panics(t, func() { _, _ = inv.Run([]byte{1}) }) }) + t.Run("terminate session", func(t *testing.T) { + for _, inv := range []*Invoker{New(ri, nil), NewHistoricAtBlock(util.Uint256{}, ri, nil)} { + ri.err = errors.New("") + require.Error(t, inv.TerminateSession(uuid.UUID{})) + ri.err = nil + ri.resTrm = false + require.Error(t, inv.TerminateSession(uuid.UUID{})) + ri.resTrm = true + require.NoError(t, inv.TerminateSession(uuid.UUID{})) + } + }) + t.Run("traverse iterator", func(t *testing.T) { + for _, inv := range []*Invoker{New(ri, nil), NewHistoricAtBlock(util.Uint256{}, ri, nil)} { + res, err := inv.TraverseIterator(uuid.UUID{}, &result.Iterator{ + Values: []stackitem.Item{stackitem.Make(42)}, + }, 0) + require.NoError(t, err) + require.Equal(t, []stackitem.Item{stackitem.Make(42)}, res) + + res, err = inv.TraverseIterator(uuid.UUID{}, &result.Iterator{ + Values: []stackitem.Item{stackitem.Make(42)}, + }, 1) + require.NoError(t, err) + require.Equal(t, []stackitem.Item{stackitem.Make(42)}, res) + + res, err = inv.TraverseIterator(uuid.UUID{}, &result.Iterator{ + Values: []stackitem.Item{stackitem.Make(42)}, + }, 2) + require.NoError(t, err) + require.Equal(t, []stackitem.Item{stackitem.Make(42)}, res) + + ri.err = errors.New("") + _, err = inv.TraverseIterator(uuid.UUID{}, &result.Iterator{ + ID: &uuid.UUID{}, + }, 2) + require.Error(t, err) + + ri.err = nil + ri.resItm = []stackitem.Item{stackitem.Make(42)} + res, err = inv.TraverseIterator(uuid.UUID{}, &result.Iterator{ + ID: &uuid.UUID{}, + }, 2) + require.NoError(t, err) + require.Equal(t, []stackitem.Item{stackitem.Make(42)}, res) + } + }) }