From b4dc4f3474e445910148f4a70103fd7186460b49 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Fri, 27 Oct 2023 17:46:09 -0400 Subject: [PATCH] storage/driver: plumb contexts into middlewares Signed-off-by: Cory Snider --- registry/handlers/app.go | 6 +-- .../middleware/cloudfront/middleware.go | 12 +++-- .../middleware/cloudfront/middleware_test.go | 5 ++- .../driver/middleware/cloudfront/s3filter.go | 31 ++++++++----- .../middleware/cloudfront/s3filter_test.go | 45 +++++++++++-------- .../driver/middleware/redirect/middleware.go | 2 +- .../middleware/redirect/middleware_test.go | 12 ++--- .../driver/middleware/storagemiddleware.go | 7 +-- 8 files changed, 72 insertions(+), 48 deletions(-) diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 53367277..5f017b70 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -148,7 +148,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App { startUploadPurger(app, app.driver, dcontext.GetLogger(app), purgeConfig) - app.driver, err = applyStorageMiddleware(app.driver, config.Middleware["storage"]) + app.driver, err = applyStorageMiddleware(app, app.driver, config.Middleware["storage"]) if err != nil { panic(err) } @@ -938,9 +938,9 @@ func applyRepoMiddleware(ctx context.Context, repository distribution.Repository } // applyStorageMiddleware wraps a storage driver with the configured middlewares -func applyStorageMiddleware(driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) { +func applyStorageMiddleware(ctx context.Context, driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) { for _, mw := range middlewares { - smw, err := storagemiddleware.Get(mw.Name, mw.Options, driver) + smw, err := storagemiddleware.Get(ctx, mw.Name, mw.Options, driver) if err != nil { return nil, fmt.Errorf("unable to configure storage middleware (%s): %v", mw.Name, err) } diff --git a/registry/storage/driver/middleware/cloudfront/middleware.go b/registry/storage/driver/middleware/cloudfront/middleware.go index 5c2c0995..741d618e 100644 --- a/registry/storage/driver/middleware/cloudfront/middleware.go +++ b/registry/storage/driver/middleware/cloudfront/middleware.go @@ -48,7 +48,7 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{} // default value. "aws", only aws IP goes to S3 directly. "awsregion", only // regions listed in awsregion options goes to S3 directly // - awsregion: a comma separated string of AWS regions. -func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { +func newCloudFrontStorageMiddleware(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { // parse baseurl base, ok := options["baseurl"] if !ok { @@ -157,7 +157,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o case "", "none": awsIPs = nil case "aws": - awsIPs = newAWSIPs(ipRangesURL, updateFrequency, nil) + awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, nil) + if err != nil { + return nil, err + } case "awsregion": var awsRegion []string if i, ok := options["awsregion"]; ok { @@ -165,7 +168,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o for _, awsRegions := range strings.Split(regions, ",") { awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions))) } - awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion) + awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, awsRegion) + if err != nil { + return nil, err + } } else { return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions") } diff --git a/registry/storage/driver/middleware/cloudfront/middleware_test.go b/registry/storage/driver/middleware/cloudfront/middleware_test.go index 4e75c643..67ec077f 100644 --- a/registry/storage/driver/middleware/cloudfront/middleware_test.go +++ b/registry/storage/driver/middleware/cloudfront/middleware_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "os" "testing" @@ -15,7 +16,7 @@ var _ = check.Suite(&MiddlewareSuite{}) func (s *MiddlewareSuite) TestNoConfig(c *check.C) { options := make(map[string]interface{}) - _, err := newCloudFrontStorageMiddleware(nil, options) + _, err := newCloudFrontStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.ErrorMatches, "no baseurl provided") } @@ -48,7 +49,7 @@ pZeMRablbPQdp8/1NyIwimq1VlG0ohQ4P6qhW7E09ZMC defer os.Remove(file.Name()) options["privatekey"] = file.Name() options["keypairid"] = "test" - storageDriver, err := newCloudFrontStorageMiddleware(nil, options) + storageDriver, err := newCloudFrontStorageMiddleware(context.Background(), nil, options) if err != nil { t.Fatal(err) } diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go index 25aafd04..75158c91 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter.go @@ -3,6 +3,7 @@ package middleware import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -23,18 +24,21 @@ const ( // newAWSIPs returns a New awsIP object. // If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified -func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs { +func newAWSIPs(ctx context.Context, host string, updateFrequency time.Duration, awsRegion []string) (*awsIPs, error) { ips := &awsIPs{ host: host, updateFrequency: updateFrequency, awsRegion: awsRegion, updaterStopChan: make(chan bool), } - if err := ips.tryUpdate(); err != nil { - dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP") + if err := ips.tryUpdate(ctx); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + dcontext.GetLogger(ctx).WithError(err).Warn("failed to update AWS IP") } go ips.updater() - return ips + return ips, nil } // awsIPs tracks a list of AWS ips, filtered by awsRegion @@ -61,9 +65,13 @@ type prefixEntry struct { Service string `json:"service"` } -func fetchAWSIPs(url string) (awsIPResponse, error) { +func fetchAWSIPs(ctx context.Context, url string) (awsIPResponse, error) { var response awsIPResponse - resp, err := http.Get(url) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return response, err + } + resp, err := http.DefaultClient.Do(req) if err != nil { return response, err } @@ -83,8 +91,8 @@ func fetchAWSIPs(url string) (awsIPResponse, error) { // tryUpdate attempts to download the new set of ip addresses. // tryUpdate must be thread safe with contains -func (s *awsIPs) tryUpdate() error { - response, err := fetchAWSIPs(s.host) +func (s *awsIPs) tryUpdate(ctx context.Context) error { + response, err := fetchAWSIPs(ctx, s.host) if err != nil { return err } @@ -135,17 +143,18 @@ func (s *awsIPs) tryUpdate() error { // This function is meant to be run in a background goroutine. // It will periodically update the ips from aws. func (s *awsIPs) updater() { + ctx := context.TODO() defer close(s.updaterStopChan) for { time.Sleep(s.updateFrequency) select { case <-s.updaterStopChan: - dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal") + dcontext.GetLogger(ctx).Info("aws ip updater received stop signal") return default: - err := s.tryUpdate() + err := s.tryUpdate(ctx) if err != nil { - dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP") + dcontext.GetLogger(ctx).WithError(err).Error("git AWS IP") } } } diff --git a/registry/storage/driver/middleware/cloudfront/s3filter_test.go b/registry/storage/driver/middleware/cloudfront/s3filter_test.go index 81ef6aa8..c347d035 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter_test.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter_test.go @@ -62,7 +62,7 @@ func TestS3TryUpdate(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil) assertEqual(t, 1, len(ips.ipv4)) assertEqual(t, 0, len(ips.ipv6)) @@ -77,8 +77,9 @@ func TestMatchIPV6(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("ff00::"))) assertEqual(t, 1, len(ips.ipv6)) assertEqual(t, 0, len(ips.ipv4)) @@ -93,8 +94,9 @@ func TestMatchIPV4(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -112,8 +114,9 @@ func TestMatchIPV4_2(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -131,8 +134,9 @@ func TestMatchIPV4WithRegionMatched(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"}) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-east-1"}) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -150,8 +154,9 @@ func TestMatchIPV4WithRegionMatch_2(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"}) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"}) + ips.tryUpdate(ctx) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -169,8 +174,9 @@ func TestMatchIPV4WithRegionNotMatched(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"}) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2"}) + ips.tryUpdate(ctx) assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) @@ -187,8 +193,9 @@ func TestInvalidData(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) - ips.tryUpdate() + ctx := context.Background() + ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil) + ips.tryUpdate(ctx) assertEqual(t, 1, len(ips.ipv4)) } @@ -205,7 +212,7 @@ func TestInvalidNetworkType(t *testing.T) { }) defer server.Close() - ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) + ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil) assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks @@ -226,7 +233,7 @@ func TestParsing(t *testing.T) { t.Parallel() server := httptest.NewServer(rawMockHandler) defer server.Close() - schema, err := fetchAWSIPs(server.URL) + schema, err := fetchAWSIPs(context.Background(), server.URL) assertEqual(t, nil, err) assertEqual(t, 1, len(schema.Prefixes)) @@ -253,7 +260,7 @@ func TestUpdateCalledRegularly(t *testing.T) { rw.Write([]byte("ok")) })) defer server.Close() - newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil) + newAWSIPs(context.Background(), fmt.Sprintf("%s/", server.URL), time.Second, nil) time.Sleep(time.Second*4 + time.Millisecond*500) if updateCount < 4 { t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount) @@ -384,7 +391,7 @@ func BenchmarkContainsRandom(b *testing.B) { } func BenchmarkContainsProd(b *testing.B) { - ips := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil) + ips, _ := newAWSIPs(context.Background(), defaultIPRangesURL, defaultUpdateFrequency, nil) ipv4 := make([][]byte, b.N) ipv6 := make([][]byte, b.N) for i := 0; i < b.N; i++ { diff --git a/registry/storage/driver/middleware/redirect/middleware.go b/registry/storage/driver/middleware/redirect/middleware.go index 9e1b303e..8976d868 100644 --- a/registry/storage/driver/middleware/redirect/middleware.go +++ b/registry/storage/driver/middleware/redirect/middleware.go @@ -19,7 +19,7 @@ type redirectStorageMiddleware struct { var _ storagedriver.StorageDriver = &redirectStorageMiddleware{} -func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { +func newRedirectStorageMiddleware(ctx context.Context, sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { o, ok := options["baseurl"] if !ok { return nil, fmt.Errorf("no baseurl provided") diff --git a/registry/storage/driver/middleware/redirect/middleware_test.go b/registry/storage/driver/middleware/redirect/middleware_test.go index 30d0bb19..2c22dec3 100644 --- a/registry/storage/driver/middleware/redirect/middleware_test.go +++ b/registry/storage/driver/middleware/redirect/middleware_test.go @@ -15,21 +15,21 @@ var _ = check.Suite(&MiddlewareSuite{}) func (s *MiddlewareSuite) TestNoConfig(c *check.C) { options := make(map[string]interface{}) - _, err := newRedirectStorageMiddleware(nil, options) + _, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.ErrorMatches, "no baseurl provided") } func (s *MiddlewareSuite) TestMissingScheme(c *check.C) { options := make(map[string]interface{}) options["baseurl"] = "example.com" - _, err := newRedirectStorageMiddleware(nil, options) + _, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl") } func (s *MiddlewareSuite) TestHttpsPort(c *check.C) { options := make(map[string]interface{}) options["baseurl"] = "https://example.com:5443" - middleware, err := newRedirectStorageMiddleware(nil, options) + middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.Equals, nil) m, ok := middleware.(*redirectStorageMiddleware) @@ -45,7 +45,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) { func (s *MiddlewareSuite) TestHTTP(c *check.C) { options := make(map[string]interface{}) options["baseurl"] = "http://example.com" - middleware, err := newRedirectStorageMiddleware(nil, options) + middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.Equals, nil) m, ok := middleware.(*redirectStorageMiddleware) @@ -62,7 +62,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) { // basePath: end with no slash options := make(map[string]interface{}) options["baseurl"] = "https://example.com/path" - middleware, err := newRedirectStorageMiddleware(nil, options) + middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.Equals, nil) m, ok := middleware.(*redirectStorageMiddleware) @@ -82,7 +82,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) { // basePath: end with slash options["baseurl"] = "https://example.com/path/" - middleware, err = newRedirectStorageMiddleware(nil, options) + middleware, err = newRedirectStorageMiddleware(context.Background(), nil, options) c.Assert(err, check.Equals, nil) m, ok = middleware.(*redirectStorageMiddleware) diff --git a/registry/storage/driver/middleware/storagemiddleware.go b/registry/storage/driver/middleware/storagemiddleware.go index bde645dc..d2c37741 100644 --- a/registry/storage/driver/middleware/storagemiddleware.go +++ b/registry/storage/driver/middleware/storagemiddleware.go @@ -1,6 +1,7 @@ package storagemiddleware import ( + "context" "fmt" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" @@ -8,7 +9,7 @@ import ( // InitFunc is the type of a StorageMiddleware factory function and is // used to register the constructor for different StorageMiddleware backends. -type InitFunc func(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) +type InitFunc func(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) var storageMiddlewares map[string]InitFunc @@ -28,10 +29,10 @@ func Register(name string, initFunc InitFunc) error { } // Get constructs a StorageMiddleware with the given options using the named backend. -func Get(name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) { +func Get(ctx context.Context, name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) { if storageMiddlewares != nil { if initFunc, exists := storageMiddlewares[name]; exists { - return initFunc(storageDriver, options) + return initFunc(ctx, storageDriver, options) } }