diff --git a/pool/pool.go b/pool/pool.go index 5b37fcb..4cf156d 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -2023,6 +2023,21 @@ type connectionManager struct { rebalanceParams rebalanceParameters clientBuilder clientBuilder logger *zap.Logger + healthChecker *healthCheck +} + +type healthCheck struct { + cancel context.CancelFunc + closedCh chan struct{} + + clientRebalanceInterval time.Duration +} + +func newHealthCheck(clientRebalanceInterval time.Duration) *healthCheck { + var h healthCheck + h.clientRebalanceInterval = clientRebalanceInterval + h.closedCh = make(chan struct{}) + return &h } type innerPool struct { @@ -2167,15 +2182,22 @@ func (cm *connectionManager) dial(ctx context.Context) error { return fmt.Errorf("at least one node must be healthy") } - ctx, cancel := context.WithCancel(ctx) - cm.cancel = cancel - cm.closedCh = make(chan struct{}) cm.innerPools = inner - go cm.startRebalance(ctx) + cm.healthChecker = newHealthCheck(cm.rebalanceParams.clientRebalanceInterval) + cm.healthChecker.startRebalance(ctx, cm.rebalance) return nil } +func (cm *connectionManager) rebalance(ctx context.Context) { + buffers := make([][]float64, len(cm.rebalanceParams.nodesParams)) + for i, params := range cm.rebalanceParams.nodesParams { + buffers[i] = make([]float64, len(params.weights)) + } + + cm.updateNodesHealth(ctx, buffers) +} + func (cm *connectionManager) log(level zapcore.Level, msg string, fields ...zap.Field) { if cm.logger == nil { return @@ -2268,25 +2290,30 @@ func adjustNodeParams(nodeParams []NodeParam) ([]*nodesParam, error) { } // startRebalance runs loop to monitor connection healthy status. -func (cm *connectionManager) startRebalance(ctx context.Context) { - ticker := time.NewTicker(cm.rebalanceParams.clientRebalanceInterval) - defer ticker.Stop() +func (h *healthCheck) startRebalance(ctx context.Context, callback func(ctx context.Context)) { + ctx, cancel := context.WithCancel(ctx) + h.cancel = cancel - buffers := make([][]float64, len(cm.rebalanceParams.nodesParams)) - for i, params := range cm.rebalanceParams.nodesParams { - buffers[i] = make([]float64, len(params.weights)) - } + go func() { + ticker := time.NewTicker(h.clientRebalanceInterval) + defer ticker.Stop() - for { - select { - case <-ctx.Done(): - close(cm.closedCh) - return - case <-ticker.C: - cm.updateNodesHealth(ctx, buffers) - ticker.Reset(cm.rebalanceParams.clientRebalanceInterval) + for { + select { + case <-ctx.Done(): + close(h.closedCh) + return + case <-ticker.C: + callback(ctx) + ticker.Reset(h.clientRebalanceInterval) + } } - } + }() +} + +func (h *healthCheck) stopRebalance() { + h.cancel() + <-h.closedCh } func (cm *connectionManager) updateNodesHealth(ctx context.Context, buffers [][]float64) { @@ -3209,8 +3236,7 @@ func (p *Pool) Close() { } func (cm *connectionManager) close() { - cm.cancel() - <-cm.closedCh + cm.healthChecker.stopRebalance() // close all clients for _, pools := range cm.innerPools { diff --git a/pool/pool_test.go b/pool/pool_test.go index 75163fd..b063294 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -391,11 +391,10 @@ func newPool(t *testing.T, cli *mockClient) (*Pool, *observer.ObservedLogs) { sampler: newSampler([]float64{1}, rand.NewSource(0)), clients: []client{cli}, }}, - closedCh: make(chan struct{}), + healthChecker: newHealthCheck(200 * time.Millisecond), rebalanceParams: rebalanceParameters{ - nodesParams: []*nodesParam{{1, []string{"peer0"}, []float64{1}}}, - nodeRequestTimeout: time.Second, - clientRebalanceInterval: 200 * time.Millisecond, + nodesParams: []*nodesParam{{1, []string{"peer0"}, []float64{1}}}, + nodeRequestTimeout: time.Second, }, logger: log}, }, observedLog