From b45b6d18b8248c4a79c27621e1e83054894b248e Mon Sep 17 00:00:00 2001
From: Cory Snider <csnider@mirantis.com>
Date: Fri, 27 Oct 2023 17:33:55 -0400
Subject: [PATCH 1/2] storage/driver: plumb contexts into factories

...and driver constructors when applicable.

Signed-off-by: Cory Snider <csnider@mirantis.com>
---
 registry/handlers/api_test.go                |  2 +-
 registry/handlers/app.go                     |  2 +-
 registry/root.go                             | 12 ++++++------
 registry/storage/driver/azure/azure.go       |  6 +++---
 registry/storage/driver/azure/azure_test.go  |  2 +-
 registry/storage/driver/factory/factory.go   |  7 ++++---
 registry/storage/driver/filesystem/driver.go |  2 +-
 registry/storage/driver/gcs/gcs.go           | 11 +++++------
 registry/storage/driver/inmemory/driver.go   |  2 +-
 registry/storage/driver/s3-aws/s3.go         | 10 +++++-----
 registry/storage/driver/s3-aws/s3_test.go    |  2 +-
 11 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/registry/handlers/api_test.go b/registry/handlers/api_test.go
index ae0c67cf6..7f2606e2f 100644
--- a/registry/handlers/api_test.go
+++ b/registry/handlers/api_test.go
@@ -1377,7 +1377,7 @@ const (
 	repositoryWithGenericStorageError = "genericstorageerr"
 )
 
-func (factory *storageManifestErrDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
+func (factory *storageManifestErrDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
 	// Initialize the mock driver
 	errGenericStorage := errors.New("generic storage error")
 	return &mockErrorDriver{
diff --git a/registry/handlers/app.go b/registry/handlers/app.go
index 8efdaf85e..53367277e 100644
--- a/registry/handlers/app.go
+++ b/registry/handlers/app.go
@@ -116,7 +116,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
 	storageParams["useragent"] = fmt.Sprintf("distribution/%s %s", version.Version, runtime.Version())
 
 	var err error
-	app.driver, err = factory.Create(config.Storage.Type(), storageParams)
+	app.driver, err = factory.Create(app, config.Storage.Type(), storageParams)
 	if err != nil {
 		// TODO(stevvooe): Move the creation of a service into a protected
 		// method, where this is created lazily. Its status can be queried via
diff --git a/registry/root.go b/registry/root.go
index 51c8c7ffb..fb370ca80 100644
--- a/registry/root.go
+++ b/registry/root.go
@@ -53,12 +53,6 @@ var GCCmd = &cobra.Command{
 			os.Exit(1)
 		}
 
-		driver, err := factory.Create(config.Storage.Type(), config.Storage.Parameters())
-		if err != nil {
-			fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err)
-			os.Exit(1)
-		}
-
 		ctx := dcontext.Background()
 		ctx, err = configureLogging(ctx, config)
 		if err != nil {
@@ -66,6 +60,12 @@ var GCCmd = &cobra.Command{
 			os.Exit(1)
 		}
 
+		driver, err := factory.Create(ctx, config.Storage.Type(), config.Storage.Parameters())
+		if err != nil {
+			fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err)
+			os.Exit(1)
+		}
+
 		registry, err := storage.NewRegistry(ctx, driver)
 		if err != nil {
 			fmt.Fprintf(os.Stderr, "failed to construct registry: %v", err)
diff --git a/registry/storage/driver/azure/azure.go b/registry/storage/driver/azure/azure.go
index 585c8b432..837104af9 100644
--- a/registry/storage/driver/azure/azure.go
+++ b/registry/storage/driver/azure/azure.go
@@ -46,16 +46,16 @@ func init() {
 
 type azureDriverFactory struct{}
 
-func (factory *azureDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
+func (factory *azureDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
 	params, err := NewParameters(parameters)
 	if err != nil {
 		return nil, err
 	}
-	return New(params)
+	return New(ctx, params)
 }
 
 // New constructs a new Driver from parameters
-func New(params *Parameters) (*Driver, error) {
+func New(ctx context.Context, params *Parameters) (*Driver, error) {
 	azClient, err := newAzureClient(params)
 	if err != nil {
 		return nil, err
diff --git a/registry/storage/driver/azure/azure_test.go b/registry/storage/driver/azure/azure_test.go
index 945cc0a90..3cdb3f1c1 100644
--- a/registry/storage/driver/azure/azure_test.go
+++ b/registry/storage/driver/azure/azure_test.go
@@ -68,7 +68,7 @@ func init() {
 		if err != nil {
 			return nil, err
 		}
-		return New(params)
+		return New(context.Background(), params)
 	}
 
 	// Skip Azure storage driver tests if environment variable parameters are not provided
diff --git a/registry/storage/driver/factory/factory.go b/registry/storage/driver/factory/factory.go
index 3a4b57ce5..f52684b76 100644
--- a/registry/storage/driver/factory/factory.go
+++ b/registry/storage/driver/factory/factory.go
@@ -1,6 +1,7 @@
 package factory
 
 import (
+	"context"
 	"fmt"
 
 	storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
@@ -23,7 +24,7 @@ type StorageDriverFactory interface {
 	// Create returns a new storagedriver.StorageDriver with the given parameters
 	// Parameters will vary by driver and may be ignored
 	// Each parameter key must only consist of lowercase letters and numbers
-	Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error)
+	Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error)
 }
 
 // Register makes a storage driver available by the provided name.
@@ -46,12 +47,12 @@ func Register(name string, factory StorageDriverFactory) {
 // parameters. To use a driver, the StorageDriverFactory must first be
 // registered with the given name. If no drivers are found, an
 // InvalidStorageDriverError is returned
-func Create(name string, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
+func Create(ctx context.Context, name string, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
 	driverFactory, ok := driverFactories[name]
 	if !ok {
 		return nil, InvalidStorageDriverError{name}
 	}
-	return driverFactory.Create(parameters)
+	return driverFactory.Create(ctx, parameters)
 }
 
 // InvalidStorageDriverError records an attempt to construct an unregistered storage driver
diff --git a/registry/storage/driver/filesystem/driver.go b/registry/storage/driver/filesystem/driver.go
index 23033268f..d5514f0ba 100644
--- a/registry/storage/driver/filesystem/driver.go
+++ b/registry/storage/driver/filesystem/driver.go
@@ -40,7 +40,7 @@ func init() {
 // filesystemDriverFactory implements the factory.StorageDriverFactory interface
 type filesystemDriverFactory struct{}
 
-func (factory *filesystemDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
+func (factory *filesystemDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
 	return FromParameters(parameters)
 }
 
diff --git a/registry/storage/driver/gcs/gcs.go b/registry/storage/driver/gcs/gcs.go
index 5b276d65f..79c03c708 100644
--- a/registry/storage/driver/gcs/gcs.go
+++ b/registry/storage/driver/gcs/gcs.go
@@ -87,8 +87,8 @@ func init() {
 type gcsDriverFactory struct{}
 
 // Create StorageDriver from parameters
-func (factory *gcsDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
-	return FromParameters(parameters)
+func (factory *gcsDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
+	return FromParameters(ctx, parameters)
 }
 
 var _ storagedriver.StorageDriver = &driver{}
@@ -118,8 +118,7 @@ type baseEmbed struct {
 // FromParameters constructs a new Driver with a given parameters map
 // Required parameters:
 // - bucket
-func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
-	ctx := context.TODO()
+func FromParameters(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
 	bucket, ok := parameters["bucket"]
 	if !ok || fmt.Sprint(bucket) == "" {
 		return nil, fmt.Errorf("No bucket parameter provided")
@@ -229,11 +228,11 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
 		gcs:            gcs,
 	}
 
-	return New(params)
+	return New(ctx, params)
 }
 
 // New constructs a new driver
-func New(params driverParameters) (storagedriver.StorageDriver, error) {
+func New(ctx context.Context, params driverParameters) (storagedriver.StorageDriver, error) {
 	rootDirectory := strings.Trim(params.rootDirectory, "/")
 	if rootDirectory != "" {
 		rootDirectory += "/"
diff --git a/registry/storage/driver/inmemory/driver.go b/registry/storage/driver/inmemory/driver.go
index 4c00ca404..97c3bcde5 100644
--- a/registry/storage/driver/inmemory/driver.go
+++ b/registry/storage/driver/inmemory/driver.go
@@ -21,7 +21,7 @@ func init() {
 // inMemoryDriverFacotry implements the factory.StorageDriverFactory interface.
 type inMemoryDriverFactory struct{}
 
-func (factory *inMemoryDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
+func (factory *inMemoryDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
 	return New(), nil
 }
 
diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go
index a6428a918..3d9cb920d 100644
--- a/registry/storage/driver/s3-aws/s3.go
+++ b/registry/storage/driver/s3-aws/s3.go
@@ -150,8 +150,8 @@ func init() {
 // s3DriverFactory implements the factory.StorageDriverFactory interface
 type s3DriverFactory struct{}
 
-func (factory *s3DriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
-	return FromParameters(parameters)
+func (factory *s3DriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
+	return FromParameters(ctx, parameters)
 }
 
 var _ storagedriver.StorageDriver = &driver{}
@@ -189,7 +189,7 @@ type Driver struct {
 // - region
 // - bucket
 // - encrypt
-func FromParameters(parameters map[string]interface{}) (*Driver, error) {
+func FromParameters(ctx context.Context, parameters map[string]interface{}) (*Driver, error) {
 	// Providing no values for these is valid in case the user is authenticating
 	// with an IAM on an ec2 instance (in which case the instance credentials will
 	// be summoned when GetAuth is called)
@@ -468,7 +468,7 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
 		getS3LogLevelFromParam(parameters["loglevel"]),
 	}
 
-	return New(params)
+	return New(ctx, params)
 }
 
 func getS3LogLevelFromParam(param interface{}) aws.LogLevelType {
@@ -529,7 +529,7 @@ func getParameterAsInt64(parameters map[string]interface{}, name string, default
 
 // New constructs a new Driver with the given AWS credentials, region, encryption flag, and
 // bucketName
-func New(params DriverParameters) (*Driver, error) {
+func New(ctx context.Context, params DriverParameters) (*Driver, error) {
 	if !params.V4Auth &&
 		(params.RegionEndpoint == "" ||
 			strings.Contains(params.RegionEndpoint, "s3.amazonaws.com")) {
diff --git a/registry/storage/driver/s3-aws/s3_test.go b/registry/storage/driver/s3-aws/s3_test.go
index ebd9d2a2b..7bfae2aa0 100644
--- a/registry/storage/driver/s3-aws/s3_test.go
+++ b/registry/storage/driver/s3-aws/s3_test.go
@@ -141,7 +141,7 @@ func init() {
 			getS3LogLevelFromParam(logLevel),
 		}
 
-		return New(parameters)
+		return New(context.Background(), parameters)
 	}
 
 	// Skip S3 storage driver tests if environment variable parameters are not provided

From b4dc4f3474e445910148f4a70103fd7186460b49 Mon Sep 17 00:00:00 2001
From: Cory Snider <csnider@mirantis.com>
Date: Fri, 27 Oct 2023 17:46:09 -0400
Subject: [PATCH 2/2] storage/driver: plumb contexts into middlewares

Signed-off-by: Cory Snider <csnider@mirantis.com>
---
 registry/handlers/app.go                      |  6 +--
 .../middleware/cloudfront/middleware.go       | 12 +++--
 .../middleware/cloudfront/middleware_test.go  |  5 ++-
 .../driver/middleware/cloudfront/s3filter.go  | 31 ++++++++-----
 .../middleware/cloudfront/s3filter_test.go    | 45 +++++++++++--------
 .../driver/middleware/redirect/middleware.go  |  2 +-
 .../middleware/redirect/middleware_test.go    | 12 ++---
 .../driver/middleware/storagemiddleware.go    |  7 +--
 8 files changed, 72 insertions(+), 48 deletions(-)

diff --git a/registry/handlers/app.go b/registry/handlers/app.go
index 53367277e..5f017b700 100644
--- a/registry/handlers/app.go
+++ b/registry/handlers/app.go
@@ -148,7 +148,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
 
 	startUploadPurger(app, app.driver, dcontext.GetLogger(app), purgeConfig)
 
-	app.driver, err = applyStorageMiddleware(app.driver, config.Middleware["storage"])
+	app.driver, err = applyStorageMiddleware(app, app.driver, config.Middleware["storage"])
 	if err != nil {
 		panic(err)
 	}
@@ -938,9 +938,9 @@ func applyRepoMiddleware(ctx context.Context, repository distribution.Repository
 }
 
 // applyStorageMiddleware wraps a storage driver with the configured middlewares
-func applyStorageMiddleware(driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) {
+func applyStorageMiddleware(ctx context.Context, driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) {
 	for _, mw := range middlewares {
-		smw, err := storagemiddleware.Get(mw.Name, mw.Options, driver)
+		smw, err := storagemiddleware.Get(ctx, mw.Name, mw.Options, driver)
 		if err != nil {
 			return nil, fmt.Errorf("unable to configure storage middleware (%s): %v", mw.Name, err)
 		}
diff --git a/registry/storage/driver/middleware/cloudfront/middleware.go b/registry/storage/driver/middleware/cloudfront/middleware.go
index 5c2c09955..741d618ed 100644
--- a/registry/storage/driver/middleware/cloudfront/middleware.go
+++ b/registry/storage/driver/middleware/cloudfront/middleware.go
@@ -48,7 +48,7 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
 //     default value. "aws", only aws IP goes to S3 directly. "awsregion", only
 //     regions listed in awsregion options goes to S3 directly
 //   - awsregion: a comma separated string of AWS regions.
-func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
+func newCloudFrontStorageMiddleware(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
 	// parse baseurl
 	base, ok := options["baseurl"]
 	if !ok {
@@ -157,7 +157,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
 			case "", "none":
 				awsIPs = nil
 			case "aws":
-				awsIPs = newAWSIPs(ipRangesURL, updateFrequency, nil)
+				awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, nil)
+				if err != nil {
+					return nil, err
+				}
 			case "awsregion":
 				var awsRegion []string
 				if i, ok := options["awsregion"]; ok {
@@ -165,7 +168,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
 						for _, awsRegions := range strings.Split(regions, ",") {
 							awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
 						}
-						awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion)
+						awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, awsRegion)
+						if err != nil {
+							return nil, err
+						}
 					} else {
 						return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
 					}
diff --git a/registry/storage/driver/middleware/cloudfront/middleware_test.go b/registry/storage/driver/middleware/cloudfront/middleware_test.go
index 4e75c6439..67ec077fd 100644
--- a/registry/storage/driver/middleware/cloudfront/middleware_test.go
+++ b/registry/storage/driver/middleware/cloudfront/middleware_test.go
@@ -1,6 +1,7 @@
 package middleware
 
 import (
+	"context"
 	"os"
 	"testing"
 
@@ -15,7 +16,7 @@ var _ = check.Suite(&MiddlewareSuite{})
 
 func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
 	options := make(map[string]interface{})
-	_, err := newCloudFrontStorageMiddleware(nil, options)
+	_, err := newCloudFrontStorageMiddleware(context.Background(), nil, options)
 	c.Assert(err, check.ErrorMatches, "no baseurl provided")
 }
 
@@ -48,7 +49,7 @@ pZeMRablbPQdp8/1NyIwimq1VlG0ohQ4P6qhW7E09ZMC
 	defer os.Remove(file.Name())
 	options["privatekey"] = file.Name()
 	options["keypairid"] = "test"
-	storageDriver, err := newCloudFrontStorageMiddleware(nil, options)
+	storageDriver, err := newCloudFrontStorageMiddleware(context.Background(), nil, options)
 	if err != nil {
 		t.Fatal(err)
 	}
diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go
index 25aafd043..75158c91d 100644
--- a/registry/storage/driver/middleware/cloudfront/s3filter.go
+++ b/registry/storage/driver/middleware/cloudfront/s3filter.go
@@ -3,6 +3,7 @@ package middleware
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -23,18 +24,21 @@ const (
 
 // newAWSIPs returns a New awsIP object.
 // If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified
-func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs {
+func newAWSIPs(ctx context.Context, host string, updateFrequency time.Duration, awsRegion []string) (*awsIPs, error) {
 	ips := &awsIPs{
 		host:            host,
 		updateFrequency: updateFrequency,
 		awsRegion:       awsRegion,
 		updaterStopChan: make(chan bool),
 	}
-	if err := ips.tryUpdate(); err != nil {
-		dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP")
+	if err := ips.tryUpdate(ctx); err != nil {
+		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+			return nil, err
+		}
+		dcontext.GetLogger(ctx).WithError(err).Warn("failed to update AWS IP")
 	}
 	go ips.updater()
-	return ips
+	return ips, nil
 }
 
 // awsIPs tracks a list of AWS ips, filtered by awsRegion
@@ -61,9 +65,13 @@ type prefixEntry struct {
 	Service    string `json:"service"`
 }
 
-func fetchAWSIPs(url string) (awsIPResponse, error) {
+func fetchAWSIPs(ctx context.Context, url string) (awsIPResponse, error) {
 	var response awsIPResponse
-	resp, err := http.Get(url)
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return response, err
+	}
+	resp, err := http.DefaultClient.Do(req)
 	if err != nil {
 		return response, err
 	}
@@ -83,8 +91,8 @@ func fetchAWSIPs(url string) (awsIPResponse, error) {
 
 // tryUpdate attempts to download the new set of ip addresses.
 // tryUpdate must be thread safe with contains
-func (s *awsIPs) tryUpdate() error {
-	response, err := fetchAWSIPs(s.host)
+func (s *awsIPs) tryUpdate(ctx context.Context) error {
+	response, err := fetchAWSIPs(ctx, s.host)
 	if err != nil {
 		return err
 	}
@@ -135,17 +143,18 @@ func (s *awsIPs) tryUpdate() error {
 // This function is meant to be run in a background goroutine.
 // It will periodically update the ips from aws.
 func (s *awsIPs) updater() {
+	ctx := context.TODO()
 	defer close(s.updaterStopChan)
 	for {
 		time.Sleep(s.updateFrequency)
 		select {
 		case <-s.updaterStopChan:
-			dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal")
+			dcontext.GetLogger(ctx).Info("aws ip updater received stop signal")
 			return
 		default:
-			err := s.tryUpdate()
+			err := s.tryUpdate(ctx)
 			if err != nil {
-				dcontext.GetLogger(context.Background()).WithError(err).Error("git  AWS IP")
+				dcontext.GetLogger(ctx).WithError(err).Error("git  AWS IP")
 			}
 		}
 	}
diff --git a/registry/storage/driver/middleware/cloudfront/s3filter_test.go b/registry/storage/driver/middleware/cloudfront/s3filter_test.go
index 81ef6aa8a..c347d0354 100644
--- a/registry/storage/driver/middleware/cloudfront/s3filter_test.go
+++ b/registry/storage/driver/middleware/cloudfront/s3filter_test.go
@@ -62,7 +62,7 @@ func TestS3TryUpdate(t *testing.T) {
 	})
 	defer server.Close()
 
-	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
+	ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil)
 
 	assertEqual(t, 1, len(ips.ipv4))
 	assertEqual(t, 0, len(ips.ipv6))
@@ -77,8 +77,9 @@ func TestMatchIPV6(t *testing.T) {
 	})
 	defer server.Close()
 
-	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
-	ips.tryUpdate()
+	ctx := context.Background()
+	ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
+	ips.tryUpdate(ctx)
 	assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
 	assertEqual(t, 1, len(ips.ipv6))
 	assertEqual(t, 0, len(ips.ipv4))
@@ -93,8 +94,9 @@ func TestMatchIPV4(t *testing.T) {
 	})
 	defer server.Close()
 
-	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
-	ips.tryUpdate()
+	ctx := context.Background()
+	ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
+	ips.tryUpdate(ctx)
 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -112,8 +114,9 @@ func TestMatchIPV4_2(t *testing.T) {
 	})
 	defer server.Close()
 
-	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
-	ips.tryUpdate()
+	ctx := context.Background()
+	ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
+	ips.tryUpdate(ctx)
 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -131,8 +134,9 @@ func TestMatchIPV4WithRegionMatched(t *testing.T) {
 	})
 	defer server.Close()
 
-	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"})
-	ips.tryUpdate()
+	ctx := context.Background()
+	ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-east-1"})
+	ips.tryUpdate(ctx)
 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -150,8 +154,9 @@ func TestMatchIPV4WithRegionMatch_2(t *testing.T) {
 	})
 	defer server.Close()
 
-	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
-	ips.tryUpdate()
+	ctx := context.Background()
+	ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
+	ips.tryUpdate(ctx)
 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
 	assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -169,8 +174,9 @@ func TestMatchIPV4WithRegionNotMatched(t *testing.T) {
 	})
 	defer server.Close()
 
-	ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"})
-	ips.tryUpdate()
+	ctx := context.Background()
+	ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2"})
+	ips.tryUpdate(ctx)
 	assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
 	assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
 	assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -187,8 +193,9 @@ func TestInvalidData(t *testing.T) {
 	})
 	defer server.Close()
 
-	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
-	ips.tryUpdate()
+	ctx := context.Background()
+	ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
+	ips.tryUpdate(ctx)
 	assertEqual(t, 1, len(ips.ipv4))
 }
 
@@ -205,7 +212,7 @@ func TestInvalidNetworkType(t *testing.T) {
 	})
 	defer server.Close()
 
-	ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
+	ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil)
 	assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
 	assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4))))  // netv4 networks
 	assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
@@ -226,7 +233,7 @@ func TestParsing(t *testing.T) {
 	t.Parallel()
 	server := httptest.NewServer(rawMockHandler)
 	defer server.Close()
-	schema, err := fetchAWSIPs(server.URL)
+	schema, err := fetchAWSIPs(context.Background(), server.URL)
 
 	assertEqual(t, nil, err)
 	assertEqual(t, 1, len(schema.Prefixes))
@@ -253,7 +260,7 @@ func TestUpdateCalledRegularly(t *testing.T) {
 			rw.Write([]byte("ok"))
 		}))
 	defer server.Close()
-	newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil)
+	newAWSIPs(context.Background(), fmt.Sprintf("%s/", server.URL), time.Second, nil)
 	time.Sleep(time.Second*4 + time.Millisecond*500)
 	if updateCount < 4 {
 		t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount)
@@ -384,7 +391,7 @@ func BenchmarkContainsRandom(b *testing.B) {
 }
 
 func BenchmarkContainsProd(b *testing.B) {
-	ips := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil)
+	ips, _ := newAWSIPs(context.Background(), defaultIPRangesURL, defaultUpdateFrequency, nil)
 	ipv4 := make([][]byte, b.N)
 	ipv6 := make([][]byte, b.N)
 	for i := 0; i < b.N; i++ {
diff --git a/registry/storage/driver/middleware/redirect/middleware.go b/registry/storage/driver/middleware/redirect/middleware.go
index 9e1b303ea..8976d8689 100644
--- a/registry/storage/driver/middleware/redirect/middleware.go
+++ b/registry/storage/driver/middleware/redirect/middleware.go
@@ -19,7 +19,7 @@ type redirectStorageMiddleware struct {
 
 var _ storagedriver.StorageDriver = &redirectStorageMiddleware{}
 
-func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
+func newRedirectStorageMiddleware(ctx context.Context, sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
 	o, ok := options["baseurl"]
 	if !ok {
 		return nil, fmt.Errorf("no baseurl provided")
diff --git a/registry/storage/driver/middleware/redirect/middleware_test.go b/registry/storage/driver/middleware/redirect/middleware_test.go
index 30d0bb192..2c22dec36 100644
--- a/registry/storage/driver/middleware/redirect/middleware_test.go
+++ b/registry/storage/driver/middleware/redirect/middleware_test.go
@@ -15,21 +15,21 @@ var _ = check.Suite(&MiddlewareSuite{})
 
 func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
 	options := make(map[string]interface{})
-	_, err := newRedirectStorageMiddleware(nil, options)
+	_, err := newRedirectStorageMiddleware(context.Background(), nil, options)
 	c.Assert(err, check.ErrorMatches, "no baseurl provided")
 }
 
 func (s *MiddlewareSuite) TestMissingScheme(c *check.C) {
 	options := make(map[string]interface{})
 	options["baseurl"] = "example.com"
-	_, err := newRedirectStorageMiddleware(nil, options)
+	_, err := newRedirectStorageMiddleware(context.Background(), nil, options)
 	c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl")
 }
 
 func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
 	options := make(map[string]interface{})
 	options["baseurl"] = "https://example.com:5443"
-	middleware, err := newRedirectStorageMiddleware(nil, options)
+	middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options)
 	c.Assert(err, check.Equals, nil)
 
 	m, ok := middleware.(*redirectStorageMiddleware)
@@ -45,7 +45,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
 func (s *MiddlewareSuite) TestHTTP(c *check.C) {
 	options := make(map[string]interface{})
 	options["baseurl"] = "http://example.com"
-	middleware, err := newRedirectStorageMiddleware(nil, options)
+	middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options)
 	c.Assert(err, check.Equals, nil)
 
 	m, ok := middleware.(*redirectStorageMiddleware)
@@ -62,7 +62,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
 	// basePath: end with no slash
 	options := make(map[string]interface{})
 	options["baseurl"] = "https://example.com/path"
-	middleware, err := newRedirectStorageMiddleware(nil, options)
+	middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options)
 	c.Assert(err, check.Equals, nil)
 
 	m, ok := middleware.(*redirectStorageMiddleware)
@@ -82,7 +82,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
 
 	// basePath: end with slash
 	options["baseurl"] = "https://example.com/path/"
-	middleware, err = newRedirectStorageMiddleware(nil, options)
+	middleware, err = newRedirectStorageMiddleware(context.Background(), nil, options)
 	c.Assert(err, check.Equals, nil)
 
 	m, ok = middleware.(*redirectStorageMiddleware)
diff --git a/registry/storage/driver/middleware/storagemiddleware.go b/registry/storage/driver/middleware/storagemiddleware.go
index bde645dc5..d2c37741e 100644
--- a/registry/storage/driver/middleware/storagemiddleware.go
+++ b/registry/storage/driver/middleware/storagemiddleware.go
@@ -1,6 +1,7 @@
 package storagemiddleware
 
 import (
+	"context"
 	"fmt"
 
 	storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
@@ -8,7 +9,7 @@ import (
 
 // InitFunc is the type of a StorageMiddleware factory function and is
 // used to register the constructor for different StorageMiddleware backends.
-type InitFunc func(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error)
+type InitFunc func(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error)
 
 var storageMiddlewares map[string]InitFunc
 
@@ -28,10 +29,10 @@ func Register(name string, initFunc InitFunc) error {
 }
 
 // Get constructs a StorageMiddleware with the given options using the named backend.
-func Get(name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) {
+func Get(ctx context.Context, name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) {
 	if storageMiddlewares != nil {
 		if initFunc, exists := storageMiddlewares[name]; exists {
-			return initFunc(storageDriver, options)
+			return initFunc(ctx, storageDriver, options)
 		}
 	}