[#339] pool/tree: Make circuit breaker more generic

Signed-off-by: Alex Vanin <a.vanin@yadro.com>
This commit is contained in:
Alexey Vanin 2025-03-04 12:17:20 +03:00
parent c8d71c450a
commit f78fb6dcb0
3 changed files with 51 additions and 71 deletions

View file

@@ -1,73 +1,54 @@
package tree package tree
import ( import (
"context"
"errors" "errors"
"sync" "sync"
"time" "time"
) )
const (
defaultThreshold = 10
defaultBreakDuration = 10 * time.Second
)
type ( type (
circuitBreaker struct { circuitBreaker struct {
breakDuration time.Duration breakDuration time.Duration
threshold int threshold int
mu sync.Mutex mu sync.Mutex
state map[string]state state map[uint64]state
} }
state struct { state struct {
counter int counter int
breakTimestamp time.Time breakTimestamp time.Time
} }
dialer interface {
dial(context.Context) error
endpoint() string
}
) )
var ErrCBClosed = errors.New("circuit breaker is closed") var ErrCBClosed = errors.New("circuit breaker is closed")
func NewCircuitBreaker(breakDuration time.Duration, threshold int) *circuitBreaker { func NewCircuitBreaker(breakDuration time.Duration, threshold int) *circuitBreaker {
if threshold == 0 {
threshold = defaultThreshold
}
if breakDuration == 0 {
breakDuration = defaultBreakDuration
}
return &circuitBreaker{ return &circuitBreaker{
breakDuration: breakDuration, breakDuration: breakDuration,
threshold: threshold, threshold: threshold,
state: make(map[string]state), state: make(map[uint64]state),
} }
} }
func (c *circuitBreaker) Dial(ctx context.Context, cli dialer) error { func (c *circuitBreaker) Do(id uint64, f func() error) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
endpoint := cli.endpoint() if _, ok := c.state[id]; !ok {
c.state[id] = state{}
if _, ok := c.state[endpoint]; !ok {
c.state[endpoint] = state{}
} }
s := c.state[endpoint] s := c.state[id]
defer func() { defer func() {
c.state[endpoint] = s c.state[id] = s
}() }()
if time.Since(s.breakTimestamp) < c.breakDuration { if time.Since(s.breakTimestamp) < c.breakDuration {
return ErrCBClosed return ErrCBClosed
} }
err := cli.dial(ctx) err := f()
if err == nil { if err == nil {
s.counter = 0 s.counter = 0
return nil return nil

View file

@@ -1,7 +1,6 @@
package tree package tree
import ( import (
"context"
"errors" "errors"
"testing" "testing"
"time" "time"
@@ -9,52 +8,34 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type (
testDialer struct {
err error
addr string
}
)
func (d *testDialer) dial(_ context.Context) error {
return d.err
}
func (d *testDialer) endpoint() string {
return d.addr
}
func TestCircuitBreaker(t *testing.T) { func TestCircuitBreaker(t *testing.T) {
ctx := context.Background()
remoteErr := errors.New("service is being synchronized") remoteErr := errors.New("service is being synchronized")
d := &testDialer{ breakDuration := 1 * time.Second
err: remoteErr, threshold := 10
addr: "addr", cb := NewCircuitBreaker(breakDuration, threshold)
}
// Hit threshold // Hit threshold
cb := NewCircuitBreaker(1*time.Second, 10) for i := 0; i < threshold; i++ {
for i := 0; i < 10; i++ { err := cb.Do(1, func() error { return remoteErr })
err := cb.Dial(ctx, d)
require.ErrorIs(t, err, remoteErr) require.ErrorIs(t, err, remoteErr)
} }
// Different client should not be affected by threshold
require.NoError(t, cb.Do(2, func() error { return nil }))
// Immediate request should return circuit breaker error // Immediate request should return circuit breaker error
d.err = nil require.ErrorIs(t, cb.Do(1, func() error { return nil }), ErrCBClosed)
require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed)
// Request after breakDuration should be ok // Request after breakDuration should be ok
time.Sleep(1 * time.Second) time.Sleep(breakDuration)
require.NoError(t, cb.Dial(ctx, d)) require.NoError(t, cb.Do(1, func() error { return nil }))
// Then hit threshold once again // Try hitting threshold one more time after break duration
d.err = remoteErr for i := 0; i < threshold; i++ {
for i := 0; i < 10; i++ { err := cb.Do(1, func() error { return remoteErr })
err := cb.Dial(ctx, d)
require.ErrorIs(t, err, remoteErr) require.ErrorIs(t, err, remoteErr)
} }
// Immediate request should return circuit breaker error // Immediate request should return circuit breaker error
d.err = nil require.ErrorIs(t, cb.Do(1, func() error { return nil }), ErrCBClosed)
require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed)
} }

View file

@@ -28,6 +28,8 @@ const (
defaultHealthcheckTimeout = 4 * time.Second defaultHealthcheckTimeout = 4 * time.Second
defaultDialTimeout = 5 * time.Second defaultDialTimeout = 5 * time.Second
defaultStreamTimeout = 10 * time.Second defaultStreamTimeout = 10 * time.Second
defaultCircuitBreakerDuration = 10 * time.Second
defaultCircuitBreakerTreshold = 10
) )
// SubTreeSort defines an order of nodes returned from GetSubTree RPC. // SubTreeSort defines an order of nodes returned from GetSubTree RPC.
@@ -76,6 +78,8 @@ type InitParameters struct {
dialOptions []grpc.DialOption dialOptions []grpc.DialOption
maxRequestAttempts int maxRequestAttempts int
netMapInfoSource NetMapInfoSource netMapInfoSource NetMapInfoSource
circuitBreakerThreshold int
circuitBreakerDuration time.Duration
} }
type NetMapInfoSource interface { type NetMapInfoSource interface {
@@ -117,8 +121,7 @@ type Pool struct {
// * retry in case of request failure (see Pool.requestWithRetry) // * retry in case of request failure (see Pool.requestWithRetry)
// startIndices will be used if netMapInfoSource is not set // startIndices will be used if netMapInfoSource is not set
startIndices [2]int startIndices [2]int
// circuit breaker for dial operations when netmap is being used
// circuit breaker for dial operations
cb *circuitBreaker cb *circuitBreaker
} }
@@ -251,7 +254,10 @@ func NewPool(options InitParameters) (*Pool, error) {
methods: methods, methods: methods,
netMapInfoSource: options.netMapInfoSource, netMapInfoSource: options.netMapInfoSource,
clientMap: make(map[uint64]client), clientMap: make(map[uint64]client),
cb: NewCircuitBreaker(0, 0), cb: NewCircuitBreaker(
options.circuitBreakerDuration,
options.circuitBreakerThreshold,
),
} }
if options.netMapInfoSource == nil { if options.netMapInfoSource == nil {
@@ -288,7 +294,7 @@ func (p *Pool) Dial(ctx context.Context) error {
clients := make([]client, len(nodes)) clients := make([]client, len(nodes))
for j, node := range nodes { for j, node := range nodes {
clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout) clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
if err := p.cb.Dial(ctx, clients[j]); err != nil { if err := clients[j].dial(ctx); err != nil {
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err)) p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err))
continue continue
} }
@@ -793,6 +799,15 @@ func fillDefaultInitParams(params *InitParameters) {
if params.maxRequestAttempts <= 0 { if params.maxRequestAttempts <= 0 {
params.maxRequestAttempts = len(params.nodeParams) params.maxRequestAttempts = len(params.nodeParams)
} }
if params.circuitBreakerDuration <= 0 {
params.circuitBreakerDuration = defaultCircuitBreakerDuration
}
if params.circuitBreakerThreshold <= 0 {
params.circuitBreakerThreshold = defaultCircuitBreakerTreshold
}
} }
func (p *Pool) log(level zapcore.Level, msg string, fields ...zap.Field) { func (p *Pool) log(level zapcore.Level, msg string, fields ...zap.Field) {
@@ -988,7 +1003,10 @@ LOOP:
treeCl, ok := p.getClientFromMap(cnrNode.Hash()) treeCl, ok := p.getClientFromMap(cnrNode.Hash())
if !ok { if !ok {
err = p.cb.Do(cnrNode.Hash(), func() error {
treeCl, err = p.getNewTreeClient(ctx, cnrNode) treeCl, err = p.getNewTreeClient(ctx, cnrNode)
return err
})
if err != nil { if err != nil {
finErr = finalError(finErr, err) finErr = finalError(finErr, err)
p.log(zap.DebugLevel, "failed to create tree client", zap.String("request_id", reqID), zap.Int("remaining attempts", attempts)) p.log(zap.DebugLevel, "failed to create tree client", zap.String("request_id", reqID), zap.Int("remaining attempts", attempts))
@@ -1059,7 +1077,7 @@ func (p *Pool) getNewTreeClient(ctx context.Context, node netmap.NodeInfo) (*tre
} }
newTreeCl := newTreeClient(addr.URIAddr(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout) newTreeCl := newTreeClient(addr.URIAddr(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
if err = p.cb.Dial(ctx, newTreeCl); err != nil { if err = newTreeCl.dial(ctx); err != nil {
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", addr.URIAddr()), zap.Error(err)) p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", addr.URIAddr()), zap.Error(err))
// We have to close connection here after failed `dial()`. // We have to close connection here after failed `dial()`.