diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 83e987d85..91c56e762 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -430,6 +430,14 @@ func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) { } } +// Shutdown close the underlying registry +func (app *App) Shutdown() error { + if r, ok := app.registry.(proxy.Closer); ok { + return r.Close() + } + return nil +} + // register a handler with the application, by route name. The handler will be // passed through the application filters and context will be constructed at // request time. diff --git a/registry/proxy/proxyregistry.go b/registry/proxy/proxyregistry.go index 33dcc4afa..55c8f4beb 100644 --- a/registry/proxy/proxyregistry.go +++ b/registry/proxy/proxyregistry.go @@ -211,6 +211,15 @@ func (pr *proxyingRegistry) BlobStatter() distribution.BlobStatter { return pr.embedded.BlobStatter() } +type Closer interface { + // Close release all resources used by this object + Close() error +} + +func (pr *proxyingRegistry) Close() error { + return pr.scheduler.Stop() +} + // authChallenger encapsulates a request to the upstream to establish credential challenges type authChallenger interface { tryEstablishChallenges(context.Context) error diff --git a/registry/proxy/scheduler/scheduler.go b/registry/proxy/scheduler/scheduler.go index ed1d9d419..78366c2ff 100644 --- a/registry/proxy/scheduler/scheduler.go +++ b/registry/proxy/scheduler/scheduler.go @@ -206,12 +206,13 @@ func (ttles *TTLExpirationScheduler) startTimer(entry *schedulerEntry, ttl time. } // Stop stops the scheduler. -func (ttles *TTLExpirationScheduler) Stop() { +func (ttles *TTLExpirationScheduler) Stop() error { ttles.Lock() defer ttles.Unlock() - if err := ttles.writeState(); err != nil { - dcontext.GetLogger(ttles.ctx).Errorf("Error writing scheduler state: %s", err) + err := ttles.writeState() + if err != nil { + err = fmt.Errorf("error writing scheduler state: %w", err) } for _, entry := range ttles.entries { @@ -221,6 +222,7 @@ func (ttles *TTLExpirationScheduler) Stop() { close(ttles.doneChan) ttles.saveTimer.Stop() ttles.stopped = true + return err } func (ttles *TTLExpirationScheduler) writeState() error { diff --git a/registry/proxy/scheduler/scheduler_test.go b/registry/proxy/scheduler/scheduler_test.go index 38fa0f580..fb0869c2d 100644 --- a/registry/proxy/scheduler/scheduler_test.go +++ b/registry/proxy/scheduler/scheduler_test.go @@ -136,7 +136,12 @@ func TestRestoreOld(t *testing.T) { if err != nil { t.Fatalf("Error starting ttlExpirationScheduler: %s", err) } - defer s.Stop() + defer func(s *TTLExpirationScheduler) { + err := s.Stop() + if err != nil { + t.Fatalf("Error stopping ttlExpirationScheduler: %s", err) + } + }(s) wg.Wait() mu.Lock() @@ -177,7 +182,10 @@ func TestStopRestore(t *testing.T) { // Start and stop before all operations complete // state will be written to fs - s.Stop() + err = s.Stop() + if err != nil { + t.Fatalf(err.Error()) + } time.Sleep(10 * time.Millisecond) // v2 will restore state from fs diff --git a/registry/registry.go b/registry/registry.go index b7989f021..3d3bf1eb1 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "net/http" "os" @@ -312,7 +313,7 @@ func (registry *Registry) ListenAndServe() error { } // setup channel to get notified on SIGTERM signal - signal.Notify(registry.quit, syscall.SIGTERM) + signal.Notify(registry.quit, os.Interrupt, syscall.SIGTERM) serveErr := make(chan error) // Start serving in goroutine and listen for stop signal in main thread @@ -332,9 +333,13 @@ func (registry *Registry) ListenAndServe() error { } } -// Shutdown gracefully shuts down the registry's HTTP server. +// Shutdown gracefully shuts down the registry's HTTP server and application object. func (registry *Registry) Shutdown(ctx context.Context) error { - return registry.server.Shutdown(ctx) + err := registry.server.Shutdown(ctx) + if appErr := registry.app.Shutdown(); appErr != nil { + err = errors.Join(err, appErr) + } + return err } func configureDebugServer(config *configuration.Configuration) {