diff --git a/pkg/services/tree/redirect.go b/pkg/services/tree/redirect.go index ec41a60d..5bde3ae3 100644 --- a/pkg/services/tree/redirect.go +++ b/pkg/services/tree/redirect.go @@ -12,10 +12,24 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "go.uber.org/zap" + "google.golang.org/grpc" ) var errNoSuitableNode = errors.New("no node was found to execute the request") +func relayUnary[Req any, Resp any](ctx context.Context, s *Service, ns []netmapSDK.NodeInfo, req *Req, callback func(TreeServiceClient, context.Context, *Req, ...grpc.CallOption) (*Resp, error)) (*Resp, error) { + var resp *Resp + var outErr error + err := s.forEachNode(ctx, ns, func(c TreeServiceClient) bool { + resp, outErr = callback(c, ctx, req) + return true + }) + if err != nil { + return nil, err + } + return resp, outErr +} + // forEachNode executes callback for each node in the container until true is returned. // Returns errNoSuitableNode if there was no successful attempt to dial any node. func (s *Service) forEachNode(ctx context.Context, cntNodes []netmapSDK.NodeInfo, f func(c TreeServiceClient) bool) error { diff --git a/pkg/services/tree/service.go b/pkg/services/tree/service.go index 2cb2af29..acc2775e 100644 --- a/pkg/services/tree/service.go +++ b/pkg/services/tree/service.go @@ -122,16 +122,7 @@ func (s *Service) Add(ctx context.Context, req *AddRequest) (*AddResponse, error return nil, err } if pos < 0 { - var resp *AddResponse - var outErr error - err = s.forEachNode(ctx, ns, func(c TreeServiceClient) bool { - resp, outErr = c.Add(ctx, req) - return true - }) - if err != nil { - return nil, err - } - return resp, outErr + return relayUnary(ctx, s, ns, req, (TreeServiceClient).Add) } d := pilorama.CIDDescriptor{CID: cid, Position: pos, Size: len(ns)} @@ -174,16 +165,7 @@ func (s *Service) AddByPath(ctx context.Context, req *AddByPathRequest) (*AddByP return nil, err } if pos < 0 { - var resp *AddByPathResponse - var outErr error - err = s.forEachNode(ctx, ns, func(c TreeServiceClient) bool { - resp, outErr = c.AddByPath(ctx, req) - return true - }) - if err != nil { - return nil, err - } - return resp, outErr + return relayUnary(ctx, s, ns, req, (TreeServiceClient).AddByPath) } meta := protoToMeta(b.GetMeta()) @@ -238,16 +220,7 @@ func (s *Service) Remove(ctx context.Context, req *RemoveRequest) (*RemoveRespon return nil, err } if pos < 0 { - var resp *RemoveResponse - var outErr error - err = s.forEachNode(ctx, ns, func(c TreeServiceClient) bool { - resp, outErr = c.Remove(ctx, req) - return true - }) - if err != nil { - return nil, err - } - return resp, outErr + return relayUnary(ctx, s, ns, req, (TreeServiceClient).Remove) } if b.GetNodeId() == pilorama.RootID { @@ -291,16 +264,7 @@ func (s *Service) Move(ctx context.Context, req *MoveRequest) (*MoveResponse, er return nil, err } if pos < 0 { - var resp *MoveResponse - var outErr error - err = s.forEachNode(ctx, ns, func(c TreeServiceClient) bool { - resp, outErr = c.Move(ctx, req) - return true - }) - if err != nil { - return nil, err - } - return resp, outErr + return relayUnary(ctx, s, ns, req, (TreeServiceClient).Move) } if b.GetNodeId() == pilorama.RootID { @@ -343,16 +307,7 @@ func (s *Service) GetNodeByPath(ctx context.Context, req *GetNodeByPathRequest) return nil, err } if pos < 0 { - var resp *GetNodeByPathResponse - var outErr error - err = s.forEachNode(ctx, ns, func(c TreeServiceClient) bool { - resp, outErr = c.GetNodeByPath(ctx, req) - return true - }) - if err != nil { - return nil, err - } - return resp, outErr + return relayUnary(ctx, s, ns, req, (TreeServiceClient).GetNodeByPath) } attr := b.GetPathAttribute() @@ -763,16 +718,7 @@ func (s *Service) TreeList(ctx context.Context, req *TreeListRequest) (*TreeList return nil, err } if pos < 0 { - var resp *TreeListResponse - var outErr error - err = s.forEachNode(ctx, ns, func(c TreeServiceClient) bool { - resp, outErr = c.TreeList(ctx, req) - return outErr == nil - }) - if err != nil { - return nil, err - } - return resp, outErr + return relayUnary(ctx, s, ns, req, (TreeServiceClient).TreeList) } ids, err := s.forest.TreeList(ctx, cid)