diff --git a/pool/pool.go b/pool/pool.go index d795af8..631227b 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -1934,6 +1934,13 @@ type resCreateSession struct { // // See pool package overview to get some examples. type Pool struct { + manager *connectionManager + logger *zap.Logger + + maxObjectSize uint64 +} + +type connectionManager struct { innerPools []*innerPool key *ecdsa.PrivateKey cancel context.CancelFunc @@ -1943,8 +1950,6 @@ type Pool struct { rebalanceParams rebalanceParameters clientBuilder clientBuilder logger *zap.Logger - - maxObjectSize uint64 } type innerPool struct { @@ -1966,8 +1971,8 @@ const ( defaultBufferMaxSizeForPut = 3 * 1024 * 1024 // 3 MB ) -// NewPool creates connection pool using parameters. -func NewPool(options InitParameters) (*Pool, error) { +// newConnectionManager creates connection pool using parameters. +func newConnectionManager(options InitParameters) (*connectionManager, error) { if options.key == nil { return nil, fmt.Errorf("missed required parameter 'Key'") } @@ -1984,7 +1989,7 @@ func NewPool(options InitParameters) (*Pool, error) { fillDefaultInitParams(&options, cache) - pool := &Pool{ + manager := &connectionManager{ key: options.key, cache: cache, logger: options.logger, @@ -1998,6 +2003,21 @@ func NewPool(options InitParameters) (*Pool, error) { clientBuilder: options.clientBuilder, } + return manager, nil +} + +// NewPool creates connection pool using parameters. +func NewPool(options InitParameters) (*Pool, error) { + manager, err := newConnectionManager(options) + if err != nil { + return nil, err + } + + pool := &Pool{ + manager: manager, + logger: options.logger, + } + return pool, nil } @@ -2010,28 +2030,42 @@ func NewPool(options InitParameters) (*Pool, error) { // // See also InitParameters.SetClientRebalanceInterval.
func (p *Pool) Dial(ctx context.Context) error { - inner := make([]*innerPool, len(p.rebalanceParams.nodesParams)) + err := p.manager.dial(ctx) + if err != nil { + return err + } + + ni, err := p.NetworkInfo(ctx) + if err != nil { + return fmt.Errorf("get network info for max object size: %w", err) + } + p.maxObjectSize = ni.MaxObjectSize() + return nil +} + +func (cm *connectionManager) dial(ctx context.Context) error { + inner := make([]*innerPool, len(cm.rebalanceParams.nodesParams)) var atLeastOneHealthy bool - for i, params := range p.rebalanceParams.nodesParams { + for i, params := range cm.rebalanceParams.nodesParams { clients := make([]client, len(params.weights)) for j, addr := range params.addresses { - clients[j] = p.clientBuilder(addr) + clients[j] = cm.clientBuilder(addr) if err := clients[j].dial(ctx); err != nil { - p.log(zap.WarnLevel, "failed to build client", zap.String("address", addr), zap.Error(err)) + cm.log(zap.WarnLevel, "failed to build client", zap.String("address", addr), zap.Error(err)) continue } var st session.Object - err := initSessionForDuration(ctx, &st, clients[j], p.rebalanceParams.sessionExpirationDuration, *p.key, false) + err := initSessionForDuration(ctx, &st, clients[j], cm.rebalanceParams.sessionExpirationDuration, *cm.key, false) if err != nil { clients[j].setUnhealthy() - p.log(zap.WarnLevel, "failed to create frostfs session token for client", + cm.log(zap.WarnLevel, "failed to create frostfs session token for client", zap.String("address", addr), zap.Error(err)) continue } - _ = p.cache.Put(formCacheKey(addr, p.key, false), st) + _ = cm.cache.Put(formCacheKey(addr, cm.key, false), st) atLeastOneHealthy = true } source := rand.NewSource(time.Now().UnixNano()) @@ -2048,26 +2082,20 @@ func (p *Pool) Dial(ctx context.Context) error { } ctx, cancel := context.WithCancel(ctx) - p.cancel = cancel - p.closedCh = make(chan struct{}) - p.innerPools = inner + cm.cancel = cancel + cm.closedCh = make(chan struct{}) + cm.innerPools = 
inner - ni, err := p.NetworkInfo(ctx) - if err != nil { - return fmt.Errorf("get network info for max object size: %w", err) - } - p.maxObjectSize = ni.MaxObjectSize() - - go p.startRebalance(ctx) + go cm.startRebalance(ctx) return nil } -func (p *Pool) log(level zapcore.Level, msg string, fields ...zap.Field) { - if p.logger == nil { +func (cm *connectionManager) log(level zapcore.Level, msg string, fields ...zap.Field) { + if cm.logger == nil { return } - p.logger.Log(level, msg, fields...) + cm.logger.Log(level, msg, fields...) } func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { @@ -2154,47 +2182,47 @@ func adjustNodeParams(nodeParams []NodeParam) ([]*nodesParam, error) { } // startRebalance runs loop to monitor connection healthy status. -func (p *Pool) startRebalance(ctx context.Context) { - ticker := time.NewTicker(p.rebalanceParams.clientRebalanceInterval) +func (cm *connectionManager) startRebalance(ctx context.Context) { + ticker := time.NewTicker(cm.rebalanceParams.clientRebalanceInterval) defer ticker.Stop() - buffers := make([][]float64, len(p.rebalanceParams.nodesParams)) - for i, params := range p.rebalanceParams.nodesParams { + buffers := make([][]float64, len(cm.rebalanceParams.nodesParams)) + for i, params := range cm.rebalanceParams.nodesParams { buffers[i] = make([]float64, len(params.weights)) } for { select { case <-ctx.Done(): - close(p.closedCh) + close(cm.closedCh) return case <-ticker.C: - p.updateNodesHealth(ctx, buffers) - ticker.Reset(p.rebalanceParams.clientRebalanceInterval) + cm.updateNodesHealth(ctx, buffers) + ticker.Reset(cm.rebalanceParams.clientRebalanceInterval) } } } -func (p *Pool) updateNodesHealth(ctx context.Context, buffers [][]float64) { +func (cm *connectionManager) updateNodesHealth(ctx context.Context, buffers [][]float64) { wg := sync.WaitGroup{} - for i, inner := range p.innerPools { + for i, inner := range cm.innerPools { wg.Add(1) bufferWeights := buffers[i] go func(i int, _ *innerPool) { 
defer wg.Done() - p.updateInnerNodesHealth(ctx, i, bufferWeights) + cm.updateInnerNodesHealth(ctx, i, bufferWeights) }(i, inner) } wg.Wait() } -func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights []float64) { - if i > len(p.innerPools)-1 { +func (cm *connectionManager) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights []float64) { + if i > len(cm.innerPools)-1 { return } - pool := p.innerPools[i] - options := p.rebalanceParams + pool := cm.innerPools[i] + options := cm.rebalanceParams healthyChanged := new(atomic.Bool) wg := sync.WaitGroup{} @@ -2213,7 +2241,7 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights bufferWeights[j] = options.nodesParams[i].weights[j] } else { bufferWeights[j] = 0 - p.cache.DeleteByPrefix(cli.address()) + cm.cache.DeleteByPrefix(cli.address()) } if changed { @@ -2222,7 +2250,7 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights fields = append(fields, zap.String("reason", err.Error())) } - p.log(zap.DebugLevel, "health has changed", fields...) + cm.log(zap.DebugLevel, "health has changed", fields...) 
healthyChanged.Store(true) } }(j, cli) @@ -2290,8 +2318,8 @@ func adjustWeights(weights []float64) []float64 { return adjusted } -func (p *Pool) connection() (client, error) { - for _, inner := range p.innerPools { +func (cm *connectionManager) connection() (client, error) { + for _, inner := range cm.innerPools { cp, err := inner.connection() if err == nil { return cp, nil @@ -2333,13 +2361,13 @@ func formCacheKey(address string, key *ecdsa.PrivateKey, clientCut bool) string return address + stype + k.String() } -func (p *Pool) checkSessionTokenErr(err error, address string) bool { +func (cm *connectionManager) checkSessionTokenErr(err error, address string) bool { if err == nil { return false } if sdkClient.IsErrSessionNotFound(err) || sdkClient.IsErrSessionExpired(err) { - p.cache.DeleteByPrefix(address) + cm.cache.DeleteByPrefix(address) return true } @@ -2412,8 +2440,8 @@ type callContext struct { sessionClientCut bool } -func (p *Pool) initCallContext(ctx *callContext, cfg prmCommon, prmCtx prmContext) error { - cp, err := p.connection() +func (cm *connectionManager) initCallContext(ctx *callContext, cfg prmCommon, prmCtx prmContext) error { + cp, err := cm.connection() if err != nil { return err } @@ -2421,7 +2449,7 @@ func (p *Pool) initCallContext(ctx *callContext, cfg prmCommon, prmCtx prmContex ctx.key = cfg.key if ctx.key == nil { // use pool key if caller didn't specify its own - ctx.key = p.key + ctx.key = cm.key } ctx.endpoint = cp.address() @@ -2445,19 +2473,19 @@ func (p *Pool) initCallContext(ctx *callContext, cfg prmCommon, prmCtx prmContex // opens new session or uses cached one. // Must be called only on initialized callContext with set sessionTarget. 
-func (p *Pool) openDefaultSession(ctx context.Context, cc *callContext) error { +func (cm *connectionManager) openDefaultSession(ctx context.Context, cc *callContext) error { cacheKey := formCacheKey(cc.endpoint, cc.key, cc.sessionClientCut) - tok, ok := p.cache.Get(cacheKey) + tok, ok := cm.cache.Get(cacheKey) if !ok { // init new session - err := initSessionForDuration(ctx, &tok, cc.client, p.stokenDuration, *cc.key, cc.sessionClientCut) + err := initSessionForDuration(ctx, &tok, cc.client, cm.stokenDuration, *cc.key, cc.sessionClientCut) if err != nil { return fmt.Errorf("session API client: %w", err) } // cache the opened session - p.cache.Put(cacheKey, tok) + cm.cache.Put(cacheKey, tok) } tok.ForVerb(cc.sessionVerb) @@ -2479,26 +2507,26 @@ func (p *Pool) openDefaultSession(ctx context.Context, cc *callContext) error { // opens default session (if sessionDefault is set), and calls f. If f returns // session-related error then cached token is removed. -func (p *Pool) call(ctx context.Context, cc *callContext, f func() error) error { +func (cm *connectionManager) call(ctx context.Context, cc *callContext, f func() error) error { var err error if cc.sessionDefault { - err = p.openDefaultSession(ctx, cc) + err = cm.openDefaultSession(ctx, cc) if err != nil { return fmt.Errorf("open default session: %w", err) } } err = f() - _ = p.checkSessionTokenErr(err, cc.endpoint) + _ = cm.checkSessionTokenErr(err, cc.endpoint) return err } // fillAppropriateKey use pool key if caller didn't specify its own. 
-func (p *Pool) fillAppropriateKey(prm *prmCommon) { +func (cm *connectionManager) fillAppropriateKey(prm *prmCommon) { if prm.key == nil { - prm.key = p.key + prm.key = cm.key } } @@ -2522,16 +2550,16 @@ func (p *Pool) PatchObject(ctx context.Context, prm PrmObjectPatch) (ResPatchObj prmCtx.useVerb(session.VerbObjectPatch) prmCtx.useContainer(prm.addr.Container()) - p.fillAppropriateKey(&prm.prmCommon) + p.manager.fillAppropriateKey(&prm.prmCommon) var ctxCall callContext - if err := p.initCallContext(&ctxCall, prm.prmCommon, prmCtx); err != nil { + if err := p.manager.initCallContext(&ctxCall, prm.prmCommon, prmCtx); err != nil { return ResPatchObject{}, fmt.Errorf("init call context: %w", err) } if ctxCall.sessionDefault { ctxCall.sessionTarget = prm.UseSession - if err := p.openDefaultSession(ctx, &ctxCall); err != nil { + if err := p.manager.openDefaultSession(ctx, &ctxCall); err != nil { return ResPatchObject{}, fmt.Errorf("open default session: %w", err) } } @@ -2539,7 +2567,7 @@ func (p *Pool) PatchObject(ctx context.Context, prm PrmObjectPatch) (ResPatchObj res, err := ctxCall.client.objectPatch(ctx, prm) if err != nil { // removes session token from cache in case of token error - p.checkSessionTokenErr(err, ctxCall.endpoint) + p.manager.checkSessionTokenErr(err, ctxCall.endpoint) return ResPatchObject{}, fmt.Errorf("init patching on API client %s: %w", ctxCall.endpoint, err) } @@ -2557,24 +2585,24 @@ func (p *Pool) PutObject(ctx context.Context, prm PrmObjectPut) (ResPutObject, e prmCtx.useVerb(session.VerbObjectPut) prmCtx.useContainer(cnr) - p.fillAppropriateKey(&prm.prmCommon) + p.manager.fillAppropriateKey(&prm.prmCommon) var ctxCall callContext ctxCall.sessionClientCut = prm.clientCut - if err := p.initCallContext(&ctxCall, prm.prmCommon, prmCtx); err != nil { + if err := p.manager.initCallContext(&ctxCall, prm.prmCommon, prmCtx); err != nil { return ResPutObject{}, fmt.Errorf("init call context: %w", err) } if ctxCall.sessionDefault { 
ctxCall.sessionTarget = prm.UseSession - if err := p.openDefaultSession(ctx, &ctxCall); err != nil { + if err := p.manager.openDefaultSession(ctx, &ctxCall); err != nil { return ResPutObject{}, fmt.Errorf("open default session: %w", err) } } if prm.clientCut { var ni netmap.NetworkInfo - ni.SetCurrentEpoch(p.cache.Epoch()) + ni.SetCurrentEpoch(p.manager.cache.Epoch()) ni.SetMaxObjectSize(p.maxObjectSize) // we want to use initial max object size in PayloadSizeLimiter prm.setNetworkInfo(ni) } @@ -2582,7 +2610,7 @@ func (p *Pool) PutObject(ctx context.Context, prm PrmObjectPut) (ResPutObject, e res, err := ctxCall.client.objectPut(ctx, prm) if err != nil { // removes session token from cache in case of token error - p.checkSessionTokenErr(err, ctxCall.endpoint) + p.manager.checkSessionTokenErr(err, ctxCall.endpoint) return ResPutObject{}, fmt.Errorf("init writing on API client %s: %w", ctxCall.endpoint, err) } @@ -2614,17 +2642,17 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error { } } - p.fillAppropriateKey(&prm.prmCommon) + p.manager.fillAppropriateKey(&prm.prmCommon) var cc callContext cc.sessionTarget = prm.UseSession - err := p.initCallContext(&cc, prm.prmCommon, prmCtx) + err := p.manager.initCallContext(&cc, prm.prmCommon, prmCtx) if err != nil { return err } - return p.call(ctx, &cc, func() error { + return p.manager.call(ctx, &cc, func() error { if err = cc.client.objectDelete(ctx, prm); err != nil { return fmt.Errorf("remove object via client %s: %w", cc.endpoint, err) } @@ -2663,19 +2691,19 @@ type ResGetObject struct { // // Main return value MUST NOT be processed on an erroneous return. 
func (p *Pool) GetObject(ctx context.Context, prm PrmObjectGet) (ResGetObject, error) { - p.fillAppropriateKey(&prm.prmCommon) + p.manager.fillAppropriateKey(&prm.prmCommon) var cc callContext cc.sessionTarget = prm.UseSession var res ResGetObject - err := p.initCallContext(&cc, prm.prmCommon, prmContext{}) + err := p.manager.initCallContext(&cc, prm.prmCommon, prmContext{}) if err != nil { return res, err } - return res, p.call(ctx, &cc, func() error { + return res, p.manager.call(ctx, &cc, func() error { res, err = cc.client.objectGet(ctx, prm) if err != nil { return fmt.Errorf("get object via client %s: %w", cc.endpoint, err) @@ -2688,19 +2716,19 @@ func (p *Pool) GetObject(ctx context.Context, prm PrmObjectGet) (ResGetObject, e // // Main return value MUST NOT be processed on an erroneous return. func (p *Pool) HeadObject(ctx context.Context, prm PrmObjectHead) (object.Object, error) { - p.fillAppropriateKey(&prm.prmCommon) + p.manager.fillAppropriateKey(&prm.prmCommon) var cc callContext cc.sessionTarget = prm.UseSession var obj object.Object - err := p.initCallContext(&cc, prm.prmCommon, prmContext{}) + err := p.manager.initCallContext(&cc, prm.prmCommon, prmContext{}) if err != nil { return obj, err } - return obj, p.call(ctx, &cc, func() error { + return obj, p.manager.call(ctx, &cc, func() error { obj, err = cc.client.objectHead(ctx, prm) if err != nil { return fmt.Errorf("head object via client %s: %w", cc.endpoint, err) @@ -2739,19 +2767,19 @@ func (x *ResObjectRange) Close() error { // // Main return value MUST NOT be processed on an erroneous return. 
func (p *Pool) ObjectRange(ctx context.Context, prm PrmObjectRange) (ResObjectRange, error) { - p.fillAppropriateKey(&prm.prmCommon) + p.manager.fillAppropriateKey(&prm.prmCommon) var cc callContext cc.sessionTarget = prm.UseSession var res ResObjectRange - err := p.initCallContext(&cc, prm.prmCommon, prmContext{}) + err := p.manager.initCallContext(&cc, prm.prmCommon, prmContext{}) if err != nil { return res, err } - return res, p.call(ctx, &cc, func() error { + return res, p.manager.call(ctx, &cc, func() error { res, err = cc.client.objectRange(ctx, prm) if err != nil { return fmt.Errorf("object range via client %s: %w", cc.endpoint, err) @@ -2810,19 +2838,19 @@ func (x *ResObjectSearch) Close() { // // Main return value MUST NOT be processed on an erroneous return. func (p *Pool) SearchObjects(ctx context.Context, prm PrmObjectSearch) (ResObjectSearch, error) { - p.fillAppropriateKey(&prm.prmCommon) + p.manager.fillAppropriateKey(&prm.prmCommon) var cc callContext cc.sessionTarget = prm.UseSession var res ResObjectSearch - err := p.initCallContext(&cc, prm.prmCommon, prmContext{}) + err := p.manager.initCallContext(&cc, prm.prmCommon, prmContext{}) if err != nil { return res, err } - return res, p.call(ctx, &cc, func() error { + return res, p.manager.call(ctx, &cc, func() error { res, err = cc.client.objectSearch(ctx, prm) if err != nil { return fmt.Errorf("search object via client %s: %w", cc.endpoint, err) @@ -2842,7 +2870,7 @@ func (p *Pool) SearchObjects(ctx context.Context, prm PrmObjectSearch) (ResObjec // // Main return value MUST NOT be processed on an erroneous return. func (p *Pool) PutContainer(ctx context.Context, prm PrmContainerPut) (cid.ID, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return cid.ID{}, err } @@ -2859,7 +2887,7 @@ func (p *Pool) PutContainer(ctx context.Context, prm PrmContainerPut) (cid.ID, e // // Main return value MUST NOT be processed on an erroneous return. 
func (p *Pool) GetContainer(ctx context.Context, prm PrmContainerGet) (container.Container, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return container.Container{}, err } @@ -2874,7 +2902,7 @@ func (p *Pool) GetContainer(ctx context.Context, prm PrmContainerGet) (container // ListContainers requests identifiers of the account-owned containers. func (p *Pool) ListContainers(ctx context.Context, prm PrmContainerList) ([]cid.ID, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return nil, err } @@ -2896,7 +2924,7 @@ func (p *Pool) ListContainers(ctx context.Context, prm PrmContainerList) ([]cid. // // Success can be verified by reading by identifier (see GetContainer). func (p *Pool) DeleteContainer(ctx context.Context, prm PrmContainerDelete) error { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return err } @@ -2911,7 +2939,7 @@ func (p *Pool) DeleteContainer(ctx context.Context, prm PrmContainerDelete) erro // AddAPEChain sends a request to set APE chain rules for a target (basically, for a container). func (p *Pool) AddAPEChain(ctx context.Context, prm PrmAddAPEChain) error { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return err } @@ -2926,7 +2954,7 @@ func (p *Pool) AddAPEChain(ctx context.Context, prm PrmAddAPEChain) error { // RemoveAPEChain sends a request to remove APE chain rules for a target. func (p *Pool) RemoveAPEChain(ctx context.Context, prm PrmRemoveAPEChain) error { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return err } @@ -2941,7 +2969,7 @@ func (p *Pool) RemoveAPEChain(ctx context.Context, prm PrmRemoveAPEChain) error // ListAPEChains sends a request to list APE chains rules for a target. 
func (p *Pool) ListAPEChains(ctx context.Context, prm PrmListAPEChains) ([]ape.Chain, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return nil, err } @@ -2958,7 +2986,7 @@ func (p *Pool) ListAPEChains(ctx context.Context, prm PrmListAPEChains) ([]ape.C // // Main return value MUST NOT be processed on an erroneous return. func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (accounting.Decimal, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return accounting.Decimal{}, err } @@ -2973,8 +3001,12 @@ func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (accounting.Decim // Statistic returns connection statistics. func (p Pool) Statistic() Statistic { + return p.manager.Statistic() +} + +func (cm connectionManager) Statistic() Statistic { stat := Statistic{} - for _, inner := range p.innerPools { + for _, inner := range cm.innerPools { nodes := make([]string, 0, len(inner.clients)) inner.lock.RLock() for _, cl := range inner.clients { @@ -3042,7 +3074,7 @@ func waitFor(ctx context.Context, params *WaitParams, condition func(context.Con // // Main return value MUST NOT be processed on an erroneous return. func (p *Pool) NetworkInfo(ctx context.Context) (netmap.NetworkInfo, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return netmap.NetworkInfo{}, err } @@ -3059,7 +3091,7 @@ func (p *Pool) NetworkInfo(ctx context.Context) (netmap.NetworkInfo, error) { // // Main return value MUST NOT be processed on an erroneous return. func (p *Pool) NetMapSnapshot(ctx context.Context) (netmap.NetMap, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return netmap.NetMap{}, err } @@ -3074,11 +3106,15 @@ func (p *Pool) NetMapSnapshot(ctx context.Context) (netmap.NetMap, error) { // Close closes the Pool and releases all the associated resources. 
func (p *Pool) Close() { - p.cancel() - <-p.closedCh + p.manager.close() +} + +func (cm *connectionManager) close() { + cm.cancel() + <-cm.closedCh // close all clients - for _, pools := range p.innerPools { + for _, pools := range cm.innerPools { for _, cli := range pools.clients { _ = cli.close() } diff --git a/pool/pool_test.go b/pool/pool_test.go index 1362654..09d2bc6 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -104,11 +104,11 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { expectedAuthKey := frostfsecdsa.PublicKey(clientKeys[1].PublicKey) condition := func() bool { - cp, err := clientPool.connection() + cp, err := clientPool.manager.connection() if err != nil { return false } - st, _ := clientPool.cache.Get(formCacheKey(cp.address(), clientPool.key, false)) + st, _ := clientPool.manager.cache.Get(formCacheKey(cp.address(), clientPool.manager.key, false)) return st.AssertAuthKey(&expectedAuthKey) } require.Never(t, condition, 900*time.Millisecond, 100*time.Millisecond) @@ -141,9 +141,9 @@ func TestOneNode(t *testing.T) { require.NoError(t, err) t.Cleanup(pool.Close) - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) - st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) + st, _ := pool.manager.cache.Get(formCacheKey(cp.address(), pool.manager.key, false)) expectedAuthKey := frostfsecdsa.PublicKey(key1.PublicKey) require.True(t, st.AssertAuthKey(&expectedAuthKey)) } @@ -171,9 +171,9 @@ func TestTwoNodes(t *testing.T) { require.NoError(t, err) t.Cleanup(pool.Close) - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) - st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) + st, _ := pool.manager.cache.Get(formCacheKey(cp.address(), pool.manager.key, false)) require.True(t, assertAuthKeyForAny(st, clientKeys)) } @@ -226,9 +226,9 @@ func TestOneOfTwoFailed(t *testing.T) { time.Sleep(2 * time.Second) for range 5 { - cp, err := 
pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) - st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) + st, _ := pool.manager.cache.Get(formCacheKey(cp.address(), pool.manager.key, false)) require.True(t, assertAuthKeyForAny(st, clientKeys)) } } @@ -369,7 +369,7 @@ func TestUpdateNodesHealth(t *testing.T) { tc.prepareCli(cli) p, log := newPool(t, cli) - p.updateNodesHealth(ctx, [][]float64{{1}}) + p.manager.updateNodesHealth(ctx, [][]float64{{1}}) changed := tc.wasHealthy != tc.willHealthy require.Equalf(t, tc.willHealthy, cli.isHealthy(), "healthy status should be: %v", tc.willHealthy) @@ -385,19 +385,20 @@ func newPool(t *testing.T, cli *mockClient) (*Pool, *observer.ObservedLogs) { require.NoError(t, err) return &Pool{ - innerPools: []*innerPool{{ - sampler: newSampler([]float64{1}, rand.NewSource(0)), - clients: []client{cli}, - }}, - cache: cache, - key: newPrivateKey(t), - closedCh: make(chan struct{}), - rebalanceParams: rebalanceParameters{ - nodesParams: []*nodesParam{{1, []string{"peer0"}, []float64{1}}}, - nodeRequestTimeout: time.Second, - clientRebalanceInterval: 200 * time.Millisecond, - }, - logger: log, + manager: &connectionManager{ + innerPools: []*innerPool{{ + sampler: newSampler([]float64{1}, rand.NewSource(0)), + clients: []client{cli}, + }}, + cache: cache, + key: newPrivateKey(t), + closedCh: make(chan struct{}), + rebalanceParams: rebalanceParameters{ + nodesParams: []*nodesParam{{1, []string{"peer0"}, []float64{1}}}, + nodeRequestTimeout: time.Second, + clientRebalanceInterval: 200 * time.Millisecond, + }, + logger: log}, }, observedLog } @@ -435,7 +436,7 @@ func TestTwoFailed(t *testing.T) { time.Sleep(2 * time.Second) - _, err = pool.connection() + _, err = pool.manager.connection() require.Error(t, err) require.Contains(t, err.Error(), "no healthy") } @@ -469,9 +470,9 @@ func TestSessionCache(t *testing.T) { t.Cleanup(pool.Close) // cache must contain session token - cp, err := 
pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) - st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) + st, _ := pool.manager.cache.Get(formCacheKey(cp.address(), pool.manager.key, false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) var prm PrmObjectGet @@ -482,9 +483,9 @@ func TestSessionCache(t *testing.T) { require.Error(t, err) // cache must not contain session token - cp, err = pool.connection() + cp, err = pool.manager.connection() require.NoError(t, err) - _, ok := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) + _, ok := pool.manager.cache.Get(formCacheKey(cp.address(), pool.manager.key, false)) require.False(t, ok) var prm2 PrmObjectPut @@ -494,9 +495,9 @@ func TestSessionCache(t *testing.T) { require.NoError(t, err) // cache must contain session token - cp, err = pool.connection() + cp, err = pool.manager.connection() require.NoError(t, err) - st, _ = pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) + st, _ = pool.manager.cache.Get(formCacheKey(cp.address(), pool.manager.key, false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) } @@ -538,17 +539,17 @@ func TestPriority(t *testing.T) { expectedAuthKey1 := frostfsecdsa.PublicKey(clientKeys[0].PublicKey) firstNode := func() bool { - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) - st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) + st, _ := pool.manager.cache.Get(formCacheKey(cp.address(), pool.manager.key, false)) return st.AssertAuthKey(&expectedAuthKey1) } expectedAuthKey2 := frostfsecdsa.PublicKey(clientKeys[1].PublicKey) secondNode := func() bool { - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) - st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) + st, _ := pool.manager.cache.Get(formCacheKey(cp.address(), pool.manager.key, false)) return st.AssertAuthKey(&expectedAuthKey2) } 
require.Never(t, secondNode, time.Second, 200*time.Millisecond) @@ -583,9 +584,9 @@ func TestSessionCacheWithKey(t *testing.T) { require.NoError(t, err) // cache must contain session token - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) - st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) + st, _ := pool.manager.cache.Get(formCacheKey(cp.address(), pool.manager.key, false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) var prm PrmObjectDelete @@ -595,7 +596,7 @@ func TestSessionCacheWithKey(t *testing.T) { err = pool.DeleteObject(ctx, prm) require.NoError(t, err) - st, _ = pool.cache.Get(formCacheKey(cp.address(), anonKey, false)) + st, _ = pool.manager.cache.Get(formCacheKey(cp.address(), anonKey, false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) } @@ -636,10 +637,10 @@ func TestSessionTokenOwner(t *testing.T) { cc.sessionTarget = func(tok session.Object) { tkn = tok } - err = p.initCallContext(&cc, prm, prmCtx) + err = p.manager.initCallContext(&cc, prm, prmCtx) require.NoError(t, err) - err = p.openDefaultSession(ctx, &cc) + err = p.manager.openDefaultSession(ctx, &cc) require.NoError(t, err) require.True(t, tkn.VerifySignature()) require.True(t, tkn.Issuer().Equals(anonOwner)) @@ -922,14 +923,14 @@ func TestSwitchAfterErrorThreshold(t *testing.T) { t.Cleanup(pool.Close) for range errorThreshold { - conn, err := pool.connection() + conn, err := pool.manager.connection() require.NoError(t, err) require.Equal(t, nodes[0].address, conn.address()) _, err = conn.objectGet(ctx, PrmObjectGet{}) require.Error(t, err) } - conn, err := pool.connection() + conn, err := pool.manager.connection() require.NoError(t, err) require.Equal(t, nodes[1].address, conn.address()) _, err = conn.objectGet(ctx, PrmObjectGet{}) diff --git a/pool/sampler_test.go b/pool/sampler_test.go index ab06e0f..5577313 100644 --- a/pool/sampler_test.go +++ b/pool/sampler_test.go @@ -59,7 +59,7 @@ func TestHealthyReweight(t 
*testing.T) { sampler: newSampler(weights, rand.NewSource(0)), clients: []client{client1, client2}, } - p := &Pool{ + cm := &connectionManager{ innerPools: []*innerPool{inner}, cache: cache, key: newPrivateKey(t), @@ -67,14 +67,14 @@ func TestHealthyReweight(t *testing.T) { } // check getting first node connection before rebalance happened - connection0, err := p.connection() + connection0, err := cm.connection() require.NoError(t, err) mock0 := connection0.(*mockClient) require.Equal(t, names[0], mock0.address()) - p.updateInnerNodesHealth(context.TODO(), 0, buffer) + cm.updateInnerNodesHealth(context.TODO(), 0, buffer) - connection1, err := p.connection() + connection1, err := cm.connection() require.NoError(t, err) mock1 := connection1.(*mockClient) require.Equal(t, names[1], mock1.address()) @@ -84,10 +84,10 @@ func TestHealthyReweight(t *testing.T) { inner.clients[0] = newMockClient(names[0], *newPrivateKey(t)) inner.lock.Unlock() - p.updateInnerNodesHealth(context.TODO(), 0, buffer) + cm.updateInnerNodesHealth(context.TODO(), 0, buffer) inner.sampler = newSampler(weights, rand.NewSource(0)) - connection0, err = p.connection() + connection0, err = cm.connection() require.NoError(t, err) mock0 = connection0.(*mockClient) require.Equal(t, names[0], mock0.address()) @@ -108,12 +108,12 @@ func TestHealthyNoReweight(t *testing.T) { newMockClient(names[1], *newPrivateKey(t)), }, } - p := &Pool{ + cm := &connectionManager{ innerPools: []*innerPool{inner}, rebalanceParams: rebalanceParameters{nodesParams: []*nodesParam{{weights: weights}}}, } - p.updateInnerNodesHealth(context.TODO(), 0, buffer) + cm.updateInnerNodesHealth(context.TODO(), 0, buffer) inner.lock.RLock() defer inner.lock.RUnlock()