diff --git a/health/health.go b/health/health.go index 89cd7c4f..fff87323 100644 --- a/health/health.go +++ b/health/health.go @@ -3,6 +3,7 @@ package health import ( "context" "encoding/json" + "errors" "fmt" "net/http" "sync" @@ -104,7 +105,7 @@ func (tu *thresholdUpdater) Check(context.Context) error { tu.mu.Lock() defer tu.mu.Unlock() - if tu.count >= tu.threshold { + if tu.count >= tu.threshold || errors.As(tu.status, new(pollingTerminatedErr)) { return tu.status } @@ -128,38 +129,35 @@ func (tu *thresholdUpdater) Update(status error) { // NewThresholdStatusUpdater returns a new thresholdUpdater func NewThresholdStatusUpdater(t int) Updater { - return &thresholdUpdater{threshold: t} + if t > 0 { + return &thresholdUpdater{threshold: t} + } + return NewStatusUpdater() } -// PeriodicChecker wraps an updater to provide a periodic checker -func PeriodicChecker(check Checker, period time.Duration) Checker { - u := NewStatusUpdater() - go func() { - t := time.NewTicker(period) - defer t.Stop() - for { - <-t.C - u.Update(check.Check(context.Background())) - } - }() +type pollingTerminatedErr struct{ Err error } - return u +func (e pollingTerminatedErr) Error() string { + return fmt.Sprintf("health: check is not polled: %v", e.Err) } -// PeriodicThresholdChecker wraps an updater to provide a periodic checker that -// uses a threshold before it changes status -func PeriodicThresholdChecker(check Checker, period time.Duration, threshold int) Checker { - tu := NewThresholdStatusUpdater(threshold) - go func() { - t := time.NewTicker(period) - defer t.Stop() - for { - <-t.C - tu.Update(check.Check(context.Background())) - } - }() +func (e pollingTerminatedErr) Unwrap() error { return e.Err } - return tu +// Poll periodically polls the checker c at interval and updates the updater u +// with the result. The checker is called with ctx as the context. When ctx is +// done, Poll updates the updater with ctx.Err() and returns. +func Poll(ctx context.Context, u Updater, c Checker, interval time.Duration) { + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + u.Update(pollingTerminatedErr{Err: ctx.Err()}) + return + case <-t.C: + u.Update(c.Check(ctx)) + } + } } // CheckStatus returns a map with all the current health check errors @@ -215,30 +213,6 @@ func RegisterFunc(name string, check CheckFunc) { DefaultRegistry.RegisterFunc(name, check) } -// RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker -// from an arbitrary func(context.Context) error. -func (registry *Registry) RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) { - registry.Register(name, PeriodicChecker(check, period)) -} - -// RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker -// in the default registry from an arbitrary func() error. -func RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) { - DefaultRegistry.RegisterPeriodicFunc(name, period, check) -} - -// RegisterPeriodicThresholdFunc allows the convenience of registering a -// PeriodicChecker from an arbitrary func() error. -func (registry *Registry) RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) { - registry.Register(name, PeriodicThresholdChecker(check, period, threshold)) -} - -// RegisterPeriodicThresholdFunc allows the convenience of registering a -// PeriodicChecker in the default registry from an arbitrary func() error. -func RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) { - DefaultRegistry.RegisterPeriodicThresholdFunc(name, period, threshold, check) -} - // StatusHandler returns a JSON blob with all the currently registered Health Checks // and their corresponding status. // Returns 503 if any Error status exists, 200 otherwise diff --git a/health/health_test.go b/health/health_test.go index c877ed7f..ac87c863 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" ) // TestReturns200IfThereAreNoChecks ensures that the result code of the health @@ -106,3 +107,87 @@ func TestHealthHandler(t *testing.T) { updater.Update(nil) checkUp(t, "when server is back up") // now we should be back up. } + +func TestThresholdStatusUpdater(t *testing.T) { + u := NewThresholdStatusUpdater(3) + + assertCheckOK := func() { + t.Helper() + if err := u.Check(context.Background()); err != nil { + t.Errorf("u.Check() = %v; want nil", err) + } + } + + assertCheckErr := func(expected string) { + t.Helper() + if err := u.Check(context.Background()); err == nil || err.Error() != expected { + t.Errorf("u.Check() = %v; want %v", err, expected) + } + } + + // Updater should report healthy until the threshold is reached. + for i := 1; i <= 3; i++ { + assertCheckOK() + u.Update(fmt.Errorf("fake error %d", i)) + } + assertCheckErr("fake error 3") + + // The threshold should reset after one successful update. + u.Update(nil) + assertCheckOK() + u.Update(errors.New("first errored update after reset")) + assertCheckOK() + u.Update(nil) + + // pollingTerminatedErr should bypass the threshold. + pte := pollingTerminatedErr{Err: errors.New("womp womp")} + u.Update(pte) + assertCheckErr(pte.Error()) +} + +func TestPoll(t *testing.T) { + type ContextKey struct{} + for _, threshold := range []int{0, 10} { + t.Run(fmt.Sprintf("threshold=%d", threshold), func(t *testing.T) { + ctx, cancel := context.WithCancel(context.WithValue(context.Background(), ContextKey{}, t.Name())) + defer cancel() + checkerCalled := make(chan struct{}) + checker := CheckFunc(func(ctx context.Context) error { + if v, ok := ctx.Value(ContextKey{}).(string); !ok || v != t.Name() { + t.Errorf("unexpected context passed into checker: got context with value %q, want %q", v, t.Name()) + } + select { + case <-checkerCalled: + default: + close(checkerCalled) + } + return nil + }) + + updater := NewThresholdStatusUpdater(threshold) + pollReturned := make(chan struct{}) + go func() { + Poll(ctx, updater, checker, 1*time.Millisecond) + close(pollReturned) + }() + + select { + case <-checkerCalled: + case <-time.After(1 * time.Second): + t.Error("checker has not been polled") + } + + cancel() + + select { + case <-pollReturned: + case <-time.After(1 * time.Second): + t.Error("poll has not returned after context was canceled") + } + + if err := updater.Check(context.Background()); !errors.Is(err, context.Canceled) { + t.Errorf("updater.Check() = %v; want %v", err, context.Canceled) + } + }) + } +} diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 0e4d528b..6866eb82 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -348,22 +348,20 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { interval = defaultCheckInterval } - storageDriverCheck := func(context.Context) error { - _, err := app.driver.Stat(app, "/") // "/" should always exist + storageDriverCheck := health.CheckFunc(func(ctx context.Context) error { + _, err := app.driver.Stat(ctx, "/") // "/" should always exist if _, ok := err.(storagedriver.PathNotFoundError); ok { err = nil // pass this through, backend is responding, but this path doesn't exist. } if err != nil { - dcontext.GetLogger(app).Errorf("storage driver health check: %v", err) + dcontext.GetLogger(ctx).Errorf("storage driver health check: %v", err) } return err - } + }) - if app.Config.Health.StorageDriver.Threshold != 0 { - healthRegistry.RegisterPeriodicThresholdFunc("storagedriver_"+app.Config.Storage.Type(), interval, app.Config.Health.StorageDriver.Threshold, storageDriverCheck) - } else { - healthRegistry.RegisterPeriodicFunc("storagedriver_"+app.Config.Storage.Type(), interval, storageDriverCheck) - } + updater := health.NewThresholdStatusUpdater(app.Config.Health.StorageDriver.Threshold) + healthRegistry.Register("storagedriver_"+app.Config.Storage.Type(), updater) + go health.Poll(app, updater, storageDriverCheck, interval) } for _, fileChecker := range app.Config.Health.FileCheckers { @@ -372,7 +370,9 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { interval = defaultCheckInterval } dcontext.GetLogger(app).Infof("configuring file health check path=%s, interval=%d", fileChecker.File, interval/time.Second) - healthRegistry.Register(fileChecker.File, health.PeriodicChecker(checks.FileChecker(fileChecker.File), interval)) + u := health.NewStatusUpdater() + healthRegistry.Register(fileChecker.File, u) + go health.Poll(app, u, checks.FileChecker(fileChecker.File), interval) } for _, httpChecker := range app.Config.Health.HTTPCheckers { @@ -388,13 +388,10 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { checker := checks.HTTPChecker(httpChecker.URI, statusCode, httpChecker.Timeout, httpChecker.Headers) - if httpChecker.Threshold != 0 { - dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold) - healthRegistry.Register(httpChecker.URI, health.PeriodicThresholdChecker(checker, interval, httpChecker.Threshold)) - } else { - dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d", httpChecker.URI, interval/time.Second) - healthRegistry.Register(httpChecker.URI, health.PeriodicChecker(checker, interval)) - } + dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold) + updater := health.NewThresholdStatusUpdater(httpChecker.Threshold) + healthRegistry.Register(httpChecker.URI, updater) + go health.Poll(app, updater, checker, interval) } for _, tcpChecker := range app.Config.Health.TCPCheckers { @@ -405,13 +402,10 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { checker := checks.TCPChecker(tcpChecker.Addr, tcpChecker.Timeout) - if tcpChecker.Threshold != 0 { - dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d, threshold=%d", tcpChecker.Addr, interval/time.Second, tcpChecker.Threshold) - healthRegistry.Register(tcpChecker.Addr, health.PeriodicThresholdChecker(checker, interval, tcpChecker.Threshold)) - } else { - dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d", tcpChecker.Addr, interval/time.Second) - healthRegistry.Register(tcpChecker.Addr, health.PeriodicChecker(checker, interval)) - } + dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d, threshold=%d", tcpChecker.Addr, interval/time.Second, tcpChecker.Threshold) + updater := health.NewThresholdStatusUpdater(tcpChecker.Threshold) + healthRegistry.Register(tcpChecker.Addr, updater) + go health.Poll(app, updater, checker, interval) } }