diff --git a/pool/tree/circuitbreaker.go b/pool/tree/circuitbreaker.go new file mode 100644 index 00000000..83d11c42 --- /dev/null +++ b/pool/tree/circuitbreaker.go @@ -0,0 +1,83 @@ +package tree + +import ( + "context" + "errors" + "sync" + "time" +) + +const ( + defaultThreshold = 10 + defaultBreakDuration = 10 * time.Second +) + +type ( + circuitBreaker struct { + breakDuration time.Duration + threshold int + + mu sync.Mutex + state map[string]state + } + + state struct { + counter int + breakTimestamp time.Time + } + + dialer interface { + dial(context.Context) error + endpoint() string + } +) + +var ErrCBClosed = errors.New("circuit breaker is closed") + +func NewCircuitBreaker(breakDuration time.Duration, threshold int) *circuitBreaker { + if threshold == 0 { + threshold = defaultThreshold + } + if breakDuration == 0 { + breakDuration = defaultBreakDuration + } + return &circuitBreaker{ + breakDuration: breakDuration, + threshold: threshold, + state: make(map[string]state), + } +} + +func (c *circuitBreaker) Dial(ctx context.Context, cli dialer) error { + c.mu.Lock() + defer c.mu.Unlock() + + endpoint := cli.endpoint() + + if _, ok := c.state[endpoint]; !ok { + c.state[endpoint] = state{} + } + + s := c.state[endpoint] + defer func() { + c.state[endpoint] = s + }() + + if time.Since(s.breakTimestamp) < c.breakDuration { + return ErrCBClosed + } + + err := cli.dial(ctx) + if err == nil { + s.counter = 0 + return nil + } + + s.counter++ + if s.counter >= c.threshold { + s.counter = c.threshold + s.breakTimestamp = time.Now() + } + + return err +} diff --git a/pool/tree/circuitbreaker_test.go b/pool/tree/circuitbreaker_test.go new file mode 100644 index 00000000..f507655e --- /dev/null +++ b/pool/tree/circuitbreaker_test.go @@ -0,0 +1,60 @@ +package tree + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type ( + testDialer struct { + err error + addr string + } +) + +func (d *testDialer) dial(_ context.Context) error { + return d.err +} + +func (d *testDialer) endpoint() string { + return d.addr +} + +func TestCircuitBreaker(t *testing.T) { + ctx := context.Background() + remoteErr := errors.New("service is being synchronized") + d := &testDialer{ + err: remoteErr, + addr: "addr", + } + + // Hit threshold + cb := NewCircuitBreaker(1*time.Second, 10) + for i := 0; i < 10; i++ { + err := cb.Dial(ctx, d) + require.ErrorIs(t, err, remoteErr) + } + + // Immediate request should return circuit breaker error + d.err = nil + require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed) + + // Request after breakDuration should be ok + time.Sleep(1 * time.Second) + require.NoError(t, cb.Dial(ctx, d)) + + // Then hit threshold once again + d.err = remoteErr + for i := 0; i < 10; i++ { + err := cb.Dial(ctx, d) + require.ErrorIs(t, err, remoteErr) + } + + // Immediate request should return circuit breaker error + d.err = nil + require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed) +} diff --git a/pool/tree/pool.go b/pool/tree/pool.go index c82e2697..b0c8af63 100644 --- a/pool/tree/pool.go +++ b/pool/tree/pool.go @@ -117,6 +117,9 @@ type Pool struct { // * retry in case of request failure (see Pool.requestWithRetry) // startIndices will be used if netMapInfoSource is not set startIndices [2]int + + // circuit breaker for dial operations + cb *circuitBreaker } type innerPool struct { @@ -248,6 +251,7 @@ func NewPool(options InitParameters) (*Pool, error) { methods: methods, netMapInfoSource: options.netMapInfoSource, clientMap: make(map[uint64]client), + cb: NewCircuitBreaker(0, 0), } if options.netMapInfoSource == nil { @@ -284,7 +288,7 @@ func (p *Pool) Dial(ctx context.Context) error { clients := make([]client, len(nodes)) for j, node := range nodes { clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout) - if err := clients[j].dial(ctx); err != nil { + if err := p.cb.Dial(ctx, clients[j]); err != nil { p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err)) continue } @@ -1055,7 +1059,7 @@ func (p *Pool) getNewTreeClient(ctx context.Context, node netmap.NodeInfo) (*tre } newTreeCl := newTreeClient(addr.URIAddr(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout) - if err = newTreeCl.dial(ctx); err != nil { + if err = p.cb.Dial(ctx, newTreeCl); err != nil { p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", addr.URIAddr()), zap.Error(err)) // We have to close connection here after failed `dial()`.