diff --git a/pool/tree/circuitbreaker.go b/pool/tree/circuitbreaker.go index 5eadcb4b..82615a66 100644 --- a/pool/tree/circuitbreaker.go +++ b/pool/tree/circuitbreaker.go @@ -31,15 +31,16 @@ func newCircuitBreaker(breakDuration time.Duration, threshold int) *circuitBreak } } -func (cb *circuitBreaker) breakTime(id uint64) (time.Time, bool) { +func (cb *circuitBreaker) checkBreak(id uint64) error { cb.mu.RLock() - defer cb.mu.RUnlock() + s, ok := cb.state[id] + cb.mu.RUnlock() - if s, ok := cb.state[id]; ok { - return s.breakTimestamp, true + if ok && time.Since(s.breakTimestamp) < cb.breakDuration { + return ErrCBClosed } - return time.Time{}, false + return nil } func (cb *circuitBreaker) openBreak(id uint64) { @@ -48,7 +49,7 @@ func (cb *circuitBreaker) openBreak(id uint64) { delete(cb.state, id) } -func (cb *circuitBreaker) incError(id uint64, doTime time.Time) { +func (cb *circuitBreaker) incError(id uint64) { cb.mu.Lock() defer cb.mu.Unlock() @@ -57,30 +58,24 @@ func (cb *circuitBreaker) incError(id uint64, doTime time.Time) { s.counter++ if s.counter >= cb.threshold { s.counter = cb.threshold - if s.breakTimestamp.Before(doTime) { - s.breakTimestamp = doTime + if time.Since(s.breakTimestamp) >= cb.breakDuration { + s.breakTimestamp = time.Now() } } cb.state[id] = s } -func (c *circuitBreaker) Do(id uint64, f func() error) error { - breakTime, ok := c.breakTime(id) - if ok && time.Since(breakTime) < c.breakDuration { - return ErrCBClosed +func (cb *circuitBreaker) Do(id uint64, f func() error) error { + if err := cb.checkBreak(id); err != nil { + return err } - // Use this timestamp to update circuit breaker in case of an error. - // f() may be blocked for unpredictable duration, so concurrent calls - // may update time in 'incError' endlessly and circuit will never be open - doTime := time.Now() - err := f() if err == nil { - c.openBreak(id) + cb.openBreak(id) } else { - c.incError(id, doTime) + cb.incError(id) } return err diff --git a/pool/tree/circuitbreaker_test.go b/pool/tree/circuitbreaker_test.go index aefa9d69..c616d1b6 100644 --- a/pool/tree/circuitbreaker_test.go +++ b/pool/tree/circuitbreaker_test.go @@ -2,7 +2,6 @@ package tree import ( "errors" - "runtime" "testing" "time" @@ -43,9 +42,9 @@ func TestCircuitBreaker(t *testing.T) { func TestCircuitBreakerNoBlock(t *testing.T) { remoteErr := errors.New("service is being synchronized") - funcDuration := 2 * time.Second + funcDuration := 200 * time.Millisecond threshold := 100 - cb := newCircuitBreaker(1*time.Minute, threshold) + cb := newCircuitBreaker(10*funcDuration, threshold) slowFunc := func() error { time.Sleep(funcDuration) @@ -53,16 +52,17 @@ func TestCircuitBreakerNoBlock(t *testing.T) { } for i := 0; i < threshold; i++ { - // run in multiple goroutines Do function and make sure it is not + // run in multiple goroutines Do function go func() { cb.Do(1, slowFunc) }() } - // wait for one slow func duration + some delta - time.Sleep(funcDuration + 100*time.Millisecond) - runtime.Gosched() - // expect that all goroutines were not blocked by mutex in circuit breaker - // therefore all functions are done and circuit is closed - require.ErrorIs(t, cb.Do(1, func() error { return nil }), ErrCBClosed) + time.Sleep(funcDuration) + + // eventually at most after one more func duration circuit breaker will be + // closed and not blocked by slow func execution under mutex + require.Eventually(t, func() bool { + return errors.Is(cb.Do(1, func() error { return nil }), ErrCBClosed) + }, funcDuration, funcDuration/10) }