diff --git a/pool/tree/pool.go b/pool/tree/pool.go index ab6e98a1..c82e2697 100644 --- a/pool/tree/pool.go +++ b/pool/tree/pool.go @@ -414,12 +414,19 @@ func (p *Pool) GetNodes(ctx context.Context, prm GetNodesParams) ([]*tree.GetNod // // Must be initialized using Pool.GetSubTree, any other usage is unsafe. type SubTreeReader struct { - cli *rpcapi.GetSubTreeResponseReader + cli *rpcapi.GetSubTreeResponseReader + probe *tree.GetSubTreeResponseBody } // Read reads another list of the subtree nodes. func (x *SubTreeReader) Read(buf []*tree.GetSubTreeResponseBody) (int, error) { - for i := range buf { + i := 0 + if x.probe != nil && len(buf) != 0 { + buf[0] = x.probe + x.probe = nil + i = 1 + } + for ; i < len(buf); i++ { var resp tree.GetSubTreeResponse err := x.cli.Read(&resp) if err == io.EOF { @@ -436,6 +443,10 @@ func (x *SubTreeReader) Read(buf []*tree.GetSubTreeResponseBody) (int, error) { // ReadAll reads all nodes subtree nodes. func (x *SubTreeReader) ReadAll() ([]*tree.GetSubTreeResponseBody, error) { var res []*tree.GetSubTreeResponseBody + if x.probe != nil { + res = append(res, x.probe) + x.probe = nil + } for { var resp tree.GetSubTreeResponse err := x.cli.Read(&resp) @@ -452,6 +463,12 @@ func (x *SubTreeReader) ReadAll() ([]*tree.GetSubTreeResponseBody, error) { // Next gets the next node from subtree. func (x *SubTreeReader) Next() (*tree.GetSubTreeResponseBody, error) { + if x.probe != nil { + res := x.probe + x.probe = nil + return res, nil + } + var resp tree.GetSubTreeResponse err := x.cli.Read(&resp) if err == io.EOF { @@ -495,16 +512,24 @@ func (p *Pool) GetSubTree(ctx context.Context, prm GetSubTreeParams) (*SubTreeRe } var cli *rpcapi.GetSubTreeResponseReader + var probeBody *tree.GetSubTreeResponseBody err := p.requestWithRetry(ctx, prm.CID, func(client *rpcclient.Client) (inErr error) { cli, inErr = rpcapi.GetSubTree(client, request, rpcclient.WithContext(ctx)) - return handleError("failed to get sub tree client", inErr) + if inErr != nil { + return handleError("failed to get sub tree client", inErr) + } + + probe := &tree.GetSubTreeResponse{} + inErr = cli.Read(probe) + probeBody = probe.GetBody() + return handleError("failed to get first resp from sub tree client", inErr) }) p.methods[methodGetSubTree].IncRequests(time.Since(start)) if err != nil { return nil, err } - return &SubTreeReader{cli: cli}, nil + return &SubTreeReader{cli: cli, probe: probeBody}, nil } // AddNode invokes eponymous method from TreeServiceClient. diff --git a/pool/tree/pool_server_test.go b/pool/tree/pool_server_test.go index edba93f4..014e4544 100644 --- a/pool/tree/pool_server_test.go +++ b/pool/tree/pool_server_test.go @@ -28,6 +28,9 @@ type mockTreeServer struct { healthy bool addCounter int + + getSubTreeError error + getSubTreeCounter int } type mockNetmapSource struct { @@ -91,7 +94,8 @@ func (m *mockTreeServer) GetNodeByPath(context.Context, *tree.GetNodeByPathReque } func (m *mockTreeServer) GetSubTree(*tree.GetSubTreeRequest, tree.TreeService_GetSubTreeServer) error { - panic("implement me") + m.getSubTreeCounter++ + return m.getSubTreeError } func (m *mockTreeServer) TreeList(context.Context, *tree.TreeListRequest) (*tree.TreeListResponse, error) { @@ -235,3 +239,32 @@ func TestConnectionLeak(t *testing.T) { // not more than 1 extra goroutine is created due to async operations require.LessOrEqual(t, runtime.NumGoroutine()-routinesBefore, 1) } + +func TestStreamRetry(t *testing.T) { + const ( + numberOfNodes = 4 + placementPolicy = "REP 2" + ) + + // Initialize gRPC servers and create pool with netmap source + treePool, servers, _ := preparePoolWithNetmapSource(t, numberOfNodes, placementPolicy) + for i := range servers { + servers[i].getSubTreeError = errors.New("tree not found") + } + defer func() { + for i := range servers { + servers[i].Stop() + } + }() + + cnr := cidtest.ID() + ctx := context.Background() + + _, err := treePool.GetSubTree(ctx, GetSubTreeParams{CID: cnr}) + require.Error(t, err) + + for i := range servers { + // check we retried every available node in the pool + require.Equal(t, 1, servers[i].getSubTreeCounter) + } +}