health: improve periodic polling of checks

The API for periodic health checks is repetitive, with a distinct
function for polling a checker to each kind of updater. It also gives
the user no control over the lifetime of the polling goroutines nor
which context is passed into the checker.

Replace the existing PeriodicXYZChecker functions with a single Poll
function which composes an Updater with a Checker. Its context parameter
is passed into the checker and also controls when the polling loop
terminates. To guard against health checks failing closed (ostensibly
healthy) when the polling loop is terminated, the updater is forcefully
updated to an error status, overriding any configured threshold.

Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
Cory Snider 2023-10-27 16:57:31 -04:00
parent a1b49d3d17
commit f2cbfe2402
3 changed files with 128 additions and 75 deletions

View file

@ -3,6 +3,7 @@ package health
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
@ -104,7 +105,7 @@ func (tu *thresholdUpdater) Check(context.Context) error {
tu.mu.Lock() tu.mu.Lock()
defer tu.mu.Unlock() defer tu.mu.Unlock()
if tu.count >= tu.threshold { if tu.count >= tu.threshold || errors.As(tu.status, new(pollingTerminatedErr)) {
return tu.status return tu.status
} }
@ -128,38 +129,35 @@ func (tu *thresholdUpdater) Update(status error) {
// NewThresholdStatusUpdater returns a new thresholdUpdater // NewThresholdStatusUpdater returns a new thresholdUpdater
func NewThresholdStatusUpdater(t int) Updater { func NewThresholdStatusUpdater(t int) Updater {
if t > 0 {
return &thresholdUpdater{threshold: t} return &thresholdUpdater{threshold: t}
}
return NewStatusUpdater()
} }
// PeriodicChecker wraps an updater to provide a periodic checker type pollingTerminatedErr struct{ Err error }
func PeriodicChecker(check Checker, period time.Duration) Checker {
u := NewStatusUpdater()
go func() {
t := time.NewTicker(period)
defer t.Stop()
for {
<-t.C
u.Update(check.Check(context.Background()))
}
}()
return u func (e pollingTerminatedErr) Error() string {
return fmt.Sprintf("health: check is not polled: %v", e.Err)
} }
// PeriodicThresholdChecker wraps an updater to provide a periodic checker that func (e pollingTerminatedErr) Unwrap() error { return e.Err }
// uses a threshold before it changes status
func PeriodicThresholdChecker(check Checker, period time.Duration, threshold int) Checker { // Poll periodically polls the checker c at interval and updates the updater u
tu := NewThresholdStatusUpdater(threshold) // with the result. The checker is called with ctx as the context. When ctx is
go func() { // done, Poll updates the updater with ctx.Err() and returns.
t := time.NewTicker(period) func Poll(ctx context.Context, u Updater, c Checker, interval time.Duration) {
t := time.NewTicker(interval)
defer t.Stop() defer t.Stop()
for { for {
<-t.C select {
tu.Update(check.Check(context.Background())) case <-ctx.Done():
u.Update(pollingTerminatedErr{Err: ctx.Err()})
return
case <-t.C:
u.Update(c.Check(ctx))
}
} }
}()
return tu
} }
// CheckStatus returns a map with all the current health check errors // CheckStatus returns a map with all the current health check errors
@ -215,30 +213,6 @@ func RegisterFunc(name string, check CheckFunc) {
DefaultRegistry.RegisterFunc(name, check) DefaultRegistry.RegisterFunc(name, check)
} }
// RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker
// from an arbitrary func(context.Context) error.
func (registry *Registry) RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) {
registry.Register(name, PeriodicChecker(check, period))
}
// RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker
// in the default registry from an arbitrary func() error.
func RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) {
DefaultRegistry.RegisterPeriodicFunc(name, period, check)
}
// RegisterPeriodicThresholdFunc allows the convenience of registering a
// PeriodicChecker from an arbitrary func() error.
func (registry *Registry) RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) {
registry.Register(name, PeriodicThresholdChecker(check, period, threshold))
}
// RegisterPeriodicThresholdFunc allows the convenience of registering a
// PeriodicChecker in the default registry from an arbitrary func() error.
func RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) {
DefaultRegistry.RegisterPeriodicThresholdFunc(name, period, threshold, check)
}
// StatusHandler returns a JSON blob with all the currently registered Health Checks // StatusHandler returns a JSON blob with all the currently registered Health Checks
// and their corresponding status. // and their corresponding status.
// Returns 503 if any Error status exists, 200 otherwise // Returns 503 if any Error status exists, 200 otherwise

View file

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
) )
// TestReturns200IfThereAreNoChecks ensures that the result code of the health // TestReturns200IfThereAreNoChecks ensures that the result code of the health
@ -106,3 +107,87 @@ func TestHealthHandler(t *testing.T) {
updater.Update(nil) updater.Update(nil)
checkUp(t, "when server is back up") // now we should be back up. checkUp(t, "when server is back up") // now we should be back up.
} }
func TestThresholdStatusUpdater(t *testing.T) {
u := NewThresholdStatusUpdater(3)
assertCheckOK := func() {
t.Helper()
if err := u.Check(context.Background()); err != nil {
t.Errorf("u.Check() = %v; want nil", err)
}
}
assertCheckErr := func(expected string) {
t.Helper()
if err := u.Check(context.Background()); err == nil || err.Error() != expected {
t.Errorf("u.Check() = %v; want %v", err, expected)
}
}
// Updater should report healthy until the threshold is reached.
for i := 1; i <= 3; i++ {
assertCheckOK()
u.Update(fmt.Errorf("fake error %d", i))
}
assertCheckErr("fake error 3")
// The threshold should reset after one successful update.
u.Update(nil)
assertCheckOK()
u.Update(errors.New("first errored update after reset"))
assertCheckOK()
u.Update(nil)
// pollingTerminatedErr should bypass the threshold.
pte := pollingTerminatedErr{Err: errors.New("womp womp")}
u.Update(pte)
assertCheckErr(pte.Error())
}
func TestPoll(t *testing.T) {
type ContextKey struct{}
for _, threshold := range []int{0, 10} {
t.Run(fmt.Sprintf("threshold=%d", threshold), func(t *testing.T) {
ctx, cancel := context.WithCancel(context.WithValue(context.Background(), ContextKey{}, t.Name()))
defer cancel()
checkerCalled := make(chan struct{})
checker := CheckFunc(func(ctx context.Context) error {
if v, ok := ctx.Value(ContextKey{}).(string); !ok || v != t.Name() {
t.Errorf("unexpected context passed into checker: got context with value %q, want %q", v, t.Name())
}
select {
case <-checkerCalled:
default:
close(checkerCalled)
}
return nil
})
updater := NewThresholdStatusUpdater(threshold)
pollReturned := make(chan struct{})
go func() {
Poll(ctx, updater, checker, 1*time.Millisecond)
close(pollReturned)
}()
select {
case <-checkerCalled:
case <-time.After(1 * time.Second):
t.Error("checker has not been polled")
}
cancel()
select {
case <-pollReturned:
case <-time.After(1 * time.Second):
t.Error("poll has not returned after context was canceled")
}
if err := updater.Check(context.Background()); !errors.Is(err, context.Canceled) {
t.Errorf("updater.Check() = %v; want %v", err, context.Canceled)
}
})
}
}

View file

@ -348,22 +348,20 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
interval = defaultCheckInterval interval = defaultCheckInterval
} }
storageDriverCheck := func(context.Context) error { storageDriverCheck := health.CheckFunc(func(ctx context.Context) error {
_, err := app.driver.Stat(app, "/") // "/" should always exist _, err := app.driver.Stat(ctx, "/") // "/" should always exist
if _, ok := err.(storagedriver.PathNotFoundError); ok { if _, ok := err.(storagedriver.PathNotFoundError); ok {
err = nil // pass this through, backend is responding, but this path doesn't exist. err = nil // pass this through, backend is responding, but this path doesn't exist.
} }
if err != nil { if err != nil {
dcontext.GetLogger(app).Errorf("storage driver health check: %v", err) dcontext.GetLogger(ctx).Errorf("storage driver health check: %v", err)
} }
return err return err
} })
if app.Config.Health.StorageDriver.Threshold != 0 { updater := health.NewThresholdStatusUpdater(app.Config.Health.StorageDriver.Threshold)
healthRegistry.RegisterPeriodicThresholdFunc("storagedriver_"+app.Config.Storage.Type(), interval, app.Config.Health.StorageDriver.Threshold, storageDriverCheck) healthRegistry.Register("storagedriver_"+app.Config.Storage.Type(), updater)
} else { go health.Poll(app, updater, storageDriverCheck, interval)
healthRegistry.RegisterPeriodicFunc("storagedriver_"+app.Config.Storage.Type(), interval, storageDriverCheck)
}
} }
for _, fileChecker := range app.Config.Health.FileCheckers { for _, fileChecker := range app.Config.Health.FileCheckers {
@ -372,7 +370,9 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
interval = defaultCheckInterval interval = defaultCheckInterval
} }
dcontext.GetLogger(app).Infof("configuring file health check path=%s, interval=%d", fileChecker.File, interval/time.Second) dcontext.GetLogger(app).Infof("configuring file health check path=%s, interval=%d", fileChecker.File, interval/time.Second)
healthRegistry.Register(fileChecker.File, health.PeriodicChecker(checks.FileChecker(fileChecker.File), interval)) u := health.NewStatusUpdater()
healthRegistry.Register(fileChecker.File, u)
go health.Poll(app, u, checks.FileChecker(fileChecker.File), interval)
} }
for _, httpChecker := range app.Config.Health.HTTPCheckers { for _, httpChecker := range app.Config.Health.HTTPCheckers {
@ -388,13 +388,10 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
checker := checks.HTTPChecker(httpChecker.URI, statusCode, httpChecker.Timeout, httpChecker.Headers) checker := checks.HTTPChecker(httpChecker.URI, statusCode, httpChecker.Timeout, httpChecker.Headers)
if httpChecker.Threshold != 0 {
dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold) dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold)
healthRegistry.Register(httpChecker.URI, health.PeriodicThresholdChecker(checker, interval, httpChecker.Threshold)) updater := health.NewThresholdStatusUpdater(httpChecker.Threshold)
} else { healthRegistry.Register(httpChecker.URI, updater)
dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d", httpChecker.URI, interval/time.Second) go health.Poll(app, updater, checker, interval)
healthRegistry.Register(httpChecker.URI, health.PeriodicChecker(checker, interval))
}
} }
for _, tcpChecker := range app.Config.Health.TCPCheckers { for _, tcpChecker := range app.Config.Health.TCPCheckers {
@ -405,13 +402,10 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
checker := checks.TCPChecker(tcpChecker.Addr, tcpChecker.Timeout) checker := checks.TCPChecker(tcpChecker.Addr, tcpChecker.Timeout)
if tcpChecker.Threshold != 0 {
dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d, threshold=%d", tcpChecker.Addr, interval/time.Second, tcpChecker.Threshold) dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d, threshold=%d", tcpChecker.Addr, interval/time.Second, tcpChecker.Threshold)
healthRegistry.Register(tcpChecker.Addr, health.PeriodicThresholdChecker(checker, interval, tcpChecker.Threshold)) updater := health.NewThresholdStatusUpdater(tcpChecker.Threshold)
} else { healthRegistry.Register(tcpChecker.Addr, updater)
dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d", tcpChecker.Addr, interval/time.Second) go health.Poll(app, updater, checker, interval)
healthRegistry.Register(tcpChecker.Addr, health.PeriodicChecker(checker, interval))
}
} }
} }