forked from TrueCloudLab/distribution
storage/driver: plumb contexts into middlewares
Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
parent
b45b6d18b8
commit
b4dc4f3474
8 changed files with 72 additions and 48 deletions
|
@ -148,7 +148,7 @@ func NewApp(ctx context.Context, config *configuration.Configuration) *App {
|
||||||
|
|
||||||
startUploadPurger(app, app.driver, dcontext.GetLogger(app), purgeConfig)
|
startUploadPurger(app, app.driver, dcontext.GetLogger(app), purgeConfig)
|
||||||
|
|
||||||
app.driver, err = applyStorageMiddleware(app.driver, config.Middleware["storage"])
|
app.driver, err = applyStorageMiddleware(app, app.driver, config.Middleware["storage"])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -938,9 +938,9 @@ func applyRepoMiddleware(ctx context.Context, repository distribution.Repository
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyStorageMiddleware wraps a storage driver with the configured middlewares
|
// applyStorageMiddleware wraps a storage driver with the configured middlewares
|
||||||
func applyStorageMiddleware(driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) {
|
func applyStorageMiddleware(ctx context.Context, driver storagedriver.StorageDriver, middlewares []configuration.Middleware) (storagedriver.StorageDriver, error) {
|
||||||
for _, mw := range middlewares {
|
for _, mw := range middlewares {
|
||||||
smw, err := storagemiddleware.Get(mw.Name, mw.Options, driver)
|
smw, err := storagemiddleware.Get(ctx, mw.Name, mw.Options, driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to configure storage middleware (%s): %v", mw.Name, err)
|
return nil, fmt.Errorf("unable to configure storage middleware (%s): %v", mw.Name, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
|
||||||
// default value. "aws", only aws IP goes to S3 directly. "awsregion", only
|
// default value. "aws", only aws IP goes to S3 directly. "awsregion", only
|
||||||
// regions listed in awsregion options goes to S3 directly
|
// regions listed in awsregion options goes to S3 directly
|
||||||
// - awsregion: a comma separated string of AWS regions.
|
// - awsregion: a comma separated string of AWS regions.
|
||||||
func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
|
func newCloudFrontStorageMiddleware(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
|
||||||
// parse baseurl
|
// parse baseurl
|
||||||
base, ok := options["baseurl"]
|
base, ok := options["baseurl"]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -157,7 +157,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
||||||
case "", "none":
|
case "", "none":
|
||||||
awsIPs = nil
|
awsIPs = nil
|
||||||
case "aws":
|
case "aws":
|
||||||
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, nil)
|
awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
case "awsregion":
|
case "awsregion":
|
||||||
var awsRegion []string
|
var awsRegion []string
|
||||||
if i, ok := options["awsregion"]; ok {
|
if i, ok := options["awsregion"]; ok {
|
||||||
|
@ -165,7 +168,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
||||||
for _, awsRegions := range strings.Split(regions, ",") {
|
for _, awsRegions := range strings.Split(regions, ",") {
|
||||||
awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
|
awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
|
||||||
}
|
}
|
||||||
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion)
|
awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, awsRegion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
|
return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -15,7 +16,7 @@ var _ = check.Suite(&MiddlewareSuite{})
|
||||||
|
|
||||||
func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
|
func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
_, err := newCloudFrontStorageMiddleware(nil, options)
|
_, err := newCloudFrontStorageMiddleware(context.Background(), nil, options)
|
||||||
c.Assert(err, check.ErrorMatches, "no baseurl provided")
|
c.Assert(err, check.ErrorMatches, "no baseurl provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,7 +49,7 @@ pZeMRablbPQdp8/1NyIwimq1VlG0ohQ4P6qhW7E09ZMC
|
||||||
defer os.Remove(file.Name())
|
defer os.Remove(file.Name())
|
||||||
options["privatekey"] = file.Name()
|
options["privatekey"] = file.Name()
|
||||||
options["keypairid"] = "test"
|
options["keypairid"] = "test"
|
||||||
storageDriver, err := newCloudFrontStorageMiddleware(nil, options)
|
storageDriver, err := newCloudFrontStorageMiddleware(context.Background(), nil, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package middleware
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -23,18 +24,21 @@ const (
|
||||||
|
|
||||||
// newAWSIPs returns a New awsIP object.
|
// newAWSIPs returns a New awsIP object.
|
||||||
// If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified
|
// If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified
|
||||||
func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs {
|
func newAWSIPs(ctx context.Context, host string, updateFrequency time.Duration, awsRegion []string) (*awsIPs, error) {
|
||||||
ips := &awsIPs{
|
ips := &awsIPs{
|
||||||
host: host,
|
host: host,
|
||||||
updateFrequency: updateFrequency,
|
updateFrequency: updateFrequency,
|
||||||
awsRegion: awsRegion,
|
awsRegion: awsRegion,
|
||||||
updaterStopChan: make(chan bool),
|
updaterStopChan: make(chan bool),
|
||||||
}
|
}
|
||||||
if err := ips.tryUpdate(); err != nil {
|
if err := ips.tryUpdate(ctx); err != nil {
|
||||||
dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP")
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dcontext.GetLogger(ctx).WithError(err).Warn("failed to update AWS IP")
|
||||||
}
|
}
|
||||||
go ips.updater()
|
go ips.updater()
|
||||||
return ips
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// awsIPs tracks a list of AWS ips, filtered by awsRegion
|
// awsIPs tracks a list of AWS ips, filtered by awsRegion
|
||||||
|
@ -61,9 +65,13 @@ type prefixEntry struct {
|
||||||
Service string `json:"service"`
|
Service string `json:"service"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchAWSIPs(url string) (awsIPResponse, error) {
|
func fetchAWSIPs(ctx context.Context, url string) (awsIPResponse, error) {
|
||||||
var response awsIPResponse
|
var response awsIPResponse
|
||||||
resp, err := http.Get(url)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return response, err
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return response, err
|
return response, err
|
||||||
}
|
}
|
||||||
|
@ -83,8 +91,8 @@ func fetchAWSIPs(url string) (awsIPResponse, error) {
|
||||||
|
|
||||||
// tryUpdate attempts to download the new set of ip addresses.
|
// tryUpdate attempts to download the new set of ip addresses.
|
||||||
// tryUpdate must be thread safe with contains
|
// tryUpdate must be thread safe with contains
|
||||||
func (s *awsIPs) tryUpdate() error {
|
func (s *awsIPs) tryUpdate(ctx context.Context) error {
|
||||||
response, err := fetchAWSIPs(s.host)
|
response, err := fetchAWSIPs(ctx, s.host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -135,17 +143,18 @@ func (s *awsIPs) tryUpdate() error {
|
||||||
// This function is meant to be run in a background goroutine.
|
// This function is meant to be run in a background goroutine.
|
||||||
// It will periodically update the ips from aws.
|
// It will periodically update the ips from aws.
|
||||||
func (s *awsIPs) updater() {
|
func (s *awsIPs) updater() {
|
||||||
|
ctx := context.TODO()
|
||||||
defer close(s.updaterStopChan)
|
defer close(s.updaterStopChan)
|
||||||
for {
|
for {
|
||||||
time.Sleep(s.updateFrequency)
|
time.Sleep(s.updateFrequency)
|
||||||
select {
|
select {
|
||||||
case <-s.updaterStopChan:
|
case <-s.updaterStopChan:
|
||||||
dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal")
|
dcontext.GetLogger(ctx).Info("aws ip updater received stop signal")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
err := s.tryUpdate()
|
err := s.tryUpdate(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP")
|
dcontext.GetLogger(ctx).WithError(err).Error("git AWS IP")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,7 +62,7 @@ func TestS3TryUpdate(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil)
|
||||||
|
|
||||||
assertEqual(t, 1, len(ips.ipv4))
|
assertEqual(t, 1, len(ips.ipv4))
|
||||||
assertEqual(t, 0, len(ips.ipv6))
|
assertEqual(t, 0, len(ips.ipv6))
|
||||||
|
@ -77,8 +77,9 @@ func TestMatchIPV6(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
ctx := context.Background()
|
||||||
ips.tryUpdate()
|
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
|
||||||
|
ips.tryUpdate(ctx)
|
||||||
assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
|
assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
|
||||||
assertEqual(t, 1, len(ips.ipv6))
|
assertEqual(t, 1, len(ips.ipv6))
|
||||||
assertEqual(t, 0, len(ips.ipv4))
|
assertEqual(t, 0, len(ips.ipv4))
|
||||||
|
@ -93,8 +94,9 @@ func TestMatchIPV4(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
ctx := context.Background()
|
||||||
ips.tryUpdate()
|
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
|
||||||
|
ips.tryUpdate(ctx)
|
||||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
@ -112,8 +114,9 @@ func TestMatchIPV4_2(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
ctx := context.Background()
|
||||||
ips.tryUpdate()
|
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
|
||||||
|
ips.tryUpdate(ctx)
|
||||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
@ -131,8 +134,9 @@ func TestMatchIPV4WithRegionMatched(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"})
|
ctx := context.Background()
|
||||||
ips.tryUpdate()
|
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-east-1"})
|
||||||
|
ips.tryUpdate(ctx)
|
||||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
@ -150,8 +154,9 @@ func TestMatchIPV4WithRegionMatch_2(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
|
ctx := context.Background()
|
||||||
ips.tryUpdate()
|
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
|
||||||
|
ips.tryUpdate(ctx)
|
||||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
@ -169,8 +174,9 @@ func TestMatchIPV4WithRegionNotMatched(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"})
|
ctx := context.Background()
|
||||||
ips.tryUpdate()
|
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2"})
|
||||||
|
ips.tryUpdate(ctx)
|
||||||
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
|
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
|
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
@ -187,8 +193,9 @@ func TestInvalidData(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
ctx := context.Background()
|
||||||
ips.tryUpdate()
|
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
|
||||||
|
ips.tryUpdate(ctx)
|
||||||
assertEqual(t, 1, len(ips.ipv4))
|
assertEqual(t, 1, len(ips.ipv4))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,7 +212,7 @@ func TestInvalidNetworkType(t *testing.T) {
|
||||||
})
|
})
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil)
|
||||||
assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
|
assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
|
||||||
assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks
|
assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks
|
||||||
assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
|
assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
|
||||||
|
@ -226,7 +233,7 @@ func TestParsing(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
server := httptest.NewServer(rawMockHandler)
|
server := httptest.NewServer(rawMockHandler)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
schema, err := fetchAWSIPs(server.URL)
|
schema, err := fetchAWSIPs(context.Background(), server.URL)
|
||||||
|
|
||||||
assertEqual(t, nil, err)
|
assertEqual(t, nil, err)
|
||||||
assertEqual(t, 1, len(schema.Prefixes))
|
assertEqual(t, 1, len(schema.Prefixes))
|
||||||
|
@ -253,7 +260,7 @@ func TestUpdateCalledRegularly(t *testing.T) {
|
||||||
rw.Write([]byte("ok"))
|
rw.Write([]byte("ok"))
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil)
|
newAWSIPs(context.Background(), fmt.Sprintf("%s/", server.URL), time.Second, nil)
|
||||||
time.Sleep(time.Second*4 + time.Millisecond*500)
|
time.Sleep(time.Second*4 + time.Millisecond*500)
|
||||||
if updateCount < 4 {
|
if updateCount < 4 {
|
||||||
t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount)
|
t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount)
|
||||||
|
@ -384,7 +391,7 @@ func BenchmarkContainsRandom(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkContainsProd(b *testing.B) {
|
func BenchmarkContainsProd(b *testing.B) {
|
||||||
ips := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil)
|
ips, _ := newAWSIPs(context.Background(), defaultIPRangesURL, defaultUpdateFrequency, nil)
|
||||||
ipv4 := make([][]byte, b.N)
|
ipv4 := make([][]byte, b.N)
|
||||||
ipv6 := make([][]byte, b.N)
|
ipv6 := make([][]byte, b.N)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
|
|
@ -19,7 +19,7 @@ type redirectStorageMiddleware struct {
|
||||||
|
|
||||||
var _ storagedriver.StorageDriver = &redirectStorageMiddleware{}
|
var _ storagedriver.StorageDriver = &redirectStorageMiddleware{}
|
||||||
|
|
||||||
func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
|
func newRedirectStorageMiddleware(ctx context.Context, sd storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
|
||||||
o, ok := options["baseurl"]
|
o, ok := options["baseurl"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("no baseurl provided")
|
return nil, fmt.Errorf("no baseurl provided")
|
||||||
|
|
|
@ -15,21 +15,21 @@ var _ = check.Suite(&MiddlewareSuite{})
|
||||||
|
|
||||||
func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
|
func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
_, err := newRedirectStorageMiddleware(nil, options)
|
_, err := newRedirectStorageMiddleware(context.Background(), nil, options)
|
||||||
c.Assert(err, check.ErrorMatches, "no baseurl provided")
|
c.Assert(err, check.ErrorMatches, "no baseurl provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MiddlewareSuite) TestMissingScheme(c *check.C) {
|
func (s *MiddlewareSuite) TestMissingScheme(c *check.C) {
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
options["baseurl"] = "example.com"
|
options["baseurl"] = "example.com"
|
||||||
_, err := newRedirectStorageMiddleware(nil, options)
|
_, err := newRedirectStorageMiddleware(context.Background(), nil, options)
|
||||||
c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl")
|
c.Assert(err, check.ErrorMatches, "no scheme specified for redirect baseurl")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
|
func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
options["baseurl"] = "https://example.com:5443"
|
options["baseurl"] = "https://example.com:5443"
|
||||||
middleware, err := newRedirectStorageMiddleware(nil, options)
|
middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options)
|
||||||
c.Assert(err, check.Equals, nil)
|
c.Assert(err, check.Equals, nil)
|
||||||
|
|
||||||
m, ok := middleware.(*redirectStorageMiddleware)
|
m, ok := middleware.(*redirectStorageMiddleware)
|
||||||
|
@ -45,7 +45,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
|
||||||
func (s *MiddlewareSuite) TestHTTP(c *check.C) {
|
func (s *MiddlewareSuite) TestHTTP(c *check.C) {
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
options["baseurl"] = "http://example.com"
|
options["baseurl"] = "http://example.com"
|
||||||
middleware, err := newRedirectStorageMiddleware(nil, options)
|
middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options)
|
||||||
c.Assert(err, check.Equals, nil)
|
c.Assert(err, check.Equals, nil)
|
||||||
|
|
||||||
m, ok := middleware.(*redirectStorageMiddleware)
|
m, ok := middleware.(*redirectStorageMiddleware)
|
||||||
|
@ -62,7 +62,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
|
||||||
// basePath: end with no slash
|
// basePath: end with no slash
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
options["baseurl"] = "https://example.com/path"
|
options["baseurl"] = "https://example.com/path"
|
||||||
middleware, err := newRedirectStorageMiddleware(nil, options)
|
middleware, err := newRedirectStorageMiddleware(context.Background(), nil, options)
|
||||||
c.Assert(err, check.Equals, nil)
|
c.Assert(err, check.Equals, nil)
|
||||||
|
|
||||||
m, ok := middleware.(*redirectStorageMiddleware)
|
m, ok := middleware.(*redirectStorageMiddleware)
|
||||||
|
@ -82,7 +82,7 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
|
||||||
|
|
||||||
// basePath: end with slash
|
// basePath: end with slash
|
||||||
options["baseurl"] = "https://example.com/path/"
|
options["baseurl"] = "https://example.com/path/"
|
||||||
middleware, err = newRedirectStorageMiddleware(nil, options)
|
middleware, err = newRedirectStorageMiddleware(context.Background(), nil, options)
|
||||||
c.Assert(err, check.Equals, nil)
|
c.Assert(err, check.Equals, nil)
|
||||||
|
|
||||||
m, ok = middleware.(*redirectStorageMiddleware)
|
m, ok = middleware.(*redirectStorageMiddleware)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package storagemiddleware
|
package storagemiddleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
|
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
|
||||||
|
@ -8,7 +9,7 @@ import (
|
||||||
|
|
||||||
// InitFunc is the type of a StorageMiddleware factory function and is
|
// InitFunc is the type of a StorageMiddleware factory function and is
|
||||||
// used to register the constructor for different StorageMiddleware backends.
|
// used to register the constructor for different StorageMiddleware backends.
|
||||||
type InitFunc func(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error)
|
type InitFunc func(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error)
|
||||||
|
|
||||||
var storageMiddlewares map[string]InitFunc
|
var storageMiddlewares map[string]InitFunc
|
||||||
|
|
||||||
|
@ -28,10 +29,10 @@ func Register(name string, initFunc InitFunc) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get constructs a StorageMiddleware with the given options using the named backend.
|
// Get constructs a StorageMiddleware with the given options using the named backend.
|
||||||
func Get(name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) {
|
func Get(ctx context.Context, name string, options map[string]interface{}, storageDriver storagedriver.StorageDriver) (storagedriver.StorageDriver, error) {
|
||||||
if storageMiddlewares != nil {
|
if storageMiddlewares != nil {
|
||||||
if initFunc, exists := storageMiddlewares[name]; exists {
|
if initFunc, exists := storageMiddlewares[name]; exists {
|
||||||
return initFunc(storageDriver, options)
|
return initFunc(ctx, storageDriver, options)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue