forked from TrueCloudLab/frostfs-sdk-go
[#339] pool/tree: Add circuit breaker
Signed-off-by: Alex Vanin <a.vanin@yadro.com>
This commit is contained in:
parent
2b8329e026
commit
c8d71c450a
3 changed files with 149 additions and 2 deletions
83
pool/tree/circuitbreaker.go
Normal file
83
pool/tree/circuitbreaker.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package tree
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultThreshold = 10
|
||||
defaultBreakDuration = 10 * time.Second
|
||||
)
|
||||
|
||||
type (
|
||||
circuitBreaker struct {
|
||||
breakDuration time.Duration
|
||||
threshold int
|
||||
|
||||
mu sync.Mutex
|
||||
state map[string]state
|
||||
}
|
||||
|
||||
state struct {
|
||||
counter int
|
||||
breakTimestamp time.Time
|
||||
}
|
||||
|
||||
dialer interface {
|
||||
dial(context.Context) error
|
||||
endpoint() string
|
||||
}
|
||||
)
|
||||
|
||||
var ErrCBClosed = errors.New("circuit breaker is closed")
|
||||
|
||||
func NewCircuitBreaker(breakDuration time.Duration, threshold int) *circuitBreaker {
|
||||
if threshold == 0 {
|
||||
threshold = defaultThreshold
|
||||
}
|
||||
if breakDuration == 0 {
|
||||
breakDuration = defaultBreakDuration
|
||||
}
|
||||
return &circuitBreaker{
|
||||
breakDuration: breakDuration,
|
||||
threshold: threshold,
|
||||
state: make(map[string]state),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *circuitBreaker) Dial(ctx context.Context, cli dialer) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
endpoint := cli.endpoint()
|
||||
|
||||
if _, ok := c.state[endpoint]; !ok {
|
||||
c.state[endpoint] = state{}
|
||||
}
|
||||
|
||||
s := c.state[endpoint]
|
||||
defer func() {
|
||||
c.state[endpoint] = s
|
||||
}()
|
||||
|
||||
if time.Since(s.breakTimestamp) < c.breakDuration {
|
||||
return ErrCBClosed
|
||||
}
|
||||
|
||||
err := cli.dial(ctx)
|
||||
if err == nil {
|
||||
s.counter = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
s.counter++
|
||||
if s.counter >= c.threshold {
|
||||
s.counter = c.threshold
|
||||
s.breakTimestamp = time.Now()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
60
pool/tree/circuitbreaker_test.go
Normal file
60
pool/tree/circuitbreaker_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package tree
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type (
|
||||
testDialer struct {
|
||||
err error
|
||||
addr string
|
||||
}
|
||||
)
|
||||
|
||||
func (d *testDialer) dial(_ context.Context) error {
|
||||
return d.err
|
||||
}
|
||||
|
||||
func (d *testDialer) endpoint() string {
|
||||
return d.addr
|
||||
}
|
||||
|
||||
func TestCircuitBreaker(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
remoteErr := errors.New("service is being synchronized")
|
||||
d := &testDialer{
|
||||
err: remoteErr,
|
||||
addr: "addr",
|
||||
}
|
||||
|
||||
// Hit threshold
|
||||
cb := NewCircuitBreaker(1*time.Second, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
err := cb.Dial(ctx, d)
|
||||
require.ErrorIs(t, err, remoteErr)
|
||||
}
|
||||
|
||||
// Immediate request should return circuit breaker error
|
||||
d.err = nil
|
||||
require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed)
|
||||
|
||||
// Request after breakDuration should be ok
|
||||
time.Sleep(1 * time.Second)
|
||||
require.NoError(t, cb.Dial(ctx, d))
|
||||
|
||||
// Then hit threshold once again
|
||||
d.err = remoteErr
|
||||
for i := 0; i < 10; i++ {
|
||||
err := cb.Dial(ctx, d)
|
||||
require.ErrorIs(t, err, remoteErr)
|
||||
}
|
||||
|
||||
// Immediate request should return circuit breaker error
|
||||
d.err = nil
|
||||
require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed)
|
||||
}
|
|
@ -117,6 +117,9 @@ type Pool struct {
|
|||
// * retry in case of request failure (see Pool.requestWithRetry)
|
||||
// startIndices will be used if netMapInfoSource is not set
|
||||
startIndices [2]int
|
||||
|
||||
// circuit breaker for dial operations
|
||||
cb *circuitBreaker
|
||||
}
|
||||
|
||||
type innerPool struct {
|
||||
|
@ -248,6 +251,7 @@ func NewPool(options InitParameters) (*Pool, error) {
|
|||
methods: methods,
|
||||
netMapInfoSource: options.netMapInfoSource,
|
||||
clientMap: make(map[uint64]client),
|
||||
cb: NewCircuitBreaker(0, 0),
|
||||
}
|
||||
|
||||
if options.netMapInfoSource == nil {
|
||||
|
@ -284,7 +288,7 @@ func (p *Pool) Dial(ctx context.Context) error {
|
|||
clients := make([]client, len(nodes))
|
||||
for j, node := range nodes {
|
||||
clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
|
||||
if err := clients[j].dial(ctx); err != nil {
|
||||
if err := p.cb.Dial(ctx, clients[j]); err != nil {
|
||||
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
@ -1055,7 +1059,7 @@ func (p *Pool) getNewTreeClient(ctx context.Context, node netmap.NodeInfo) (*tre
|
|||
}
|
||||
|
||||
newTreeCl := newTreeClient(addr.URIAddr(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
|
||||
if err = newTreeCl.dial(ctx); err != nil {
|
||||
if err = p.cb.Dial(ctx, newTreeCl); err != nil {
|
||||
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", addr.URIAddr()), zap.Error(err))
|
||||
|
||||
// We have to close connection here after failed `dial()`.
|
||||
|
|
Loading…
Add table
Reference in a new issue