diff --git a/health/api/api_test.go b/health/api/api_test.go index 66aa6b419..f93d696f3 100644 --- a/health/api/api_test.go +++ b/health/api/api_test.go @@ -1,6 +1,7 @@ package api import ( + "context" "net/http" "net/http/httptest" "testing" @@ -59,7 +60,7 @@ func TestPOSTDownHandlerChangeStatus(t *testing.T) { t.Errorf("Did not get a 200.") } - if len(health.CheckStatus()) != 1 { + if len(health.CheckStatus(context.Background())) != 1 { t.Errorf("DownHandler didn't add an error check.") } } @@ -80,7 +81,7 @@ func TestPOSTUpHandlerChangeStatus(t *testing.T) { t.Errorf("Did not get a 200.") } - if len(health.CheckStatus()) != 0 { + if len(health.CheckStatus(context.Background())) != 0 { t.Errorf("UpHandler didn't remove the error check.") } } diff --git a/health/checks/checks.go b/health/checks/checks.go index 846ca2015..9382db32e 100644 --- a/health/checks/checks.go +++ b/health/checks/checks.go @@ -1,13 +1,13 @@ package checks import ( + "context" "errors" "fmt" "net" "net/http" "os" "path/filepath" - "strconv" "time" "github.com/distribution/distribution/v3/health" @@ -16,7 +16,7 @@ import ( // FileChecker checks the existence of a file and returns an error // if the file exists. func FileChecker(f string) health.Checker { - return health.CheckFunc(func() error { + return health.CheckFunc(func(context.Context) error { absoluteFilePath, err := filepath.Abs(f) if err != nil { return fmt.Errorf("failed to get absolute path for %q: %v", f, err) @@ -36,13 +36,13 @@ func FileChecker(f string) health.Checker { // HTTPChecker does a HEAD request and verifies that the HTTP status code // returned matches statusCode. func HTTPChecker(r string, statusCode int, timeout time.Duration, headers http.Header) health.Checker { - return health.CheckFunc(func() error { + return health.CheckFunc(func(ctx context.Context) error { client := http.Client{ Timeout: timeout, } - req, err := http.NewRequest(http.MethodHead, r, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodHead, r, nil) if err != nil { - return errors.New("error creating request: " + r) + return fmt.Errorf("%v: error creating request: %w", r, err) } for headerName, headerValues := range headers { for _, headerValue := range headerValues { @@ -51,11 +51,11 @@ func HTTPChecker(r string, statusCode int, timeout time.Duration, headers http.H } response, err := client.Do(req) if err != nil { - return errors.New("error while checking: " + r) + return fmt.Errorf("%v: error while checking: %w", r, err) } defer response.Body.Close() if response.StatusCode != statusCode { - return errors.New("downstream service returned unexpected status: " + strconv.Itoa(response.StatusCode)) + return fmt.Errorf("%v: downstream service returned unexpected status: %d", r, response.StatusCode) } return nil }) @@ -63,10 +63,11 @@ func HTTPChecker(r string, statusCode int, timeout time.Duration, headers http.H // TCPChecker attempts to open a TCP connection. func TCPChecker(addr string, timeout time.Duration) health.Checker { - return health.CheckFunc(func() error { - conn, err := net.DialTimeout("tcp", addr, timeout) + return health.CheckFunc(func(ctx context.Context) error { + d := net.Dialer{Timeout: timeout} + conn, err := d.DialContext(ctx, "tcp", addr) if err != nil { - return errors.New("connection to " + addr + " failed") + return fmt.Errorf("%v: connection failed: %w", addr, err) } conn.Close() return nil diff --git a/health/checks/checks_test.go b/health/checks/checks_test.go index 6b6dd14fa..60332f180 100644 --- a/health/checks/checks_test.go +++ b/health/checks/checks_test.go @@ -1,25 +1,26 @@ package checks import ( + "context" "testing" ) func TestFileChecker(t *testing.T) { - if err := FileChecker("/tmp").Check(); err == nil { + if err := FileChecker("/tmp").Check(context.Background()); err == nil { t.Errorf("/tmp was expected as exists") } - if err := FileChecker("NoSuchFileFromMoon").Check(); err != nil { + if err := FileChecker("NoSuchFileFromMoon").Check(context.Background()); err != nil { t.Errorf("NoSuchFileFromMoon was expected as not exists, error:%v", err) } } func TestHTTPChecker(t *testing.T) { - if err := HTTPChecker("https://www.google.cybertron", 200, 0, nil).Check(); err == nil { + if err := HTTPChecker("https://www.google.cybertron", 200, 0, nil).Check(context.Background()); err == nil { t.Errorf("Google on Cybertron was expected as not exists") } - if err := HTTPChecker("https://www.google.pt", 200, 0, nil).Check(); err != nil { + if err := HTTPChecker("https://www.google.pt", 200, 0, nil).Check(context.Background()); err != nil { t.Errorf("Google at Portugal was expected as exists, error:%v", err) } } diff --git a/health/health.go b/health/health.go index 67873f597..89cd7c4f3 100644 --- a/health/health.go +++ b/health/health.go @@ -1,6 +1,7 @@ package health import ( + "context" "encoding/json" "fmt" "net/http" @@ -35,17 +36,17 @@ var DefaultRegistry *Registry // Checker is the interface for a Health Checker type Checker interface { // Check returns nil if the service is okay. - Check() error + Check(context.Context) error } // CheckFunc is a convenience type to create functions that implement // the Checker interface -type CheckFunc func() error +type CheckFunc func(context.Context) error // Check Implements the Checker interface to allow for any func() error method // to be passed as a Checker -func (cf CheckFunc) Check() error { - return cf() +func (cf CheckFunc) Check(ctx context.Context) error { + return cf(ctx) } // Updater implements a health check that is explicitly set. @@ -66,7 +67,7 @@ type updater struct { } // Check implements the Checker interface -func (u *updater) Check() error { +func (u *updater) Check(context.Context) error { u.mu.Lock() defer u.mu.Unlock() @@ -99,7 +100,7 @@ type thresholdUpdater struct { } // Check implements the Checker interface -func (tu *thresholdUpdater) Check() error { +func (tu *thresholdUpdater) Check(context.Context) error { tu.mu.Lock() defer tu.mu.Unlock() @@ -138,7 +139,7 @@ func PeriodicChecker(check Checker, period time.Duration) Checker { defer t.Stop() for { <-t.C - u.Update(check.Check()) + u.Update(check.Check(context.Background())) } }() @@ -154,7 +155,7 @@ func PeriodicThresholdChecker(check Checker, period time.Duration, threshold int defer t.Stop() for { <-t.C - tu.Update(check.Check()) + tu.Update(check.Check(context.Background())) } }() @@ -162,12 +163,12 @@ func PeriodicThresholdChecker(check Checker, period time.Duration, threshold int } // CheckStatus returns a map with all the current health check errors -func (registry *Registry) CheckStatus() map[string]string { // TODO(stevvooe) this needs a proper type +func (registry *Registry) CheckStatus(ctx context.Context) map[string]string { // TODO(stevvooe) this needs a proper type registry.mu.RLock() defer registry.mu.RUnlock() statusKeys := make(map[string]string) for k, v := range registry.registeredChecks { - err := v.Check() + err := v.Check(ctx) if err != nil { statusKeys[k] = err.Error() } @@ -178,8 +179,8 @@ func (registry *Registry) CheckStatus() map[string]string { // TODO(stevvooe) th // CheckStatus returns a map with all the current health check errors from the // default registry. -func CheckStatus() map[string]string { - return DefaultRegistry.CheckStatus() +func CheckStatus(ctx context.Context) map[string]string { + return DefaultRegistry.CheckStatus(ctx) } // Register associates the checker with the provided name. @@ -203,19 +204,19 @@ func Register(name string, check Checker) { } // RegisterFunc allows the convenience of registering a checker directly from -// an arbitrary func() error. -func (registry *Registry) RegisterFunc(name string, check func() error) { - registry.Register(name, CheckFunc(check)) +// an arbitrary func(context.Context) error. +func (registry *Registry) RegisterFunc(name string, check CheckFunc) { + registry.Register(name, check) } // RegisterFunc allows the convenience of registering a checker in the default -// registry directly from an arbitrary func() error. -func RegisterFunc(name string, check func() error) { +// registry directly from an arbitrary func(context.Context) error. +func RegisterFunc(name string, check CheckFunc) { DefaultRegistry.RegisterFunc(name, check) } // RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker -// from an arbitrary func() error. +// from an arbitrary func(context.Context) error. func (registry *Registry) RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) { registry.Register(name, PeriodicChecker(check, period)) } @@ -243,7 +244,7 @@ func RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold // Returns 503 if any Error status exists, 200 otherwise func StatusHandler(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet { - checks := CheckStatus() + checks := CheckStatus(r.Context()) status := http.StatusOK // If there is an error, return 503 @@ -263,7 +264,7 @@ func StatusHandler(w http.ResponseWriter, r *http.Request) { // disable a web application when the health checks fail. func Handler(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - checks := CheckStatus() + checks := CheckStatus(r.Context()) if len(checks) != 0 { // NOTE(milosgajdos): disable errcheck as the error is // accessible via /debug/health diff --git a/health/health_test.go b/health/health_test.go index 80b5ebed6..c877ed7fe 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -1,6 +1,7 @@ package health import ( + "context" "errors" "fmt" "net/http" @@ -36,7 +37,7 @@ func TestReturns503IfThereAreErrorChecks(t *testing.T) { } // Create a manual error - Register("some_check", CheckFunc(func() error { + Register("some_check", CheckFunc(func(context.Context) error { return errors.New("This Check did not succeed") })) diff --git a/registry/handlers/app.go b/registry/handlers/app.go index fb8e9dd29..0e4d528b8 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -348,7 +348,7 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { interval = defaultCheckInterval } - storageDriverCheck := func() error { + storageDriverCheck := func(context.Context) error { _, err := app.driver.Stat(app, "/") // "/" should always exist if _, ok := err.(storagedriver.PathNotFoundError); ok { err = nil // pass this through, backend is responding, but this path doesn't exist. diff --git a/registry/handlers/health_test.go b/registry/handlers/health_test.go index dc706b870..079ebda51 100644 --- a/registry/handlers/health_test.go +++ b/registry/handlers/health_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" "time" @@ -48,7 +49,7 @@ func TestFileHealthCheck(t *testing.T) { // Wait for health check to happen <-time.After(2 * interval) - status := healthRegistry.CheckStatus() + status := healthRegistry.CheckStatus(ctx) if len(status) != 1 { t.Fatal("expected 1 item in health check results") } @@ -59,7 +60,7 @@ func TestFileHealthCheck(t *testing.T) { os.Remove(tmpfile.Name()) <-time.After(2 * interval) - if len(healthRegistry.CheckStatus()) != 0 { + if len(healthRegistry.CheckStatus(ctx)) != 0 { t.Fatal("expected 0 items in health check results") } } @@ -112,7 +113,7 @@ func TestTCPHealthCheck(t *testing.T) { // Wait for health check to happen <-time.After(2 * interval) - if len(healthRegistry.CheckStatus()) != 0 { + if len(healthRegistry.CheckStatus(ctx)) != 0 { t.Fatal("expected 0 items in health check results") } @@ -120,11 +121,11 @@ func TestTCPHealthCheck(t *testing.T) { <-time.After(2 * interval) // Health check should now fail - status := healthRegistry.CheckStatus() + status := healthRegistry.CheckStatus(ctx) if len(status) != 1 { t.Fatal("expected 1 item in health check results") } - if status[addrStr] != "connection to "+addrStr+" failed" { + if !strings.Contains(status[addrStr], "connection failed") { t.Fatal(`did not get "connection failed" result for health check`) } } @@ -174,7 +175,7 @@ func TestHTTPHealthCheck(t *testing.T) { for i := 0; ; i++ { <-time.After(interval) - status := healthRegistry.CheckStatus() + status := healthRegistry.CheckStatus(ctx) if i < threshold-1 { // definitely shouldn't have hit the threshold yet @@ -191,7 +192,7 @@ func TestHTTPHealthCheck(t *testing.T) { if len(status) != 1 { t.Fatal("expected 1 item in health check results") } - if status[checkedServer.URL] != "downstream service returned unexpected status: 500" { + if !strings.Contains(status[checkedServer.URL], "downstream service returned unexpected status: 500") { t.Fatal("did not get expected result for health check") } @@ -203,7 +204,7 @@ func TestHTTPHealthCheck(t *testing.T) { <-time.After(2 * interval) - if len(healthRegistry.CheckStatus()) != 0 { + if len(healthRegistry.CheckStatus(ctx)) != 0 { t.Fatal("expected 0 items in health check results") } }