Plumb contexts into storage driver factories and middlewares (#4142)

This commit is contained in:
Wang Yan 2023-10-31 18:18:22 +08:00 committed by GitHub
commit 6814691c19
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 101 additions and 77 deletions

View file

@ -1377,7 +1377,7 @@ const (
repositoryWithGenericStorageError = "genericstorageerr" repositoryWithGenericStorageError = "genericstorageerr"
) )
func (factory *storageManifestErrDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func (factory *storageManifestErrDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
// Initialize the mock driver // Initialize the mock driver
errGenericStorage := errors.New("generic storage error") errGenericStorage := errors.New("generic storage error")
return &mockErrorDriver{ return &mockErrorDriver{

View file

@ -116,7 +116,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
storageParams["useragent"] = fmt.Sprintf("distribution/%s %s", version.Version, runtime.Version()) storageParams["useragent"] = fmt.Sprintf("distribution/%s %s", version.Version, runtime.Version())
var err error var err error
app.driver, err = factory.Create(config.Storage.Type(), storageParams) app.driver, err = factory.Create(app, config.Storage.Type(), storageParams)
if err != nil { if err != nil {
// TODO(stevvooe): Move the creation of a service into a protected // TODO(stevvooe): Move the creation of a service into a protected
// method, where this is created lazily. Its status can be queried via // method, where this is created lazily. Its status can be queried via
@ -148,7 +148,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
startUploadPurger(app, app.driver, dcontext.GetLogger(app), purgeConfig) startUploadPurger(app, app.driver, dcontext.GetLogger(app), purgeConfig)
app.driver, err = applyStorageMiddleware(app.driver, config.Middleware["storage"]) app.driver, err = applyStorageMiddleware(app, app.driver, config.Middleware["storage"])
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -938,9 +938,9 @@ func applyRepoMiddleware(ctx context.Context, repository distribution.Repository
} }
// applyStorageMiddleware wraps a storage driver with the configured middlewares // applyStorageMiddleware wraps a storage driver with the configured middlewares
func applyStorageMiddleware(driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) { func applyStorageMiddleware(ctx context.Context, driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) {
for _, mw := range middlewares { for _, mw := range middlewares {
smw, err := storagemiddleware.Get(mw.Name, mw.Options, driver) smw, err := storagemiddleware.Get(ctx, mw.Name, mw.Options, driver)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to configure storage middleware (%s): %v", mw.Name, err) return nil, fmt.Errorf("unable to configure storage middleware (%s): %v", mw.Name, err)
} }

View file

@ -53,12 +53,6 @@ var GCCmd = &cobra.Command{
os.Exit(1) os.Exit(1)
} }
driver, err := factory.Create(config.Storage.Type(), config.Storage.Parameters())
if err != nil {
fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err)
os.Exit(1)
}
ctx := dcontext.Background() ctx := dcontext.Background()
ctx, err = configureLogging(ctx, config) ctx, err = configureLogging(ctx, config)
if err != nil { if err != nil {
@ -66,6 +60,12 @@ var GCCmd = &cobra.Command{
os.Exit(1) os.Exit(1)
} }
driver, err := factory.Create(ctx, config.Storage.Type(), config.Storage.Parameters())
if err != nil {
fmt.Fprintf(os.Stderr, "failed to construct %s driver: %v", config.Storage.Type(), err)
os.Exit(1)
}
registry, err := storage.NewRegistry(ctx, driver) registry, err := storage.NewRegistry(ctx, driver)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "failed to construct registry: %v", err) fmt.Fprintf(os.Stderr, "failed to construct registry: %v", err)

View file

@ -46,16 +46,16 @@ func init() {
type azureDriverFactory struct{} type azureDriverFactory struct{}
func (factory *azureDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func (factory *azureDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
params, err := NewParameters(parameters) params, err := NewParameters(parameters)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return New(params) return New(ctx, params)
} }
// New constructs a new Driver from parameters // New constructs a new Driver from parameters
func New(params *Parameters) (*Driver, error) { func New(ctx context.Context, params *Parameters) (*Driver, error) {
azClient, err := newAzureClient(params) azClient, err := newAzureClient(params)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -68,7 +68,7 @@ func init() {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return New(params) return New(context.Background(), params)
} }
// Skip Azure storage driver tests if environment variable parameters are not provided // Skip Azure storage driver tests if environment variable parameters are not provided

View file

@ -1,6 +1,7 @@
package factory package factory
import ( import (
"context"
"fmt" "fmt"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
@ -23,7 +24,7 @@ type StorageDriverFactory interface {
// Create returns a new storagedriver.StorageDriver with the given parameters // Create returns a new storagedriver.StorageDriver with the given parameters
// Parameters will vary by driver and may be ignored // Parameters will vary by driver and may be ignored
// Each parameter key must only consist of lowercase letters and numbers // Each parameter key must only consist of lowercase letters and numbers
Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error)
} }
// Register makes a storage driver available by the provided name. // Register makes a storage driver available by the provided name.
@ -46,12 +47,12 @@ func Register(name string, factory StorageDriverFactory) {
// parameters. To use a driver, the StorageDriverFactory must first be // parameters. To use a driver, the StorageDriverFactory must first be
// registered with the given name. If no drivers are found, an // registered with the given name. If no drivers are found, an
// InvalidStorageDriverError is returned // InvalidStorageDriverError is returned
func Create(name string, parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func Create(ctx context.Context, name string, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
driverFactory, ok := driverFactories[name] driverFactory, ok := driverFactories[name]
if !ok { if !ok {
return nil, InvalidStorageDriverError{name} return nil, InvalidStorageDriverError{name}
} }
return driverFactory.Create(parameters) return driverFactory.Create(ctx, parameters)
} }
// InvalidStorageDriverError records an attempt to construct an unregistered storage driver // InvalidStorageDriverError records an attempt to construct an unregistered storage driver

View file

@ -40,7 +40,7 @@ func init() {
// filesystemDriverFactory implements the factory.StorageDriverFactory interface // filesystemDriverFactory implements the factory.StorageDriverFactory interface
type filesystemDriverFactory struct{} type filesystemDriverFactory struct{}
func (factory *filesystemDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func (factory *filesystemDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters) return FromParameters(parameters)
} }

View file

@ -87,8 +87,8 @@ func init() {
type gcsDriverFactory struct{} type gcsDriverFactory struct{}
// Create StorageDriver from parameters // Create StorageDriver from parameters
func (factory *gcsDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func (factory *gcsDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters) return FromParameters(ctx, parameters)
} }
var _ storagedriver.StorageDriver = &driver{} var _ storagedriver.StorageDriver = &driver{}
@ -118,8 +118,7 @@ type baseEmbed struct {
// FromParameters constructs a new Driver with a given parameters map // FromParameters constructs a new Driver with a given parameters map
// Required parameters: // Required parameters:
// - bucket // - bucket
func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func FromParameters(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
ctx := context.TODO()
bucket, ok := parameters["bucket"] bucket, ok := parameters["bucket"]
if !ok || fmt.Sprint(bucket) == "" { if !ok || fmt.Sprint(bucket) == "" {
return nil, fmt.Errorf("No bucket parameter provided") return nil, fmt.Errorf("No bucket parameter provided")
@ -229,11 +228,11 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri
gcs: gcs, gcs: gcs,
} }
return New(params) return New(ctx, params)
} }
// New constructs a new driver // New constructs a new driver
func New(params driverParameters) (storagedriver.StorageDriver, error) { func New(ctx context.Context, params driverParameters) (storagedriver.StorageDriver, error) {
rootDirectory := strings.Trim(params.rootDirectory, "/") rootDirectory := strings.Trim(params.rootDirectory, "/")
if rootDirectory != "" { if rootDirectory != "" {
rootDirectory += "/" rootDirectory += "/"

View file

@ -21,7 +21,7 @@ func init() {
// inMemoryDriverFacotry implements the factory.StorageDriverFactory interface. // inMemoryDriverFacotry implements the factory.StorageDriverFactory interface.
type inMemoryDriverFactory struct{} type inMemoryDriverFactory struct{}
func (factory *inMemoryDriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func (factory *inMemoryDriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return New(), nil return New(), nil
} }

View file

@ -48,7 +48,7 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
// default value. "aws", only aws IP goes to S3 directly. "awsregion", only // default value. "aws", only aws IP goes to S3 directly. "awsregion", only
// regions listed in awsregion options goes to S3 directly // regions listed in awsregion options goes to S3 directly
// - awsregion: a comma separated string of AWS regions. // - awsregion: a comma separated string of AWS regions.
func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { func newCloudFrontStorageMiddleware(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
// parse baseurl // parse baseurl
base, ok := options["baseurl"] base, ok := options["baseurl"]
if !ok { if !ok {
@ -157,7 +157,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
case "", "none": case "", "none":
awsIPs = nil awsIPs = nil
case "aws": case "aws":
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, nil) awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, nil)
if err != nil {
return nil, err
}
case "awsregion": case "awsregion":
var awsRegion []string var awsRegion []string
if i, ok := options["awsregion"]; ok { if i, ok := options["awsregion"]; ok {
@ -165,7 +168,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
for _, awsRegions := range strings.Split(regions, ",") { for _, awsRegions := range strings.Split(regions, ",") {
awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions))) awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
} }
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion) awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, awsRegion)
if err != nil {
return nil, err
}
} else { } else {
return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions") return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
} }

View file

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"os" "os"
"testing" "testing"
@ -15,7 +16,7 @@ var _ = check.Suite(&MiddlewareSuite{})
func (s *MiddlewareSuite) TestNoConfig(c *check.C) { func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
options := make(map[string]interface{}) options := make(map[string]interface{})
_, err := newCloudFrontStorageMiddleware(nil, options) _, err := newCloudFrontStorageMiddleware(context.Background(), nil, options)
c.Assert(err, check.ErrorMatches, "no baseurl provided") c.Assert(err, check.ErrorMatches, "no baseurl provided")
} }
@ -48,7 +49,7 @@ pZeMRablbPQdp8/1NyIwimq1VlG0ohQ4P6qhW7E09ZMC
defer os.Remove(file.Name()) defer os.Remove(file.Name())
options["privatekey"] = file.Name() options["privatekey"] = file.Name()
options["keypairid"] = "test" options["keypairid"] = "test"
storageDriver, err := newCloudFrontStorageMiddleware(nil, options) storageDriver, err := newCloudFrontStorageMiddleware(context.Background(), nil, options)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -3,6 +3,7 @@ package middleware
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -23,18 +24,21 @@ const (
// newAWSIPs returns a New awsIP object. // newAWSIPs returns a New awsIP object.
// If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified // If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified
func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs { func newAWSIPs(ctx context.Context, host string, updateFrequency time.Duration, awsRegion []string) (*awsIPs, error) {
ips := &awsIPs{ ips := &awsIPs{
host: host, host: host,
updateFrequency: updateFrequency, updateFrequency: updateFrequency,
awsRegion: awsRegion, awsRegion: awsRegion,
updaterStopChan: make(chan bool), updaterStopChan: make(chan bool),
} }
if err := ips.tryUpdate(); err != nil { if err := ips.tryUpdate(ctx); err != nil {
dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP") if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
dcontext.GetLogger(ctx).WithError(err).Warn("failed to update AWS IP")
} }
go ips.updater() go ips.updater()
return ips return ips, nil
} }
// awsIPs tracks a list of AWS ips, filtered by awsRegion // awsIPs tracks a list of AWS ips, filtered by awsRegion
@ -61,9 +65,13 @@ type prefixEntry struct {
Service string `json:"service"` Service string `json:"service"`
} }
func fetchAWSIPs(url string) (awsIPResponse, error) { func fetchAWSIPs(ctx context.Context, url string) (awsIPResponse, error) {
var response awsIPResponse var response awsIPResponse
resp, err := http.Get(url) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return response, err
}
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return response, err return response, err
} }
@ -83,8 +91,8 @@ func fetchAWSIPs(url string) (awsIPResponse, error) {
// tryUpdate attempts to download the new set of ip addresses. // tryUpdate attempts to download the new set of ip addresses.
// tryUpdate must be thread safe with contains // tryUpdate must be thread safe with contains
func (s *awsIPs) tryUpdate() error { func (s *awsIPs) tryUpdate(ctx context.Context) error {
response, err := fetchAWSIPs(s.host) response, err := fetchAWSIPs(ctx, s.host)
if err != nil { if err != nil {
return err return err
} }
@ -135,17 +143,18 @@ func (s *awsIPs) tryUpdate() error {
// This function is meant to be run in a background goroutine. // This function is meant to be run in a background goroutine.
// It will periodically update the ips from aws. // It will periodically update the ips from aws.
func (s *awsIPs) updater() { func (s *awsIPs) updater() {
ctx := context.TODO()
defer close(s.updaterStopChan) defer close(s.updaterStopChan)
for { for {
time.Sleep(s.updateFrequency) time.Sleep(s.updateFrequency)
select { select {
case <-s.updaterStopChan: case <-s.updaterStopChan:
dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal") dcontext.GetLogger(ctx).Info("aws ip updater received stop signal")
return return
default: default:
err := s.tryUpdate() err := s.tryUpdate(ctx)
if err != nil { if err != nil {
dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP") dcontext.GetLogger(ctx).WithError(err).Error("git AWS IP")
} }
} }
} }

View file

@ -62,7 +62,7 @@ func TestS3TryUpdate(t *testing.T) {
}) })
defer server.Close() defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil)
assertEqual(t, 1, len(ips.ipv4)) assertEqual(t, 1, len(ips.ipv4))
assertEqual(t, 0, len(ips.ipv6)) assertEqual(t, 0, len(ips.ipv6))
@ -77,8 +77,9 @@ func TestMatchIPV6(t *testing.T) {
}) })
defer server.Close() defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) ctx := context.Background()
ips.tryUpdate() ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("ff00::"))) assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
assertEqual(t, 1, len(ips.ipv6)) assertEqual(t, 1, len(ips.ipv6))
assertEqual(t, 0, len(ips.ipv4)) assertEqual(t, 0, len(ips.ipv4))
@ -93,8 +94,9 @@ func TestMatchIPV4(t *testing.T) {
}) })
defer server.Close() defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) ctx := context.Background()
ips.tryUpdate() ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@ -112,8 +114,9 @@ func TestMatchIPV4_2(t *testing.T) {
}) })
defer server.Close() defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) ctx := context.Background()
ips.tryUpdate() ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@ -131,8 +134,9 @@ func TestMatchIPV4WithRegionMatched(t *testing.T) {
}) })
defer server.Close() defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"}) ctx := context.Background()
ips.tryUpdate() ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-east-1"})
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@ -150,8 +154,9 @@ func TestMatchIPV4WithRegionMatch_2(t *testing.T) {
}) })
defer server.Close() defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"}) ctx := context.Background()
ips.tryUpdate() ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@ -169,8 +174,9 @@ func TestMatchIPV4WithRegionNotMatched(t *testing.T) {
}) })
defer server.Close() defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"}) ctx := context.Background()
ips.tryUpdate() ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2"})
ips.tryUpdate(ctx)
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0"))) assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1"))) assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0"))) assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@ -187,8 +193,9 @@ func TestInvalidData(t *testing.T) {
}) })
defer server.Close() defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) ctx := context.Background()
ips.tryUpdate() ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
ips.tryUpdate(ctx)
assertEqual(t, 1, len(ips.ipv4)) assertEqual(t, 1, len(ips.ipv4))
} }
@ -205,7 +212,7 @@ func TestInvalidNetworkType(t *testing.T) {
}) })
defer server.Close() defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil) ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil)
assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks
assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
@ -226,7 +233,7 @@ func TestParsing(t *testing.T) {
t.Parallel() t.Parallel()
server := httptest.NewServer(rawMockHandler) server := httptest.NewServer(rawMockHandler)
defer server.Close() defer server.Close()
schema, err := fetchAWSIPs(server.URL) schema, err := fetchAWSIPs(context.Background(), server.URL)
assertEqual(t, nil, err) assertEqual(t, nil, err)
assertEqual(t, 1, len(schema.Prefixes)) assertEqual(t, 1, len(schema.Prefixes))
@ -253,7 +260,7 @@ func TestUpdateCalledRegularly(t *testing.T) {
rw.Write([]byte("ok")) rw.Write([]byte("ok"))
})) }))
defer server.Close() defer server.Close()
newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil) newAWSIPs(context.Background(), fmt.Sprintf("%s/", server.URL), time.Second, nil)
time.Sleep(time.Second*4 + time.Millisecond*500) time.Sleep(time.Second*4 + time.Millisecond*500)
if updateCount < 4 { if updateCount < 4 {
t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount) t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount)
@ -384,7 +391,7 @@ func BenchmarkContainsRandom(b *testing.B) {
} }
func BenchmarkContainsProd(b *testing.B) { func BenchmarkContainsProd(b *testing.B) {
ips := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil) ips, _ := newAWSIPs(context.Background(), defaultIPRangesURL, defaultUpdateFrequency, nil)
ipv4 := make([][]byte, b.N) ipv4 := make([][]byte, b.N)
ipv6 := make([][]byte, b.N) ipv6 := make([][]byte, b.N)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View file

@ -19,7 +19,7 @@ type redirectStorageMiddleware struct {
var _ storagedriver.StorageDriver = &redirectStorageMiddleware{} var _ storagedriver.StorageDriver = &redirectStorageMiddleware{}
func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) { func newRedirectStorageMiddleware(ctx context.Context, sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
o, ok := options["baseurl"] o, ok := options["baseurl"]
if !ok { if !ok {
return nil, fmt.Errorf("no baseurl provided") return nil, fmt.Errorf("no baseurl provided")

View file

@ -15,21 +15,21 @@ var _ = check.Suite(&MiddlewareSuite{})
func (s *MiddlewareSuite) TestNoConfig(c *check.C) { func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
options := make(map[string]interface{}) options := make(map[string]interface{})
_, err := newRedirectStorageMiddleware(nil, options) _, err := newRedirectStorageMiddleware(context.Background(), nil, options)
c.Assert(err, check.ErrorMatches, "no baseurl provided") c.Assert(err, check.ErrorMatches, "no baseurl provided")
} }
func (s *MiddlewareSuite) TestMissingScheme(c *check.C) { func (s *MiddlewareSuite) TestMissingScheme(c *check.C) {
options := make(map[string]interface{}) options := make(map[string]interface{})
options["baseurl"] = "example.com" options["baseurl"] = "example.com"
_, err := newRedirectStorageMiddleware(nil, options) _, err := newRedirectStorageMiddleware(context.Background(), nil, options)
c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl") c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl")
} }
func (s *MiddlewareSuite) TestHttpsPort(c *check.C) { func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
options := make(map[string]interface{}) options := make(map[string]interface{})
options["baseurl"] = "https://example.com:5443" options["baseurl"] = "https://example.com:5443"
middleware, err := newRedirectStorageMiddleware(nil, options) middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options)
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
m, ok := middleware.(*redirectStorageMiddleware) m, ok := middleware.(*redirectStorageMiddleware)
@ -45,7 +45,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
func (s *MiddlewareSuite) TestHTTP(c *check.C) { func (s *MiddlewareSuite) TestHTTP(c *check.C) {
options := make(map[string]interface{}) options := make(map[string]interface{})
options["baseurl"] = "http://example.com" options["baseurl"] = "http://example.com"
middleware, err := newRedirectStorageMiddleware(nil, options) middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options)
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
m, ok := middleware.(*redirectStorageMiddleware) m, ok := middleware.(*redirectStorageMiddleware)
@ -62,7 +62,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
// basePath: end with no slash // basePath: end with no slash
options := make(map[string]interface{}) options := make(map[string]interface{})
options["baseurl"] = "https://example.com/path" options["baseurl"] = "https://example.com/path"
middleware, err := newRedirectStorageMiddleware(nil, options) middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options)
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
m, ok := middleware.(*redirectStorageMiddleware) m, ok := middleware.(*redirectStorageMiddleware)
@ -82,7 +82,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
// basePath: end with slash // basePath: end with slash
options["baseurl"] = "https://example.com/path/" options["baseurl"] = "https://example.com/path/"
middleware, err = newRedirectStorageMiddleware(nil, options) middleware, err = newRedirectStorageMiddleware(context.Background(), nil, options)
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
m, ok = middleware.(*redirectStorageMiddleware) m, ok = middleware.(*redirectStorageMiddleware)

View file

@ -1,6 +1,7 @@
package storagemiddleware package storagemiddleware
import ( import (
"context"
"fmt" "fmt"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
@ -8,7 +9,7 @@ import (
// InitFunc is the type of a StorageMiddleware factory function and is // InitFunc is the type of a StorageMiddleware factory function and is
// used to register the constructor for different StorageMiddleware backends. // used to register the constructor for different StorageMiddleware backends.
type InitFunc func(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) type InitFunc func(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error)
var storageMiddlewares map[string]InitFunc var storageMiddlewares map[string]InitFunc
@ -28,10 +29,10 @@ func Register(name string, initFunc InitFunc) error {
} }
// Get constructs a StorageMiddleware with the given options using the named backend. // Get constructs a StorageMiddleware with the given options using the named backend.
func Get(name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) { func Get(ctx context.Context, name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) {
if storageMiddlewares != nil { if storageMiddlewares != nil {
if initFunc, exists := storageMiddlewares[name]; exists { if initFunc, exists := storageMiddlewares[name]; exists {
return initFunc(storageDriver, options) return initFunc(ctx, storageDriver, options)
} }
} }

View file

@ -150,8 +150,8 @@ func init() {
// s3DriverFactory implements the factory.StorageDriverFactory interface // s3DriverFactory implements the factory.StorageDriverFactory interface
type s3DriverFactory struct{} type s3DriverFactory struct{}
func (factory *s3DriverFactory) Create(parameters map[string]interface{}) (storagedriver.StorageDriver, error) { func (factory *s3DriverFactory) Create(ctx context.Context, parameters map[string]interface{}) (storagedriver.StorageDriver, error) {
return FromParameters(parameters) return FromParameters(ctx, parameters)
} }
var _ storagedriver.StorageDriver = &driver{} var _ storagedriver.StorageDriver = &driver{}
@ -189,7 +189,7 @@ type Driver struct {
// - region // - region
// - bucket // - bucket
// - encrypt // - encrypt
func FromParameters(parameters map[string]interface{}) (*Driver, error) { func FromParameters(ctx context.Context, parameters map[string]interface{}) (*Driver, error) {
// Providing no values for these is valid in case the user is authenticating // Providing no values for these is valid in case the user is authenticating
// with an IAM on an ec2 instance (in which case the instance credentials will // with an IAM on an ec2 instance (in which case the instance credentials will
// be summoned when GetAuth is called) // be summoned when GetAuth is called)
@ -468,7 +468,7 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) {
getS3LogLevelFromParam(parameters["loglevel"]), getS3LogLevelFromParam(parameters["loglevel"]),
} }
return New(params) return New(ctx, params)
} }
func getS3LogLevelFromParam(param interface{}) aws.LogLevelType { func getS3LogLevelFromParam(param interface{}) aws.LogLevelType {
@ -529,7 +529,7 @@ func getParameterAsInt64(parameters map[string]interface{}, name string, default
// New constructs a new Driver with the given AWS credentials, region, encryption flag, and // New constructs a new Driver with the given AWS credentials, region, encryption flag, and
// bucketName // bucketName
func New(params DriverParameters) (*Driver, error) { func New(ctx context.Context, params DriverParameters) (*Driver, error) {
if !params.V4Auth && if !params.V4Auth &&
(params.RegionEndpoint == "" || (params.RegionEndpoint == "" ||
strings.Contains(params.RegionEndpoint, "s3.amazonaws.com")) { strings.Contains(params.RegionEndpoint, "s3.amazonaws.com")) {

View file

@ -141,7 +141,7 @@ func init() {
getS3LogLevelFromParam(logLevel), getS3LogLevelFromParam(logLevel),
} }
return New(parameters) return New(context.Background(), parameters)
} }
// Skip S3 storage driver tests if environment variable parameters are not provided // Skip S3 storage driver tests if environment variable parameters are not provided