diff --git a/registry/handlers/api_test.go b/registry/handlers/api_test.go index ae0c67cf6..7f2606e2f 100644 --- a/registry/handlers/api_test.go +++ b/registry/handlers/api_test.go @@ -1377,7 +1377,7 @@ const ( repositoryWithGenericStorageError = "genericstorageerr" ) -func (factory *storageManifestErrDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { +func (factory *storageManifestErrDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { // Initialize the mock driver errGenericStorage := errors.New("generic storage error") return &mockErrorDriver{ diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 8efdaf85e..5f017b700 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -116,7 +116,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App { storageParams["useragent"] = fmt.Sprintf("distribution/%s %s", version.Version, runtime.Version()) var err error - app.driver, err = factory.Create(config.Storage.Type(), storageParams) + app.driver, err = factory.Create(app, config.Storage.Type(), storageParams) if err != nil { // TODO(stevvooe): Move the creation of a service into a protected // method, where this is created lazily. Its status can be queried via @@ -148,7 +148,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App { startUploadPurger(app, app.driver, dcontext.GetLogger(app), purgeConfig) - app.driver, err = applyStorageMiddleware(app.driver, config.Middleware["storage"]) + app.driver, err = applyStorageMiddleware(app, app.driver, config.Middleware["storage"]) if err != nil { panic(err) } @@ -938,9 +938,9 @@ func applyRepoMiddleware(ctx context.Context, repository distribution.Repository } // applyStorageMiddleware wraps a storage driver with the configured middlewares -func applyStorageMiddleware(driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) { +func applyStorageMiddleware(ctx context.Context, driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) { for _, mw := range middlewares { - smw, err := storagemiddleware.Get(mw.Name, mw.Options, driver) + smw, err := storagemiddleware.Get(ctx, mw.Name, mw.Options, driver) if err != nil { return nil, fmt.Errorf("unable to configure storage middleware (%s): %v", mw.Name, err) } diff --git a/registry/root.go b/registry/root.go index 51c8c7ffb..fb370ca80 100644 --- a/registry/root.go +++ b/registry/root.go @@ -53,12 +53,6 @@ var GCCmd = &cobra.Command{ os.Exit(1) } - driver, err := factory.Create(config.Storage.Type(), config.Storage.Parameters()) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err) - os.Exit(1) - } - ctx := dcontext.Background() ctx, err = configureLogging(ctx, config) if err != nil { @@ -66,6 +60,12 @@ var GCCmd = &cobra.Command{ os.Exit(1) } + driver, err := factory.Create(ctx, config.Storage.Type(), config.Storage.Parameters()) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err) + os.Exit(1) + } + registry, err := storage.NewRegistry(ctx, driver) if err != nil { fmt.Fprintf(os.Stderr, "failed to construct registry: %v", err) diff --git a/registry/storage/driver/azure/azure.go b/registry/storage/driver/azure/azure.go index 585c8b432..837104af9 100644 --- a/registry/storage/driver/azure/azure.go +++ b/registry/storage/driver/azure/azure.go @@ -46,16 +46,16 @@ func init() { type azureDriverFactory struct{} -func (factory *azureDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { +func (factory *azureDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { params, err := NewParameters(parameters) if err != nil { return nil, err } - return New(params) + return New(ctx, params) } // New constructs a new Driver from parameters -func New(params *Parameters) (*Driver, error) { +func New(ctx context.Context, params *Parameters) (*Driver, error) { azClient, err := newAzureClient(params) if err != nil { return nil, err diff --git a/registry/storage/driver/azure/azure_test.go b/registry/storage/driver/azure/azure_test.go index 945cc0a90..3cdb3f1c1 100644 --- a/registry/storage/driver/azure/azure_test.go +++ b/registry/storage/driver/azure/azure_test.go @@ -68,7 +68,7 @@ func init() { if err != nil { return nil, err } - return New(params) + return New(context.Background(), params) } // Skip Azure storage driver tests if environment variable parameters are not provided diff --git a/registry/storage/driver/factory/factory.go b/registry/storage/driver/factory/factory.go index 3a4b57ce5..f52684b76 100644 --- a/registry/storage/driver/factory/factory.go +++ b/registry/storage/driver/factory/factory.go @@ -1,6 +1,7 @@ package factory import ( + "context" "fmt" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" @@ -23,7 +24,7 @@ type StorageDriverFactory interface { // Create returns a new storagedriver.StorageDriver with the given parameters // Parameters will vary by driver and may be ignored // Each parameter key must only consist of lowercase letters and numbers - Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) + Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) } // Register makes a storage driver available by the provided name. @@ -46,12 +47,12 @@ func Register(name string, factory StorageDriverFactory) { // parameters. To use a driver, the StorageDriverFactory must first be // registered with the given name. If no drivers are found, an // InvalidStorageDriverError is returned -func Create(name string, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { +func Create(ctx context.Context, name string, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { driverFactory, ok := driverFactories[name] if !ok { return nil, InvalidStorageDriverError{name} } - return driverFactory.Create(parameters) + return driverFactory.Create(ctx, parameters) } // InvalidStorageDriverError records an attempt to construct an unregistered storage driver diff --git a/registry/storage/driver/filesystem/driver.go b/registry/storage/driver/filesystem/driver.go index 23033268f..d5514f0ba 100644 --- a/registry/storage/driver/filesystem/driver.go +++ b/registry/storage/driver/filesystem/driver.go @@ -40,7 +40,7 @@ func init() { // filesystemDriverFactory implements the factory.StorageDriverFactory interface type filesystemDriverFactory struct{} -func (factory *filesystemDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { +func (factory *filesystemDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { return FromParameters(parameters) } diff --git a/registry/storage/driver/gcs/gcs.go b/registry/storage/driver/gcs/gcs.go index 5b276d65f..79c03c708 100644 --- a/registry/storage/driver/gcs/gcs.go +++ b/registry/storage/driver/gcs/gcs.go @@ -87,8 +87,8 @@ func init() { type gcsDriverFactory struct{} // Create StorageDriver from parameters -func (factory *gcsDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { - return FromParameters(parameters) +func (factory *gcsDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { + return FromParameters(ctx, parameters) } var _ storagedriver.StorageDriver = &driver{} @@ -118,8 +118,7 @@ type baseEmbed struct { // FromParameters constructs a new Driver with a given parameters map // Required parameters: // - bucket -func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { - ctx := context.TODO() +func FromParameters(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { bucket, ok := parameters["bucket"] if !ok || fmt.Sprint(bucket) == "" { return nil, fmt.Errorf("No bucket parameter provided") @@ -229,11 +228,11 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri gcs: gcs, } - return New(params) + return New(ctx, params) } // New constructs a new driver -func New(params driverParameters) (storagedriver.StorageDriver, error) { +func New(ctx context.Context, params driverParameters) (storagedriver.StorageDriver, error) { rootDirectory := strings.Trim(params.rootDirectory, "/") if rootDirectory != "" { rootDirectory += "/" diff --git a/registry/storage/driver/inmemory/driver.go b/registry/storage/driver/inmemory/driver.go index 4c00ca404..97c3bcde5 100644 --- a/registry/storage/driver/inmemory/driver.go +++ b/registry/storage/driver/inmemory/driver.go @@ -21,7 +21,7 @@ func init() { // inMemoryDriverFacotry implements the factory.StorageDriverFactory interface. type inMemoryDriverFactory struct{} -func (factory *inMemoryDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { +func (factory *inMemoryDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { return New(), nil } diff --git a/registry/storage/driver/middleware/cloudfront/middleware.go b/registry/storage/driver/middleware/cloudfront/middleware.go index 5c2c09955..741d618ed 100644 --- a/registry/storage/driver/middleware/cloudfront/middleware.go +++ b/registry/storage/driver/middleware/cloudfront/middleware.go @@ -48,7 +48,7 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{} // default value. "aws", only aws IP goes to S3 directly. "awsregion", only // regions listed in awsregion options goes to S3 directly // - awsregion: a comma separated string of AWS regions. -func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { +func newCloudFrontStorageMiddleware(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { // parse baseurl base, ok := options["baseurl"] if !ok { @@ -157,7 +157,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o case "", "none": awsIPs = nil case "aws": - awsIPs = newAWSIPs(ipRangesURL, updateFrequency, nil) + awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, nil) + if err != nil { + return nil, err + } case "awsregion": var awsRegion []string if i, ok := options["awsregion"]; ok { @@ -165,7 +168,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o for _, awsRegions := range strings.Split(regions, ",") { awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions))) } - awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion) + awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, awsRegion) + if err != nil { + return nil, err + } } else { return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions") } diff --git a/registry/storage/driver/middleware/cloudfront/middleware_test.go b/registry/storage/driver/middleware/cloudfront/middleware_test.go index 4e75c6439..67ec077fd 100644 --- a/registry/storage/driver/middleware/cloudfront/middleware_test.go +++ b/registry/storage/driver/middleware/cloudfront/middleware_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "os" "testing" @@ -15,7 +16,7 @@ var _ = check.Suite(&MiddlewareSuite{}) func (s *MiddlewareSuite) TestNoConfig(c *check.C) { options := make(map[string]interface{}) - _, err := newCloudFrontStorageMiddleware(nil, options) + _, err := newCloudFrontStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.ErrorMatches, "no baseurl provided") } @@ -48,7 +49,7 @@ pZeMRablbPQdp8/1NyIwimq1VlG0ohQ4P6qhW7E09ZMC defer os.Remove(file.Name()) options["privatekey"] = file.Name() options["keypairid"] = "test" - storageDriver, err := newCloudFrontStorageMiddleware(nil, options) + storageDriver, err := newCloudFrontStorageMiddleware(context.Background(), nil, options) if err != nil { t.Fatal(err) } diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go index 25aafd043..75158c91d 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter.go @@ -3,6 +3,7 @@ package middleware import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -23,18 +24,21 @@ const ( // newAWSIPs returns a New awsIP object. // If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified -func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs { +func newAWSIPs(ctx context.Context, host string, updateFrequency time.Duration, awsRegion []string) (*awsIPs, error) { ips := &awsIPs{ host: host, updateFrequency: updateFrequency, awsRegion: awsRegion, updaterStopChan: make(chan bool), } - if err := ips.tryUpdate(); err != nil { - dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP") + if err := ips.tryUpdate(ctx); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + dcontext.GetLogger(ctx).WithError(err).Warn("failed to update AWS IP") } go ips.updater() - return ips + return ips, nil } // awsIPs tracks a list of AWS ips, filtered by awsRegion @@ -61,9 +65,13 @@ type prefixEntry struct { Service string `json:"service"` } -func fetchAWSIPs(url string) (awsIPResponse, error) { +func fetchAWSIPs(ctx context.Context, url string) (awsIPResponse, error) { var response awsIPResponse - resp, err := http.Get(url) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return response, err + } + resp, err := http.DefaultClient.Do(req) if err != nil { return response, err } @@ -83,8 +91,8 @@ func fetchAWSIPs(url string) (awsIPResponse, error) { // tryUpdate attempts to download the new set of ip addresses. // tryUpdate must be thread safe with contains -func (s *awsIPs) tryUpdate() error { - response, err := fetchAWSIPs(s.host) +func (s *awsIPs) tryUpdate(ctx context.Context) error { + response, err := fetchAWSIPs(ctx, s.host) if err != nil { return err } @@ -135,17 +143,18 @@ func (s *awsIPs) tryUpdate() error { // This function is meant to be run in a background goroutine. // It will periodically update the ips from aws. func (s *awsIPs) updater() { + ctx := context.TODO() defer close(s.updaterStopChan) for { time.Sleep(s.updateFrequency) select { case <-s.updaterStopChan: - dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal") + dcontext.GetLogger(ctx).Info("aws ip updater received stop signal") return default: - err := s.tryUpdate() + err := s.tryUpdate(ctx) if err != nil { - dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP") + dcontext.GetLogger(ctx).WithError(err).Error("git AWS IP") } } } diff --git a/registry/storage/driver/middleware/cloudfront/s3filter_test.go b/registry/storage/driver/middleware/cloudfront/s3filter_test.go index 81ef6aa8a..c347d0354 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter_test.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter_test.go @@ -62,7 +62,7 @@ func TestS3TryUpdate(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil) assertEqual(t, 1, len(ips.ipv4)) assertEqual(t, 0, len(ips.ipv6)) @@ -77,8 +77,9 @@ func TestMatchIPV6(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("ff00::"))) assertEqual(t, 1, len(ips.ipv6)) assertEqual(t, 0, len(ips.ipv4)) @@ -93,8 +94,9 @@ func TestMatchIPV4(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -112,8 +114,9 @@ func TestMatchIPV4_2(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -131,8 +134,9 @@ func TestMatchIPV4WithRegionMatched(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"}) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-east-1"}) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -150,8 +154,9 @@ func TestMatchIPV4WithRegionMatch_2(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"}) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"}) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -169,8 +174,9 @@ func TestMatchIPV4WithRegionNotMatched(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"}) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2"}) + ips.tryUpdate(ctx) assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -187,8 +193,9 @@ func TestInvalidData(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil) + ips.tryUpdate(ctx) assertEqual(t, 1, len(ips.ipv4)) } @@ -205,7 +212,7 @@ func TestInvalidNetworkType(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil) assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks @@ -226,7 +233,7 @@ func TestParsing(t *testing.T) { t.Parallel() server := httptest.NewServer(rawMockHandler) defer server.Close() - schema, err := fetchAWSIPs(server.URL) + schema, err := fetchAWSIPs(context.Background(), server.URL) assertEqual(t, nil, err) assertEqual(t, 1, len(schema.Prefixes)) @@ -253,7 +260,7 @@ func TestUpdateCalledRegularly(t *testing.T) { rw.Write([]byte("ok")) })) defer server.Close() - newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil) + newAWSIPs(context.Background(), fmt.Sprintf("%s/", server.URL), time.Second, nil) time.Sleep(time.Second*4 + time.Millisecond*500) if updateCount < 4 { t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount) @@ -384,7 +391,7 @@ func BenchmarkContainsRandom(b *testing.B) { } func BenchmarkContainsProd(b *testing.B) { - ips := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil) + ips, _ := newAWSIPs(context.Background(), defaultIPRangesURL, defaultUpdateFrequency, nil) ipv4 := make([][]byte, b.N) ipv6 := make([][]byte, b.N) for i := 0; i < b.N; i++ { diff --git a/registry/storage/driver/middleware/redirect/middleware.go b/registry/storage/driver/middleware/redirect/middleware.go index 9e1b303ea..8976d8689 100644 --- a/registry/storage/driver/middleware/redirect/middleware.go +++ b/registry/storage/driver/middleware/redirect/middleware.go @@ -19,7 +19,7 @@ type redirectStorageMiddleware struct { var _ storagedriver.StorageDriver = &redirectStorageMiddleware{} -func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { +func newRedirectStorageMiddleware(ctx context.Context, sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { o, ok := options["baseurl"] if !ok { return nil, fmt.Errorf("no baseurl provided") diff --git a/registry/storage/driver/middleware/redirect/middleware_test.go b/registry/storage/driver/middleware/redirect/middleware_test.go index 30d0bb192..2c22dec36 100644 --- a/registry/storage/driver/middleware/redirect/middleware_test.go +++ b/registry/storage/driver/middleware/redirect/middleware_test.go @@ -15,21 +15,21 @@ var _ = check.Suite(&MiddlewareSuite{}) func (s *MiddlewareSuite) TestNoConfig(c *check.C) { options := make(map[string]interface{}) - _, err := newRedirectStorageMiddleware(nil, options) + _, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.ErrorMatches, "no baseurl provided") } func (s *MiddlewareSuite) TestMissingScheme(c *check.C) { options := make(map[string]interface{}) options["baseurl"] = "example.com" - _, err := newRedirectStorageMiddleware(nil, options) + _, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl") } func (s *MiddlewareSuite) TestHttpsPort(c *check.C) { options := make(map[string]interface{}) options["baseurl"] = "https://example.com:5443" - middleware, err := newRedirectStorageMiddleware(nil, options) + middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.Equals, nil) m, ok := middleware.(*redirectStorageMiddleware) @@ -45,7 +45,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) { func (s *MiddlewareSuite) TestHTTP(c *check.C) { options := make(map[string]interface{}) options["baseurl"] = "http://example.com" - middleware, err := newRedirectStorageMiddleware(nil, options) + middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.Equals, nil) m, ok := middleware.(*redirectStorageMiddleware) @@ -62,7 +62,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) { // basePath: end with no slash options := make(map[string]interface{}) options["baseurl"] = "https://example.com/path" - middleware, err := newRedirectStorageMiddleware(nil, options) + middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.Equals, nil) m, ok := middleware.(*redirectStorageMiddleware) @@ -82,7 +82,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) { // basePath: end with slash options["baseurl"] = "https://example.com/path/" - middleware, err = newRedirectStorageMiddleware(nil, options) + middleware, err = newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.Equals, nil) m, ok = middleware.(*redirectStorageMiddleware) diff --git a/registry/storage/driver/middleware/storagemiddleware.go b/registry/storage/driver/middleware/storagemiddleware.go index bde645dc5..d2c37741e 100644 --- a/registry/storage/driver/middleware/storagemiddleware.go +++ b/registry/storage/driver/middleware/storagemiddleware.go @@ -1,6 +1,7 @@ package storagemiddleware import ( + "context" "fmt" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" @@ -8,7 +9,7 @@ import ( // InitFunc is the type of a StorageMiddleware factory function and is // used to register the constructor for different StorageMiddleware backends. -type InitFunc func(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) +type InitFunc func(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) var storageMiddlewares map[string]InitFunc @@ -28,10 +29,10 @@ func Register(name string, initFunc InitFunc) error { } // Get constructs a StorageMiddleware with the given options using the named backend. -func Get(name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) { +func Get(ctx context.Context, name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) { if storageMiddlewares != nil { if initFunc, exists := storageMiddlewares[name]; exists { - return initFunc(storageDriver, options) + return initFunc(ctx, storageDriver, options) } } diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go index a6428a918..3d9cb920d 100644 --- a/registry/storage/driver/s3-aws/s3.go +++ b/registry/storage/driver/s3-aws/s3.go @@ -150,8 +150,8 @@ func init() { // s3DriverFactory implements the factory.StorageDriverFactory interface type s3DriverFactory struct{} -func (factory *s3DriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { - return FromParameters(parameters) +func (factory *s3DriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { + return FromParameters(ctx, parameters) } var _ storagedriver.StorageDriver = &driver{} @@ -189,7 +189,7 @@ type Driver struct { // - region // - bucket // - encrypt -func FromParameters(parameters map[string]interface{}) (*Driver, error) { +func FromParameters(ctx context.Context, parameters map[string]interface{}) (*Driver, error) { // Providing no values for these is valid in case the user is authenticating // with an IAM on an ec2 instance (in which case the instance credentials will // be summoned when GetAuth is called) @@ -468,7 +468,7 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { getS3LogLevelFromParam(parameters["loglevel"]), } - return New(params) + return New(ctx, params) } func getS3LogLevelFromParam(param interface{}) aws.LogLevelType { @@ -529,7 +529,7 @@ func getParameterAsInt64(parameters map[string]interface{}, name string, default // New constructs a new Driver with the given AWS credentials, region, encryption flag, and // bucketName -func New(params DriverParameters) (*Driver, error) { +func New(ctx context.Context, params DriverParameters) (*Driver, error) { if !params.V4Auth && (params.RegionEndpoint == "" || strings.Contains(params.RegionEndpoint, "s3.amazonaws.com")) { diff --git a/registry/storage/driver/s3-aws/s3_test.go b/registry/storage/driver/s3-aws/s3_test.go index ebd9d2a2b..7bfae2aa0 100644 --- a/registry/storage/driver/s3-aws/s3_test.go +++ b/registry/storage/driver/s3-aws/s3_test.go @@ -141,7 +141,7 @@ func init() { getS3LogLevelFromParam(logLevel), } - return New(parameters) + return New(context.Background(), parameters) } // Skip S3 storage driver tests if environment variable parameters are not provided