diff --git a/health/doc.go b/health/doc.go index 8faa32f7c..194b8a566 100644 --- a/health/doc.go +++ b/health/doc.go @@ -39,7 +39,7 @@ // // The recommended way of registering checks is using a periodic Check. // PeriodicChecks run on a certain schedule and asynchronously update the -// status of the check. This allows "CheckStatus()" to return without blocking +// status of the check. This allows CheckStatus to return without blocking // on an expensive check. // // A trivial example of a check that runs every 5 seconds and shuts down our diff --git a/health/health.go b/health/health.go index dab2794df..220282dcd 100644 --- a/health/health.go +++ b/health/health.go @@ -11,10 +11,26 @@ import ( "github.com/docker/distribution/registry/api/errcode" ) -var ( - mutex sync.RWMutex - registeredChecks = make(map[string]Checker) -) +// A Registry is a collection of checks. Most applications will use the global +// registry defined in DefaultRegistry. However, unit tests may need to create +// separate registries to isolate themselves from other tests. +type Registry struct { + mu sync.RWMutex + registeredChecks map[string]Checker +} + +// NewRegistry creates a new registry. This isn't necessary for normal use of +// the package, but may be useful for unit tests so individual tests have their +// own set of checks. +func NewRegistry() *Registry { + return &Registry{ + registeredChecks: make(map[string]Checker), + } +} + +// DefaultRegistry is the default registry where checks are registered. It is +// the registry used by the HTTP handler. 
+var DefaultRegistry *Registry // Checker is the interface for a Health Checker type Checker interface { @@ -144,11 +160,11 @@ func PeriodicThresholdChecker(check Checker, period time.Duration, threshold int } // CheckStatus returns a map with all the current health check errors -func CheckStatus() map[string]string { // TODO(stevvooe) this needs a proper type - mutex.RLock() - defer mutex.RUnlock() +func (registry *Registry) CheckStatus() map[string]string { // TODO(stevvooe) this needs a proper type + registry.mu.RLock() + defer registry.mu.RUnlock() statusKeys := make(map[string]string) - for k, v := range registeredChecks { + for k, v := range registry.registeredChecks { err := v.Check() if err != nil { statusKeys[k] = err.Error() @@ -158,48 +174,66 @@ func CheckStatus() map[string]string { // TODO(stevvooe) this needs a proper typ return statusKeys } -// Register associates the checker with the provided name. We allow -// overwrites to a specific check status. -func Register(name string, check Checker) { - mutex.Lock() - defer mutex.Unlock() - _, ok := registeredChecks[name] +// CheckStatus returns a map with all the current health check errors from the +// default registry. +func CheckStatus() map[string]string { + return DefaultRegistry.CheckStatus() +} + +// Register associates the checker with the provided name. +func (registry *Registry) Register(name string, check Checker) { + if registry == nil { + registry = DefaultRegistry + } + registry.mu.Lock() + defer registry.mu.Unlock() + _, ok := registry.registeredChecks[name] if ok { panic("Check already exists: " + name) } - registeredChecks[name] = check + registry.registeredChecks[name] = check } -// Unregister removes the named checker. -func Unregister(name string) { - mutex.Lock() - defer mutex.Unlock() - delete(registeredChecks, name) +// Register associates the checker with the provided name in the default +// registry. 
+func Register(name string, check Checker) { + DefaultRegistry.Register(name, check) } -// UnregisterAll removes all registered checkers. -func UnregisterAll() { - mutex.Lock() - defer mutex.Unlock() - registeredChecks = make(map[string]Checker) +// RegisterFunc allows the convenience of registering a checker directly from +// an arbitrary func() error. +func (registry *Registry) RegisterFunc(name string, check func() error) { + registry.Register(name, CheckFunc(check)) } -// RegisterFunc allows the convenience of registering a checker directly -// from an arbitrary func() error +// RegisterFunc allows the convenience of registering a checker in the default +// registry directly from an arbitrary func() error. func RegisterFunc(name string, check func() error) { - Register(name, CheckFunc(check)) + DefaultRegistry.RegisterFunc(name, check) } // RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker -// from an arbitrary func() error +// from an arbitrary func() error. +func (registry *Registry) RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) { + registry.Register(name, PeriodicChecker(CheckFunc(check), period)) +} + +// RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker +// in the default registry from an arbitrary func() error. func RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) { - Register(name, PeriodicChecker(CheckFunc(check), period)) + DefaultRegistry.RegisterPeriodicFunc(name, period, check) } // RegisterPeriodicThresholdFunc allows the convenience of registering a -// PeriodicChecker from an arbitrary func() error +// PeriodicChecker from an arbitrary func() error. 
+func (registry *Registry) RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) { + registry.Register(name, PeriodicThresholdChecker(CheckFunc(check), period, threshold)) +} + +// RegisterPeriodicThresholdFunc allows the convenience of registering a +// PeriodicChecker in the default registry from an arbitrary func() error. func RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) { - Register(name, PeriodicThresholdChecker(CheckFunc(check), period, threshold)) + DefaultRegistry.RegisterPeriodicThresholdFunc(name, period, threshold, check) } // StatusHandler returns a JSON blob with all the currently registered Health Checks @@ -265,7 +299,8 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m } } -// Registers global /debug/health api endpoint +// Registers global /debug/health api endpoint, creates default registry func init() { + DefaultRegistry = NewRegistry() http.HandleFunc("/debug/health", StatusHandler) } diff --git a/health/health_test.go b/health/health_test.go index 3228cb801..766fe159f 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -51,7 +51,7 @@ func TestReturns503IfThereAreErrorChecks(t *testing.T) { // the web application when things aren't so healthy. func TestHealthHandler(t *testing.T) { // clear out existing checks. - registeredChecks = make(map[string]Checker) + DefaultRegistry = NewRegistry() // protect an http server handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 9cf6447a6..91f4e1a37 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -234,7 +234,15 @@ func NewApp(ctx context.Context, configuration configuration.Configuration) *App // process. Because the configuration and app are tightly coupled, // implementing this properly will require a refactor. 
This method may panic // if called twice in the same process. -func (app *App) RegisterHealthChecks() { +func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { + if len(healthRegistries) > 1 { + panic("RegisterHealthChecks called with more than one registry") + } + healthRegistry := health.DefaultRegistry + if len(healthRegistries) == 1 { + healthRegistry = healthRegistries[0] + } + if app.Config.Health.StorageDriver.Enabled { interval := app.Config.Health.StorageDriver.Interval if interval == 0 { @@ -247,9 +255,9 @@ func (app *App) RegisterHealthChecks() { } if app.Config.Health.StorageDriver.Threshold != 0 { - health.RegisterPeriodicThresholdFunc("storagedriver_"+app.Config.Storage.Type(), interval, app.Config.Health.StorageDriver.Threshold, storageDriverCheck) + healthRegistry.RegisterPeriodicThresholdFunc("storagedriver_"+app.Config.Storage.Type(), interval, app.Config.Health.StorageDriver.Threshold, storageDriverCheck) } else { - health.RegisterPeriodicFunc("storagedriver_"+app.Config.Storage.Type(), interval, storageDriverCheck) + healthRegistry.RegisterPeriodicFunc("storagedriver_"+app.Config.Storage.Type(), interval, storageDriverCheck) } } @@ -260,10 +268,10 @@ func (app *App) RegisterHealthChecks() { } if fileChecker.Threshold != 0 { ctxu.GetLogger(app).Infof("configuring file health check path=%s, interval=%d, threshold=%d", fileChecker.File, interval/time.Second, fileChecker.Threshold) - health.Register(fileChecker.File, health.PeriodicThresholdChecker(checks.FileChecker(fileChecker.File), interval, fileChecker.Threshold)) + healthRegistry.Register(fileChecker.File, health.PeriodicThresholdChecker(checks.FileChecker(fileChecker.File), interval, fileChecker.Threshold)) } else { ctxu.GetLogger(app).Infof("configuring file health check path=%s, interval=%d", fileChecker.File, interval/time.Second) - health.Register(fileChecker.File, health.PeriodicChecker(checks.FileChecker(fileChecker.File), interval)) + 
healthRegistry.Register(fileChecker.File, health.PeriodicChecker(checks.FileChecker(fileChecker.File), interval)) } } @@ -274,10 +282,10 @@ func (app *App) RegisterHealthChecks() { } if httpChecker.Threshold != 0 { ctxu.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold) - health.Register(httpChecker.URI, health.PeriodicThresholdChecker(checks.HTTPChecker(httpChecker.URI), interval, httpChecker.Threshold)) + healthRegistry.Register(httpChecker.URI, health.PeriodicThresholdChecker(checks.HTTPChecker(httpChecker.URI), interval, httpChecker.Threshold)) } else { ctxu.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d", httpChecker.URI, interval/time.Second) - health.Register(httpChecker.URI, health.PeriodicChecker(checks.HTTPChecker(httpChecker.URI), interval)) + healthRegistry.Register(httpChecker.URI, health.PeriodicChecker(checks.HTTPChecker(httpChecker.URI), interval)) } } } diff --git a/registry/handlers/health_test.go b/registry/handlers/health_test.go index 38ea9b2fa..de2b71ccb 100644 --- a/registry/handlers/health_test.go +++ b/registry/handlers/health_test.go @@ -1,7 +1,6 @@ package handlers import ( - "encoding/json" "io/ioutil" "net/http" "net/http/httptest" @@ -15,9 +14,6 @@ import ( ) func TestFileHealthCheck(t *testing.T) { - // In case other tests registered checks before this one - health.UnregisterAll() - interval := time.Second tmpfile, err := ioutil.TempFile(os.TempDir(), "healthcheck") @@ -43,60 +39,29 @@ func TestFileHealthCheck(t *testing.T) { ctx := context.Background() app := NewApp(ctx, config) - app.RegisterHealthChecks() - - debugServer := httptest.NewServer(nil) + healthRegistry := health.NewRegistry() + app.RegisterHealthChecks(healthRegistry) // Wait for health check to happen <-time.After(2 * interval) - resp, err := http.Get(debugServer.URL + "/debug/health") - if err != nil { - t.Fatalf("error performing HTTP GET: %v", 
err) + status := healthRegistry.CheckStatus() + if len(status) != 1 { + t.Fatal("expected 1 item in health check results") } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading HTTP body: %v", err) - } - resp.Body.Close() - var decoded map[string]string - err = json.Unmarshal(body, &decoded) - if err != nil { - t.Fatalf("error unmarshaling json: %v", err) - } - if len(decoded) != 1 { - t.Fatal("expected 1 item in returned json") - } - if decoded[tmpfile.Name()] != "file exists" { + if status[tmpfile.Name()] != "file exists" { t.Fatal(`did not get "file exists" result for health check`) } os.Remove(tmpfile.Name()) <-time.After(2 * interval) - resp, err = http.Get(debugServer.URL + "/debug/health") - if err != nil { - t.Fatalf("error performing HTTP GET: %v", err) - } - body, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading HTTP body: %v", err) - } - resp.Body.Close() - var decoded2 map[string]string - err = json.Unmarshal(body, &decoded2) - if err != nil { - t.Fatalf("error unmarshaling json: %v", err) - } - if len(decoded2) != 0 { - t.Fatal("expected 0 items in returned json") + if len(healthRegistry.CheckStatus()) != 0 { + t.Fatal("expected 0 items in health check results") } } func TestHTTPHealthCheck(t *testing.T) { - // In case other tests registered checks before this one - health.UnregisterAll() - interval := time.Second threshold := 3 @@ -132,32 +97,18 @@ func TestHTTPHealthCheck(t *testing.T) { ctx := context.Background() app := NewApp(ctx, config) - app.RegisterHealthChecks() - - debugServer := httptest.NewServer(nil) + healthRegistry := health.NewRegistry() + app.RegisterHealthChecks(healthRegistry) for i := 0; ; i++ { <-time.After(interval) - resp, err := http.Get(debugServer.URL + "/debug/health") - if err != nil { - t.Fatalf("error performing HTTP GET: %v", err) - } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading HTTP body: %v", err) - } - resp.Body.Close() 
- var decoded map[string]string - err = json.Unmarshal(body, &decoded) - if err != nil { - t.Fatalf("error unmarshaling json: %v", err) - } + status := healthRegistry.CheckStatus() if i < threshold-1 { // definitely shouldn't have hit the threshold yet - if len(decoded) != 0 { - t.Fatal("expected 1 items in returned json") + if len(status) != 0 { + t.Fatal("expected 0 items in health check results") } continue } @@ -166,10 +117,10 @@ func TestHTTPHealthCheck(t *testing.T) { continue } - if len(decoded) != 1 { - t.Fatal("expected 1 item in returned json") + if len(status) != 1 { + t.Fatal("expected 1 item in health check results") } - if decoded[checkedServer.URL] != "downstream service returned unexpected status: 500" { + if status[checkedServer.URL] != "downstream service returned unexpected status: 500" { t.Fatal("did not get expected result for health check") } @@ -180,21 +131,8 @@ func TestHTTPHealthCheck(t *testing.T) { close(stopFailing) <-time.After(2 * interval) - resp, err := http.Get(debugServer.URL + "/debug/health") - if err != nil { - t.Fatalf("error performing HTTP GET: %v", err) - } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading HTTP body: %v", err) - } - resp.Body.Close() - var decoded map[string]string - err = json.Unmarshal(body, &decoded) - if err != nil { - t.Fatalf("error unmarshaling json: %v", err) - } - if len(decoded) != 0 { - t.Fatal("expected 0 items in returned json") + + if len(healthRegistry.CheckStatus()) != 0 { + t.Fatal("expected 0 items in health check results") } }