rpc: restrict (*Client).TraverseIterator with single RPC call

Do not unwrap the whole set of iterator values even on demand.
This commit is contained in:
Anna Shaleva 2022-07-06 17:08:53 +03:00
parent fad061f3d9
commit 9bdd8151af
2 changed files with 17 additions and 28 deletions

View file

@ -1150,36 +1150,30 @@ func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) {
// TraverseIterator returns a set of iterator values (maxItemsCount at max) for // TraverseIterator returns a set of iterator values (maxItemsCount at max) for
// the specified iterator and session. If result contains no elements, then either // the specified iterator and session. If result contains no elements, then either
// Iterator has no elements or session was expired and terminated by the server. // Iterator has no elements or session was expired and terminated by the server.
// If maxItemsCount is non-positive, then the full set of iterator values will be // If maxItemsCount is non-positive, then config.DefaultMaxIteratorResultItems
// returned using several `traverseiterator` calls if needed. Note that iterator // iterator values will be returned using single `traverseiterator` call.
// session lifetime is restricted by the RPC-server configuration and is being // Note that iterator session lifetime is restricted by the RPC-server
// reset each time iterator is accessed. If session won't be accessed within session // configuration and is being reset each time iterator is accessed. If session
// expiration time, then it will be terminated by the RPC-server automatically. // won't be accessed within session expiration time, then it will be terminated
// by the RPC-server automatically.
func (c *Client) TraverseIterator(sessionID, iteratorID uuid.UUID, maxItemsCount int) ([]stackitem.Item, error) { func (c *Client) TraverseIterator(sessionID, iteratorID uuid.UUID, maxItemsCount int) ([]stackitem.Item, error) {
var traverseAll bool
if maxItemsCount <= 0 { if maxItemsCount <= 0 {
maxItemsCount = config.DefaultMaxIteratorResultItems maxItemsCount = config.DefaultMaxIteratorResultItems
traverseAll = true
} }
var ( var (
result []stackitem.Item
params = request.NewRawParams(sessionID.String(), iteratorID.String(), maxItemsCount) params = request.NewRawParams(sessionID.String(), iteratorID.String(), maxItemsCount)
resp []json.RawMessage
) )
for { if err := c.performRequest("traverseiterator", params, &resp); err != nil {
var resp []json.RawMessage return nil, err
if err := c.performRequest("traverseiterator", params, &resp); err != nil { }
return nil, err result := make([]stackitem.Item, len(resp))
} for i, iBytes := range resp {
for i, iBytes := range resp { itm, err := stackitem.FromJSONWithTypes(iBytes)
itm, err := stackitem.FromJSONWithTypes(iBytes) if err != nil {
if err != nil { return nil, fmt.Errorf("failed to unmarshal %d-th iterator value: %w", i, err)
return nil, fmt.Errorf("failed to unmarshal %d-th iterator value: %w", i, err)
}
result = append(result, itm)
}
if len(resp) < maxItemsCount || !traverseAll {
break
} }
result[i] = itm
} }
return result, nil return result, nil

View file

@ -1163,12 +1163,7 @@ func TestClient_IteratorSessions(t *testing.T) {
set, err := c.TraverseIterator(sID, iID, -1) set, err := c.TraverseIterator(sID, iID, -1)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, storageItemsCount, len(set)) require.Equal(t, config.DefaultMaxIteratorResultItems, len(set))
// No more items should be left.
set, err = c.TraverseIterator(sID, iID, -1)
require.NoError(t, err)
require.Equal(t, 0, len(set))
}) })
t.Run("traverse, concurrent access", func(t *testing.T) { t.Run("traverse, concurrent access", func(t *testing.T) {