diff --git a/health/api/api_test.go b/health/api/api_test.go index 66aa6b419..f93d696f3 100644 --- a/health/api/api_test.go +++ b/health/api/api_test.go @@ -1,6 +1,7 @@ package api import ( + "context" "net/http" "net/http/httptest" "testing" @@ -59,7 +60,7 @@ func TestPOSTDownHandlerChangeStatus(t *testing.T) { t.Errorf("Did not get a 200.") } - if len(health.CheckStatus()) != 1 { + if len(health.CheckStatus(context.Background())) != 1 { t.Errorf("DownHandler didn't add an error check.") } } @@ -80,7 +81,7 @@ func TestPOSTUpHandlerChangeStatus(t *testing.T) { t.Errorf("Did not get a 200.") } - if len(health.CheckStatus()) != 0 { + if len(health.CheckStatus(context.Background())) != 0 { t.Errorf("UpHandler didn't remove the error check.") } } diff --git a/health/checks/checks.go b/health/checks/checks.go index 846ca2015..9382db32e 100644 --- a/health/checks/checks.go +++ b/health/checks/checks.go @@ -1,13 +1,13 @@ package checks import ( + "context" "errors" "fmt" "net" "net/http" "os" "path/filepath" - "strconv" "time" "github.com/distribution/distribution/v3/health" @@ -16,7 +16,7 @@ import ( // FileChecker checks the existence of a file and returns an error // if the file exists. func FileChecker(f string) health.Checker { - return health.CheckFunc(func() error { + return health.CheckFunc(func(context.Context) error { absoluteFilePath, err := filepath.Abs(f) if err != nil { return fmt.Errorf("failed to get absolute path for %q: %v", f, err) @@ -36,13 +36,13 @@ func FileChecker(f string) health.Checker { // HTTPChecker does a HEAD request and verifies that the HTTP status code // returned matches statusCode. func HTTPChecker(r string, statusCode int, timeout time.Duration, headers http.Header) health.Checker { - return health.CheckFunc(func() error { + return health.CheckFunc(func(ctx context.Context) error { client := http.Client{ Timeout: timeout, } - req, err := http.NewRequest(http.MethodHead, r, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodHead, r, nil) if err != nil { - return errors.New("error creating request: " + r) + return fmt.Errorf("%v: error creating request: %w", r, err) } for headerName, headerValues := range headers { for _, headerValue := range headerValues { @@ -51,11 +51,11 @@ func HTTPChecker(r string, statusCode int, timeout time.Duration, headers http.H } response, err := client.Do(req) if err != nil { - return errors.New("error while checking: " + r) + return fmt.Errorf("%v: error while checking: %w", r, err) } defer response.Body.Close() if response.StatusCode != statusCode { - return errors.New("downstream service returned unexpected status: " + strconv.Itoa(response.StatusCode)) + return fmt.Errorf("%v: downstream service returned unexpected status: %d", r, response.StatusCode) } return nil }) @@ -63,10 +63,11 @@ func HTTPChecker(r string, statusCode int, timeout time.Duration, headers http.H // TCPChecker attempts to open a TCP connection. func TCPChecker(addr string, timeout time.Duration) health.Checker { - return health.CheckFunc(func() error { - conn, err := net.DialTimeout("tcp", addr, timeout) + return health.CheckFunc(func(ctx context.Context) error { + d := net.Dialer{Timeout: timeout} + conn, err := d.DialContext(ctx, "tcp", addr) if err != nil { - return errors.New("connection to " + addr + " failed") + return fmt.Errorf("%v: connection failed: %w", addr, err) } conn.Close() return nil diff --git a/health/checks/checks_test.go b/health/checks/checks_test.go index 6b6dd14fa..60332f180 100644 --- a/health/checks/checks_test.go +++ b/health/checks/checks_test.go @@ -1,25 +1,26 @@ package checks import ( + "context" "testing" ) func TestFileChecker(t *testing.T) { - if err := FileChecker("/tmp").Check(); err == nil { + if err := FileChecker("/tmp").Check(context.Background()); err == nil { t.Errorf("/tmp was expected as exists") } - if err := FileChecker("NoSuchFileFromMoon").Check(); err != nil { + if err := FileChecker("NoSuchFileFromMoon").Check(context.Background()); err != nil { t.Errorf("NoSuchFileFromMoon was expected as not exists, error:%v", err) } } func TestHTTPChecker(t *testing.T) { - if err := HTTPChecker("https://www.google.cybertron", 200, 0, nil).Check(); err == nil { + if err := HTTPChecker("https://www.google.cybertron", 200, 0, nil).Check(context.Background()); err == nil { t.Errorf("Google on Cybertron was expected as not exists") } - if err := HTTPChecker("https://www.google.pt", 200, 0, nil).Check(); err != nil { + if err := HTTPChecker("https://www.google.pt", 200, 0, nil).Check(context.Background()); err != nil { t.Errorf("Google at Portugal was expected as exists, error:%v", err) } } diff --git a/health/health.go b/health/health.go index 57a714b66..fff87323d 100644 --- a/health/health.go +++ b/health/health.go @@ -1,7 +1,9 @@ package health import ( + "context" "encoding/json" + "errors" "fmt" "net/http" "sync" @@ -35,17 +37,17 @@ var DefaultRegistry *Registry // Checker is the interface for a Health Checker type Checker interface { // Check returns nil if the service is okay. - Check() error + Check(context.Context) error } // CheckFunc is a convenience type to create functions that implement // the Checker interface -type CheckFunc func() error +type CheckFunc func(context.Context) error // Check Implements the Checker interface to allow for any func() error method // to be passed as a Checker -func (cf CheckFunc) Check() error { - return cf() +func (cf CheckFunc) Check(ctx context.Context) error { + return cf(ctx) } // Updater implements a health check that is explicitly set. @@ -66,7 +68,7 @@ type updater struct { } // Check implements the Checker interface -func (u *updater) Check() error { +func (u *updater) Check(context.Context) error { u.mu.Lock() defer u.mu.Unlock() @@ -99,11 +101,11 @@ type thresholdUpdater struct { } // Check implements the Checker interface -func (tu *thresholdUpdater) Check() error { +func (tu *thresholdUpdater) Check(context.Context) error { tu.mu.Lock() defer tu.mu.Unlock() - if tu.count >= tu.threshold { + if tu.count >= tu.threshold || errors.As(tu.status, new(pollingTerminatedErr)) { return tu.status } @@ -127,47 +129,44 @@ func (tu *thresholdUpdater) Update(status error) { // NewThresholdStatusUpdater returns a new thresholdUpdater func NewThresholdStatusUpdater(t int) Updater { - return &thresholdUpdater{threshold: t} + if t > 0 { + return &thresholdUpdater{threshold: t} + } + return NewStatusUpdater() } -// PeriodicChecker wraps an updater to provide a periodic checker -func PeriodicChecker(check Checker, period time.Duration) Checker { - u := NewStatusUpdater() - go func() { - t := time.NewTicker(period) - defer t.Stop() - for { - <-t.C - u.Update(check.Check()) - } - }() +type pollingTerminatedErr struct{ Err error } - return u +func (e pollingTerminatedErr) Error() string { + return fmt.Sprintf("health: check is not polled: %v", e.Err) } -// PeriodicThresholdChecker wraps an updater to provide a periodic checker that -// uses a threshold before it changes status -func PeriodicThresholdChecker(check Checker, period time.Duration, threshold int) Checker { - tu := NewThresholdStatusUpdater(threshold) - go func() { - t := time.NewTicker(period) - defer t.Stop() - for { - <-t.C - tu.Update(check.Check()) - } - }() +func (e pollingTerminatedErr) Unwrap() error { return e.Err } - return tu +// Poll periodically polls the checker c at interval and updates the updater u +// with the result. The checker is called with ctx as the context. When ctx is +// done, Poll updates the updater with ctx.Err() and returns. +func Poll(ctx context.Context, u Updater, c Checker, interval time.Duration) { + t := time.NewTicker(interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + u.Update(pollingTerminatedErr{Err: ctx.Err()}) + return + case <-t.C: + u.Update(c.Check(ctx)) + } + } } // CheckStatus returns a map with all the current health check errors -func (registry *Registry) CheckStatus() map[string]string { // TODO(stevvooe) this needs a proper type +func (registry *Registry) CheckStatus(ctx context.Context) map[string]string { // TODO(stevvooe) this needs a proper type registry.mu.RLock() defer registry.mu.RUnlock() statusKeys := make(map[string]string) for k, v := range registry.registeredChecks { - err := v.Check() + err := v.Check(ctx) if err != nil { statusKeys[k] = err.Error() } @@ -178,8 +177,8 @@ func (registry *Registry) CheckStatus() map[string]string { // TODO(stevvooe) th // CheckStatus returns a map with all the current health check errors from the // default registry. -func CheckStatus() map[string]string { - return DefaultRegistry.CheckStatus() +func CheckStatus(ctx context.Context) map[string]string { + return DefaultRegistry.CheckStatus(ctx) } // Register associates the checker with the provided name. @@ -203,47 +202,23 @@ func Register(name string, check Checker) { } // RegisterFunc allows the convenience of registering a checker directly from -// an arbitrary func() error. -func (registry *Registry) RegisterFunc(name string, check func() error) { - registry.Register(name, CheckFunc(check)) +// an arbitrary func(context.Context) error. +func (registry *Registry) RegisterFunc(name string, check CheckFunc) { + registry.Register(name, check) } // RegisterFunc allows the convenience of registering a checker in the default -// registry directly from an arbitrary func() error. -func RegisterFunc(name string, check func() error) { +// registry directly from an arbitrary func(context.Context) error. +func RegisterFunc(name string, check CheckFunc) { DefaultRegistry.RegisterFunc(name, check) } -// RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker -// from an arbitrary func() error. -func (registry *Registry) RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) { - registry.Register(name, PeriodicChecker(check, period)) -} - -// RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker -// in the default registry from an arbitrary func() error. -func RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) { - DefaultRegistry.RegisterPeriodicFunc(name, period, check) -} - -// RegisterPeriodicThresholdFunc allows the convenience of registering a -// PeriodicChecker from an arbitrary func() error. -func (registry *Registry) RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) { - registry.Register(name, PeriodicThresholdChecker(check, period, threshold)) -} - -// RegisterPeriodicThresholdFunc allows the convenience of registering a -// PeriodicChecker in the default registry from an arbitrary func() error. -func RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) { - DefaultRegistry.RegisterPeriodicThresholdFunc(name, period, threshold, check) -} - // StatusHandler returns a JSON blob with all the currently registered Health Checks // and their corresponding status. // Returns 503 if any Error status exists, 200 otherwise func StatusHandler(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet { - checks := CheckStatus() + checks := CheckStatus(r.Context()) status := http.StatusOK // If there is an error, return 503 @@ -263,7 +238,7 @@ func StatusHandler(w http.ResponseWriter, r *http.Request) { // disable a web application when the health checks fail. func Handler(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - checks := CheckStatus() + checks := CheckStatus(r.Context()) if len(checks) != 0 { // NOTE(milosgajdos): disable errcheck as the error is // accessible via /debug/health @@ -282,7 +257,7 @@ func Handler(handler http.Handler) http.Handler { func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks map[string]string) { p, err := json.Marshal(checks) if err != nil { - dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status: %v", err) + dcontext.GetLogger(r.Context()).Errorf("error serializing health status: %v", err) p, err = json.Marshal(struct { ServerError string `json:"server_error"` }{ @@ -291,7 +266,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m status = http.StatusInternalServerError if err != nil { - dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status failure message: %v", err) + dcontext.GetLogger(r.Context()).Errorf("error serializing health status failure message: %v", err) return } } @@ -300,7 +275,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m w.Header().Set("Content-Length", fmt.Sprint(len(p))) w.WriteHeader(status) if _, err := w.Write(p); err != nil { - dcontext.GetLogger(dcontext.Background()).Errorf("error writing health status response body: %v", err) + dcontext.GetLogger(r.Context()).Errorf("error writing health status response body: %v", err) } } diff --git a/health/health_test.go b/health/health_test.go index 80b5ebed6..ac87c8634 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -1,11 +1,13 @@ package health import ( + "context" "errors" "fmt" "net/http" "net/http/httptest" "testing" + "time" ) // TestReturns200IfThereAreNoChecks ensures that the result code of the health @@ -36,7 +38,7 @@ func TestReturns503IfThereAreErrorChecks(t *testing.T) { } // Create a manual error - Register("some_check", CheckFunc(func() error { + Register("some_check", CheckFunc(func(context.Context) error { return errors.New("This Check did not succeed") })) @@ -105,3 +107,87 @@ func TestHealthHandler(t *testing.T) { updater.Update(nil) checkUp(t, "when server is back up") // now we should be back up. } + +func TestThresholdStatusUpdater(t *testing.T) { + u := NewThresholdStatusUpdater(3) + + assertCheckOK := func() { + t.Helper() + if err := u.Check(context.Background()); err != nil { + t.Errorf("u.Check() = %v; want nil", err) + } + } + + assertCheckErr := func(expected string) { + t.Helper() + if err := u.Check(context.Background()); err == nil || err.Error() != expected { + t.Errorf("u.Check() = %v; want %v", err, expected) + } + } + + // Updater should report healthy until the threshold is reached. + for i := 1; i <= 3; i++ { + assertCheckOK() + u.Update(fmt.Errorf("fake error %d", i)) + } + assertCheckErr("fake error 3") + + // The threshold should reset after one successful update. + u.Update(nil) + assertCheckOK() + u.Update(errors.New("first errored update after reset")) + assertCheckOK() + u.Update(nil) + + // pollingTerminatedErr should bypass the threshold. + pte := pollingTerminatedErr{Err: errors.New("womp womp")} + u.Update(pte) + assertCheckErr(pte.Error()) +} + +func TestPoll(t *testing.T) { + type ContextKey struct{} + for _, threshold := range []int{0, 10} { + t.Run(fmt.Sprintf("threshold=%d", threshold), func(t *testing.T) { + ctx, cancel := context.WithCancel(context.WithValue(context.Background(), ContextKey{}, t.Name())) + defer cancel() + checkerCalled := make(chan struct{}) + checker := CheckFunc(func(ctx context.Context) error { + if v, ok := ctx.Value(ContextKey{}).(string); !ok || v != t.Name() { + t.Errorf("unexpected context passed into checker: got context with value %q, want %q", v, t.Name()) + } + select { + case <-checkerCalled: + default: + close(checkerCalled) + } + return nil + }) + + updater := NewThresholdStatusUpdater(threshold) + pollReturned := make(chan struct{}) + go func() { + Poll(ctx, updater, checker, 1*time.Millisecond) + close(pollReturned) + }() + + select { + case <-checkerCalled: + case <-time.After(1 * time.Second): + t.Error("checker has not been polled") + } + + cancel() + + select { + case <-pollReturned: + case <-time.After(1 * time.Second): + t.Error("poll has not returned after context was canceled") + } + + if err := updater.Check(context.Background()); !errors.Is(err, context.Canceled) { + t.Errorf("updater.Check() = %v; want %v", err, context.Canceled) + } + }) + } +} diff --git a/registry/handlers/app.go b/registry/handlers/app.go index fb8e9dd29..6866eb823 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -348,22 +348,20 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { interval = defaultCheckInterval } - storageDriverCheck := func() error { - _, err := app.driver.Stat(app, "/") // "/" should always exist + storageDriverCheck := health.CheckFunc(func(ctx context.Context) error { + _, err := app.driver.Stat(ctx, "/") // "/" should always exist if _, ok := err.(storagedriver.PathNotFoundError); ok { err = nil // pass this through, backend is responding, but this path doesn't exist. } if err != nil { - dcontext.GetLogger(app).Errorf("storage driver health check: %v", err) + dcontext.GetLogger(ctx).Errorf("storage driver health check: %v", err) } return err - } + }) - if app.Config.Health.StorageDriver.Threshold != 0 { - healthRegistry.RegisterPeriodicThresholdFunc("storagedriver_"+app.Config.Storage.Type(), interval, app.Config.Health.StorageDriver.Threshold, storageDriverCheck) - } else { - healthRegistry.RegisterPeriodicFunc("storagedriver_"+app.Config.Storage.Type(), interval, storageDriverCheck) - } + updater := health.NewThresholdStatusUpdater(app.Config.Health.StorageDriver.Threshold) + healthRegistry.Register("storagedriver_"+app.Config.Storage.Type(), updater) + go health.Poll(app, updater, storageDriverCheck, interval) } for _, fileChecker := range app.Config.Health.FileCheckers { @@ -372,7 +370,9 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { interval = defaultCheckInterval } dcontext.GetLogger(app).Infof("configuring file health check path=%s, interval=%d", fileChecker.File, interval/time.Second) - healthRegistry.Register(fileChecker.File, health.PeriodicChecker(checks.FileChecker(fileChecker.File), interval)) + u := health.NewStatusUpdater() + healthRegistry.Register(fileChecker.File, u) + go health.Poll(app, u, checks.FileChecker(fileChecker.File), interval) } for _, httpChecker := range app.Config.Health.HTTPCheckers { @@ -388,13 +388,10 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { checker := checks.HTTPChecker(httpChecker.URI, statusCode, httpChecker.Timeout, httpChecker.Headers) - if httpChecker.Threshold != 0 { - dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold) - healthRegistry.Register(httpChecker.URI, health.PeriodicThresholdChecker(checker, interval, httpChecker.Threshold)) - } else { - dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d", httpChecker.URI, interval/time.Second) - healthRegistry.Register(httpChecker.URI, health.PeriodicChecker(checker, interval)) - } + dcontext.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold) + updater := health.NewThresholdStatusUpdater(httpChecker.Threshold) + healthRegistry.Register(httpChecker.URI, updater) + go health.Poll(app, updater, checker, interval) } for _, tcpChecker := range app.Config.Health.TCPCheckers { @@ -405,13 +402,10 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { checker := checks.TCPChecker(tcpChecker.Addr, tcpChecker.Timeout) - if tcpChecker.Threshold != 0 { - dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d, threshold=%d", tcpChecker.Addr, interval/time.Second, tcpChecker.Threshold) - healthRegistry.Register(tcpChecker.Addr, health.PeriodicThresholdChecker(checker, interval, tcpChecker.Threshold)) - } else { - dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d", tcpChecker.Addr, interval/time.Second) - healthRegistry.Register(tcpChecker.Addr, health.PeriodicChecker(checker, interval)) - } + dcontext.GetLogger(app).Infof("configuring TCP health check addr=%s, interval=%d, threshold=%d", tcpChecker.Addr, interval/time.Second, tcpChecker.Threshold) + updater := health.NewThresholdStatusUpdater(tcpChecker.Threshold) + healthRegistry.Register(tcpChecker.Addr, updater) + go health.Poll(app, updater, checker, interval) } } diff --git a/registry/handlers/health_test.go b/registry/handlers/health_test.go index dc706b870..079ebda51 100644 --- a/registry/handlers/health_test.go +++ b/registry/handlers/health_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" "time" @@ -48,7 +49,7 @@ func TestFileHealthCheck(t *testing.T) { // Wait for health check to happen <-time.After(2 * interval) - status := healthRegistry.CheckStatus() + status := healthRegistry.CheckStatus(ctx) if len(status) != 1 { t.Fatal("expected 1 item in health check results") } @@ -59,7 +60,7 @@ func TestFileHealthCheck(t *testing.T) { os.Remove(tmpfile.Name()) <-time.After(2 * interval) - if len(healthRegistry.CheckStatus()) != 0 { + if len(healthRegistry.CheckStatus(ctx)) != 0 { t.Fatal("expected 0 items in health check results") } } @@ -112,7 +113,7 @@ func TestTCPHealthCheck(t *testing.T) { // Wait for health check to happen <-time.After(2 * interval) - if len(healthRegistry.CheckStatus()) != 0 { + if len(healthRegistry.CheckStatus(ctx)) != 0 { t.Fatal("expected 0 items in health check results") } @@ -120,11 +121,11 @@ func TestTCPHealthCheck(t *testing.T) { <-time.After(2 * interval) // Health check should now fail - status := healthRegistry.CheckStatus() + status := healthRegistry.CheckStatus(ctx) if len(status) != 1 { t.Fatal("expected 1 item in health check results") } - if status[addrStr] != "connection to "+addrStr+" failed" { + if !strings.Contains(status[addrStr], "connection failed") { t.Fatal(`did not get "connection failed" result for health check`) } } @@ -174,7 +175,7 @@ func TestHTTPHealthCheck(t *testing.T) { for i := 0; ; i++ { <-time.After(interval) - status := healthRegistry.CheckStatus() + status := healthRegistry.CheckStatus(ctx) if i < threshold-1 { // definitely shouldn't have hit the threshold yet @@ -191,7 +192,7 @@ func TestHTTPHealthCheck(t *testing.T) { if len(status) != 1 { t.Fatal("expected 1 item in health check results") } - if status[checkedServer.URL] != "downstream service returned unexpected status: 500" { + if !strings.Contains(status[checkedServer.URL], "downstream service returned unexpected status: 500") { t.Fatal("did not get expected result for health check") } @@ -203,7 +204,7 @@ func TestHTTPHealthCheck(t *testing.T) { <-time.After(2 * interval) - if len(healthRegistry.CheckStatus()) != 0 { + if len(healthRegistry.CheckStatus(ctx)) != 0 { t.Fatal("expected 0 items in health check results") } }