From f78fb6dcb064fd2e6521a54b42b088d82560ebb0 Mon Sep 17 00:00:00 2001 From: Alex Vanin Date: Tue, 4 Mar 2025 12:17:20 +0300 Subject: [PATCH] [#339] pool/tree: Make circuit breaker more generic Signed-off-by: Alex Vanin --- pool/tree/circuitbreaker.go | 35 ++++++----------------- pool/tree/circuitbreaker_test.go | 49 ++++++++++---------------------- pool/tree/pool.go | 38 ++++++++++++++++++------- 3 files changed, 51 insertions(+), 71 deletions(-) diff --git a/pool/tree/circuitbreaker.go b/pool/tree/circuitbreaker.go index 83d11c4..c494ddd 100644 --- a/pool/tree/circuitbreaker.go +++ b/pool/tree/circuitbreaker.go @@ -1,73 +1,54 @@ package tree import ( - "context" "errors" "sync" "time" ) -const ( - defaultThreshold = 10 - defaultBreakDuration = 10 * time.Second -) - type ( circuitBreaker struct { breakDuration time.Duration threshold int mu sync.Mutex - state map[string]state + state map[uint64]state } state struct { counter int breakTimestamp time.Time } - - dialer interface { - dial(context.Context) error - endpoint() string - } ) var ErrCBClosed = errors.New("circuit breaker is closed") func NewCircuitBreaker(breakDuration time.Duration, threshold int) *circuitBreaker { - if threshold == 0 { - threshold = defaultThreshold - } - if breakDuration == 0 { - breakDuration = defaultBreakDuration - } return &circuitBreaker{ breakDuration: breakDuration, threshold: threshold, - state: make(map[string]state), + state: make(map[uint64]state), } } -func (c *circuitBreaker) Dial(ctx context.Context, cli dialer) error { +func (c *circuitBreaker) Do(id uint64, f func() error) error { c.mu.Lock() defer c.mu.Unlock() - endpoint := cli.endpoint() - - if _, ok := c.state[endpoint]; !ok { - c.state[endpoint] = state{} + if _, ok := c.state[id]; !ok { + c.state[id] = state{} } - s := c.state[endpoint] + s := c.state[id] defer func() { - c.state[endpoint] = s + c.state[id] = s }() if time.Since(s.breakTimestamp) < c.breakDuration { return ErrCBClosed } - err := cli.dial(ctx) + err := f() if err == nil { s.counter = 0 return nil diff --git a/pool/tree/circuitbreaker_test.go b/pool/tree/circuitbreaker_test.go index f507655..4f7974c 100644 --- a/pool/tree/circuitbreaker_test.go +++ b/pool/tree/circuitbreaker_test.go @@ -1,7 +1,6 @@ package tree import ( - "context" "errors" "testing" "time" @@ -9,52 +8,34 @@ import ( "github.com/stretchr/testify/require" ) -type ( - testDialer struct { - err error - addr string - } -) - -func (d *testDialer) dial(_ context.Context) error { - return d.err -} - -func (d *testDialer) endpoint() string { - return d.addr -} - func TestCircuitBreaker(t *testing.T) { - ctx := context.Background() remoteErr := errors.New("service is being synchronized") - d := &testDialer{ - err: remoteErr, - addr: "addr", - } + breakDuration := 1 * time.Second + threshold := 10 + cb := NewCircuitBreaker(breakDuration, threshold) // Hit threshold - cb := NewCircuitBreaker(1*time.Second, 10) - for i := 0; i < 10; i++ { - err := cb.Dial(ctx, d) + for i := 0; i < threshold; i++ { + err := cb.Do(1, func() error { return remoteErr }) require.ErrorIs(t, err, remoteErr) } + // Different client should not be affected by threshold + require.NoError(t, cb.Do(2, func() error { return nil })) + // Immediate request should return circuit breaker error - d.err = nil - require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed) + require.ErrorIs(t, cb.Do(1, func() error { return nil }), ErrCBClosed) // Request after breakDuration should be ok - time.Sleep(1 * time.Second) - require.NoError(t, cb.Dial(ctx, d)) + time.Sleep(breakDuration) + require.NoError(t, cb.Do(1, func() error { return nil })) - // Then hit threshold once again - d.err = remoteErr - for i := 0; i < 10; i++ { - err := cb.Dial(ctx, d) + // Try hitting threshold one more time after break duration + for i := 0; i < threshold; i++ { + err := cb.Do(1, func() error { return remoteErr }) require.ErrorIs(t, err, remoteErr) } // Immediate request should return circuit breaker error - d.err = nil - require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed) + require.ErrorIs(t, cb.Do(1, func() error { return nil }), ErrCBClosed) } diff --git a/pool/tree/pool.go b/pool/tree/pool.go index b0c8af6..5c7fbce 100644 --- a/pool/tree/pool.go +++ b/pool/tree/pool.go @@ -24,10 +24,12 @@ import ( ) const ( - defaultRebalanceInterval = 15 * time.Second - defaultHealthcheckTimeout = 4 * time.Second - defaultDialTimeout = 5 * time.Second - defaultStreamTimeout = 10 * time.Second + defaultRebalanceInterval = 15 * time.Second + defaultHealthcheckTimeout = 4 * time.Second + defaultDialTimeout = 5 * time.Second + defaultStreamTimeout = 10 * time.Second + defaultCircuitBreakerDuration = 10 * time.Second + defaultCircuitBreakerTreshold = 10 ) // SubTreeSort defines an order of nodes returned from GetSubTree RPC. @@ -76,6 +78,8 @@ type InitParameters struct { dialOptions []grpc.DialOption maxRequestAttempts int netMapInfoSource NetMapInfoSource + circuitBreakerThreshold int + circuitBreakerDuration time.Duration } type NetMapInfoSource interface { @@ -117,8 +121,7 @@ type Pool struct { // * retry in case of request failure (see Pool.requestWithRetry) // startIndices will be used if netMapInfoSource is not set startIndices [2]int - - // circuit breaker for dial operations + // circuit breaker for dial operations when netmap is being used cb *circuitBreaker } @@ -251,7 +254,10 @@ func NewPool(options InitParameters) (*Pool, error) { methods: methods, netMapInfoSource: options.netMapInfoSource, clientMap: make(map[uint64]client), - cb: NewCircuitBreaker(0, 0), + cb: NewCircuitBreaker( + options.circuitBreakerDuration, + options.circuitBreakerThreshold, + ), } if options.netMapInfoSource == nil { @@ -288,7 +294,7 @@ func (p *Pool) Dial(ctx context.Context) error { clients := make([]client, len(nodes)) for j, node := range nodes { clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout) - if err := p.cb.Dial(ctx, clients[j]); err != nil { + if err := clients[j].dial(ctx); err != nil { p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err)) continue } @@ -793,6 +799,15 @@ func fillDefaultInitParams(params *InitParameters) { if params.maxRequestAttempts <= 0 { params.maxRequestAttempts = len(params.nodeParams) } + + if params.circuitBreakerDuration <= 0 { + params.circuitBreakerDuration = defaultCircuitBreakerDuration + + } + + if params.circuitBreakerThreshold <= 0 { + params.circuitBreakerThreshold = defaultCircuitBreakerTreshold + } } func (p *Pool) log(level zapcore.Level, msg string, fields ...zap.Field) { @@ -988,7 +1003,10 @@ LOOP: treeCl, ok := p.getClientFromMap(cnrNode.Hash()) if !ok { - treeCl, err = p.getNewTreeClient(ctx, cnrNode) + err = p.cb.Do(cnrNode.Hash(), func() error { + treeCl, err = p.getNewTreeClient(ctx, cnrNode) + return err + }) if err != nil { finErr = finalError(finErr, err) p.log(zap.DebugLevel, "failed to create tree client", zap.String("request_id", reqID), zap.Int("remaining attempts", attempts)) @@ -1059,7 +1077,7 @@ func (p *Pool) getNewTreeClient(ctx context.Context, node netmap.NodeInfo) (*tre } newTreeCl := newTreeClient(addr.URIAddr(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout) - if err = p.cb.Dial(ctx, newTreeCl); err != nil { + if err = newTreeCl.dial(ctx); err != nil { p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", addr.URIAddr()), zap.Error(err)) // We have to close connection here after failed `dial()`.