[#339] pool/tree: Add circuit breaker

Signed-off-by: Alex Vanin <a.vanin@yadro.com>
This commit is contained in:
Alexey Vanin 2025-03-03 17:06:04 +03:00
parent 2b8329e026
commit c8d71c450a
3 changed files with 149 additions and 2 deletions

View file

@ -0,0 +1,83 @@
package tree
import (
"context"
"errors"
"sync"
"time"
)
const (
defaultThreshold = 10
defaultBreakDuration = 10 * time.Second
)
type (
circuitBreaker struct {
breakDuration time.Duration
threshold int
mu sync.Mutex
state map[string]state
}
state struct {
counter int
breakTimestamp time.Time
}
dialer interface {
dial(context.Context) error
endpoint() string
}
)
var ErrCBClosed = errors.New("circuit breaker is closed")
func NewCircuitBreaker(breakDuration time.Duration, threshold int) *circuitBreaker {
if threshold == 0 {
threshold = defaultThreshold
}
if breakDuration == 0 {
breakDuration = defaultBreakDuration
}
return &circuitBreaker{
breakDuration: breakDuration,
threshold: threshold,
state: make(map[string]state),
}
}
func (c *circuitBreaker) Dial(ctx context.Context, cli dialer) error {
c.mu.Lock()
defer c.mu.Unlock()
endpoint := cli.endpoint()
if _, ok := c.state[endpoint]; !ok {
c.state[endpoint] = state{}
}
s := c.state[endpoint]
defer func() {
c.state[endpoint] = s
}()
if time.Since(s.breakTimestamp) < c.breakDuration {
return ErrCBClosed
}
err := cli.dial(ctx)
if err == nil {
s.counter = 0
return nil
}
s.counter++
if s.counter >= c.threshold {
s.counter = c.threshold
s.breakTimestamp = time.Now()
}
return err
}

View file

@ -0,0 +1,60 @@
package tree
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type (
testDialer struct {
err error
addr string
}
)
func (d *testDialer) dial(_ context.Context) error {
return d.err
}
func (d *testDialer) endpoint() string {
return d.addr
}
func TestCircuitBreaker(t *testing.T) {
ctx := context.Background()
remoteErr := errors.New("service is being synchronized")
d := &testDialer{
err: remoteErr,
addr: "addr",
}
// Hit threshold
cb := NewCircuitBreaker(1*time.Second, 10)
for i := 0; i < 10; i++ {
err := cb.Dial(ctx, d)
require.ErrorIs(t, err, remoteErr)
}
// Immediate request should return circuit breaker error
d.err = nil
require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed)
// Request after breakDuration should be ok
time.Sleep(1 * time.Second)
require.NoError(t, cb.Dial(ctx, d))
// Then hit threshold once again
d.err = remoteErr
for i := 0; i < 10; i++ {
err := cb.Dial(ctx, d)
require.ErrorIs(t, err, remoteErr)
}
// Immediate request should return circuit breaker error
d.err = nil
require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed)
}

View file

@ -117,6 +117,9 @@ type Pool struct {
// * retry in case of request failure (see Pool.requestWithRetry)
// startIndices will be used if netMapInfoSource is not set
startIndices [2]int
// circuit breaker for dial operations
cb *circuitBreaker
}
type innerPool struct {
@ -248,6 +251,7 @@ func NewPool(options InitParameters) (*Pool, error) {
methods: methods,
netMapInfoSource: options.netMapInfoSource,
clientMap: make(map[uint64]client),
cb: NewCircuitBreaker(0, 0),
}
if options.netMapInfoSource == nil {
@ -284,7 +288,7 @@ func (p *Pool) Dial(ctx context.Context) error {
clients := make([]client, len(nodes))
for j, node := range nodes {
clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
if err := clients[j].dial(ctx); err != nil {
if err := p.cb.Dial(ctx, clients[j]); err != nil {
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err))
continue
}
@ -1055,7 +1059,7 @@ func (p *Pool) getNewTreeClient(ctx context.Context, node netmap.NodeInfo) (*tre
}
newTreeCl := newTreeClient(addr.URIAddr(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
if err = newTreeCl.dial(ctx); err != nil {
if err = p.cb.Dial(ctx, newTreeCl); err != nil {
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", addr.URIAddr()), zap.Error(err))
// We have to close connection here after failed `dial()`.