From d0f5aa670becaa83b757239e05e0224c248c135b Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 24 Oct 2023 13:16:58 -0400 Subject: [PATCH 1/7] Move context package internal Our context package predates the establishment of current best practices regarding context usage and it shows. It encourages bad practices such as using contexts to propagate non-request-scoped values like the application version and using string-typed keys for context values. Move the package internal to remove it from the API surface of distribution/v3@v3.0.0 so we are free to iterate on it without being constrained by compatibility. Signed-off-by: Cory Snider --- health/health.go | 8 +-- internal/client/repository_test.go | 52 +++++++++---------- {context => internal/dcontext}/context.go | 2 +- {context => internal/dcontext}/doc.go | 10 ++-- {context => internal/dcontext}/http.go | 2 +- {context => internal/dcontext}/http_test.go | 2 +- {context => internal/dcontext}/logger.go | 2 +- {context => internal/dcontext}/trace.go | 2 +- {context => internal/dcontext}/trace_test.go | 2 +- {context => internal/dcontext}/util.go | 2 +- {context => internal/dcontext}/version.go | 2 +- .../dcontext}/version_test.go | 2 +- notifications/bridge.go | 4 +- notifications/listener.go | 2 +- notifications/listener_test.go | 2 +- registry/auth/htpasswd/access.go | 2 +- registry/auth/htpasswd/access_test.go | 6 +-- registry/auth/silly/access.go | 2 +- registry/auth/silly/access_test.go | 4 +- registry/auth/token/accesscontroller.go | 2 +- registry/auth/token/token_test.go | 4 +- registry/handlers/app.go | 2 +- registry/handlers/app_test.go | 6 +-- registry/handlers/blob.go | 10 ++-- registry/handlers/blobupload.go | 2 +- registry/handlers/context.go | 2 +- registry/handlers/health_test.go | 8 +-- registry/handlers/helpers.go | 2 +- registry/handlers/manifests.go | 2 +- registry/proxy/proxyauth.go | 4 +- registry/proxy/proxyblobstore.go | 2 +- registry/proxy/proxymanifeststore.go | 2 +- registry/proxy/proxyregistry.go | 2 +- registry/proxy/scheduler/scheduler.go | 2 +- registry/proxy/scheduler/scheduler_test.go | 14 ++--- registry/registry.go | 2 +- registry/registry_test.go | 2 +- registry/root.go | 2 +- registry/storage/blobstore.go | 2 +- registry/storage/blobwriter.go | 2 +- registry/storage/blobwriter_nonresumable.go | 2 +- .../cache/cachedblobdescriptorstore.go | 2 +- registry/storage/driver/base/base.go | 2 +- registry/storage/driver/gcs/gcs_test.go | 2 +- .../middleware/cloudfront/middleware.go | 2 +- .../driver/middleware/cloudfront/s3filter.go | 2 +- .../middleware/cloudfront/s3filter_test.go | 2 +- registry/storage/driver/s3-aws/s3.go | 2 +- registry/storage/driver/s3-aws/s3_test.go | 26 +++++----- registry/storage/filereader_test.go | 8 +-- registry/storage/garbagecollect_test.go | 36 ++++++------- registry/storage/linkedblobstore.go | 2 +- registry/storage/manifestlisthandler.go | 2 +- registry/storage/manifeststore.go | 2 +- registry/storage/ociindexhandler.go | 2 +- registry/storage/ocimanifesthandler.go | 2 +- registry/storage/schema2manifesthandler.go | 2 +- .../storage/schema2manifesthandler_test.go | 6 +-- registry/storage/vacuum.go | 2 +- testutil/manifests.go | 6 +-- testutil/tarfile.go | 4 +- 61 files changed, 151 insertions(+), 151 deletions(-) rename {context => internal/dcontext}/context.go (99%) rename {context => internal/dcontext}/doc.go (93%) rename {context => internal/dcontext}/http.go (99%) rename {context => internal/dcontext}/http_test.go (99%) rename {context => internal/dcontext}/logger.go (99%) rename {context => internal/dcontext}/trace.go (99%) rename {context => internal/dcontext}/trace_test.go (99%) rename {context => internal/dcontext}/util.go (97%) rename {context => internal/dcontext}/version.go (97%) rename {context => internal/dcontext}/version_test.go (95%) diff --git a/health/health.go b/health/health.go index 06961f353..3e21f731f 100644 --- a/health/health.go +++ b/health/health.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" ) @@ -279,7 +279,7 @@ func Handler(handler http.Handler) http.Handler { func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks map[string]string) { p, err := json.Marshal(checks) if err != nil { - context.GetLogger(context.Background()).Errorf("error serializing health status: %v", err) + dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status: %v", err) p, err = json.Marshal(struct { ServerError string `json:"server_error"` }{ @@ -288,7 +288,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m status = http.StatusInternalServerError if err != nil { - context.GetLogger(context.Background()).Errorf("error serializing health status failure message: %v", err) + dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status failure message: %v", err) return } } @@ -297,7 +297,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m w.Header().Set("Content-Length", fmt.Sprint(len(p))) w.WriteHeader(status) if _, err := w.Write(p); err != nil { - context.GetLogger(context.Background()).Errorf("error writing health status response body: %v", err) + dcontext.GetLogger(dcontext.Background()).Errorf("error writing health status response body: %v", err) } } diff --git a/internal/client/repository_test.go b/internal/client/repository_test.go index b6f4d224c..a98fe2899 100644 --- a/internal/client/repository_test.go +++ b/internal/client/repository_test.go @@ -17,7 +17,7 @@ import ( "time" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/registry/api/errcode" @@ -108,7 +108,7 @@ func TestBlobServeBlob(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo1") r, err := NewRepository(repo, e, nil) if err != nil { @@ -157,7 +157,7 @@ func TestBlobServeBlobHEAD(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo1") r, err := NewRepository(repo, e, nil) if err != nil { @@ -250,7 +250,7 @@ func TestBlobResume(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -307,7 +307,7 @@ func TestBlobDelete(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -327,7 +327,7 @@ func TestBlobFetch(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo1") r, err := NewRepository(repo, e, nil) if err != nil { @@ -382,7 +382,7 @@ func TestBlobExistsNoContentLength(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -406,7 +406,7 @@ func TestBlobExists(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo1") r, err := NewRepository(repo, e, nil) if err != nil { @@ -512,7 +512,7 @@ func TestBlobUploadChunked(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -622,7 +622,7 @@ func TestBlobUploadMonolithic(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -728,7 +728,7 @@ func TestBlobUploadMonolithicDockerUploadUUIDFromURL(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -833,7 +833,7 @@ func TestBlobUploadMonolithicNoDockerUploadUUID(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -891,7 +891,7 @@ func TestBlobMount(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -1066,7 +1066,7 @@ func checkEqualManifest(m1, m2 *ocischema.DeserializedManifest) error { } func TestOCIManifestFetch(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo") m1, dgst, pl := newRandomOCIManifest(t, 6) var m testutil.RequestResponseMap @@ -1149,7 +1149,7 @@ func TestManifestFetchWithEtag(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -1171,7 +1171,7 @@ func TestManifestFetchWithEtag(t *testing.T) { } func TestManifestFetchWithAccept(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo") _, dgst, _ := newRandomOCIManifest(t, 6) headers := make(chan []string, 1) @@ -1258,7 +1258,7 @@ func TestManifestDelete(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() ms, err := r.Manifests(ctx) if err != nil { t.Fatal(err) @@ -1315,7 +1315,7 @@ func TestManifestPut(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() ms, err := r.Manifests(ctx) if err != nil { t.Fatal(err) @@ -1372,7 +1372,7 @@ func TestManifestTags(t *testing.T) { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() tagService := r.Tags(ctx) tags, err := tagService.All(ctx) @@ -1423,7 +1423,7 @@ func TestTagDelete(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() ts := r.Tags(ctx) if err := ts.Untag(ctx, tag); err != nil { @@ -1460,7 +1460,7 @@ func TestObtainsErrorForMissingTag(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -1487,7 +1487,7 @@ func TestObtainsManifestForTagWithoutHeaders(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -1566,7 +1566,7 @@ func TestManifestTagsPaginated(t *testing.T) { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() tagService := r.Tags(ctx) tags, err := tagService.All(ctx) @@ -1614,7 +1614,7 @@ func TestManifestUnauthorized(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() ms, err := r.Manifests(ctx) if err != nil { t.Fatal(err) @@ -1652,7 +1652,7 @@ func TestCatalog(t *testing.T) { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() numFilled, err := r.Repositories(ctx, entries, "") if err != io.EOF { t.Fatal(err) @@ -1684,7 +1684,7 @@ func TestCatalogInParts(t *testing.T) { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() numFilled, err := r.Repositories(ctx, entries, "") if err != nil { t.Fatal(err) diff --git a/context/context.go b/internal/dcontext/context.go similarity index 99% rename from context/context.go rename to internal/dcontext/context.go index 23f93e0bc..fe8379800 100644 --- a/context/context.go +++ b/internal/dcontext/context.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "context" diff --git a/context/doc.go b/internal/dcontext/doc.go similarity index 93% rename from context/doc.go rename to internal/dcontext/doc.go index 51376dd69..a8d9a7402 100644 --- a/context/doc.go +++ b/internal/dcontext/doc.go @@ -1,16 +1,16 @@ -// Package context provides several utilities for working with +// Package dcontext provides several utilities for working with // Go's context in http requests. Primarily, the focus is on logging relevant // request information but this package is not limited to that purpose. // // The easiest way to get started is to get the background context: // -// ctx := context.Background() +// ctx := dcontext.Background() // // The returned context should be passed around your application and be the // root of all other context instances. If the application has a version, this // line should be called before anything else: // -// ctx := context.WithVersion(context.Background(), version) +// ctx := dcontext.WithVersion(dcontext.Background(), version) // // The above will store the version in the context and will be available to // the logger. @@ -27,7 +27,7 @@ // the context and reported with the logger. The following example would // return a logger that prints the version with each log message: // -// ctx := context.Context(context.Background(), "version", version) +// ctx := context.WithValue(dcontext.Background(), "version", version) // GetLogger(ctx, "version").Infof("this log message has a version field") // // The above would print out a log message like this: @@ -85,4 +85,4 @@ // can be traced in log messages. Using the fields like "http.request.id", one // can analyze call flow for a particular request with a simple grep of the // logs. -package context +package dcontext diff --git a/context/http.go b/internal/dcontext/http.go similarity index 99% rename from context/http.go rename to internal/dcontext/http.go index bcdf29658..69c29b74a 100644 --- a/context/http.go +++ b/internal/dcontext/http.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "context" diff --git a/context/http_test.go b/internal/dcontext/http_test.go similarity index 99% rename from context/http_test.go rename to internal/dcontext/http_test.go index d9e672316..9d1069d28 100644 --- a/context/http_test.go +++ b/internal/dcontext/http_test.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "net/http" diff --git a/context/logger.go b/internal/dcontext/logger.go similarity index 99% rename from context/logger.go rename to internal/dcontext/logger.go index f956a2282..058fc8310 100644 --- a/context/logger.go +++ b/internal/dcontext/logger.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "context" diff --git a/context/trace.go b/internal/dcontext/trace.go similarity index 99% rename from context/trace.go rename to internal/dcontext/trace.go index 2e169f0c5..ba2481053 100644 --- a/context/trace.go +++ b/internal/dcontext/trace.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "context" diff --git a/context/trace_test.go b/internal/dcontext/trace_test.go similarity index 99% rename from context/trace_test.go rename to internal/dcontext/trace_test.go index 4ee530bfa..bb6d1779f 100644 --- a/context/trace_test.go +++ b/internal/dcontext/trace_test.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "runtime" diff --git a/context/util.go b/internal/dcontext/util.go similarity index 97% rename from context/util.go rename to internal/dcontext/util.go index c462e7563..5b32ba16f 100644 --- a/context/util.go +++ b/internal/dcontext/util.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "context" diff --git a/context/version.go b/internal/dcontext/version.go similarity index 97% rename from context/version.go rename to internal/dcontext/version.go index 97cf9d665..0a0e9cbb6 100644 --- a/context/version.go +++ b/internal/dcontext/version.go @@ -1,4 +1,4 @@ -package context +package dcontext import "context" diff --git a/context/version_test.go b/internal/dcontext/version_test.go similarity index 95% rename from context/version_test.go rename to internal/dcontext/version_test.go index b81652691..9829fe959 100644 --- a/context/version_test.go +++ b/internal/dcontext/version_test.go @@ -1,4 +1,4 @@ -package context +package dcontext import "testing" diff --git a/notifications/bridge.go b/notifications/bridge.go index e63aab65b..133153deb 100644 --- a/notifications/bridge.go +++ b/notifications/bridge.go @@ -5,7 +5,7 @@ import ( "time" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/reference" events "github.com/docker/go-events" "github.com/google/uuid" @@ -49,7 +49,7 @@ func NewBridge(ub URLBuilder, source SourceRecord, actor ActorRecord, request Re func NewRequestRecord(id string, r *http.Request) RequestRecord { return RequestRecord{ ID: id, - Addr: context.RemoteAddr(r), + Addr: dcontext.RemoteAddr(r), Host: r.Host, Method: r.Method, UserAgent: r.UserAgent(), diff --git a/notifications/listener.go b/notifications/listener.go index fad293415..563c13928 100644 --- a/notifications/listener.go +++ b/notifications/listener.go @@ -7,7 +7,7 @@ import ( "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/reference" "github.com/opencontainers/go-digest" ) diff --git a/notifications/listener_test.go b/notifications/listener_test.go index e33109beb..5781d4537 100644 --- a/notifications/listener_test.go +++ b/notifications/listener_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage/cache/memory" diff --git a/registry/auth/htpasswd/access.go b/registry/auth/htpasswd/access.go index 8d8a04198..0a1d0c1ce 100644 --- a/registry/auth/htpasswd/access.go +++ b/registry/auth/htpasswd/access.go @@ -18,7 +18,7 @@ import ( "golang.org/x/crypto/bcrypt" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) diff --git a/registry/auth/htpasswd/access_test.go b/registry/auth/htpasswd/access_test.go index 25947757a..0871ef411 100644 --- a/registry/auth/htpasswd/access_test.go +++ b/registry/auth/htpasswd/access_test.go @@ -8,7 +8,7 @@ import ( "os" "testing" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) @@ -33,7 +33,7 @@ func TestBasicAccessController(t *testing.T) { "realm": testRealm, "path": tempFile.Name(), } - ctx := context.Background() + ctx := dcontext.Background() accessController, err := newAccessController(options) if err != nil { @@ -45,7 +45,7 @@ func TestBasicAccessController(t *testing.T) { userNumber := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithRequest(ctx, r) + ctx := dcontext.WithRequest(ctx, r) authCtx, err := accessController.Authorized(ctx) if err != nil { switch err := err.(type) { diff --git a/registry/auth/silly/access.go b/registry/auth/silly/access.go index b09373f3f..685cf6a62 100644 --- a/registry/auth/silly/access.go +++ b/registry/auth/silly/access.go @@ -13,7 +13,7 @@ import ( "net/http" "strings" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) diff --git a/registry/auth/silly/access_test.go b/registry/auth/silly/access_test.go index 482fb2f7a..f463e98c4 100644 --- a/registry/auth/silly/access_test.go +++ b/registry/auth/silly/access_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) @@ -16,7 +16,7 @@ func TestSillyAccessController(t *testing.T) { } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithRequest(context.Background(), r) + ctx := dcontext.WithRequest(dcontext.Background(), r) authCtx, err := ac.Authorized(ctx) if err != nil { switch err := err.(type) { diff --git a/registry/auth/token/accesscontroller.go b/registry/auth/token/accesscontroller.go index 249020192..b2e4e4b27 100644 --- a/registry/auth/token/accesscontroller.go +++ b/registry/auth/token/accesscontroller.go @@ -13,7 +13,7 @@ import ( "os" "strings" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" "github.com/go-jose/go-jose/v3" ) diff --git a/registry/auth/token/token_test.go b/registry/auth/token/token_test.go index 417d0d9bc..a331a93bf 100644 --- a/registry/auth/token/token_test.go +++ b/registry/auth/token/token_test.go @@ -18,7 +18,7 @@ import ( "testing" "time" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" @@ -466,7 +466,7 @@ func TestAccessController(t *testing.T) { Action: "baz", } - ctx := context.WithRequest(context.Background(), req) + ctx := dcontext.WithRequest(dcontext.Background(), req) authCtx, err := accessController.Authorized(ctx, testAccess) challenge, ok := err.(auth.Challenge) if !ok { diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 8efdaf85e..7ce27d6dc 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -19,9 +19,9 @@ import ( "github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3/configuration" - dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/health" "github.com/distribution/distribution/v3/health/checks" + "github.com/distribution/distribution/v3/internal/dcontext" prometheus "github.com/distribution/distribution/v3/metrics" "github.com/distribution/distribution/v3/notifications" "github.com/distribution/distribution/v3/registry/api/errcode" diff --git a/registry/handlers/app_test.go b/registry/handlers/app_test.go index 391deea62..d18625926 100644 --- a/registry/handlers/app_test.go +++ b/registry/handlers/app_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/distribution/distribution/v3/configuration" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" v2 "github.com/distribution/distribution/v3/registry/api/v2" "github.com/distribution/distribution/v3/registry/auth" @@ -25,7 +25,7 @@ import ( // tested individually. func TestAppDispatcher(t *testing.T) { driver := inmemory.New() - ctx := context.Background() + ctx := dcontext.Background() registry, err := storage.NewRegistry(ctx, driver, storage.BlobDescriptorCacheProvider(memorycache.NewInMemoryBlobDescriptorCacheProvider(0)), storage.EnableDelete, storage.EnableRedirect) if err != nil { t.Fatalf("error creating registry: %v", err) @@ -139,7 +139,7 @@ func TestAppDispatcher(t *testing.T) { // TestNewApp covers the creation of an application via NewApp with a // configuration. func TestNewApp(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() config := configuration.Configuration{ Storage: configuration.Storage{ "inmemory": nil, diff --git a/registry/handlers/blob.go b/registry/handlers/blob.go index 4d50754d7..b4979782f 100644 --- a/registry/handlers/blob.go +++ b/registry/handlers/blob.go @@ -4,7 +4,7 @@ import ( "net/http" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" "github.com/gorilla/handlers" "github.com/opencontainers/go-digest" @@ -53,7 +53,7 @@ type blobHandler struct { // GetBlob fetches the binary data from backend storage returns it in the // response. func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) { - context.GetLogger(bh).Debug("GetBlob") + dcontext.GetLogger(bh).Debug("GetBlob") blobs := bh.Repository.Blobs(bh) desc, err := blobs.Stat(bh, bh.Digest) if err != nil { @@ -66,7 +66,7 @@ func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) { } if err := blobs.ServeBlob(bh, w, r, desc.Digest); err != nil { - context.GetLogger(bh).Debugf("unexpected error getting blob HTTP handler: %v", err) + dcontext.GetLogger(bh).Debugf("unexpected error getting blob HTTP handler: %v", err) bh.Errors = append(bh.Errors, errcode.ErrorCodeUnknown.WithDetail(err)) return } @@ -74,7 +74,7 @@ func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) { // DeleteBlob deletes a layer blob func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) { - context.GetLogger(bh).Debug("DeleteBlob") + dcontext.GetLogger(bh).Debug("DeleteBlob") blobs := bh.Repository.Blobs(bh) err := blobs.Delete(bh, bh.Digest) @@ -88,7 +88,7 @@ func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) { return default: bh.Errors = append(bh.Errors, err) - context.GetLogger(bh).Errorf("Unknown error deleting blob: %s", err.Error()) + dcontext.GetLogger(bh).Errorf("Unknown error deleting blob: %s", err.Error()) return } } diff --git a/registry/handlers/blobupload.go b/registry/handlers/blobupload.go index 09c00bb8c..6b6c640da 100644 --- a/registry/handlers/blobupload.go +++ b/registry/handlers/blobupload.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/reference" diff --git a/registry/handlers/context.go b/registry/handlers/context.go index cac97a04c..cb3540237 100644 --- a/registry/handlers/context.go +++ b/registry/handlers/context.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" v2 "github.com/distribution/distribution/v3/registry/api/v2" "github.com/distribution/distribution/v3/registry/auth" diff --git a/registry/handlers/health_test.go b/registry/handlers/health_test.go index e0d61555a..dc706b870 100644 --- a/registry/handlers/health_test.go +++ b/registry/handlers/health_test.go @@ -9,8 +9,8 @@ import ( "time" "github.com/distribution/distribution/v3/configuration" - "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/health" + "github.com/distribution/distribution/v3/internal/dcontext" ) func TestFileHealthCheck(t *testing.T) { @@ -39,7 +39,7 @@ func TestFileHealthCheck(t *testing.T) { }, } - ctx := context.Background() + ctx := dcontext.Background() app := NewApp(ctx, config) healthRegistry := health.NewRegistry() @@ -103,7 +103,7 @@ func TestTCPHealthCheck(t *testing.T) { }, } - ctx := context.Background() + ctx := dcontext.Background() app := NewApp(ctx, config) healthRegistry := health.NewRegistry() @@ -165,7 +165,7 @@ func TestHTTPHealthCheck(t *testing.T) { }, } - ctx := context.Background() + ctx := dcontext.Background() app := NewApp(ctx, config) healthRegistry := health.NewRegistry() diff --git a/registry/handlers/helpers.go b/registry/handlers/helpers.go index d70306fd7..3ccba5558 100644 --- a/registry/handlers/helpers.go +++ b/registry/handlers/helpers.go @@ -9,7 +9,7 @@ import ( "strconv" "strings" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" ) // closeResources closes all the provided resources after running the target diff --git a/registry/handlers/manifests.go b/registry/handlers/manifests.go index d7bbab926..06b7c0c75 100644 --- a/registry/handlers/manifests.go +++ b/registry/handlers/manifests.go @@ -8,7 +8,7 @@ import ( "strings" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/manifest/schema2" diff --git a/registry/proxy/proxyauth.go b/registry/proxy/proxyauth.go index ee79ed7e6..8cdc3ebff 100644 --- a/registry/proxy/proxyauth.go +++ b/registry/proxy/proxyauth.go @@ -5,9 +5,9 @@ import ( "net/url" "strings" - "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/client/auth" "github.com/distribution/distribution/v3/internal/client/auth/challenge" + "github.com/distribution/distribution/v3/internal/dcontext" ) const challengeHeader = "Docker-Distribution-Api-Version" @@ -44,7 +44,7 @@ func configureAuth(username, password, remoteURL string) (auth.CredentialStore, } for _, url := range authURLs { - context.GetLogger(context.Background()).Infof("Discovered token authentication URL: %s", url) + dcontext.GetLogger(dcontext.Background()).Infof("Discovered token authentication URL: %s", url) creds[url] = userpass{ username: username, password: password, diff --git a/registry/proxy/proxyblobstore.go b/registry/proxy/proxyblobstore.go index f83ef329a..bf8ca22fd 100644 --- a/registry/proxy/proxyblobstore.go +++ b/registry/proxy/proxyblobstore.go @@ -11,7 +11,7 @@ import ( "github.com/opencontainers/go-digest" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/proxy/scheduler" "github.com/distribution/reference" ) diff --git a/registry/proxy/proxymanifeststore.go b/registry/proxy/proxymanifeststore.go index fa60a07e1..1b0e5a314 100644 --- a/registry/proxy/proxymanifeststore.go +++ b/registry/proxy/proxymanifeststore.go @@ -7,7 +7,7 @@ import ( "github.com/opencontainers/go-digest" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/proxy/scheduler" "github.com/distribution/reference" ) diff --git a/registry/proxy/proxyregistry.go b/registry/proxy/proxyregistry.go index 15dc783b7..33dcc4afa 100644 --- a/registry/proxy/proxyregistry.go +++ b/registry/proxy/proxyregistry.go @@ -10,11 +10,11 @@ import ( "github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3/configuration" - dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/client" "github.com/distribution/distribution/v3/internal/client/auth" "github.com/distribution/distribution/v3/internal/client/auth/challenge" "github.com/distribution/distribution/v3/internal/client/transport" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/proxy/scheduler" "github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage/driver" diff --git a/registry/proxy/scheduler/scheduler.go b/registry/proxy/scheduler/scheduler.go index e492cf716..ed1d9d419 100644 --- a/registry/proxy/scheduler/scheduler.go +++ b/registry/proxy/scheduler/scheduler.go @@ -7,7 +7,7 @@ import ( "sync" "time" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/reference" ) diff --git a/registry/proxy/scheduler/scheduler_test.go b/registry/proxy/scheduler/scheduler_test.go index 1309a1b01..38fa0f580 100644 --- a/registry/proxy/scheduler/scheduler_test.go +++ b/registry/proxy/scheduler/scheduler_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/distribution/reference" ) @@ -40,7 +40,7 @@ func TestSchedule(t *testing.T) { } var mu sync.Mutex - s := New(context.Background(), inmemory.New(), "/ttl") + s := New(dcontext.Background(), inmemory.New(), "/ttl") deleteFunc := func(repoName reference.Reference) error { if len(remainingRepos) == 0 { t.Fatalf("Incorrect expiry count") @@ -123,14 +123,14 @@ func TestRestoreOld(t *testing.T) { t.Fatalf("Error serializing test data: %s", err.Error()) } - ctx := context.Background() + ctx := dcontext.Background() pathToStatFile := "/ttl" fs := inmemory.New() err = fs.PutContent(ctx, pathToStatFile, serialized) if err != nil { t.Fatal("Unable to write serialized data to fs") } - s := New(context.Background(), fs, "/ttl") + s := New(dcontext.Background(), fs, "/ttl") s.OnBlobExpire(deleteFunc) err = s.Start() if err != nil { @@ -165,7 +165,7 @@ func TestStopRestore(t *testing.T) { fs := inmemory.New() pathToStateFile := "/ttl" - s := New(context.Background(), fs, pathToStateFile) + s := New(dcontext.Background(), fs, pathToStateFile) s.onBlobExpire = deleteFunc err := s.Start() @@ -181,7 +181,7 @@ func TestStopRestore(t *testing.T) { time.Sleep(10 * time.Millisecond) // v2 will restore state from fs - s2 := New(context.Background(), fs, pathToStateFile) + s2 := New(dcontext.Background(), fs, pathToStateFile) s2.onBlobExpire = deleteFunc err = s2.Start() if err != nil { @@ -197,7 +197,7 @@ func TestStopRestore(t *testing.T) { } func TestDoubleStart(t *testing.T) { - s := New(context.Background(), inmemory.New(), "/ttl") + s := New(dcontext.Background(), inmemory.New(), "/ttl") err := s.Start() if err != nil { t.Fatalf("Unable to start scheduler") diff --git a/registry/registry.go b/registry/registry.go index 98a29eeda..933013c58 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -21,8 +21,8 @@ import ( "golang.org/x/crypto/acme/autocert" "github.com/distribution/distribution/v3/configuration" - dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/health" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/handlers" "github.com/distribution/distribution/v3/registry/listener" "github.com/distribution/distribution/v3/version" diff --git a/registry/registry_test.go b/registry/registry_test.go index 3e919e315..194da97ac 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -25,7 +25,7 @@ import ( "time" "github.com/distribution/distribution/v3/configuration" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" _ "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" diff --git a/registry/root.go b/registry/root.go index 51c8c7ffb..272027e67 100644 --- a/registry/root.go +++ b/registry/root.go @@ -4,7 +4,7 @@ import ( "fmt" "os" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage/driver/factory" "github.com/distribution/distribution/v3/version" diff --git a/registry/storage/blobstore.go b/registry/storage/blobstore.go index f54e901d5..c03736ea3 100644 --- a/registry/storage/blobstore.go +++ b/registry/storage/blobstore.go @@ -6,7 +6,7 @@ import ( "path" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/opencontainers/go-digest" ) diff --git a/registry/storage/blobwriter.go b/registry/storage/blobwriter.go index 70c8b2afe..9de3b0c2f 100644 --- a/registry/storage/blobwriter.go +++ b/registry/storage/blobwriter.go @@ -9,7 +9,7 @@ import ( "time" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/opencontainers/go-digest" "github.com/sirupsen/logrus" diff --git a/registry/storage/blobwriter_nonresumable.go b/registry/storage/blobwriter_nonresumable.go index a7df9a22a..b3b3f6abe 100644 --- a/registry/storage/blobwriter_nonresumable.go +++ b/registry/storage/blobwriter_nonresumable.go @@ -4,7 +4,7 @@ package storage import ( - "github.com/distribution/distribution/v3/context" + "context" ) // resumeHashAt is a noop when resumable digest support is disabled. diff --git a/registry/storage/cache/cachedblobdescriptorstore.go b/registry/storage/cache/cachedblobdescriptorstore.go index b4dc828cc..38dd1dcff 100644 --- a/registry/storage/cache/cachedblobdescriptorstore.go +++ b/registry/storage/cache/cachedblobdescriptorstore.go @@ -4,7 +4,7 @@ import ( "context" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" prometheus "github.com/distribution/distribution/v3/metrics" "github.com/opencontainers/go-digest" ) diff --git a/registry/storage/driver/base/base.go b/registry/storage/driver/base/base.go index 693717129..756a0d4c9 100644 --- a/registry/storage/driver/base/base.go +++ b/registry/storage/driver/base/base.go @@ -42,7 +42,7 @@ import ( "io" "time" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" prometheus "github.com/distribution/distribution/v3/metrics" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/docker/go-metrics" diff --git a/registry/storage/driver/gcs/gcs_test.go b/registry/storage/driver/gcs/gcs_test.go index 65998e7bc..a76015d08 100644 --- a/registry/storage/driver/gcs/gcs_test.go +++ b/registry/storage/driver/gcs/gcs_test.go @@ -10,7 +10,7 @@ import ( "testing" "cloud.google.com/go/storage" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/testsuites" "golang.org/x/oauth2" diff --git a/registry/storage/driver/middleware/cloudfront/middleware.go b/registry/storage/driver/middleware/cloudfront/middleware.go index 5c2c09955..32474cbd0 100644 --- a/registry/storage/driver/middleware/cloudfront/middleware.go +++ b/registry/storage/driver/middleware/cloudfront/middleware.go @@ -13,7 +13,7 @@ import ( "time" "github.com/aws/aws-sdk-go/service/cloudfront/sign" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagemiddleware "github.com/distribution/distribution/v3/registry/storage/driver/middleware" ) diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go index 25aafd043..7a23bcbd0 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter.go @@ -11,7 +11,7 @@ import ( "sync" "time" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" ) const ( diff --git a/registry/storage/driver/middleware/cloudfront/s3filter_test.go b/registry/storage/driver/middleware/cloudfront/s3filter_test.go index 81ef6aa8a..0d1055601 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter_test.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" ) // Rather than pull in all of testify diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go index a6428a918..a85624d44 100644 --- a/registry/storage/driver/s3-aws/s3.go +++ b/registry/storage/driver/s3-aws/s3.go @@ -36,7 +36,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/base" "github.com/distribution/distribution/v3/registry/storage/driver/factory" diff --git a/registry/storage/driver/s3-aws/s3_test.go b/registry/storage/driver/s3-aws/s3_test.go index ebd9d2a2b..de6dd12ba 100644 --- a/registry/storage/driver/s3-aws/s3_test.go +++ b/registry/storage/driver/s3-aws/s3_test.go @@ -16,7 +16,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/s3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/testsuites" ) @@ -180,7 +180,7 @@ func TestEmptyRootList(t *testing.T) { filename := "/test" contents := []byte("contents") - ctx := context.Background() + ctx := dcontext.Background() err = rootedDriver.PutContent(ctx, filename, contents) if err != nil { t.Fatalf("unexpected error creating content: %v", err) @@ -209,7 +209,7 @@ func TestStorageClass(t *testing.T) { rootDir := t.TempDir() contents := []byte("contents") - ctx := context.Background() + ctx := dcontext.Background() // We don't need to test all the storage classes, just that its selectable. // The first 3 are common to AWS and MinIO, so use those. @@ -377,7 +377,7 @@ func TestDelete(t *testing.T) { // init file structure matching objs var created []string for _, p := range objs { - err := drvr.PutContent(context.Background(), p, []byte("content "+p)) + err := drvr.PutContent(dcontext.Background(), p, []byte("content "+p)) if err != nil { fmt.Printf("unable to init file %s: %s\n", p, err) continue @@ -390,7 +390,7 @@ func TestDelete(t *testing.T) { cleanup := func(objs []string) { var lastErr error for _, p := range objs { - err := drvr.Delete(context.Background(), p) + err := drvr.Delete(dcontext.Background(), p) if err != nil { switch err.(type) { case storagedriver.PathNotFoundError: @@ -409,7 +409,7 @@ func TestDelete(t *testing.T) { t.Run(tc.name, func(t *testing.T) { objs := init() - err := drvr.Delete(context.Background(), tc.delete) + err := drvr.Delete(dcontext.Background(), tc.delete) if tc.err != nil { if err == nil { @@ -437,7 +437,7 @@ func TestDelete(t *testing.T) { return false } for _, path := range objs { - stat, err := drvr.Stat(context.Background(), path) + stat, err := drvr.Stat(dcontext.Background(), path) if err != nil { switch err.(type) { case storagedriver.PathNotFoundError: @@ -491,7 +491,7 @@ func TestWalk(t *testing.T) { // create file structure matching fileset above created := make([]string, 0, len(fileset)) for _, p := range fileset { - err := drvr.PutContent(context.Background(), p, []byte("content "+p)) + err := drvr.PutContent(dcontext.Background(), p, []byte("content "+p)) if err != nil { fmt.Printf("unable to create file %s: %s\n", p, err) continue @@ -503,7 +503,7 @@ func TestWalk(t *testing.T) { defer func() { var lastErr error for _, p := range created { - err := drvr.Delete(context.Background(), p) + err := drvr.Delete(dcontext.Background(), p) if err != nil { _ = fmt.Errorf("cleanup failed for path %s: %s", p, err) lastErr = err @@ -692,7 +692,7 @@ func TestWalk(t *testing.T) { tc.from = "/" } t.Run(tc.name, func(t *testing.T) { - err := drvr.Walk(context.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error { + err := drvr.Walk(dcontext.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error { walked = append(walked, fileInfo.Path()) return tc.fn(fileInfo) }, tc.options...) @@ -718,7 +718,7 @@ func TestOverThousandBlobs(t *testing.T) { t.Fatalf("unexpected error creating driver with standard storage: %v", err) } - ctx := context.Background() + ctx := dcontext.Background() for i := 0; i < 1005; i++ { filename := "/thousandfiletest/file" + strconv.Itoa(i) contents := []byte("contents") @@ -746,7 +746,7 @@ func TestMoveWithMultipartCopy(t *testing.T) { t.Fatalf("unexpected error creating driver: %v", err) } - ctx := context.Background() + ctx := dcontext.Background() sourcePath := "/source" destPath := "/dest" @@ -795,7 +795,7 @@ func TestListObjectsV2(t *testing.T) { t.Fatalf("unexpected error creating driver: %v", err) } - ctx := context.Background() + ctx := dcontext.Background() n := 6 prefix := "/test-list-objects-v2" var filePaths []string diff --git a/registry/storage/filereader_test.go b/registry/storage/filereader_test.go index 3f05582e0..5dfa29473 100644 --- a/registry/storage/filereader_test.go +++ b/registry/storage/filereader_test.go @@ -7,13 +7,13 @@ import ( mrand "math/rand" "testing" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/opencontainers/go-digest" ) func TestSimpleRead(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() content := make([]byte, 1<<20) n, err := crand.Read(content) if err != nil { @@ -55,7 +55,7 @@ func TestFileReaderSeek(t *testing.T) { repititions := 1024 path := "/patterned" content := bytes.Repeat([]byte(pattern), repititions) - ctx := context.Background() + ctx := dcontext.Background() if err := driver.PutContent(ctx, path, content); err != nil { t.Fatalf("error putting patterned content: %v", err) @@ -156,7 +156,7 @@ func TestFileReaderSeek(t *testing.T) { // read method, with an io.EOF error. func TestFileReaderNonExistentFile(t *testing.T) { driver := inmemory.New() - fr, err := newFileReader(context.Background(), driver, "/doesnotexist", 10) + fr, err := newFileReader(dcontext.Background(), driver, "/doesnotexist", 10) if err != nil { t.Fatalf("unexpected error initializing reader: %v", err) } diff --git a/registry/storage/garbagecollect_test.go b/registry/storage/garbagecollect_test.go index edf7f1c37..5e69a1786 100644 --- a/registry/storage/garbagecollect_test.go +++ b/registry/storage/garbagecollect_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/distribution/distribution/v3/testutil" @@ -21,7 +21,7 @@ type image struct { } func createRegistry(t *testing.T, driver driver.StorageDriver, options ...RegistryOption) distribution.Namespace { - ctx := context.Background() + ctx := dcontext.Background() options = append(options, EnableDelete) registry, err := NewRegistry(ctx, driver, options...) if err != nil { @@ -31,7 +31,7 @@ func createRegistry(t *testing.T, driver driver.StorageDriver, options ...Regist } func makeRepository(t *testing.T, registry distribution.Namespace, name string) distribution.Repository { - ctx := context.Background() + ctx := dcontext.Background() // Initialize a dummy repository named, err := reference.WithName(name) @@ -47,7 +47,7 @@ func makeRepository(t *testing.T, registry distribution.Namespace, name string) } func makeManifestService(t *testing.T, repository distribution.Repository) distribution.ManifestService { - ctx := context.Background() + ctx := dcontext.Background() manifestService, err := repository.Manifests(ctx) if err != nil { @@ -57,7 +57,7 @@ func makeManifestService(t *testing.T, repository distribution.Repository) distr } func allManifests(t *testing.T, manifestService distribution.ManifestService) map[digest.Digest]struct{} { - ctx := context.Background() + ctx := dcontext.Background() allManMap := make(map[digest.Digest]struct{}) manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator) if !ok { @@ -74,7 +74,7 @@ func allManifests(t *testing.T, manifestService distribution.ManifestService) ma } func allBlobs(t *testing.T, registry distribution.Namespace) map[digest.Digest]struct{} { - ctx := context.Background() + ctx := dcontext.Background() blobService := registry.Blobs() allBlobsMap := make(map[digest.Digest]struct{}) err := blobService.Enumerate(ctx, func(dgst digest.Digest) error { @@ -95,7 +95,7 @@ func uploadImage(t *testing.T, repository distribution.Repository, im image) dig } // upload manifest - ctx := context.Background() + ctx := dcontext.Background() manifestService := makeManifestService(t, repository) manifestDigest, err := manifestService.Put(ctx, im.manifest) if err != nil { @@ -130,7 +130,7 @@ func uploadRandomSchema2Image(t *testing.T, repository distribution.Repository) } func TestNoDeletionNoEffect(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver) @@ -158,7 +158,7 @@ func TestNoDeletionNoEffect(t *testing.T) { before := allBlobs(t, registry) // Run GC - err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: false, RemoveUntagged: false, }) @@ -173,7 +173,7 @@ func TestNoDeletionNoEffect(t *testing.T) { } func TestDeleteManifestIfTagNotFound(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver) @@ -233,7 +233,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) { before2 := allManifests(t, manifestService) // run GC with dry-run (should not remove anything) - err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: true, RemoveUntagged: true, }) @@ -250,7 +250,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) { } // Run GC (removes everything because no manifests with tags exist) - err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: false, RemoveUntagged: true, }) @@ -269,7 +269,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) { } func TestGCWithMissingManifests(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() d := inmemory.New() registry := createRegistry(t, d) @@ -288,7 +288,7 @@ func TestGCWithMissingManifests(t *testing.T) { t.Fatal(err) } - err = MarkAndSweep(context.Background(), d, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), d, registry, GCOpts{ DryRun: false, RemoveUntagged: false, }) @@ -303,7 +303,7 @@ func TestGCWithMissingManifests(t *testing.T) { } func TestDeletionHasEffect(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver) @@ -318,7 +318,7 @@ func TestDeletionHasEffect(t *testing.T) { manifests.Delete(ctx, image3.manifestDigest) // Run GC - err := MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err := MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: false, RemoveUntagged: false, }) @@ -368,7 +368,7 @@ func getKeys(digests map[digest.Digest]io.ReadSeeker) (ds []digest.Digest) { } func TestDeletionWithSharedLayer(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver) @@ -455,7 +455,7 @@ func TestOrphanBlobDeleted(t *testing.T) { uploadRandomSchema2Image(t, repo) // Run GC - err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: false, RemoveUntagged: false, }) diff --git a/registry/storage/linkedblobstore.go b/registry/storage/linkedblobstore.go index 392d39b06..18d40d278 100644 --- a/registry/storage/linkedblobstore.go +++ b/registry/storage/linkedblobstore.go @@ -9,7 +9,7 @@ import ( "time" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/reference" "github.com/google/uuid" diff --git a/registry/storage/manifestlisthandler.go b/registry/storage/manifestlisthandler.go index 1fc7aac7a..caaf76bb0 100644 --- a/registry/storage/manifestlisthandler.go +++ b/registry/storage/manifestlisthandler.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/opencontainers/go-digest" diff --git a/registry/storage/manifeststore.go b/registry/storage/manifeststore.go index 37c86495f..027ab65ce 100644 --- a/registry/storage/manifeststore.go +++ b/registry/storage/manifeststore.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest" "github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/ocischema" diff --git a/registry/storage/ociindexhandler.go b/registry/storage/ociindexhandler.go index 01864f061..b4d73413d 100644 --- a/registry/storage/ociindexhandler.go +++ b/registry/storage/ociindexhandler.go @@ -4,7 +4,7 @@ import ( "context" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/opencontainers/go-digest" ) diff --git a/registry/storage/ocimanifesthandler.go b/registry/storage/ocimanifesthandler.go index 97216d2a4..f69c2a2b3 100644 --- a/registry/storage/ocimanifesthandler.go +++ b/registry/storage/ocimanifesthandler.go @@ -6,7 +6,7 @@ import ( "net/url" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/opencontainers/go-digest" v1 "github.com/opencontainers/image-spec/specs-go/v1" diff --git a/registry/storage/schema2manifesthandler.go b/registry/storage/schema2manifesthandler.go index aa8980ec8..adf0677ee 100644 --- a/registry/storage/schema2manifesthandler.go +++ b/registry/storage/schema2manifesthandler.go @@ -7,7 +7,7 @@ import ( "net/url" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/opencontainers/go-digest" ) diff --git a/registry/storage/schema2manifesthandler_test.go b/registry/storage/schema2manifesthandler_test.go index 908f8c137..26c19ec12 100644 --- a/registry/storage/schema2manifesthandler_test.go +++ b/registry/storage/schema2manifesthandler_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" @@ -14,7 +14,7 @@ import ( ) func TestVerifyManifestForeignLayer(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver, ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")), @@ -152,7 +152,7 @@ func TestVerifyManifestForeignLayer(t *testing.T) { } func TestVerifyManifestBlobLayerAndConfig(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver, ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")), diff --git a/registry/storage/vacuum.go b/registry/storage/vacuum.go index 749fb3190..38ebbd67f 100644 --- a/registry/storage/vacuum.go +++ b/registry/storage/vacuum.go @@ -4,7 +4,7 @@ import ( "context" "path" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/opencontainers/go-digest" ) diff --git a/testutil/manifests.go b/testutil/manifests.go index 96af5703b..dd1cc3a3d 100644 --- a/testutil/manifests.go +++ b/testutil/manifests.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/opencontainers/go-digest" @@ -12,7 +12,7 @@ import ( // MakeManifestList constructs a manifest list out of a list of manifest digests func MakeManifestList(blobstatter distribution.BlobStatter, manifestDigests []digest.Digest) (*manifestlist.DeserializedManifestList, error) { - ctx := context.Background() + ctx := dcontext.Background() manifestDescriptors := make([]manifestlist.ManifestDescriptor, 0, len(manifestDigests)) for _, manifestDigest := range manifestDigests { @@ -39,7 +39,7 @@ func MakeManifestList(blobstatter distribution.BlobStatter, manifestDigests []di // MakeSchema2Manifest constructs a schema 2 manifest from a given list of digests and returns // the digest of the manifest func MakeSchema2Manifest(repository distribution.Repository, digests []digest.Digest) (distribution.Manifest, error) { - ctx := context.Background() + ctx := dcontext.Background() blobStore := repository.Blobs(ctx) var configJSON []byte diff --git a/testutil/tarfile.go b/testutil/tarfile.go index 2b30a24dc..d4bc7e2a1 100644 --- a/testutil/tarfile.go +++ b/testutil/tarfile.go @@ -10,7 +10,7 @@ import ( "time" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/opencontainers/go-digest" ) @@ -96,7 +96,7 @@ func CreateRandomLayers(n int) (map[digest.Digest]io.ReadSeeker, error) { // UploadBlobs lets you upload blobs to a repository func UploadBlobs(repository distribution.Repository, layers map[digest.Digest]io.ReadSeeker) error { - ctx := context.Background() + ctx := dcontext.Background() for dgst, rs := range layers { wr, err := repository.Blobs(ctx).Create(ctx) if err != nil { From 9157226e7bf3dbe440fa1563ff662cd3b9b39e34 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 24 Oct 2023 13:39:55 -0400 Subject: [PATCH 2/7] Extract request utilities into its own package The RemoteAddr and RemoteIP functions operate on *http.Request values, not contexts. They have very low cohesion with the rest of the package. Signed-off-by: Cory Snider --- internal/dcontext/http.go | 47 +----------- internal/dcontext/http_test.go | 70 ----------------- internal/requestutil/util.go | 51 +++++++++++++ internal/requestutil/util_test.go | 76 +++++++++++++++++++ notifications/bridge.go | 4 +- .../driver/middleware/cloudfront/s3filter.go | 5 +- 6 files changed, 134 insertions(+), 119 deletions(-) create mode 100644 internal/requestutil/util.go create mode 100644 internal/requestutil/util_test.go diff --git a/internal/dcontext/http.go b/internal/dcontext/http.go index 69c29b74a..df068f13e 100644 --- a/internal/dcontext/http.go +++ b/internal/dcontext/http.go @@ -3,15 +3,14 @@ package dcontext import ( "context" "errors" - "net" "net/http" "strings" "sync" "time" + "github.com/distribution/distribution/v3/internal/requestutil" "github.com/google/uuid" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" ) // Common errors used with this package. @@ -20,48 +19,6 @@ var ( ErrNoResponseWriterContext = errors.New("no http response in context") ) -func parseIP(ipStr string) net.IP { - ip := net.ParseIP(ipStr) - if ip == nil { - log.Warnf("invalid remote IP address: %q", ipStr) - } - return ip -} - -// RemoteAddr extracts the remote address of the request, taking into -// account proxy headers. -func RemoteAddr(r *http.Request) string { - if prior := r.Header.Get("X-Forwarded-For"); prior != "" { - remoteAddr, _, _ := strings.Cut(prior, ",") - remoteAddr = strings.Trim(remoteAddr, " ") - if parseIP(remoteAddr) != nil { - return remoteAddr - } - } - // X-Real-Ip is less supported, but worth checking in the - // absence of X-Forwarded-For - if realIP := r.Header.Get("X-Real-Ip"); realIP != "" { - if parseIP(realIP) != nil { - return realIP - } - } - - return r.RemoteAddr -} - -// RemoteIP extracts the remote IP of the request, taking into -// account proxy headers. -func RemoteIP(r *http.Request) string { - addr := RemoteAddr(r) - - // Try parsing it as "IP:port" - if ip, _, err := net.SplitHostPort(addr); err == nil { - return ip - } - - return addr -} - // WithRequest places the request on the context. The context of the request // is assigned a unique id, available at "http.request.id". The request itself // is available at "http.request". Other common attributes are available under @@ -193,7 +150,7 @@ func (ctx *httpRequestContext) Value(key interface{}) interface{} { case "http.request.uri": return ctx.r.RequestURI case "http.request.remoteaddr": - return RemoteAddr(ctx.r) + return requestutil.RemoteAddr(ctx.r) case "http.request.method": return ctx.r.Method case "http.request.host": diff --git a/internal/dcontext/http_test.go b/internal/dcontext/http_test.go index 9d1069d28..99c47bcdd 100644 --- a/internal/dcontext/http_test.go +++ b/internal/dcontext/http_test.go @@ -2,9 +2,6 @@ package dcontext import ( "net/http" - "net/http/httptest" - "net/http/httputil" - "net/url" "reflect" "testing" "time" @@ -219,70 +216,3 @@ func TestWithVars(t *testing.T) { } } } - -// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test -// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten -// at the transport layer to 127.0.0.1: . However, as the X-Forwarded-For header -// just contains the IP address, it is different enough for testing. -func TestRemoteAddr(t *testing.T) { - var expectedRemote string - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - - if r.RemoteAddr == expectedRemote { - t.Errorf("Unexpected matching remote addresses") - } - - actualRemote := RemoteAddr(r) - if expectedRemote != actualRemote { - t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote) - } - - w.WriteHeader(200) - })) - - defer backend.Close() - backendURL, err := url.Parse(backend.URL) - if err != nil { - t.Fatal(err) - } - - proxy := httputil.NewSingleHostReverseProxy(backendURL) - frontend := httptest.NewServer(proxy) - defer frontend.Close() - - // X-Forwarded-For set by proxy - expectedRemote = "127.0.0.1" - proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil) - if err != nil { - t.Fatal(err) - } - - resp, err := http.DefaultClient.Do(proxyReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - // RemoteAddr in X-Real-Ip - getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil) - if err != nil { - t.Fatal(err) - } - - expectedRemote = "1.2.3.4" - getReq.Header["X-Real-ip"] = []string{expectedRemote} - resp, err = http.DefaultClient.Do(getReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - // Valid X-Real-Ip and invalid X-Forwarded-For - getReq.Header["X-forwarded-for"] = []string{"1.2.3"} - resp, err = http.DefaultClient.Do(getReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() -} diff --git a/internal/requestutil/util.go b/internal/requestutil/util.go new file mode 100644 index 000000000..099e3454c --- /dev/null +++ b/internal/requestutil/util.go @@ -0,0 +1,51 @@ +package requestutil + +import ( + "net" + "net/http" + "strings" + + log "github.com/sirupsen/logrus" +) + +func parseIP(ipStr string) net.IP { + ip := net.ParseIP(ipStr) + if ip == nil { + log.Warnf("invalid remote IP address: %q", ipStr) + } + return ip +} + +// RemoteAddr extracts the remote address of the request, taking into +// account proxy headers. +func RemoteAddr(r *http.Request) string { + if prior := r.Header.Get("X-Forwarded-For"); prior != "" { + remoteAddr, _, _ := strings.Cut(prior, ",") + remoteAddr = strings.Trim(remoteAddr, " ") + if parseIP(remoteAddr) != nil { + return remoteAddr + } + } + // X-Real-Ip is less supported, but worth checking in the + // absence of X-Forwarded-For + if realIP := r.Header.Get("X-Real-Ip"); realIP != "" { + if parseIP(realIP) != nil { + return realIP + } + } + + return r.RemoteAddr +} + +// RemoteIP extracts the remote IP of the request, taking into +// account proxy headers. +func RemoteIP(r *http.Request) string { + addr := RemoteAddr(r) + + // Try parsing it as "IP:port" + if ip, _, err := net.SplitHostPort(addr); err == nil { + return ip + } + + return addr +} diff --git a/internal/requestutil/util_test.go b/internal/requestutil/util_test.go new file mode 100644 index 000000000..fc33527f6 --- /dev/null +++ b/internal/requestutil/util_test.go @@ -0,0 +1,76 @@ +package requestutil + +import ( + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "testing" +) + +// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test +// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten +// at the transport layer to 127.0.0.1: . However, as the X-Forwarded-For header +// just contains the IP address, it is different enough for testing. +func TestRemoteAddr(t *testing.T) { + var expectedRemote string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + if r.RemoteAddr == expectedRemote { + t.Errorf("Unexpected matching remote addresses") + } + + actualRemote := RemoteAddr(r) + if expectedRemote != actualRemote { + t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote) + } + + w.WriteHeader(200) + })) + + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxy := httputil.NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxy) + defer frontend.Close() + + // X-Forwarded-For set by proxy + expectedRemote = "127.0.0.1" + proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil) + if err != nil { + t.Fatal(err) + } + + resp, err := http.DefaultClient.Do(proxyReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // RemoteAddr in X-Real-Ip + getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil) + if err != nil { + t.Fatal(err) + } + + expectedRemote = "1.2.3.4" + getReq.Header["X-Real-ip"] = []string{expectedRemote} + resp, err = http.DefaultClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Valid X-Real-Ip and invalid X-Forwarded-For + getReq.Header["X-forwarded-for"] = []string{"1.2.3"} + resp, err = http.DefaultClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() +} diff --git a/notifications/bridge.go b/notifications/bridge.go index 133153deb..8b594774b 100644 --- a/notifications/bridge.go +++ b/notifications/bridge.go @@ -5,7 +5,7 @@ import ( "time" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/internal/dcontext" + "github.com/distribution/distribution/v3/internal/requestutil" "github.com/distribution/reference" events "github.com/docker/go-events" "github.com/google/uuid" @@ -49,7 +49,7 @@ func NewBridge(ub URLBuilder, source SourceRecord, actor ActorRecord, request Re func NewRequestRecord(id string, r *http.Request) RequestRecord { return RequestRecord{ ID: id, - Addr: dcontext.RemoteAddr(r), + Addr: requestutil.RemoteAddr(r), Host: r.Host, Method: r.Method, UserAgent: r.UserAgent(), diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go index 7a23bcbd0..c7ddd6f55 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter.go @@ -12,6 +12,7 @@ import ( "time" "github.com/distribution/distribution/v3/internal/dcontext" + "github.com/distribution/distribution/v3/internal/requestutil" ) const ( @@ -188,7 +189,7 @@ func parseIPFromRequest(ctx context.Context) (net.IP, error) { if err != nil { return nil, err } - ipStr := dcontext.RemoteIP(request) + ipStr := requestutil.RemoteIP(request) ip := net.ParseIP(ipStr) if ip == nil { return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr) @@ -208,7 +209,7 @@ func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool { } else { loggerField := map[interface{}]interface{}{ "user-client": request.UserAgent(), - "ip": dcontext.RemoteIP(request), + "ip": requestutil.RemoteIP(request), } if awsIPs.contains(addr) { dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront") From 49e22cbf3e29b09a0c816fbfdead278ee8bed2b4 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 24 Oct 2023 14:08:04 -0400 Subject: [PATCH 3/7] registry/auth: pass request to AccessController Signed-off-by: Cory Snider --- registry/auth/auth.go | 17 ++++++++--------- registry/auth/htpasswd/access.go | 11 +++-------- registry/auth/htpasswd/access_test.go | 5 +---- registry/auth/silly/access.go | 9 ++------- registry/auth/silly/access_test.go | 4 +--- registry/auth/token/accesscontroller.go | 10 ++-------- registry/auth/token/token_test.go | 12 +++++------- registry/handlers/app.go | 2 +- 8 files changed, 23 insertions(+), 47 deletions(-) diff --git a/registry/auth/auth.go b/registry/auth/auth.go index 9cb036f1f..1f28ea85e 100644 --- a/registry/auth/auth.go +++ b/registry/auth/auth.go @@ -18,7 +18,7 @@ // resource := auth.Resource{Type: "customerOrder", Name: orderNumber} // access := auth.Access{Resource: resource, Action: "update"} // -// if ctx, err := accessController.Authorized(ctx, access); err != nil { +// if ctx, err := accessController.Authorized(r, access); err != nil { // if challenge, ok := err.(auth.Challenge) { // // Let the challenge write the response. // challenge.SetHeaders(r, w) @@ -93,16 +93,15 @@ type Challenge interface { // and required access levels for a request. Implementations can support both // complete denial and http authorization challenges. type AccessController interface { - // Authorized returns a non-nil error if the context is granted access and + // Authorized returns a nil error if the request is granted access and // returns a new authorized context. If one or more Access structs are // provided, the requested access will be compared with what is available - // to the context. The given context will contain a "http.request" key with - // a `*http.Request` value. If the error is non-nil, access should always - // be denied. The error may be of type Challenge, in which case the caller - // may have the Challenge handle the request or choose what action to take - // based on the Challenge header or response status. The returned context - // object should have a "auth.user" value set to a UserInfo struct. - Authorized(ctx context.Context, access ...Access) (context.Context, error) + // to the request. Access is denied if the error is non-nil. The error may + // be of type Challenge, in which case the caller may have the Challenge + // handle the request or choose what action to take based on the Challenge + // header or response status. The returned context object should be derived + // from r.Context() and have a "auth.user" value set to a UserInfo struct. + Authorized(r *http.Request, access ...Access) (context.Context, error) } // CredentialAuthenticator is an object which is able to authenticate credentials diff --git a/registry/auth/htpasswd/access.go b/registry/auth/htpasswd/access.go index 0a1d0c1ce..a5c89a42d 100644 --- a/registry/auth/htpasswd/access.go +++ b/registry/auth/htpasswd/access.go @@ -49,12 +49,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, return &accessController{realm: realm.(string), path: path}, nil } -func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { - req, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - +func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) { username, password, ok := req.BasicAuth() if !ok { return nil, &challenge{ @@ -92,14 +87,14 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut ac.mu.Unlock() if err := localHTPasswd.authenticateUser(username, password); err != nil { - dcontext.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err) + dcontext.GetLogger(req.Context()).Errorf("error authenticating user %q: %v", username, err) return nil, &challenge{ realm: ac.realm, err: auth.ErrAuthenticationFailure, } } - return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil + return auth.WithUser(req.Context(), auth.UserInfo{Name: username}), nil } // challenge implements the auth.Challenge interface. diff --git a/registry/auth/htpasswd/access_test.go b/registry/auth/htpasswd/access_test.go index 0871ef411..ad5e7f70c 100644 --- a/registry/auth/htpasswd/access_test.go +++ b/registry/auth/htpasswd/access_test.go @@ -8,7 +8,6 @@ import ( "os" "testing" - "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) @@ -33,7 +32,6 @@ func TestBasicAccessController(t *testing.T) { "realm": testRealm, "path": tempFile.Name(), } - ctx := dcontext.Background() accessController, err := newAccessController(options) if err != nil { @@ -45,8 +43,7 @@ func TestBasicAccessController(t *testing.T) { userNumber := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := dcontext.WithRequest(ctx, r) - authCtx, err := accessController.Authorized(ctx) + authCtx, err := accessController.Authorized(r) if err != nil { switch err := err.(type) { case auth.Challenge: diff --git a/registry/auth/silly/access.go b/registry/auth/silly/access.go index 685cf6a62..c8f383e23 100644 --- a/registry/auth/silly/access.go +++ b/registry/auth/silly/access.go @@ -43,12 +43,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, // Authorized simply checks for the existence of the authorization header, // responding with a bearer challenge if it doesn't exist. -func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { - req, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - +func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) { if req.Header.Get("Authorization") == "" { challenge := challenge{ realm: ac.realm, @@ -66,7 +61,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut return nil, &challenge } - ctx = auth.WithUser(ctx, auth.UserInfo{Name: "silly"}) + ctx := auth.WithUser(req.Context(), auth.UserInfo{Name: "silly"}) ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey)) return ctx, nil diff --git a/registry/auth/silly/access_test.go b/registry/auth/silly/access_test.go index f463e98c4..1a137c715 100644 --- a/registry/auth/silly/access_test.go +++ b/registry/auth/silly/access_test.go @@ -5,7 +5,6 @@ import ( "net/http/httptest" "testing" - "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) @@ -16,8 +15,7 @@ func TestSillyAccessController(t *testing.T) { } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := dcontext.WithRequest(dcontext.Background(), r) - authCtx, err := ac.Authorized(ctx) + authCtx, err := ac.Authorized(r) if err != nil { switch err := err.(type) { case auth.Challenge: diff --git a/registry/auth/token/accesscontroller.go b/registry/auth/token/accesscontroller.go index b2e4e4b27..e019d0f51 100644 --- a/registry/auth/token/accesscontroller.go +++ b/registry/auth/token/accesscontroller.go @@ -13,7 +13,6 @@ import ( "os" "strings" - "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" "github.com/go-jose/go-jose/v3" ) @@ -292,7 +291,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, // Authorized handles checking whether the given request is authorized // for actions on resources described by the given access items. -func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) { +func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (context.Context, error) { challenge := &authChallenge{ realm: ac.realm, autoRedirect: ac.autoRedirect, @@ -300,11 +299,6 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth. accessSet: newAccessSet(accessItems...), } - req, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ") if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") { challenge.err = ErrTokenRequired @@ -338,7 +332,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth. } } - ctx = auth.WithResources(ctx, claims.resources()) + ctx := auth.WithResources(req.Context(), claims.resources()) return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil } diff --git a/registry/auth/token/token_test.go b/registry/auth/token/token_test.go index a331a93bf..52d34a70f 100644 --- a/registry/auth/token/token_test.go +++ b/registry/auth/token/token_test.go @@ -18,7 +18,6 @@ import ( "testing" "time" - "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" @@ -466,8 +465,7 @@ func TestAccessController(t *testing.T) { Action: "baz", } - ctx := dcontext.WithRequest(dcontext.Background(), req) - authCtx, err := accessController.Authorized(ctx, testAccess) + authCtx, err := accessController.Authorized(req, testAccess) challenge, ok := err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -502,7 +500,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(ctx, testAccess) + authCtx, err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -534,7 +532,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(ctx, testAccess) + authCtx, err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -564,7 +562,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(ctx, testAccess) + authCtx, err = accessController.Authorized(req, testAccess) if err != nil { t.Fatalf("accessController returned unexpected error: %s", err) } @@ -594,7 +592,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - _, err = accessController.Authorized(ctx, testAccess) + _, err = accessController.Authorized(req, testAccess) if err != nil { t.Fatalf("accessController returned unexpected error: %s", err) } diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 7ce27d6dc..8bb5bbbcd 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -797,7 +797,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont accessRecords = appendCatalogAccessRecord(accessRecords, r) } - ctx, err := app.accessController.Authorized(context.Context, accessRecords...) + ctx, err := app.accessController.Authorized(r.WithContext(context.Context), accessRecords...) if err != nil { switch err := err.(type) { case auth.Challenge: From bd80d7590d1ca49ddb169dd54d655114b8de45a7 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 24 Oct 2023 16:41:54 -0400 Subject: [PATCH 4/7] reg/auth: remove contexts from Authorized method The details of how request-scoped information is propagated through the registry server app should be left as private implementation details so they can be changed without fear of breaking compatibility with third-party code which imports the distribution module. The AccessController interface unnecessarily bakes into the public API details of how authorization grants are propagated through request contexts. In practice the only values the in-tree authorizers attach to the request contexts are the UserInfo and Resources for the request. Change the AccessController interface to return the UserInfo and Resources directly to allow us to change how request contexts are used within the app without altering the AccessController interface contract. Signed-off-by: Cory Snider --- registry/auth/auth.go | 24 ++++++++++++-------- registry/auth/htpasswd/access.go | 4 ++-- registry/auth/htpasswd/access_test.go | 11 +++++----- registry/auth/silly/access.go | 9 ++------ registry/auth/silly/access_test.go | 11 +++++----- registry/auth/token/accesscontroller.go | 10 ++++----- registry/auth/token/token_test.go | 29 ++++++++++--------------- registry/handlers/app.go | 8 ++++++- 8 files changed, 53 insertions(+), 53 deletions(-) diff --git a/registry/auth/auth.go b/registry/auth/auth.go index 1f28ea85e..0bda67f6c 100644 --- a/registry/auth/auth.go +++ b/registry/auth/auth.go @@ -76,6 +76,12 @@ type Access struct { Action string } +// Grant describes the permitted level of access for an authorized request. +type Grant struct { + User UserInfo // The authenticated user for the request. + Resources []Resource // The list of resources which have been authorized for the request. +} + // Challenge is a special error type which is used for HTTP 401 Unauthorized // responses and is able to write the response with WWW-Authenticate challenge // header values based on the error. @@ -93,15 +99,15 @@ type Challenge interface { // and required access levels for a request. Implementations can support both // complete denial and http authorization challenges. type AccessController interface { - // Authorized returns a nil error if the request is granted access and - // returns a new authorized context. If one or more Access structs are - // provided, the requested access will be compared with what is available - // to the request. Access is denied if the error is non-nil. The error may - // be of type Challenge, in which case the caller may have the Challenge - // handle the request or choose what action to take based on the Challenge - // header or response status. The returned context object should be derived - // from r.Context() and have a "auth.user" value set to a UserInfo struct. - Authorized(r *http.Request, access ...Access) (context.Context, error) + // Authorized determines if the request is granted access. If one or more + // Access structs are provided, the requested access will be compared with + // what is available to the request. + // + // Return a Grant to grant the request access. Return an error to deny + // access. The error may be of type Challenge, in which case the caller may + // have the Challenge handle the request or choose what action to take based + // on the Challenge header or response status. + Authorized(r *http.Request, access ...Access) (*Grant, error) } // CredentialAuthenticator is an object which is able to authenticate credentials diff --git a/registry/auth/htpasswd/access.go b/registry/auth/htpasswd/access.go index a5c89a42d..c8c432653 100644 --- a/registry/auth/htpasswd/access.go +++ b/registry/auth/htpasswd/access.go @@ -49,7 +49,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, return &accessController{realm: realm.(string), path: path}, nil } -func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) { +func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (*auth.Grant, error) { username, password, ok := req.BasicAuth() if !ok { return nil, &challenge{ @@ -94,7 +94,7 @@ func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth. } } - return auth.WithUser(req.Context(), auth.UserInfo{Name: username}), nil + return &auth.Grant{User: auth.UserInfo{Name: username}}, nil } // challenge implements the auth.Challenge interface. diff --git a/registry/auth/htpasswd/access_test.go b/registry/auth/htpasswd/access_test.go index ad5e7f70c..01f2ac9d7 100644 --- a/registry/auth/htpasswd/access_test.go +++ b/registry/auth/htpasswd/access_test.go @@ -43,7 +43,7 @@ func TestBasicAccessController(t *testing.T) { userNumber := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authCtx, err := accessController.Authorized(r) + grant, err := accessController.Authorized(r) if err != nil { switch err := err.(type) { case auth.Challenge: @@ -55,13 +55,12 @@ func TestBasicAccessController(t *testing.T) { } } - userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo) - if !ok { - t.Fatal("basic accessController did not set auth.user context") + if grant == nil { + t.Fatal("basic accessController did not return auth grant") } - if userInfo.Name != testUsers[userNumber] { - t.Fatalf("expected user name %q, got %q", testUsers[userNumber], userInfo.Name) + if grant.User.Name != testUsers[userNumber] { + t.Fatalf("expected user name %q, got %q", testUsers[userNumber], grant.User.Name) } w.WriteHeader(http.StatusNoContent) diff --git a/registry/auth/silly/access.go b/registry/auth/silly/access.go index c8f383e23..1984ba20d 100644 --- a/registry/auth/silly/access.go +++ b/registry/auth/silly/access.go @@ -8,12 +8,10 @@ package silly import ( - "context" "fmt" "net/http" "strings" - "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) @@ -43,7 +41,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, // Authorized simply checks for the existence of the authorization header, // responding with a bearer challenge if it doesn't exist. -func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) { +func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (*auth.Grant, error) { if req.Header.Get("Authorization") == "" { challenge := challenge{ realm: ac.realm, @@ -61,10 +59,7 @@ func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth. return nil, &challenge } - ctx := auth.WithUser(req.Context(), auth.UserInfo{Name: "silly"}) - ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey)) - - return ctx, nil + return &auth.Grant{User: auth.UserInfo{Name: "silly"}}, nil } type challenge struct { diff --git a/registry/auth/silly/access_test.go b/registry/auth/silly/access_test.go index 1a137c715..506af0bde 100644 --- a/registry/auth/silly/access_test.go +++ b/registry/auth/silly/access_test.go @@ -15,7 +15,7 @@ func TestSillyAccessController(t *testing.T) { } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authCtx, err := ac.Authorized(r) + grant, err := ac.Authorized(r) if err != nil { switch err := err.(type) { case auth.Challenge: @@ -27,13 +27,12 @@ func TestSillyAccessController(t *testing.T) { } } - userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo) - if !ok { - t.Fatal("silly accessController did not set auth.user context") + if grant == nil { + t.Fatal("silly accessController did not return auth grant") } - if userInfo.Name != "silly" { - t.Fatalf("expected user name %q, got %q", "silly", userInfo.Name) + if grant.User.Name != "silly" { + t.Fatalf("expected user name %q, got %q", "silly", grant.User.Name) } w.WriteHeader(http.StatusNoContent) diff --git a/registry/auth/token/accesscontroller.go b/registry/auth/token/accesscontroller.go index e019d0f51..bed4c827d 100644 --- a/registry/auth/token/accesscontroller.go +++ b/registry/auth/token/accesscontroller.go @@ -1,7 +1,6 @@ package token import ( - "context" "crypto" "crypto/x509" "encoding/json" @@ -291,7 +290,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, // Authorized handles checking whether the given request is authorized // for actions on resources described by the given access items. -func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (context.Context, error) { +func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (*auth.Grant, error) { challenge := &authChallenge{ realm: ac.realm, autoRedirect: ac.autoRedirect, @@ -332,9 +331,10 @@ func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Ac } } - ctx := auth.WithResources(req.Context(), claims.resources()) - - return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil + return &auth.Grant{ + User: auth.UserInfo{Name: claims.Subject}, + Resources: claims.resources(), + }, nil } // init handles registering the token auth backend. diff --git a/registry/auth/token/token_test.go b/registry/auth/token/token_test.go index 52d34a70f..a96546af3 100644 --- a/registry/auth/token/token_test.go +++ b/registry/auth/token/token_test.go @@ -465,7 +465,7 @@ func TestAccessController(t *testing.T) { Action: "baz", } - authCtx, err := accessController.Authorized(req, testAccess) + grant, err := accessController.Authorized(req, testAccess) challenge, ok := err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -475,8 +475,8 @@ func TestAccessController(t *testing.T) { t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired) } - if authCtx != nil { - t.Fatalf("expected nil auth context but got %s", authCtx) + if grant != nil { + t.Fatalf("expected nil auth grant but got %#v", grant) } // 2. Supply an invalid token. @@ -500,7 +500,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(req, testAccess) + grant, err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -510,8 +510,8 @@ func TestAccessController(t *testing.T) { t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired) } - if authCtx != nil { - t.Fatalf("expected nil auth context but got %s", authCtx) + if grant != nil { + t.Fatalf("expected nil auth grant but got %#v", grant) } // create a valid jwk @@ -532,7 +532,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(req, testAccess) + grant, err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -542,8 +542,8 @@ func TestAccessController(t *testing.T) { t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInsufficientScope) } - if authCtx != nil { - t.Fatalf("expected nil auth context but got %s", authCtx) + if grant != nil { + t.Fatalf("expected nil auth grant but got %#v", grant) } // 4. Supply the token we need, or deserve, or whatever. @@ -562,18 +562,13 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(req, testAccess) + grant, err = accessController.Authorized(req, testAccess) if err != nil { t.Fatalf("accessController returned unexpected error: %s", err) } - userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo) - if !ok { - t.Fatal("token accessController did not set auth.user context") - } - - if userInfo.Name != "foo" { - t.Fatalf("expected user name %q, got %q", "foo", userInfo.Name) + if grant.User.Name != "foo" { + t.Fatalf("expected user name %q, got %q", "foo", grant.User.Name) } // 5. Supply a token with full admin rights, which is represented as "*". diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 8bb5bbbcd..471f7c169 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -797,7 +797,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont accessRecords = appendCatalogAccessRecord(accessRecords, r) } - ctx, err := app.accessController.Authorized(r.WithContext(context.Context), accessRecords...) + grant, err := app.accessController.Authorized(r.WithContext(context.Context), accessRecords...) if err != nil { switch err := err.(type) { case auth.Challenge: @@ -818,6 +818,12 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont return err } + if grant == nil { + return fmt.Errorf("access controller returned neither an access grant nor an error") + } + + ctx := auth.WithUser(context.Context, grant.User) + ctx = auth.WithResources(ctx, grant.Resources) dcontext.GetLogger(ctx, auth.UserNameKey).Info("authorized request") // TODO(stevvooe): This pattern needs to be cleaned up a bit. One context From 868faeec6761a3b20c3d62cd0be57087ce04d528 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 24 Oct 2023 17:04:16 -0400 Subject: [PATCH 5/7] registry: unexport auth-related context utilities The specifics of how the authorization for a request is propagated through the registry app are private implementation details. Hide those details from outsiders so they can be changed as needed without fear of breaking third-party code. Move the utilities for attaching a request's authorization status to its context and retrieving it from the context into the registry/handlers package as unexported symbols. Signed-off-by: Cory Snider --- registry/auth/auth.go | 68 --------------------------------- registry/handlers/app.go | 8 ++-- registry/handlers/context.go | 69 +++++++++++++++++++++++++++++++++- registry/handlers/manifests.go | 3 +- 4 files changed, 73 insertions(+), 75 deletions(-) diff --git a/registry/auth/auth.go b/registry/auth/auth.go index 0bda67f6c..6266d1e53 100644 --- a/registry/auth/auth.go +++ b/registry/auth/auth.go @@ -32,22 +32,11 @@ package auth import ( - "context" "errors" "fmt" "net/http" ) -const ( - // UserKey is used to get the user object from - // a user context - UserKey = "auth.user" - - // UserNameKey is used to get the user name from - // a user context - UserNameKey = "auth.user.name" -) - var ( // ErrInvalidCredential is returned when the auth token does not authenticate correctly. ErrInvalidCredential = errors.New("invalid authorization credential") @@ -115,63 +104,6 @@ type CredentialAuthenticator interface { AuthenticateUser(username, password string) error } -// WithUser returns a context with the authorized user info. -func WithUser(ctx context.Context, user UserInfo) context.Context { - return userInfoContext{ - Context: ctx, - user: user, - } -} - -type userInfoContext struct { - context.Context - user UserInfo -} - -func (uic userInfoContext) Value(key interface{}) interface{} { - switch key { - case UserKey: - return uic.user - case UserNameKey: - return uic.user.Name - } - - return uic.Context.Value(key) -} - -// WithResources returns a context with the authorized resources. -func WithResources(ctx context.Context, resources []Resource) context.Context { - return resourceContext{ - Context: ctx, - resources: resources, - } -} - -type resourceContext struct { - context.Context - resources []Resource -} - -type resourceKey struct{} - -func (rc resourceContext) Value(key interface{}) interface{} { - if key == (resourceKey{}) { - return rc.resources - } - - return rc.Context.Value(key) -} - -// AuthorizedResources returns the list of resources which have -// been authorized for this request. -func AuthorizedResources(ctx context.Context) []Resource { - if resources, ok := ctx.Value(resourceKey{}).([]Resource); ok { - return resources - } - - return nil -} - // InitFunc is the type of an AccessController factory function and is used // to register the constructor for different AccesController backends. type InitFunc func(options map[string]interface{}) (AccessController, error) diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 471f7c169..2254a0acf 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -635,7 +635,7 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler { } // Add username to request logging - context.Context = dcontext.WithLogger(context.Context, dcontext.GetLogger(context.Context, auth.UserNameKey)) + context.Context = dcontext.WithLogger(context.Context, dcontext.GetLogger(context.Context, userNameKey)) // sync up context on the request. r = r.WithContext(context) @@ -822,10 +822,10 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont return fmt.Errorf("access controller returned neither an access grant nor an error") } - ctx := auth.WithUser(context.Context, grant.User) - ctx = auth.WithResources(ctx, grant.Resources) + ctx := withUser(context.Context, grant.User) + ctx = withResources(ctx, grant.Resources) - dcontext.GetLogger(ctx, auth.UserNameKey).Info("authorized request") + dcontext.GetLogger(ctx, userNameKey).Info("authorized request") // TODO(stevvooe): This pattern needs to be cleaned up a bit. One context // should be replaced by another, rather than replacing the context on a // mutable object. diff --git a/registry/handlers/context.go b/registry/handlers/context.go index cb3540237..c272095c8 100644 --- a/registry/handlers/context.go +++ b/registry/handlers/context.go @@ -77,10 +77,20 @@ func getUploadUUID(ctx context.Context) (uuid string) { return dcontext.GetStringValue(ctx, "vars.uuid") } +const ( + // userKey is used to get the user object from + // a user context + userKey = "auth.user" + + // userNameKey is used to get the user name from + // a user context + userNameKey = "auth.user.name" +) + // getUserName attempts to resolve a username from the context and request. If // a username cannot be resolved, the empty string is returned. func getUserName(ctx context.Context, r *http.Request) string { - username := dcontext.GetStringValue(ctx, auth.UserNameKey) + username := dcontext.GetStringValue(ctx, userNameKey) // Fallback to request user with basic auth if username == "" { @@ -93,3 +103,60 @@ func getUserName(ctx context.Context, r *http.Request) string { return username } + +// withUser returns a context with the authorized user info. +func withUser(ctx context.Context, user auth.UserInfo) context.Context { + return userInfoContext{ + Context: ctx, + user: user, + } +} + +type userInfoContext struct { + context.Context + user auth.UserInfo +} + +func (uic userInfoContext) Value(key interface{}) interface{} { + switch key { + case userKey: + return uic.user + case userNameKey: + return uic.user.Name + } + + return uic.Context.Value(key) +} + +// withResources returns a context with the authorized resources. +func withResources(ctx context.Context, resources []auth.Resource) context.Context { + return resourceContext{ + Context: ctx, + resources: resources, + } +} + +type resourceContext struct { + context.Context + resources []auth.Resource +} + +type resourceKey struct{} + +func (rc resourceContext) Value(key interface{}) interface{} { + if key == (resourceKey{}) { + return rc.resources + } + + return rc.Context.Value(key) +} + +// authorizedResources returns the list of resources which have +// been authorized for this request. +func authorizedResources(ctx context.Context) []auth.Resource { + if resources, ok := ctx.Value(resourceKey{}).([]auth.Resource); ok { + return resources + } + + return nil +} diff --git a/registry/handlers/manifests.go b/registry/handlers/manifests.go index 06b7c0c75..4c6dbd0af 100644 --- a/registry/handlers/manifests.go +++ b/registry/handlers/manifests.go @@ -13,7 +13,6 @@ import ( "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/registry/api/errcode" - "github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/reference" "github.com/gorilla/handlers" @@ -394,7 +393,7 @@ func (imh *manifestHandler) applyResourcePolicy(manifest distribution.Manifest) return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("registry does not allow %s manifest", class)) } - resources := auth.AuthorizedResources(imh) + resources := authorizedResources(imh) n := imh.Repository.Named().Name() var foundResource bool From f089932de0a69492ce67ab01bb0d59ecac881ef8 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 24 Oct 2023 15:49:47 -0400 Subject: [PATCH 6/7] storage/driver: replace URLFor method Several storage drivers and storage middlewares need to introspect the client HTTP request in order to construct content-redirect URLs. The request is indirectly passed into the driver interface method URLFor() through the context argument, which is bad practice. The request should be passed in as an explicit argument as the method is only called from request handlers. Replace the URLFor() method with a RedirectURL() method which takes an HTTP request as a parameter instead of a context. Drop the options argument from URLFor() as in practice it only ever encoded the request method, which can now be fetched directly from the request. No URLFor() callers ever passed in an "expiry" option, either. Signed-off-by: Cory Snider --- registry/storage/blobserver.go | 21 ++++---- registry/storage/driver/azure/azure.go | 18 +++---- registry/storage/driver/base/base.go | 13 ++--- registry/storage/driver/base/regulator.go | 11 ++--- registry/storage/driver/filesystem/driver.go | 8 +-- registry/storage/driver/gcs/gcs.go | 32 +++--------- registry/storage/driver/inmemory/driver.go | 8 +-- .../middleware/cloudfront/middleware.go | 13 ++--- .../driver/middleware/cloudfront/s3filter.go | 33 +++++-------- .../middleware/cloudfront/s3filter_test.go | 49 ++++++------------- .../driver/middleware/redirect/middleware.go | 4 +- .../middleware/redirect/middleware_test.go | 21 ++++---- registry/storage/driver/s3-aws/s3.go | 25 ++-------- registry/storage/driver/storagedriver.go | 10 ++-- .../storage/driver/testsuites/testsuites.go | 15 +++--- registry/storage/registry.go | 4 +- 16 files changed, 111 insertions(+), 174 deletions(-) diff --git a/registry/storage/blobserver.go b/registry/storage/blobserver.go index 6392e3554..6beef7e3e 100644 --- a/registry/storage/blobserver.go +++ b/registry/storage/blobserver.go @@ -20,7 +20,7 @@ type blobServer struct { driver driver.StorageDriver statter distribution.BlobStatter pathFn func(dgst digest.Digest) (string, error) - redirect bool // allows disabling URLFor redirects + redirect bool // allows disabling RedirectURL redirects } func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error { @@ -35,19 +35,16 @@ func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *h } if bs.redirect { - redirectURL, err := bs.driver.URLFor(ctx, path, map[string]interface{}{"method": r.Method}) - switch err.(type) { - case nil: - // Redirect to storage URL. - http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect) - return err - - case driver.ErrUnsupportedMethod: - // Fallback to serving the content directly. - default: - // Some unexpected error. + redirectURL, err := bs.driver.RedirectURL(r, path) + if err != nil { return err } + if redirectURL != "" { + // Redirect to storage URL. + http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect) + return nil + } + // Fallback to serving the content directly. } br, err := newFileReader(ctx, bs.driver, path, desc.Size) diff --git a/registry/storage/driver/azure/azure.go b/registry/storage/driver/azure/azure.go index 585c8b432..57b13dd15 100644 --- a/registry/storage/driver/azure/azure.go +++ b/registry/storage/driver/azure/azure.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "io" + "net/http" "strings" "time" @@ -286,7 +287,7 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) { // Move moves an object stored at sourcePath to destPath, removing the original // object. func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error { - sourceBlobURL, err := d.URLFor(ctx, sourcePath, nil) + sourceBlobURL, err := d.signBlobURL(ctx, sourcePath) if err != nil { return err } @@ -366,18 +367,15 @@ func (d *driver) Delete(ctx context.Context, path string) error { return nil } -// URLFor returns a publicly accessible URL for the blob stored at given path +// RedirectURL returns a publicly accessible URL for the blob stored at given path // for specified duration by making use of Azure Storage Shared Access Signatures (SAS). // See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx for more info. -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { +func (d *driver) RedirectURL(req *http.Request, path string) (string, error) { + return d.signBlobURL(req.Context(), path) +} + +func (d *driver) signBlobURL(ctx context.Context, path string) (string, error) { expiresTime := time.Now().UTC().Add(20 * time.Minute) // default expiration - expires, ok := options["expiry"] - if ok { - t, ok := expires.(time.Time) - if ok { - expiresTime = t - } - } blobName := d.blobName(path) blobRef := d.client.NewBlobClient(blobName) return d.azClient.SignBlobURL(ctx, blobRef.URL(), expiresTime) diff --git a/registry/storage/driver/base/base.go b/registry/storage/driver/base/base.go index 756a0d4c9..32c9037ba 100644 --- a/registry/storage/driver/base/base.go +++ b/registry/storage/driver/base/base.go @@ -40,6 +40,7 @@ package base import ( "context" "io" + "net/http" "time" "github.com/distribution/distribution/v3/internal/dcontext" @@ -208,18 +209,18 @@ func (base *Base) Delete(ctx context.Context, path string) error { return err } -// URLFor wraps URLFor of underlying storage driver. -func (base *Base) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - ctx, done := dcontext.WithTrace(ctx) - defer done("%s.URLFor(%q)", base.Name(), path) +// RedirectURL wraps RedirectURL of the underlying storage driver. +func (base *Base) RedirectURL(r *http.Request, path string) (string, error) { + ctx, done := dcontext.WithTrace(r.Context()) + defer done("%s.RedirectURL(%q)", base.Name(), path) if !storagedriver.PathRegexp.MatchString(path) { return "", storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()} } start := time.Now() - str, e := base.StorageDriver.URLFor(ctx, path, options) - storageAction.WithValues(base.Name(), "URLFor").UpdateSince(start) + str, e := base.StorageDriver.RedirectURL(r.WithContext(ctx), path) + storageAction.WithValues(base.Name(), "RedirectURL").UpdateSince(start) return str, base.setDriverName(e) } diff --git a/registry/storage/driver/base/regulator.go b/registry/storage/driver/base/regulator.go index 09a258e35..2cf7a3ece 100644 --- a/registry/storage/driver/base/regulator.go +++ b/registry/storage/driver/base/regulator.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/http" "reflect" "strconv" "sync" @@ -172,13 +173,11 @@ func (r *regulator) Delete(ctx context.Context, path string) error { return r.StorageDriver.Delete(ctx, path) } -// URLFor returns a URL which may be used to retrieve the content stored at -// the given path, possibly using the given options. -// May return an ErrUnsupportedMethod in certain StorageDriver -// implementations. -func (r *regulator) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { +// RedirectURL returns a URL which may be used to retrieve the content stored at +// the given path. +func (r *regulator) RedirectURL(req *http.Request, path string) (string, error) { r.enter() defer r.exit() - return r.StorageDriver.URLFor(ctx, path, options) + return r.StorageDriver.RedirectURL(req, path) } diff --git a/registry/storage/driver/filesystem/driver.go b/registry/storage/driver/filesystem/driver.go index 23033268f..9ef262a2b 100644 --- a/registry/storage/driver/filesystem/driver.go +++ b/registry/storage/driver/filesystem/driver.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "io" + "net/http" "os" "path" "time" @@ -282,10 +283,9 @@ func (d *driver) Delete(ctx context.Context, subPath string) error { return err } -// URLFor returns a URL which may be used to retrieve the content stored at the given path. -// May return an UnsupportedMethodErr in certain StorageDriver implementations. -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - return "", storagedriver.ErrUnsupportedMethod{} +// RedirectURL returns a URL which may be used to retrieve the content stored at the given path. +func (d *driver) RedirectURL(*http.Request, string) (string, error) { + return "", nil } // Walk traverses a filesystem defined within driver, starting diff --git a/registry/storage/driver/gcs/gcs.go b/registry/storage/driver/gcs/gcs.go index 5b276d65f..6a0597c53 100644 --- a/registry/storage/driver/gcs/gcs.go +++ b/registry/storage/driver/gcs/gcs.go @@ -810,40 +810,24 @@ func storageCopyObject(ctx context.Context, srcBucket, srcName string, destBucke return attrs, err } -// URLFor returns a URL which may be used to retrieve the content stored at +// RedirectURL returns a URL which may be used to retrieve the content stored at // the given path, possibly using the given options. -// Returns ErrUnsupportedMethod if this driver has no privateKey -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { +func (d *driver) RedirectURL(r *http.Request, path string) (string, error) { if d.privateKey == nil { - return "", storagedriver.ErrUnsupportedMethod{} + return "", nil } - name := d.pathToKey(path) - methodString := http.MethodGet - method, ok := options["method"] - if ok { - methodString, ok = method.(string) - if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) { - return "", storagedriver.ErrUnsupportedMethod{} - } - } - - expiresTime := time.Now().Add(20 * time.Minute) - expires, ok := options["expiry"] - if ok { - et, ok := expires.(time.Time) - if ok { - expiresTime = et - } + if r.Method != http.MethodGet && r.Method != http.MethodHead { + return "", nil } opts := &storage.SignedURLOptions{ GoogleAccessID: d.email, PrivateKey: d.privateKey, - Method: methodString, - Expires: expiresTime, + Method: r.Method, + Expires: time.Now().Add(20 * time.Minute), } - return storage.SignedURL(d.bucket, name, opts) + return storage.SignedURL(d.bucket, d.pathToKey(path), opts) } // Walk traverses a filesystem defined within driver, starting diff --git a/registry/storage/driver/inmemory/driver.go b/registry/storage/driver/inmemory/driver.go index 4c00ca404..fde516f9e 100644 --- a/registry/storage/driver/inmemory/driver.go +++ b/registry/storage/driver/inmemory/driver.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/http" "sync" "time" @@ -236,10 +237,9 @@ func (d *driver) Delete(ctx context.Context, path string) error { } } -// URLFor returns a URL which may be used to retrieve the content stored at the given path. -// May return an UnsupportedMethodErr in certain StorageDriver implementations. -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - return "", storagedriver.ErrUnsupportedMethod{} +// RedirectURL returns a URL which may be used to retrieve the content stored at the given path. +func (d *driver) RedirectURL(*http.Request, string) (string, error) { + return "", nil } // Walk traverses a filesystem defined within driver, starting diff --git a/registry/storage/driver/middleware/cloudfront/middleware.go b/registry/storage/driver/middleware/cloudfront/middleware.go index 32474cbd0..4c2fc7dda 100644 --- a/registry/storage/driver/middleware/cloudfront/middleware.go +++ b/registry/storage/driver/middleware/cloudfront/middleware.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "net/http" "net/url" "os" "strings" @@ -195,18 +196,18 @@ type S3BucketKeyer interface { S3BucketKey(path string) string } -// URLFor attempts to find a url which may be used to retrieve the file at the given path. +// RedirectURL attempts to find a url which may be used to retrieve the file at the given path. // Returns an error if the file cannot be found. -func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { +func (lh *cloudFrontStorageMiddleware) RedirectURL(r *http.Request, path string) (string, error) { // TODO(endophage): currently only supports S3 keyer, ok := lh.StorageDriver.(S3BucketKeyer) if !ok { - dcontext.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver") - return lh.StorageDriver.URLFor(ctx, path, options) + dcontext.GetLogger(r.Context()).Warn("the CloudFront middleware does not support this backend storage driver") + return lh.StorageDriver.RedirectURL(r, path) } - if eligibleForS3(ctx, lh.awsIPs) { - return lh.StorageDriver.URLFor(ctx, path, options) + if eligibleForS3(r, lh.awsIPs) { + return lh.StorageDriver.RedirectURL(r, path) } // Get signed cloudfront url. diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go index c7ddd6f55..a0e332d3b 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter.go @@ -184,11 +184,7 @@ func (s *awsIPs) contains(ip net.IP) bool { // parseIPFromRequest attempts to extract the ip address of the // client that made the request -func parseIPFromRequest(ctx context.Context) (net.IP, error) { - request, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } +func parseIPFromRequest(request *http.Request) (net.IP, error) { ipStr := requestutil.RemoteIP(request) ip := net.ParseIP(ipStr) if ip == nil { @@ -200,25 +196,20 @@ func parseIPFromRequest(ctx context.Context) (net.IP, error) { // eligibleForS3 checks if a request is eligible for using S3 directly // Return true only when the IP belongs to a specific aws region and user-agent is docker -func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool { +func eligibleForS3(request *http.Request, awsIPs *awsIPs) bool { if awsIPs != nil && awsIPs.initialized { - if addr, err := parseIPFromRequest(ctx); err == nil { - request, err := dcontext.GetRequest(ctx) - if err != nil { - dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err) - } else { - loggerField := map[interface{}]interface{}{ - "user-client": request.UserAgent(), - "ip": requestutil.RemoteIP(request), - } - if awsIPs.contains(addr) { - dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront") - return true - } - dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront") + if addr, err := parseIPFromRequest(request); err == nil { + loggerField := map[interface{}]interface{}{ + "user-client": request.UserAgent(), + "ip": requestutil.RemoteIP(request), } + if awsIPs.contains(addr) { + dcontext.GetLoggerWithFields(request.Context(), loggerField).Info("request from the allowed AWS region, skipping CloudFront") + return true + } + dcontext.GetLoggerWithFields(request.Context(), loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront") } else { - dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront") + dcontext.GetLogger(request.Context()).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront") } } return false diff --git a/registry/storage/driver/middleware/cloudfront/s3filter_test.go b/registry/storage/driver/middleware/cloudfront/s3filter_test.go index 0d1055601..3d7356108 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter_test.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "crypto/rand" "encoding/json" "fmt" @@ -11,8 +10,6 @@ import ( "reflect" // used as a replacement for testify "testing" "time" - - "github.com/distribution/distribution/v3/internal/dcontext" ) // Rather than pull in all of testify @@ -269,29 +266,22 @@ func TestEligibleForS3(t *testing.T) { }}, initialized: true, } - empty := context.TODO() - makeContext := func(ip string) context.Context { - req := &http.Request{ - RemoteAddr: ip, - } - - return dcontext.WithRequest(empty, req) - } tests := []struct { - Context context.Context - Expected bool + RemoteAddr string + Expected bool }{ - {Context: empty, Expected: false}, - {Context: makeContext("192.168.1.2"), Expected: true}, - {Context: makeContext("192.168.0.2"), Expected: false}, + {RemoteAddr: "", Expected: false}, + {RemoteAddr: "192.168.1.2", Expected: true}, + {RemoteAddr: "192.168.0.2", Expected: false}, } for _, tc := range tests { tc := tc - t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) { + t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) { t.Parallel() - assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips)) + req := &http.Request{RemoteAddr: tc.RemoteAddr} + assertEqual(t, tc.Expected, eligibleForS3(req, ips)) }) } } @@ -305,29 +295,22 @@ func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) { }}, initialized: false, } - empty := context.TODO() - makeContext := func(ip string) context.Context { - req := &http.Request{ - RemoteAddr: ip, - } - - return dcontext.WithRequest(empty, req) - } tests := []struct { - Context context.Context - Expected bool + RemoteAddr string + Expected bool }{ - {Context: empty, Expected: false}, - {Context: makeContext("192.168.1.2"), Expected: false}, - {Context: makeContext("192.168.0.2"), Expected: false}, + {RemoteAddr: "", Expected: false}, + {RemoteAddr: "192.168.1.2", Expected: false}, + {RemoteAddr: "192.168.0.2", Expected: false}, } for _, tc := range tests { tc := tc - t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) { + t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) { t.Parallel() - assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips)) + req := &http.Request{RemoteAddr: tc.RemoteAddr} + assertEqual(t, tc.Expected, eligibleForS3(req, ips)) }) } } diff --git a/registry/storage/driver/middleware/redirect/middleware.go b/registry/storage/driver/middleware/redirect/middleware.go index 9e1b303ea..9ec59345e 100644 --- a/registry/storage/driver/middleware/redirect/middleware.go +++ b/registry/storage/driver/middleware/redirect/middleware.go @@ -1,8 +1,8 @@ package middleware import ( - "context" "fmt" + "net/http" "net/url" "path" @@ -42,7 +42,7 @@ func newRedirectStorageMiddleware(sd storagedriver.StorageDriver, options map[st return &redirectStorageMiddleware{StorageDriver: sd, scheme: u.Scheme, host: u.Host, basePath: u.Path}, nil } -func (r *redirectStorageMiddleware) URLFor(ctx context.Context, urlPath string, options map[string]interface{}) (string, error) { +func (r *redirectStorageMiddleware) RedirectURL(_ *http.Request, urlPath string) (string, error) { if r.basePath != "" { urlPath = path.Join(r.basePath, urlPath) } diff --git a/registry/storage/driver/middleware/redirect/middleware_test.go b/registry/storage/driver/middleware/redirect/middleware_test.go index 30d0bb192..bf4bb6f47 100644 --- a/registry/storage/driver/middleware/redirect/middleware_test.go +++ b/registry/storage/driver/middleware/redirect/middleware_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "testing" "gopkg.in/check.v1" @@ -37,7 +36,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) { c.Assert(m.scheme, check.Equals, "https") c.Assert(m.host, check.Equals, "example.com:5443") - url, err := middleware.URLFor(context.TODO(), "/rick/data", nil) + url, err := middleware.RedirectURL(nil, "/rick/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com:5443/rick/data") } @@ -53,7 +52,7 @@ func (s *MiddlewareSuite) TestHTTP(c *check.C) { c.Assert(m.scheme, check.Equals, "http") c.Assert(m.host, check.Equals, "example.com") - url, err := middleware.URLFor(context.TODO(), "morty/data", nil) + url, err := middleware.RedirectURL(nil, "morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "http://example.com/morty/data") } @@ -71,12 +70,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) { c.Assert(m.host, check.Equals, "example.com") c.Assert(m.basePath, check.Equals, "/path") - // call URLFor() with no leading slash - url, err := middleware.URLFor(context.TODO(), "morty/data", nil) + // call RedirectURL() with no leading slash + url, err := middleware.RedirectURL(nil, "morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com/path/morty/data") - // call URLFor() with leading slash - url, err = middleware.URLFor(context.TODO(), "/morty/data", nil) + // call RedirectURL() with leading slash + url, err = middleware.RedirectURL(nil, "/morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com/path/morty/data") @@ -91,12 +90,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) { c.Assert(m.host, check.Equals, "example.com") c.Assert(m.basePath, check.Equals, "/path/") - // call URLFor() with no leading slash - url, err = middleware.URLFor(context.TODO(), "morty/data", nil) + // call RedirectURL() with no leading slash + url, err = middleware.RedirectURL(nil, "morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com/path/morty/data") - // call URLFor() with leading slash - url, err = middleware.URLFor(context.TODO(), "/morty/data", nil) + // call RedirectURL() with leading slash + url, err = middleware.RedirectURL(nil, "/morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com/path/morty/data") } diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go index a85624d44..dd4666fe5 100644 --- a/registry/storage/driver/s3-aws/s3.go +++ b/registry/storage/driver/s3-aws/s3.go @@ -1036,30 +1036,13 @@ func (d *driver) Delete(ctx context.Context, path string) error { return nil } -// URLFor returns a URL which may be used to retrieve the content stored at the given path. -// May return an UnsupportedMethodErr in certain StorageDriver implementations. -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - methodString := http.MethodGet - method, ok := options["method"] - if ok { - methodString, ok = method.(string) - if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) { - return "", storagedriver.ErrUnsupportedMethod{} - } - } - +// RedirectURL returns a URL which may be used to retrieve the content stored at the given path. +func (d *driver) RedirectURL(r *http.Request, path string) (string, error) { expiresIn := 20 * time.Minute - expires, ok := options["expiry"] - if ok { - et, ok := expires.(time.Time) - if ok { - expiresIn = time.Until(et) - } - } var req *request.Request - switch methodString { + switch r.Method { case http.MethodGet: req, _ = d.S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: aws.String(d.Bucket), @@ -1071,7 +1054,7 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int Key: aws.String(d.s3Path(path)), }) default: - panic("unreachable") + return "", nil } return req.Presign(expiresIn) diff --git a/registry/storage/driver/storagedriver.go b/registry/storage/driver/storagedriver.go index f521b4a45..d84645765 100644 --- a/registry/storage/driver/storagedriver.go +++ b/registry/storage/driver/storagedriver.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "regexp" "strconv" "strings" @@ -92,11 +93,10 @@ type StorageDriver interface { // Delete recursively deletes all objects stored at "path" and its subpaths. Delete(ctx context.Context, path string) error - // URLFor returns a URL which may be used to retrieve the content stored at - // the given path, possibly using the given options. - // May return an ErrUnsupportedMethod in certain StorageDriver - // implementations. - URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) + // RedirectURL returns a URL which the client of the request r may use + // to retrieve the content stored at path. Returning the empty string + // signals that the request may not be redirected. + RedirectURL(r *http.Request, path string) (string, error) // Walk traverses a filesystem defined within driver, starting // from the given path, calling f on each file. diff --git a/registry/storage/driver/testsuites/testsuites.go b/registry/storage/driver/testsuites/testsuites.go index 6ad827fd0..72c6ca45e 100644 --- a/registry/storage/driver/testsuites/testsuites.go +++ b/registry/storage/driver/testsuites/testsuites.go @@ -8,6 +8,7 @@ import ( "io" "math/rand" "net/http" + "net/http/httptest" "os" "path" "sort" @@ -733,9 +734,9 @@ func (suite *DriverSuite) TestDelete(c *check.C) { c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true) } -// TestURLFor checks that the URLFor method functions properly, but only if it -// is implemented -func (suite *DriverSuite) TestURLFor(c *check.C) { +// TestRedirectURL checks that the RedirectURL method functions properly, +// but only if it is implemented +func (suite *DriverSuite) TestRedirectURL(c *check.C) { filename := randomPath(32) contents := randomContents(32) @@ -744,8 +745,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) { err := suite.StorageDriver.PutContent(suite.ctx, filename, contents) c.Assert(err, check.IsNil) - url, err := suite.StorageDriver.URLFor(suite.ctx, filename, nil) - if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok { + url, err := suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodGet, filename, nil), filename) + if url == "" && err == nil { return } c.Assert(err, check.IsNil) @@ -758,8 +759,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) { c.Assert(err, check.IsNil) c.Assert(read, check.DeepEquals, contents) - url, err = suite.StorageDriver.URLFor(suite.ctx, filename, map[string]interface{}{"method": http.MethodHead}) - if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok { + url, err = suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodHead, filename, nil), filename) + if url == "" && err == nil { return } c.Assert(err, check.IsNil) diff --git a/registry/storage/registry.go b/registry/storage/registry.go index 49b604f21..ecf483bf9 100644 --- a/registry/storage/registry.go +++ b/registry/storage/registry.go @@ -34,7 +34,7 @@ type manifestURLs struct { type RegistryOption func(*registry) error // EnableRedirect is a functional option for NewRegistry. It causes the backend -// blob server to attempt using (StorageDriver).URLFor to serve all blobs. +// blob server to attempt using (StorageDriver).RedirectURL to serve all blobs. func EnableRedirect(registry *registry) error { registry.blobServer.redirect = true return nil @@ -102,7 +102,7 @@ func BlobDescriptorCacheProvider(blobDescriptorCacheProvider cache.BlobDescripto // NewRegistry creates a new registry instance from the provided driver. The // resulting registry may be shared by multiple goroutines but is cheap to // allocate. If the Redirect option is specified, the backend blob server will -// attempt to use (StorageDriver).URLFor to serve all blobs. +// attempt to use (StorageDriver).RedirectURL to serve all blobs. func NewRegistry(ctx context.Context, driver storagedriver.StorageDriver, options ...RegistryOption) (distribution.Namespace, error) { // create global statter statter := &blobStatter{ From f7e5eaae702413ccaa59bb5aae09ea370fb1a894 Mon Sep 17 00:00:00 2001 From: Cory Snider Date: Tue, 24 Oct 2023 16:11:10 -0400 Subject: [PATCH 7/7] internal/dcontext: drop GetRequest() function It is no longer used. Signed-off-by: Cory Snider --- internal/dcontext/http.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/internal/dcontext/http.go b/internal/dcontext/http.go index df068f13e..84d5b4744 100644 --- a/internal/dcontext/http.go +++ b/internal/dcontext/http.go @@ -40,16 +40,6 @@ func WithRequest(ctx context.Context, r *http.Request) context.Context { } } -// GetRequest returns the http request in the given context. Returns -// ErrNoRequestContext if the context does not have an http request associated -// with it. -func GetRequest(ctx context.Context) (*http.Request, error) { - if r, ok := ctx.Value("http.request").(*http.Request); r != nil && ok { - return r, nil - } - return nil, ErrNoRequestContext -} - // GetRequestID attempts to resolve the current request id, if possible. An // error is return if it is not available on the context. func GetRequestID(ctx context.Context) string {