forked from TrueCloudLab/frostfs-sdk-go
[#339] pool/tree: Make circuit breaker more generic
Signed-off-by: Alex Vanin <a.vanin@yadro.com>
This commit is contained in:
parent
c8d71c450a
commit
f78fb6dcb0
3 changed files with 51 additions and 71 deletions
|
@ -1,73 +1,54 @@
|
||||||
package tree
|
package tree
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
defaultThreshold = 10
|
|
||||||
defaultBreakDuration = 10 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
circuitBreaker struct {
|
circuitBreaker struct {
|
||||||
breakDuration time.Duration
|
breakDuration time.Duration
|
||||||
threshold int
|
threshold int
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
state map[string]state
|
state map[uint64]state
|
||||||
}
|
}
|
||||||
|
|
||||||
state struct {
|
state struct {
|
||||||
counter int
|
counter int
|
||||||
breakTimestamp time.Time
|
breakTimestamp time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer interface {
|
|
||||||
dial(context.Context) error
|
|
||||||
endpoint() string
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrCBClosed = errors.New("circuit breaker is closed")
|
var ErrCBClosed = errors.New("circuit breaker is closed")
|
||||||
|
|
||||||
func NewCircuitBreaker(breakDuration time.Duration, threshold int) *circuitBreaker {
|
func NewCircuitBreaker(breakDuration time.Duration, threshold int) *circuitBreaker {
|
||||||
if threshold == 0 {
|
|
||||||
threshold = defaultThreshold
|
|
||||||
}
|
|
||||||
if breakDuration == 0 {
|
|
||||||
breakDuration = defaultBreakDuration
|
|
||||||
}
|
|
||||||
return &circuitBreaker{
|
return &circuitBreaker{
|
||||||
breakDuration: breakDuration,
|
breakDuration: breakDuration,
|
||||||
threshold: threshold,
|
threshold: threshold,
|
||||||
state: make(map[string]state),
|
state: make(map[uint64]state),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *circuitBreaker) Dial(ctx context.Context, cli dialer) error {
|
func (c *circuitBreaker) Do(id uint64, f func() error) error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
endpoint := cli.endpoint()
|
if _, ok := c.state[id]; !ok {
|
||||||
|
c.state[id] = state{}
|
||||||
if _, ok := c.state[endpoint]; !ok {
|
|
||||||
c.state[endpoint] = state{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s := c.state[endpoint]
|
s := c.state[id]
|
||||||
defer func() {
|
defer func() {
|
||||||
c.state[endpoint] = s
|
c.state[id] = s
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if time.Since(s.breakTimestamp) < c.breakDuration {
|
if time.Since(s.breakTimestamp) < c.breakDuration {
|
||||||
return ErrCBClosed
|
return ErrCBClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
err := cli.dial(ctx)
|
err := f()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s.counter = 0
|
s.counter = 0
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package tree
|
package tree
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -9,52 +8,34 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
|
||||||
testDialer struct {
|
|
||||||
err error
|
|
||||||
addr string
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d *testDialer) dial(_ context.Context) error {
|
|
||||||
return d.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *testDialer) endpoint() string {
|
|
||||||
return d.addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCircuitBreaker(t *testing.T) {
|
func TestCircuitBreaker(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
remoteErr := errors.New("service is being synchronized")
|
remoteErr := errors.New("service is being synchronized")
|
||||||
d := &testDialer{
|
breakDuration := 1 * time.Second
|
||||||
err: remoteErr,
|
threshold := 10
|
||||||
addr: "addr",
|
cb := NewCircuitBreaker(breakDuration, threshold)
|
||||||
}
|
|
||||||
|
|
||||||
// Hit threshold
|
// Hit threshold
|
||||||
cb := NewCircuitBreaker(1*time.Second, 10)
|
for i := 0; i < threshold; i++ {
|
||||||
for i := 0; i < 10; i++ {
|
err := cb.Do(1, func() error { return remoteErr })
|
||||||
err := cb.Dial(ctx, d)
|
|
||||||
require.ErrorIs(t, err, remoteErr)
|
require.ErrorIs(t, err, remoteErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Different client should not be affected by threshold
|
||||||
|
require.NoError(t, cb.Do(2, func() error { return nil }))
|
||||||
|
|
||||||
// Immediate request should return circuit breaker error
|
// Immediate request should return circuit breaker error
|
||||||
d.err = nil
|
require.ErrorIs(t, cb.Do(1, func() error { return nil }), ErrCBClosed)
|
||||||
require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed)
|
|
||||||
|
|
||||||
// Request after breakDuration should be ok
|
// Request after breakDuration should be ok
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(breakDuration)
|
||||||
require.NoError(t, cb.Dial(ctx, d))
|
require.NoError(t, cb.Do(1, func() error { return nil }))
|
||||||
|
|
||||||
// Then hit threshold once again
|
// Try hitting threshold one more time after break duration
|
||||||
d.err = remoteErr
|
for i := 0; i < threshold; i++ {
|
||||||
for i := 0; i < 10; i++ {
|
err := cb.Do(1, func() error { return remoteErr })
|
||||||
err := cb.Dial(ctx, d)
|
|
||||||
require.ErrorIs(t, err, remoteErr)
|
require.ErrorIs(t, err, remoteErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Immediate request should return circuit breaker error
|
// Immediate request should return circuit breaker error
|
||||||
d.err = nil
|
require.ErrorIs(t, cb.Do(1, func() error { return nil }), ErrCBClosed)
|
||||||
require.ErrorIs(t, cb.Dial(ctx, d), ErrCBClosed)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,8 @@ const (
|
||||||
defaultHealthcheckTimeout = 4 * time.Second
|
defaultHealthcheckTimeout = 4 * time.Second
|
||||||
defaultDialTimeout = 5 * time.Second
|
defaultDialTimeout = 5 * time.Second
|
||||||
defaultStreamTimeout = 10 * time.Second
|
defaultStreamTimeout = 10 * time.Second
|
||||||
|
defaultCircuitBreakerDuration = 10 * time.Second
|
||||||
|
defaultCircuitBreakerTreshold = 10
|
||||||
)
|
)
|
||||||
|
|
||||||
// SubTreeSort defines an order of nodes returned from GetSubTree RPC.
|
// SubTreeSort defines an order of nodes returned from GetSubTree RPC.
|
||||||
|
@ -76,6 +78,8 @@ type InitParameters struct {
|
||||||
dialOptions []grpc.DialOption
|
dialOptions []grpc.DialOption
|
||||||
maxRequestAttempts int
|
maxRequestAttempts int
|
||||||
netMapInfoSource NetMapInfoSource
|
netMapInfoSource NetMapInfoSource
|
||||||
|
circuitBreakerThreshold int
|
||||||
|
circuitBreakerDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetMapInfoSource interface {
|
type NetMapInfoSource interface {
|
||||||
|
@ -117,8 +121,7 @@ type Pool struct {
|
||||||
// * retry in case of request failure (see Pool.requestWithRetry)
|
// * retry in case of request failure (see Pool.requestWithRetry)
|
||||||
// startIndices will be used if netMapInfoSource is not set
|
// startIndices will be used if netMapInfoSource is not set
|
||||||
startIndices [2]int
|
startIndices [2]int
|
||||||
|
// circuit breaker for dial operations when netmap is being used
|
||||||
// circuit breaker for dial operations
|
|
||||||
cb *circuitBreaker
|
cb *circuitBreaker
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,7 +254,10 @@ func NewPool(options InitParameters) (*Pool, error) {
|
||||||
methods: methods,
|
methods: methods,
|
||||||
netMapInfoSource: options.netMapInfoSource,
|
netMapInfoSource: options.netMapInfoSource,
|
||||||
clientMap: make(map[uint64]client),
|
clientMap: make(map[uint64]client),
|
||||||
cb: NewCircuitBreaker(0, 0),
|
cb: NewCircuitBreaker(
|
||||||
|
options.circuitBreakerDuration,
|
||||||
|
options.circuitBreakerThreshold,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.netMapInfoSource == nil {
|
if options.netMapInfoSource == nil {
|
||||||
|
@ -288,7 +294,7 @@ func (p *Pool) Dial(ctx context.Context) error {
|
||||||
clients := make([]client, len(nodes))
|
clients := make([]client, len(nodes))
|
||||||
for j, node := range nodes {
|
for j, node := range nodes {
|
||||||
clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
|
clients[j] = newTreeClient(node.Address(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
|
||||||
if err := p.cb.Dial(ctx, clients[j]); err != nil {
|
if err := clients[j].dial(ctx); err != nil {
|
||||||
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err))
|
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", node.Address()), zap.Error(err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -793,6 +799,15 @@ func fillDefaultInitParams(params *InitParameters) {
|
||||||
if params.maxRequestAttempts <= 0 {
|
if params.maxRequestAttempts <= 0 {
|
||||||
params.maxRequestAttempts = len(params.nodeParams)
|
params.maxRequestAttempts = len(params.nodeParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if params.circuitBreakerDuration <= 0 {
|
||||||
|
params.circuitBreakerDuration = defaultCircuitBreakerDuration
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.circuitBreakerThreshold <= 0 {
|
||||||
|
params.circuitBreakerThreshold = defaultCircuitBreakerTreshold
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pool) log(level zapcore.Level, msg string, fields ...zap.Field) {
|
func (p *Pool) log(level zapcore.Level, msg string, fields ...zap.Field) {
|
||||||
|
@ -988,7 +1003,10 @@ LOOP:
|
||||||
|
|
||||||
treeCl, ok := p.getClientFromMap(cnrNode.Hash())
|
treeCl, ok := p.getClientFromMap(cnrNode.Hash())
|
||||||
if !ok {
|
if !ok {
|
||||||
|
err = p.cb.Do(cnrNode.Hash(), func() error {
|
||||||
treeCl, err = p.getNewTreeClient(ctx, cnrNode)
|
treeCl, err = p.getNewTreeClient(ctx, cnrNode)
|
||||||
|
return err
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
finErr = finalError(finErr, err)
|
finErr = finalError(finErr, err)
|
||||||
p.log(zap.DebugLevel, "failed to create tree client", zap.String("request_id", reqID), zap.Int("remaining attempts", attempts))
|
p.log(zap.DebugLevel, "failed to create tree client", zap.String("request_id", reqID), zap.Int("remaining attempts", attempts))
|
||||||
|
@ -1059,7 +1077,7 @@ func (p *Pool) getNewTreeClient(ctx context.Context, node netmap.NodeInfo) (*tre
|
||||||
}
|
}
|
||||||
|
|
||||||
newTreeCl := newTreeClient(addr.URIAddr(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
|
newTreeCl := newTreeClient(addr.URIAddr(), p.dialOptions, p.nodeDialTimeout, p.streamTimeout)
|
||||||
if err = p.cb.Dial(ctx, newTreeCl); err != nil {
|
if err = newTreeCl.dial(ctx); err != nil {
|
||||||
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", addr.URIAddr()), zap.Error(err))
|
p.log(zap.WarnLevel, "failed to dial tree client", zap.String("address", addr.URIAddr()), zap.Error(err))
|
||||||
|
|
||||||
// We have to close connection here after failed `dial()`.
|
// We have to close connection here after failed `dial()`.
|
||||||
|
|
Loading…
Add table
Reference in a new issue