diff --git a/go.mod b/go.mod
index ed14604..d0fd2a5 100644
--- a/go.mod
+++ b/go.mod
@@ -14,6 +14,7 @@ require (
 	github.com/klauspost/reedsolomon v1.12.1
 	github.com/mailru/easyjson v0.7.7
 	github.com/mr-tron/base58 v1.2.0
+	github.com/multiformats/go-multiaddr v0.12.1
 	github.com/nspcc-dev/neo-go v0.106.2
 	github.com/stretchr/testify v1.9.0
 	go.uber.org/zap v1.27.0
@@ -29,13 +30,21 @@ require (
 	github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
 	github.com/golang/snappy v0.0.1 // indirect
 	github.com/gorilla/websocket v1.5.1 // indirect
+	github.com/ipfs/go-cid v0.0.7 // indirect
 	github.com/josharian/intern v1.0.0 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.6 // indirect
-	github.com/kr/pretty v0.1.0 // indirect
+	github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect
+	github.com/minio/sha256-simd v1.0.1 // indirect
+	github.com/multiformats/go-base32 v0.0.3 // indirect
+	github.com/multiformats/go-base36 v0.1.0 // indirect
+	github.com/multiformats/go-multibase v0.0.3 // indirect
+	github.com/multiformats/go-multihash v0.0.14 // indirect
+	github.com/multiformats/go-varint v0.0.6 // indirect
 	github.com/nspcc-dev/go-ordered-json v0.0.0-20240301084351-0246b013f8b2 // indirect
 	github.com/nspcc-dev/neo-go/pkg/interop v0.0.0-20240521091047-78685785716d // indirect
 	github.com/nspcc-dev/rfc6979 v0.2.1 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
+	github.com/spaolacci/murmur3 v1.1.0 // indirect
 	github.com/syndtr/goleveldb v1.0.1-0.20210305035536-64b5b1c73954 // indirect
 	github.com/twmb/murmur3 v1.1.8 // indirect
 	go.etcd.io/bbolt v1.3.9 // indirect
@@ -46,5 +55,4 @@
 	golang.org/x/sys v0.21.0 // indirect
 	golang.org/x/text v0.16.0 // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 // indirect
-	gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
 )
diff --git a/go.sum b/go.sum
index 56930c7..9baf10d 100644
Binary files a/go.sum and b/go.sum differ
diff --git a/pkg/network/address.go b/pkg/network/address.go
new file mode 100644
index 0000000..f91a886
--- /dev/null
+++ b/pkg/network/address.go
@@ -0,0 +1,110 @@
+package network
+
+import (
+	"errors"
+	"fmt"
+	"net"
+	"net/url"
+	"strings"
+
+	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/rpc/client"
+	"github.com/multiformats/go-multiaddr"
+	manet "github.com/multiformats/go-multiaddr/net"
+)
+
+var errHostIsEmpty = errors.New("host is empty")
+
+// Address represents the FrostFS node
+// network address.
+type Address struct {
+	ma multiaddr.Multiaddr
+}
+
+// URIAddr returns Address as a URI.
+//
+// Panics if host address cannot be fetched from Address.
+//
+// See also FromString.
+func (a Address) URIAddr() string {
+	_, host, err := manet.DialArgs(a.ma)
+	if err != nil {
+		// the only correct way to construct Address is AddressFromString
+		// which makes this error appear unexpected
+		panic(fmt.Errorf("could not get host addr: %w", err))
+	}
+
+	if !a.IsTLSEnabled() {
+		return host
+	}
+
+	return (&url.URL{
+		Scheme: "grpcs",
+		Host:   host,
+	}).String()
+}
+
+// FromString restores Address from a string representation.
+//
+// Supports URIAddr, MultiAddr and HostAddr strings.
+func (a *Address) FromString(s string) error {
+	var err error
+
+	a.ma, err = multiaddr.NewMultiaddr(s)
+	if err != nil {
+		var (
+			host   string
+			hasTLS bool
+		)
+		host, hasTLS, err = client.ParseURI(s)
+		if err != nil {
+			host = s
+		}
+
+		s, err = multiaddrStringFromHostAddr(host)
+		if err == nil {
+			a.ma, err = multiaddr.NewMultiaddr(s)
+			if err == nil && hasTLS {
+				a.ma = a.ma.Encapsulate(tls)
+			}
+		}
+	}
+
+	return err
+}
+
+// multiaddrStringFromHostAddr converts "localhost:8080" to "/dns4/localhost/tcp/8080".
+func multiaddrStringFromHostAddr(host string) (string, error) {
+	if len(host) == 0 {
+		return "", errHostIsEmpty
+	}
+
+	endpoint, port, err := net.SplitHostPort(host)
+	if err != nil {
+		return "", err
+	}
+
+	// Empty address in host `:8080` generates `/dns4//tcp/8080` multiaddr
+	// which is invalid. It could be `/tcp/8080` but this breaks
+	// `manet.DialArgs`. The solution is to manually parse it as 0.0.0.0
+	if endpoint == "" {
+		return "/ip4/0.0.0.0/tcp/" + port, nil
+	}
+
+	var (
+		prefix = "/dns4"
+		addr   = endpoint
+	)
+
+	if ip := net.ParseIP(endpoint); ip != nil {
+		addr = ip.String()
+		if ip.To4() == nil {
+			prefix = "/ip6"
+		} else {
+			prefix = "/ip4"
+		}
+	}
+
+	const l4Protocol = "tcp"
+
+	return strings.Join([]string{prefix, addr, l4Protocol, port}, "/"), nil
+}
diff --git a/pkg/network/tls.go b/pkg/network/tls.go
new file mode 100644
index 0000000..8d9c15b
--- /dev/null
+++ b/pkg/network/tls.go
@@ -0,0 +1,16 @@
+package network
+
+import "github.com/multiformats/go-multiaddr"
+
+const (
+	tlsProtocolName = "tls"
+)
+
+// tls var is used for (un)wrapping other multiaddrs around TLS multiaddr.
+var tls, _ = multiaddr.NewMultiaddr("/" + tlsProtocolName)
+
+// IsTLSEnabled searches for wrapped TLS protocol in multiaddr.
+func (a Address) IsTLSEnabled() bool {
+	_, err := a.ma.ValueForProtocol(multiaddr.P_TLS)
+	return err == nil
+}
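Reviewer note (not part of the patch): the snippet below illustrates how the new pkg/network helpers are expected to behave, based on the code above. The printed forms are my reading of FromString/URIAddr, not verified output, and example.com is a placeholder host.

package example

import (
	"fmt"

	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pkg/network"
)

func printAddrForms() error {
	// A plain host:port is not a multiaddr, so FromString should fall back to
	// multiaddrStringFromHostAddr and store "/dns4/localhost/tcp/8080";
	// URIAddr() then renders it back as plain "localhost:8080".
	var plain network.Address
	if err := plain.FromString("localhost:8080"); err != nil {
		return err
	}
	fmt.Println(plain.URIAddr(), plain.IsTLSEnabled()) // expected: localhost:8080 false

	// A grpcs:// URI goes through client.ParseURI, the TLS component is
	// encapsulated into the multiaddr, and URIAddr() should render a grpcs URI.
	var secure network.Address
	if err := secure.FromString("grpcs://example.com:8082"); err != nil {
		return err
	}
	fmt.Println(secure.URIAddr(), secure.IsTLSEnabled()) // expected: grpcs://example.com:8082 true

	return nil
}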
diff --git a/pool/tree/pool.go b/pool/tree/pool.go
index f4f9bf8..31134fc 100644
--- a/pool/tree/pool.go
+++ b/pool/tree/pool.go
@@ -14,6 +14,8 @@ import (
 	rpcclient "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/rpc/client"
 	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/tree"
 	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
+	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
+	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pkg/network"
 	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool"
 	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
 	"go.uber.org/zap"
@@ -73,6 +75,17 @@ type InitParameters struct {
 	nodeParams            []pool.NodeParam
 	dialOptions           []grpc.DialOption
 	maxRequestAttempts    int
+	netMapSource          NetMapSource
+	placementPolicySource PlacementPolicySource
+	v2                    bool
+}
+
+type NetMapSource interface {
+	NetMapSnapshot(ctx context.Context) (netmap.NetMap, error)
+}
+
+type PlacementPolicySource interface {
+	PlacementPolicy(ctx context.Context, cnrID cid.ID) (netmap.PlacementPolicy, error)
 }
 
 // Pool represents virtual connection to the FrostFS tree services network to communicate
@@ -96,6 +109,11 @@ type Pool struct {
 	streamTimeout   time.Duration
 	nodeDialTimeout time.Duration
 
+	v2           bool
+	netMapSource NetMapSource
+	policySource PlacementPolicySource
+	clientMap    map[uint64]client
+
 	startIndicesMtx sync.RWMutex
 	// startIndices points to the client from which the next request will be executed.
 	// Since clients are stored in innerPool field we have to use two indices.
@@ -213,9 +231,14 @@ func NewPool(options InitParameters) (*Pool, error) {
 		return nil, fmt.Errorf("missed required parameter 'Key'")
 	}
 
-	nodesParams, err := adjustNodeParams(options.nodeParams)
-	if err != nil {
-		return nil, err
+	if options.v2 {
+		if options.netMapSource == nil {
+			return nil, fmt.Errorf("missed required parameter 'NetMap source'")
+		}
+
+		if options.placementPolicySource == nil {
+			return nil, fmt.Errorf("missed required parameter 'Placement policy source'")
+		}
 	}
 
 	fillDefaultInitParams(&options)
@@ -230,13 +253,25 @@ func NewPool(options InitParameters) (*Pool, error) {
 		logger:      options.logger,
 		dialOptions: options.dialOptions,
 		rebalanceParams: rebalanceParameters{
-			nodesGroup:              nodesParams,
 			nodeRequestTimeout:      options.healthcheckTimeout,
 			clientRebalanceInterval: options.clientRebalanceInterval,
 		},
 		maxRequestAttempts: options.maxRequestAttempts,
 		streamTimeout:      options.nodeStreamTimeout,
+		nodeDialTimeout:    options.nodeDialTimeout,
 		methods:            methods,
+		v2:                 options.v2,
+		netMapSource:       options.netMapSource,
+		policySource:       options.placementPolicySource,
+		clientMap:          make(map[uint64]client),
+	}
+
+	if !options.v2 {
+		nodesParams, err := adjustNodeParams(options.nodeParams)
+		if err != nil {
+			return nil, err
+		}
+		p.rebalanceParams.nodesGroup = nodesParams
 	}
 
 	return p, nil
@@ -251,36 +286,39 @@
 //
 // See also InitParameters.SetClientRebalanceInterval.
 func (p *Pool) Dial(ctx context.Context) error {
-	inner := make([]*innerPool, len(p.rebalanceParams.nodesGroup))
-	var atLeastOneHealthy bool
+	if !p.v2 {
+		inner := make([]*innerPool, len(p.rebalanceParams.nodesGroup))
+		var atLeastOneHealthy bool
 
-	for i, nodes := range p.rebalanceParams.nodesGroup {
-		clients := make([]client, len(nodes))
-		for j, node := range nodes {
-			clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
-			if err := clients[j].dial(ctx); err != nil {
-				p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err))
-				continue
+		for i, nodes := range p.rebalanceParams.nodesGroup {
+			clients := make([]client, len(nodes))
+			for j, node := range nodes {
+				clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
+				if err := clients[j].dial(ctx); err != nil {
+					p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err))
+					continue
+				}
+
+				atLeastOneHealthy = true
 			}
 
-			atLeastOneHealthy = true
+			inner[i] = &innerPool{
+				clients: clients,
+			}
 		}
 
-		inner[i] = &innerPool{
-			clients: clients,
+		if !atLeastOneHealthy {
+			return fmt.Errorf("at least one node must be healthy")
 		}
+
+		ctx, cancel := context.WithCancel(ctx)
+		p.cancel = cancel
+		p.closedCh = make(chan struct{})
+		p.innerPools = inner
+
+		go p.startRebalance(ctx)
 	}
 
-	if !atLeastOneHealthy {
-		return fmt.Errorf("at least one node must be healthy")
-	}
-
-	ctx, cancel := context.WithCancel(ctx)
-	p.cancel = cancel
-	p.closedCh = make(chan struct{})
-	p.innerPools = inner
-
-	go p.startRebalance(ctx)
 	return nil
 }
@@ -334,6 +372,21 @@ func (x *InitParameters) SetMaxRequestAttempts(maxAttempts int) {
 	x.maxRequestAttempts = maxAttempts
 }
 
+// UseV2 sets flag for using net map to prioritize tree services.
+func (x *InitParameters) UseV2() {
+	x.v2 = true
+}
+
+// SetNetMapSource sets implementation of interface to get current net map.
+func (x *InitParameters) SetNetMapSource(netMapSource NetMapSource) {
+	x.netMapSource = netMapSource
+}
+
+// SetPlacementPolicySource sets implementation of interface to get container placement policy.
+func (x *InitParameters) SetPlacementPolicySource(placementPolicySource PlacementPolicySource) {
+	x.placementPolicySource = placementPolicySource
+}
+
 // GetNodes invokes eponymous method from TreeServiceClient.
 //
 // Can return predefined errors:
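Reviewer note (not part of the patch): a minimal sketch of how a caller is expected to wire the new v2 mode together. The adapter types staticNetMapSource and staticPolicySource are illustrative stand-ins for whatever the application uses to fetch the current net map and container policies; the signing key and other existing options are assumed to be set as before.

package example

import (
	"context"

	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool/tree"
)

// staticNetMapSource and staticPolicySource are illustrative implementations
// of the new interfaces; real applications would back them with live queries.
type staticNetMapSource struct{ nm netmap.NetMap }

func (s staticNetMapSource) NetMapSnapshot(context.Context) (netmap.NetMap, error) {
	return s.nm, nil
}

type staticPolicySource struct{ policy netmap.PlacementPolicy }

func (s staticPolicySource) PlacementPolicy(context.Context, cid.ID) (netmap.PlacementPolicy, error) {
	return s.policy, nil
}

func dialV2TreePool(ctx context.Context, nm netmap.NetMap, policy netmap.PlacementPolicy) (*tree.Pool, error) {
	var prm tree.InitParameters
	// ... set the signing key and other usual options here, as before ...

	// Enable net-map-based node selection instead of the static node list.
	prm.UseV2()
	prm.SetNetMapSource(staticNetMapSource{nm: nm})
	prm.SetPlacementPolicySource(staticPolicySource{policy: policy})

	p, err := tree.NewPool(prm)
	if err != nil {
		return nil, err
	}
	// In v2 mode Dial skips the static bootstrap/rebalance machinery;
	// connections are opened lazily per container node on first use.
	return p, p.Dial(ctx)
}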
@@ -359,7 +412,7 @@ func (p *Pool) GetNodes(ctx context.Context, prm GetNodesParams) ([]*tree.GetNod
 	}
 
 	var resp *tree.GetNodeByPathResponse
-	err := p.requestWithRetry(ctx, func(client *rpcclient.Client) (inErr error) {
+	err := p.requestWithRetry(ctx, prm.CID, func(client *rpcclient.Client) (inErr error) {
 		resp, inErr = rpcapi.GetNodeByPath(client, request, rpcclient.WithContext(ctx))
 		// Pool wants to do retry 'GetNodeByPath' request if result is empty.
 		// Empty result is expected due to delayed tree service sync.
@@ -463,7 +516,7 @@ func (p *Pool) GetSubTree(ctx context.Context, prm GetSubTreeParams) (*SubTreeRe
 	}
 
 	var cli *rpcapi.GetSubTreeResponseReader
-	err := p.requestWithRetry(ctx, func(client *rpcclient.Client) (inErr error) {
+	err := p.requestWithRetry(ctx, prm.CID, func(client *rpcclient.Client) (inErr error) {
 		cli, inErr = rpcapi.GetSubTree(client, request, rpcclient.WithContext(ctx))
 		return handleError("failed to get sub tree client", inErr)
 	})
@@ -497,7 +550,7 @@ func (p *Pool) AddNode(ctx context.Context, prm AddNodeParams) (uint64, error) {
 	}
 
 	var resp *tree.AddResponse
-	err := p.requestWithRetry(ctx, func(client *rpcclient.Client) (inErr error) {
+	err := p.requestWithRetry(ctx, prm.CID, func(client *rpcclient.Client) (inErr error) {
 		resp, inErr = rpcapi.Add(client, request, rpcclient.WithContext(ctx))
 		return handleError("failed to add node", inErr)
 	})
@@ -532,7 +585,7 @@ func (p *Pool) AddNodeByPath(ctx context.Context, prm AddNodeByPathParams) (uint
 	}
 
 	var resp *tree.AddByPathResponse
-	err := p.requestWithRetry(ctx, func(client *rpcclient.Client) (inErr error) {
+	err := p.requestWithRetry(ctx, prm.CID, func(client *rpcclient.Client) (inErr error) {
 		resp, inErr = rpcapi.AddByPath(client, request, rpcclient.WithContext(ctx))
 		return handleError("failed to add node by path", inErr)
 	})
@@ -574,7 +627,7 @@ func (p *Pool) MoveNode(ctx context.Context, prm MoveNodeParams) error {
 		return err
 	}
 
-	err := p.requestWithRetry(ctx, func(client *rpcclient.Client) error {
+	err := p.requestWithRetry(ctx, prm.CID, func(client *rpcclient.Client) error {
 		if _, err := rpcapi.Move(client, request, rpcclient.WithContext(ctx)); err != nil {
 			return handleError("failed to move node", err)
 		}
@@ -605,7 +658,7 @@ func (p *Pool) RemoveNode(ctx context.Context, prm RemoveNodeParams) error {
 		return err
 	}
 
-	err := p.requestWithRetry(ctx, func(client *rpcclient.Client) error {
+	err := p.requestWithRetry(ctx, prm.CID, func(client *rpcclient.Client) error {
 		if _, err := rpcapi.Remove(client, request, rpcclient.WithContext(ctx)); err != nil {
 			return handleError("failed to remove node", err)
 		}
@@ -631,6 +684,13 @@ func (p *Pool) Close() error {
 		}
 	}
 
+	for _, cl := range p.clientMap {
+		if closeErr := cl.close(); closeErr != nil {
+			p.log(zapcore.ErrorLevel, "close client connection", zap.Error(closeErr))
+			err = closeErr
+		}
+	}
+
 	return err
 }
@@ -814,7 +874,11 @@ func (p *Pool) setStartIndices(i, j int) {
 	p.startIndicesMtx.Unlock()
 }
 
-func (p *Pool) requestWithRetry(ctx context.Context, fn func(client *rpcclient.Client) error) error {
+func (p *Pool) requestWithRetry(ctx context.Context, cid cid.ID, fn func(client *rpcclient.Client) error) error {
+	if p.v2 {
+		return p.requestWithRetryContainerNodes(ctx, cid, fn)
+	}
+
 	var (
 		err, finErr error
 		cl          *rpcclient.Client
 	)
@@ -866,10 +930,120 @@ LOOP:
 	return finErr
 }
 
+func (p *Pool) requestWithRetryContainerNodes(ctx context.Context, cid cid.ID, fn func(client *rpcclient.Client) error) error {
+	var (
+		err, finErr error
+		cl          *rpcclient.Client
+	)
+
+	reqID := GetRequestID(ctx)
+
+	netMap, err := p.netMapSource.NetMapSnapshot(ctx)
+	if err != nil {
+		return fmt.Errorf("get net map: %w", err)
+	}
+
+	policy, err := p.policySource.PlacementPolicy(ctx, cid)
+	if err != nil {
+		return fmt.Errorf("get container placement policy: %w", err)
+	}
+
+	cnrNodes, err := netMap.ContainerNodes(policy, cid[:])
+	if err != nil {
+		return fmt.Errorf("get container nodes: %w", err)
+	}
+
+	cnrNodes, err = netMap.PlacementVectors(cnrNodes, cid[:])
+	if err != nil {
+		return fmt.Errorf("get placement vectors: %w", err)
+	}
+
+	attempts := p.maxRequestAttempts
+
+LOOP:
+	for _, cnrNodeGroup := range cnrNodes {
+		for _, cnrNode := range cnrNodeGroup {
+			if attempts == 0 {
+				break LOOP
+			}
+			attempts--
+
+			treeCl, ok := p.clientMap[cnrNode.Hash()]
+			if !ok {
+				treeCl, err = p.getNewTreeClient(ctx, cnrNode)
+				if err != nil {
+					finErr = finalError(finErr, err)
+					p.log(zap.DebugLevel, "failed to create tree client", zap.String("request_id", reqID), zap.Int("remaining attempts", attempts))
+					continue
+				}
+
+				p.clientMap[cnrNode.Hash()] = treeCl
+			}
+
+			if cl, err = treeCl.serviceClient(); err == nil {
+				err = fn(cl)
+			}
+			if shouldRedial(err) {
+				delete(p.clientMap, cnrNode.Hash())
+			}
+			if !shouldTryAgain(err) {
+				if err != nil {
+					err = fmt.Errorf("address %s: %w", treeCl.endpoint(), err)
+				}
+
+				return err
+			}
+
+			finErr = finalError(finErr, err)
+			p.log(zap.DebugLevel, "tree request error", zap.String("request_id", reqID), zap.Int("remaining attempts", attempts),
+				zap.String("address", treeCl.endpoint()), zap.Error(err))
+		}
+	}
+
+	return finErr
+}
+
+func (p *Pool) getNewTreeClient(ctx context.Context, node netmap.NodeInfo) (*treeClient, error) {
+	var (
+		treeCl *treeClient
+		err    error
+	)
+
+	node.IterateNetworkEndpoints(func(endpoint string) bool {
+		var addr network.Address
+		if err = addr.FromString(endpoint); err != nil {
+			p.log(zap.WarnLevel, "can't parse endpoint", zap.String("endpoint", endpoint), zap.Error(err))
+			return false
+		}
+
+		newTreeCl := newTreeClient(addr.URIAddr(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
+		if err = newTreeCl.dial(ctx); err != nil {
+			p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", addr.URIAddr()), zap.Error(err))
+			return false
+		}
+
+		treeCl = newTreeCl
+		return true
+	})
+
+	if treeCl == nil {
+		return nil, fmt.Errorf("tree client wasn't initialized")
+	}
+
+	return treeCl, nil
+}
+
 func shouldTryAgain(err error) bool {
 	return !(err == nil || errors.Is(err, ErrNodeAccessDenied))
 }
 
+func shouldRedial(err error) bool {
+	if err == nil || errors.Is(err, ErrNodeAccessDenied) || errors.Is(err, ErrNodeNotFound) || errors.Is(err, errNodeEmptyResult) {
+		return false
+	}
+	return true
+}
+
 func prioErr(err error) int {
 	switch {
 	case err == nil:
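Reviewer note (not part of the patch): the sketch below mirrors the node-selection step of requestWithRetryContainerNodes, to make the new retry order easier to review. It only reuses netmap API calls that appear in the change above.

package example

import (
	"fmt"

	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
)

// printDialOrder shows the order in which the v2 retry loop walks nodes:
// ContainerNodes selects the container's nodes per the placement policy,
// PlacementVectors orders each vector by HRW priority for this container,
// and every node's endpoints are then tried one by one, with dialed clients
// cached in clientMap under NodeInfo.Hash().
func printDialOrder(nm netmap.NetMap, policy netmap.PlacementPolicy, cnrID cid.ID) error {
	cnrNodes, err := nm.ContainerNodes(policy, cnrID[:])
	if err != nil {
		return err
	}

	cnrNodes, err = nm.PlacementVectors(cnrNodes, cnrID[:])
	if err != nil {
		return err
	}

	for i, group := range cnrNodes {
		for _, node := range group {
			node.IterateNetworkEndpoints(func(endpoint string) bool {
				fmt.Printf("group %d, node %x: %s\n", i, node.Hash(), endpoint)
				return false // false keeps iterating over the node's endpoints
			})
		}
	}
	return nil
}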
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status" + cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id" + cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool" + "git.frostfs.info/TrueCloudLab/hrw" + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) @@ -15,12 +20,14 @@ import ( type treeClientMock struct { address string err bool + used bool } func (t *treeClientMock) serviceClient() (*rpcClient.Client, error) { if t.err { return nil, errors.New("serviceClient() mock error") } + t.used = true return nil, nil } @@ -51,6 +58,26 @@ func (t *treeClientMock) close() error { return nil } +type placementPolicyMock struct { + policy netmap.PlacementPolicy +} + +func (p *placementPolicyMock) PlacementPolicy(context.Context, cid.ID) (netmap.PlacementPolicy, error) { + return p.policy, nil +} + +type netMapMock struct { + netMap netmap.NetMap + err error +} + +func (n *netMapMock) NetMapSnapshot(context.Context) (netmap.NetMap, error) { + if n.err != nil { + return netmap.NetMap{}, n.err + } + return n.netMap, nil +} + func TestHandleError(t *testing.T) { defaultError := errors.New("default error") for _, tc := range []struct { @@ -104,14 +131,14 @@ func TestRetry(t *testing.T) { } t.Run("first ok", func(t *testing.T) { - err := p.requestWithRetry(ctx, makeFn) + err := p.requestWithRetry(ctx, cidtest.ID(), makeFn) require.NoError(t, err) checkIndicesAndReset(t, p, 0, 0) }) t.Run("first failed", func(t *testing.T) { setErrors(p, "node00") - err := p.requestWithRetry(ctx, makeFn) + err := p.requestWithRetry(ctx, cidtest.ID(), makeFn) require.NoError(t, err) checkIndicesAndReset(t, p, 0, 1) }) @@ -119,7 +146,7 @@ func TestRetry(t *testing.T) { t.Run("all failed", func(t *testing.T) { setErrors(p, nodes[0]...) setErrors(p, nodes[1]...) - err := p.requestWithRetry(ctx, makeFn) + err := p.requestWithRetry(ctx, cidtest.ID(), makeFn) require.Error(t, err) checkIndicesAndReset(t, p, 0, 0) }) @@ -127,13 +154,13 @@ func TestRetry(t *testing.T) { t.Run("round", func(t *testing.T) { setErrors(p, nodes[0][0], nodes[0][1]) setErrors(p, nodes[1]...) - err := p.requestWithRetry(ctx, makeFn) + err := p.requestWithRetry(ctx, cidtest.ID(), makeFn) require.NoError(t, err) checkIndices(t, p, 0, 2) resetClientsErrors(p) setErrors(p, nodes[0][2], nodes[0][3]) - err = p.requestWithRetry(ctx, makeFn) + err = p.requestWithRetry(ctx, cidtest.ID(), makeFn) require.NoError(t, err) checkIndicesAndReset(t, p, 0, 0) }) @@ -141,14 +168,14 @@ func TestRetry(t *testing.T) { t.Run("group switch", func(t *testing.T) { setErrors(p, nodes[0]...) setErrors(p, nodes[1][0]) - err := p.requestWithRetry(ctx, makeFn) + err := p.requestWithRetry(ctx, cidtest.ID(), makeFn) require.NoError(t, err) checkIndicesAndReset(t, p, 1, 1) }) t.Run("group round", func(t *testing.T) { setErrors(p, nodes[0][1:]...) - err := p.requestWithRetry(ctx, makeFn) + err := p.requestWithRetry(ctx, cidtest.ID(), makeFn) require.NoError(t, err) checkIndicesAndReset(t, p, 0, 0) }) @@ -156,7 +183,7 @@ func TestRetry(t *testing.T) { t.Run("group round switch", func(t *testing.T) { setErrors(p, nodes[0]...) 
 		p.setStartIndices(0, 1)
-		err := p.requestWithRetry(ctx, makeFn)
+		err := p.requestWithRetry(ctx, cidtest.ID(), makeFn)
 		require.NoError(t, err)
 		checkIndicesAndReset(t, p, 1, 0)
 	})
@@ -164,14 +191,14 @@ func TestRetry(t *testing.T) {
 	t.Run("no panic group switch", func(t *testing.T) {
 		setErrors(p, nodes[1]...)
 		p.setStartIndices(1, 0)
-		err := p.requestWithRetry(ctx, makeFn)
+		err := p.requestWithRetry(ctx, cidtest.ID(), makeFn)
 		require.NoError(t, err)
 		checkIndicesAndReset(t, p, 0, 0)
 	})
 
 	t.Run("error empty result", func(t *testing.T) {
 		errNodes, index := 2, 0
-		err := p.requestWithRetry(ctx, func(client *rpcClient.Client) error {
+		err := p.requestWithRetry(ctx, cidtest.ID(), func(client *rpcClient.Client) error {
 			if index < errNodes {
 				index++
 				return errNodeEmptyResult
@@ -184,7 +211,7 @@ func TestRetry(t *testing.T) {
 	t.Run("error not found", func(t *testing.T) {
 		errNodes, index := 2, 0
-		err := p.requestWithRetry(ctx, func(client *rpcClient.Client) error {
+		err := p.requestWithRetry(ctx, cidtest.ID(), func(client *rpcClient.Client) error {
 			if index < errNodes {
 				index++
 				return ErrNodeNotFound
@@ -197,7 +224,7 @@ func TestRetry(t *testing.T) {
 	t.Run("error access denied", func(t *testing.T) {
 		var index int
-		err := p.requestWithRetry(ctx, func(client *rpcClient.Client) error {
+		err := p.requestWithRetry(ctx, cidtest.ID(), func(client *rpcClient.Client) error {
 			index++
 			return ErrNodeAccessDenied
 		})
@@ -211,7 +238,7 @@ func TestRetry(t *testing.T) {
 		p.maxRequestAttempts = 2
 		setErrors(p, nodes[0]...)
 		setErrors(p, nodes[1]...)
-		err := p.requestWithRetry(ctx, makeFn)
+		err := p.requestWithRetry(ctx, cidtest.ID(), makeFn)
 		require.Error(t, err)
 		checkIndicesAndReset(t, p, 0, 2)
 		p.maxRequestAttempts = oldVal
@@ -273,6 +300,122 @@ func TestRebalance(t *testing.T) {
 	})
 }
 
+func TestRetryContainerNodes(t *testing.T) {
+	ctx := context.Background()
+	nodesCount := 3
+	policy := getPlacementPolicy(uint32(nodesCount))
+	p := &Pool{
+		logger:             zaptest.NewLogger(t),
+		maxRequestAttempts: nodesCount,
+		policySource:       &placementPolicyMock{policy: policy},
+		v2:                 true,
+	}
+
+	var nm netmap.NetMap
+	nodeKeys := make([]*keys.PrivateKey, nodesCount)
+	for i := 0; i < nodesCount; i++ {
+		key, err := keys.NewPrivateKey()
+		require.NoError(t, err)
+		nodeKeys[i] = key
+		nm.SetNodes(append(nm.Nodes(), getNodeInfo(key.Bytes())))
+	}
+	p.netMapSource = &netMapMock{netMap: nm}
+
+	cnrID := cidtest.ID()
+	cnrNodes, err := nm.ContainerNodes(policy, cnrID[:])
+	require.NoError(t, err)
+	cnrNodes, err = nm.PlacementVectors(cnrNodes, cnrID[:])
+	require.NoError(t, err)
+	require.Len(t, cnrNodes, 1)
+	require.Len(t, cnrNodes[0], nodesCount)
+	nodeKeys = reorderKeys(nodeKeys, cnrNodes)
+
+	makeFn := func(client *rpcClient.Client) error {
+		return nil
+	}
+
+	t.Run("first ok", func(t *testing.T) {
+		p.clientMap = makeClientMap(nodeKeys)
+		err = p.requestWithRetry(ctx, cnrID, makeFn)
+		require.NoError(t, err)
+		checkClientUsage(t, p, nodeKeys[0])
+	})
+
+	t.Run("first failed", func(t *testing.T) {
+		p.clientMap = makeClientMap(nodeKeys)
+		setClientMapErrors(p, nodeKeys[0])
+		err = p.requestWithRetry(ctx, cnrID, makeFn)
+		require.NoError(t, err)
+		checkClientUsage(t, p, nodeKeys[1])
+	})
+
+	t.Run("first two failed", func(t *testing.T) {
+		p.clientMap = makeClientMap(nodeKeys)
+		setClientMapErrors(p, nodeKeys[0], nodeKeys[1])
+		err = p.requestWithRetry(ctx, cnrID, makeFn)
+		require.NoError(t, err)
+		checkClientUsage(t, p, nodeKeys[2])
+	})
+
+	t.Run("all failed", func(t *testing.T) {
+		p.clientMap = makeClientMap(nodeKeys)
+		setClientMapErrors(p, nodeKeys[0], nodeKeys[1], nodeKeys[2])
+		err = p.requestWithRetry(ctx, cnrID, makeFn)
+		require.Error(t, err)
+		checkClientUsage(t, p)
+	})
+
+	t.Run("error empty result", func(t *testing.T) {
+		p.clientMap = makeClientMap(nodeKeys)
+		errNodes, index := 2, 0
+		err = p.requestWithRetry(ctx, cnrID, func(client *rpcClient.Client) error {
+			if index < errNodes {
+				index++
+				return errNodeEmptyResult
+			}
+			return nil
+		})
+		require.NoError(t, err)
+		checkClientUsage(t, p, nodeKeys[:errNodes+1]...)
+	})
+
+	t.Run("error not found", func(t *testing.T) {
+		p.clientMap = makeClientMap(nodeKeys)
+		errNodes, index := 2, 0
+		err = p.requestWithRetry(ctx, cnrID, func(client *rpcClient.Client) error {
+			if index < errNodes {
+				index++
+				return ErrNodeNotFound
+			}
+			return nil
+		})
+		require.NoError(t, err)
+		checkClientUsage(t, p, nodeKeys[:errNodes+1]...)
+	})
+
+	t.Run("error access denied", func(t *testing.T) {
+		p.clientMap = makeClientMap(nodeKeys)
+		var index int
+		err = p.requestWithRetry(ctx, cnrID, func(client *rpcClient.Client) error {
+			index++
+			return ErrNodeAccessDenied
+		})
+		require.ErrorIs(t, err, ErrNodeAccessDenied)
+		require.Equal(t, 1, index)
+		checkClientUsage(t, p, nodeKeys[0])
+	})
+
+	t.Run("limit attempts", func(t *testing.T) {
+		p.clientMap = makeClientMap(nodeKeys)
+		p.maxRequestAttempts = 2
+		setClientMapErrors(p, nodeKeys[0], nodeKeys[1])
+		err = p.requestWithRetry(ctx, cnrID, makeFn)
+		require.Error(t, err)
+		checkClientUsage(t, p)
+		p.maxRequestAttempts = nodesCount
+	})
+}
+
 func makeInnerPool(nodes [][]string) []*innerPool {
 	res := make([]*innerPool, len(nodes))
 
@@ -359,3 +502,64 @@ func containsStr(list []string, item string) bool {
 
 	return false
 }
+
+func makeClientMap(nodeKeys []*keys.PrivateKey) map[uint64]client {
+	res := make(map[uint64]client, len(nodeKeys))
+	for _, key := range nodeKeys {
+		res[hrw.Hash(key.Bytes())] = &treeClientMock{}
+	}
+	return res
+}
+
+func reorderKeys(nodeKeys []*keys.PrivateKey, cnrNodes [][]netmap.NodeInfo) []*keys.PrivateKey {
+	res := make([]*keys.PrivateKey, len(nodeKeys))
+	for i := 0; i < len(cnrNodes[0]); i++ {
+		for j := 0; j < len(nodeKeys); j++ {
+			if hrw.Hash(nodeKeys[j].Bytes()) == cnrNodes[0][i].Hash() {
+				res[i] = nodeKeys[j]
+			}
+		}
+	}
+	return res
+}
+
+func checkClientUsage(t *testing.T, p *Pool, nodeKeys ...*keys.PrivateKey) {
+	for hash, cl := range p.clientMap {
+		if containsHash(nodeKeys, hash) {
+			require.True(t, cl.(*treeClientMock).used)
+		} else {
+			require.False(t, cl.(*treeClientMock).used)
+		}
+	}
+}
+
+func setClientMapErrors(p *Pool, nodeKeys ...*keys.PrivateKey) {
+	for hash, cl := range p.clientMap {
+		if containsHash(nodeKeys, hash) {
+			cl.(*treeClientMock).err = true
+		}
+	}
+}
+
+func containsHash(list []*keys.PrivateKey, hash uint64) bool {
+	for i := range list {
+		if hrw.Hash(list[i].Bytes()) == hash {
+			return true
+		}
+	}
+
+	return false
+}
+
+func getPlacementPolicy(replicas uint32) (p netmap.PlacementPolicy) {
+	var r netmap.ReplicaDescriptor
+	r.SetNumberOfObjects(replicas)
+	p.AddReplicas([]netmap.ReplicaDescriptor{r}...)
+	return p
+}
+
+func getNodeInfo(key []byte) netmap.NodeInfo {
+	var node netmap.NodeInfo
+	node.SetPublicKey(key)
+	return node
+}
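Reviewer note (not part of the patch): makeClientMap, reorderKeys and containsHash all key clients by hrw.Hash over the node's public key bytes, on the assumption that this equals NodeInfo.Hash() for a node built by getNodeInfo. An illustrative self-check of that assumption could look like this:

package tree

import (
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
	"git.frostfs.info/TrueCloudLab/hrw"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/stretchr/testify/require"
)

// TestNodeHashMatchesHRW pins the keying assumption used by the helpers above:
// a NodeInfo carrying only a public key hashes to hrw.Hash(publicKey).
func TestNodeHashMatchesHRW(t *testing.T) {
	key, err := keys.NewPrivateKey()
	require.NoError(t, err)

	var node netmap.NodeInfo
	node.SetPublicKey(key.Bytes())

	require.Equal(t, hrw.Hash(key.Bytes()), node.Hash())
}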