diff --git a/health/health.go b/health/health.go index 06961f35..3e21f731 100644 --- a/health/health.go +++ b/health/health.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" ) @@ -279,7 +279,7 @@ func Handler(handler http.Handler) http.Handler { func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks map[string]string) { p, err := json.Marshal(checks) if err != nil { - context.GetLogger(context.Background()).Errorf("error serializing health status: %v", err) + dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status: %v", err) p, err = json.Marshal(struct { ServerError string `json:"server_error"` }{ @@ -288,7 +288,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m status = http.StatusInternalServerError if err != nil { - context.GetLogger(context.Background()).Errorf("error serializing health status failure message: %v", err) + dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status failure message: %v", err) return } } @@ -297,7 +297,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m w.Header().Set("Content-Length", fmt.Sprint(len(p))) w.WriteHeader(status) if _, err := w.Write(p); err != nil { - context.GetLogger(context.Background()).Errorf("error writing health status response body: %v", err) + dcontext.GetLogger(dcontext.Background()).Errorf("error writing health status response body: %v", err) } } diff --git a/internal/client/repository_test.go b/internal/client/repository_test.go index b6f4d224..a98fe289 100644 --- a/internal/client/repository_test.go +++ b/internal/client/repository_test.go @@ -17,7 +17,7 @@ import ( "time" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/registry/api/errcode" @@ -108,7 +108,7 @@ func TestBlobServeBlob(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo1") r, err := NewRepository(repo, e, nil) if err != nil { @@ -157,7 +157,7 @@ func TestBlobServeBlobHEAD(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo1") r, err := NewRepository(repo, e, nil) if err != nil { @@ -250,7 +250,7 @@ func TestBlobResume(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -307,7 +307,7 @@ func TestBlobDelete(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -327,7 +327,7 @@ func TestBlobFetch(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo1") r, err := NewRepository(repo, e, nil) if err != nil { @@ -382,7 +382,7 @@ func TestBlobExistsNoContentLength(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -406,7 +406,7 @@ func TestBlobExists(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo1") r, err := NewRepository(repo, e, nil) if err != nil { @@ -512,7 +512,7 @@ func TestBlobUploadChunked(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -622,7 +622,7 @@ func TestBlobUploadMonolithic(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -728,7 +728,7 @@ func TestBlobUploadMonolithicDockerUploadUUIDFromURL(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -833,7 +833,7 @@ func TestBlobUploadMonolithicNoDockerUploadUUID(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -891,7 +891,7 @@ func TestBlobMount(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -1066,7 +1066,7 @@ func checkEqualManifest(m1, m2 *ocischema.DeserializedManifest) error { } func TestOCIManifestFetch(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo") m1, dgst, pl := newRandomOCIManifest(t, 6) var m testutil.RequestResponseMap @@ -1149,7 +1149,7 @@ func TestManifestFetchWithEtag(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -1171,7 +1171,7 @@ func TestManifestFetchWithEtag(t *testing.T) { } func TestManifestFetchWithAccept(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() repo, _ := reference.WithName("test.example.com/repo") _, dgst, _ := newRandomOCIManifest(t, 6) headers := make(chan []string, 1) @@ -1258,7 +1258,7 @@ func TestManifestDelete(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() ms, err := r.Manifests(ctx) if err != nil { t.Fatal(err) @@ -1315,7 +1315,7 @@ func TestManifestPut(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() ms, err := r.Manifests(ctx) if err != nil { t.Fatal(err) @@ -1372,7 +1372,7 @@ func TestManifestTags(t *testing.T) { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() tagService := r.Tags(ctx) tags, err := tagService.All(ctx) @@ -1423,7 +1423,7 @@ func TestTagDelete(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() ts := r.Tags(ctx) if err := ts.Untag(ctx, tag); err != nil { @@ -1460,7 +1460,7 @@ func TestObtainsErrorForMissingTag(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -1487,7 +1487,7 @@ func TestObtainsManifestForTagWithoutHeaders(t *testing.T) { e, c := testServer(m) defer c() - ctx := context.Background() + ctx := dcontext.Background() r, err := NewRepository(repo, e, nil) if err != nil { t.Fatal(err) @@ -1566,7 +1566,7 @@ func TestManifestTagsPaginated(t *testing.T) { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() tagService := r.Tags(ctx) tags, err := tagService.All(ctx) @@ -1614,7 +1614,7 @@ func TestManifestUnauthorized(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() ms, err := r.Manifests(ctx) if err != nil { t.Fatal(err) @@ -1652,7 +1652,7 @@ func TestCatalog(t *testing.T) { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() numFilled, err := r.Repositories(ctx, entries, "") if err != io.EOF { t.Fatal(err) @@ -1684,7 +1684,7 @@ func TestCatalogInParts(t *testing.T) { t.Fatal(err) } - ctx := context.Background() + ctx := dcontext.Background() numFilled, err := r.Repositories(ctx, entries, "") if err != nil { t.Fatal(err) diff --git a/context/context.go b/internal/dcontext/context.go similarity index 99% rename from context/context.go rename to internal/dcontext/context.go index 23f93e0b..fe837980 100644 --- a/context/context.go +++ b/internal/dcontext/context.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "context" diff --git a/context/doc.go b/internal/dcontext/doc.go similarity index 93% rename from context/doc.go rename to internal/dcontext/doc.go index 51376dd6..a8d9a740 100644 --- a/context/doc.go +++ b/internal/dcontext/doc.go @@ -1,16 +1,16 @@ -// Package context provides several utilities for working with +// Package dcontext provides several utilities for working with // Go's context in http requests. Primarily, the focus is on logging relevant // request information but this package is not limited to that purpose. // // The easiest way to get started is to get the background context: // -// ctx := context.Background() +// ctx := dcontext.Background() // // The returned context should be passed around your application and be the // root of all other context instances. If the application has a version, this // line should be called before anything else: // -// ctx := context.WithVersion(context.Background(), version) +// ctx := dcontext.WithVersion(dcontext.Background(), version) // // The above will store the version in the context and will be available to // the logger. @@ -27,7 +27,7 @@ // the context and reported with the logger. The following example would // return a logger that prints the version with each log message: // -// ctx := context.Context(context.Background(), "version", version) +// ctx := context.WithValue(dcontext.Background(), "version", version) // GetLogger(ctx, "version").Infof("this log message has a version field") // // The above would print out a log message like this: @@ -85,4 +85,4 @@ // can be traced in log messages. Using the fields like "http.request.id", one // can analyze call flow for a particular request with a simple grep of the // logs. -package context +package dcontext diff --git a/context/http.go b/internal/dcontext/http.go similarity index 83% rename from context/http.go rename to internal/dcontext/http.go index bcdf2965..84d5b474 100644 --- a/context/http.go +++ b/internal/dcontext/http.go @@ -1,17 +1,16 @@ -package context +package dcontext import ( "context" "errors" - "net" "net/http" "strings" "sync" "time" + "github.com/distribution/distribution/v3/internal/requestutil" "github.com/google/uuid" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" ) // Common errors used with this package. @@ -20,48 +19,6 @@ var ( ErrNoResponseWriterContext = errors.New("no http response in context") ) -func parseIP(ipStr string) net.IP { - ip := net.ParseIP(ipStr) - if ip == nil { - log.Warnf("invalid remote IP address: %q", ipStr) - } - return ip -} - -// RemoteAddr extracts the remote address of the request, taking into -// account proxy headers. -func RemoteAddr(r *http.Request) string { - if prior := r.Header.Get("X-Forwarded-For"); prior != "" { - remoteAddr, _, _ := strings.Cut(prior, ",") - remoteAddr = strings.Trim(remoteAddr, " ") - if parseIP(remoteAddr) != nil { - return remoteAddr - } - } - // X-Real-Ip is less supported, but worth checking in the - // absence of X-Forwarded-For - if realIP := r.Header.Get("X-Real-Ip"); realIP != "" { - if parseIP(realIP) != nil { - return realIP - } - } - - return r.RemoteAddr -} - -// RemoteIP extracts the remote IP of the request, taking into -// account proxy headers. -func RemoteIP(r *http.Request) string { - addr := RemoteAddr(r) - - // Try parsing it as "IP:port" - if ip, _, err := net.SplitHostPort(addr); err == nil { - return ip - } - - return addr -} - // WithRequest places the request on the context. The context of the request // is assigned a unique id, available at "http.request.id". The request itself // is available at "http.request". Other common attributes are available under @@ -83,16 +40,6 @@ func WithRequest(ctx context.Context, r *http.Request) context.Context { } } -// GetRequest returns the http request in the given context. Returns -// ErrNoRequestContext if the context does not have an http request associated -// with it. -func GetRequest(ctx context.Context) (*http.Request, error) { - if r, ok := ctx.Value("http.request").(*http.Request); r != nil && ok { - return r, nil - } - return nil, ErrNoRequestContext -} - // GetRequestID attempts to resolve the current request id, if possible. An // error is return if it is not available on the context. func GetRequestID(ctx context.Context) string { @@ -193,7 +140,7 @@ func (ctx *httpRequestContext) Value(key interface{}) interface{} { case "http.request.uri": return ctx.r.RequestURI case "http.request.remoteaddr": - return RemoteAddr(ctx.r) + return requestutil.RemoteAddr(ctx.r) case "http.request.method": return ctx.r.Method case "http.request.host": diff --git a/context/http_test.go b/internal/dcontext/http_test.go similarity index 70% rename from context/http_test.go rename to internal/dcontext/http_test.go index d9e67231..99c47bcd 100644 --- a/context/http_test.go +++ b/internal/dcontext/http_test.go @@ -1,10 +1,7 @@ -package context +package dcontext import ( "net/http" - "net/http/httptest" - "net/http/httputil" - "net/url" "reflect" "testing" "time" @@ -219,70 +216,3 @@ func TestWithVars(t *testing.T) { } } } - -// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test -// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten -// at the transport layer to 127.0.0.1: . However, as the X-Forwarded-For header -// just contains the IP address, it is different enough for testing. -func TestRemoteAddr(t *testing.T) { - var expectedRemote string - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - - if r.RemoteAddr == expectedRemote { - t.Errorf("Unexpected matching remote addresses") - } - - actualRemote := RemoteAddr(r) - if expectedRemote != actualRemote { - t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote) - } - - w.WriteHeader(200) - })) - - defer backend.Close() - backendURL, err := url.Parse(backend.URL) - if err != nil { - t.Fatal(err) - } - - proxy := httputil.NewSingleHostReverseProxy(backendURL) - frontend := httptest.NewServer(proxy) - defer frontend.Close() - - // X-Forwarded-For set by proxy - expectedRemote = "127.0.0.1" - proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil) - if err != nil { - t.Fatal(err) - } - - resp, err := http.DefaultClient.Do(proxyReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - // RemoteAddr in X-Real-Ip - getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil) - if err != nil { - t.Fatal(err) - } - - expectedRemote = "1.2.3.4" - getReq.Header["X-Real-ip"] = []string{expectedRemote} - resp, err = http.DefaultClient.Do(getReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - // Valid X-Real-Ip and invalid X-Forwarded-For - getReq.Header["X-forwarded-for"] = []string{"1.2.3"} - resp, err = http.DefaultClient.Do(getReq) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() -} diff --git a/context/logger.go b/internal/dcontext/logger.go similarity index 99% rename from context/logger.go rename to internal/dcontext/logger.go index f956a228..058fc831 100644 --- a/context/logger.go +++ b/internal/dcontext/logger.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "context" diff --git a/context/trace.go b/internal/dcontext/trace.go similarity index 99% rename from context/trace.go rename to internal/dcontext/trace.go index 2e169f0c..ba248105 100644 --- a/context/trace.go +++ b/internal/dcontext/trace.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "context" diff --git a/context/trace_test.go b/internal/dcontext/trace_test.go similarity index 99% rename from context/trace_test.go rename to internal/dcontext/trace_test.go index 4ee530bf..bb6d1779 100644 --- a/context/trace_test.go +++ b/internal/dcontext/trace_test.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "runtime" diff --git a/context/util.go b/internal/dcontext/util.go similarity index 97% rename from context/util.go rename to internal/dcontext/util.go index c462e756..5b32ba16 100644 --- a/context/util.go +++ b/internal/dcontext/util.go @@ -1,4 +1,4 @@ -package context +package dcontext import ( "context" diff --git a/context/version.go b/internal/dcontext/version.go similarity index 97% rename from context/version.go rename to internal/dcontext/version.go index 97cf9d66..0a0e9cbb 100644 --- a/context/version.go +++ b/internal/dcontext/version.go @@ -1,4 +1,4 @@ -package context +package dcontext import "context" diff --git a/context/version_test.go b/internal/dcontext/version_test.go similarity index 95% rename from context/version_test.go rename to internal/dcontext/version_test.go index b8165269..9829fe95 100644 --- a/context/version_test.go +++ b/internal/dcontext/version_test.go @@ -1,4 +1,4 @@ -package context +package dcontext import "testing" diff --git a/internal/requestutil/util.go b/internal/requestutil/util.go new file mode 100644 index 00000000..099e3454 --- /dev/null +++ b/internal/requestutil/util.go @@ -0,0 +1,51 @@ +package requestutil + +import ( + "net" + "net/http" + "strings" + + log "github.com/sirupsen/logrus" +) + +func parseIP(ipStr string) net.IP { + ip := net.ParseIP(ipStr) + if ip == nil { + log.Warnf("invalid remote IP address: %q", ipStr) + } + return ip +} + +// RemoteAddr extracts the remote address of the request, taking into +// account proxy headers. +func RemoteAddr(r *http.Request) string { + if prior := r.Header.Get("X-Forwarded-For"); prior != "" { + remoteAddr, _, _ := strings.Cut(prior, ",") + remoteAddr = strings.Trim(remoteAddr, " ") + if parseIP(remoteAddr) != nil { + return remoteAddr + } + } + // X-Real-Ip is less supported, but worth checking in the + // absence of X-Forwarded-For + if realIP := r.Header.Get("X-Real-Ip"); realIP != "" { + if parseIP(realIP) != nil { + return realIP + } + } + + return r.RemoteAddr +} + +// RemoteIP extracts the remote IP of the request, taking into +// account proxy headers. +func RemoteIP(r *http.Request) string { + addr := RemoteAddr(r) + + // Try parsing it as "IP:port" + if ip, _, err := net.SplitHostPort(addr); err == nil { + return ip + } + + return addr +} diff --git a/internal/requestutil/util_test.go b/internal/requestutil/util_test.go new file mode 100644 index 00000000..fc33527f --- /dev/null +++ b/internal/requestutil/util_test.go @@ -0,0 +1,76 @@ +package requestutil + +import ( + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "testing" +) + +// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test +// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten +// at the transport layer to 127.0.0.1: . However, as the X-Forwarded-For header +// just contains the IP address, it is different enough for testing. +func TestRemoteAddr(t *testing.T) { + var expectedRemote string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + if r.RemoteAddr == expectedRemote { + t.Errorf("Unexpected matching remote addresses") + } + + actualRemote := RemoteAddr(r) + if expectedRemote != actualRemote { + t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote) + } + + w.WriteHeader(200) + })) + + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxy := httputil.NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxy) + defer frontend.Close() + + // X-Forwarded-For set by proxy + expectedRemote = "127.0.0.1" + proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil) + if err != nil { + t.Fatal(err) + } + + resp, err := http.DefaultClient.Do(proxyReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // RemoteAddr in X-Real-Ip + getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil) + if err != nil { + t.Fatal(err) + } + + expectedRemote = "1.2.3.4" + getReq.Header["X-Real-ip"] = []string{expectedRemote} + resp, err = http.DefaultClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Valid X-Real-Ip and invalid X-Forwarded-For + getReq.Header["X-forwarded-for"] = []string{"1.2.3"} + resp, err = http.DefaultClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() +} diff --git a/notifications/bridge.go b/notifications/bridge.go index e63aab65..8b594774 100644 --- a/notifications/bridge.go +++ b/notifications/bridge.go @@ -5,7 +5,7 @@ import ( "time" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/requestutil" "github.com/distribution/reference" events "github.com/docker/go-events" "github.com/google/uuid" @@ -49,7 +49,7 @@ func NewBridge(ub URLBuilder, source SourceRecord, actor ActorRecord, request Re func NewRequestRecord(id string, r *http.Request) RequestRecord { return RequestRecord{ ID: id, - Addr: context.RemoteAddr(r), + Addr: requestutil.RemoteAddr(r), Host: r.Host, Method: r.Method, UserAgent: r.UserAgent(), diff --git a/notifications/listener.go b/notifications/listener.go index fad29341..563c1392 100644 --- a/notifications/listener.go +++ b/notifications/listener.go @@ -7,7 +7,7 @@ import ( "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/reference" "github.com/opencontainers/go-digest" ) diff --git a/notifications/listener_test.go b/notifications/listener_test.go index e33109be..5781d453 100644 --- a/notifications/listener_test.go +++ b/notifications/listener_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage/cache/memory" diff --git a/registry/auth/auth.go b/registry/auth/auth.go index 9cb036f1..6266d1e5 100644 --- a/registry/auth/auth.go +++ b/registry/auth/auth.go @@ -18,7 +18,7 @@ // resource := auth.Resource{Type: "customerOrder", Name: orderNumber} // access := auth.Access{Resource: resource, Action: "update"} // -// if ctx, err := accessController.Authorized(ctx, access); err != nil { +// if ctx, err := accessController.Authorized(r, access); err != nil { // if challenge, ok := err.(auth.Challenge) { // // Let the challenge write the response. // challenge.SetHeaders(r, w) @@ -32,22 +32,11 @@ package auth import ( - "context" "errors" "fmt" "net/http" ) -const ( - // UserKey is used to get the user object from - // a user context - UserKey = "auth.user" - - // UserNameKey is used to get the user name from - // a user context - UserNameKey = "auth.user.name" -) - var ( // ErrInvalidCredential is returned when the auth token does not authenticate correctly. ErrInvalidCredential = errors.New("invalid authorization credential") @@ -76,6 +65,12 @@ type Access struct { Action string } +// Grant describes the permitted level of access for an authorized request. +type Grant struct { + User UserInfo // The authenticated user for the request. + Resources []Resource // The list of resources which have been authorized for the request. +} + // Challenge is a special error type which is used for HTTP 401 Unauthorized // responses and is able to write the response with WWW-Authenticate challenge // header values based on the error. @@ -93,16 +88,15 @@ type Challenge interface { // and required access levels for a request. Implementations can support both // complete denial and http authorization challenges. type AccessController interface { - // Authorized returns a non-nil error if the context is granted access and - // returns a new authorized context. If one or more Access structs are - // provided, the requested access will be compared with what is available - // to the context. The given context will contain a "http.request" key with - // a `*http.Request` value. If the error is non-nil, access should always - // be denied. The error may be of type Challenge, in which case the caller - // may have the Challenge handle the request or choose what action to take - // based on the Challenge header or response status. The returned context - // object should have a "auth.user" value set to a UserInfo struct. - Authorized(ctx context.Context, access ...Access) (context.Context, error) + // Authorized determines if the request is granted access. If one or more + // Access structs are provided, the requested access will be compared with + // what is available to the request. + // + // Return a Grant to grant the request access. Return an error to deny + // access. The error may be of type Challenge, in which case the caller may + // have the Challenge handle the request or choose what action to take based + // on the Challenge header or response status. + Authorized(r *http.Request, access ...Access) (*Grant, error) } // CredentialAuthenticator is an object which is able to authenticate credentials @@ -110,63 +104,6 @@ type CredentialAuthenticator interface { AuthenticateUser(username, password string) error } -// WithUser returns a context with the authorized user info. -func WithUser(ctx context.Context, user UserInfo) context.Context { - return userInfoContext{ - Context: ctx, - user: user, - } -} - -type userInfoContext struct { - context.Context - user UserInfo -} - -func (uic userInfoContext) Value(key interface{}) interface{} { - switch key { - case UserKey: - return uic.user - case UserNameKey: - return uic.user.Name - } - - return uic.Context.Value(key) -} - -// WithResources returns a context with the authorized resources. -func WithResources(ctx context.Context, resources []Resource) context.Context { - return resourceContext{ - Context: ctx, - resources: resources, - } -} - -type resourceContext struct { - context.Context - resources []Resource -} - -type resourceKey struct{} - -func (rc resourceContext) Value(key interface{}) interface{} { - if key == (resourceKey{}) { - return rc.resources - } - - return rc.Context.Value(key) -} - -// AuthorizedResources returns the list of resources which have -// been authorized for this request. -func AuthorizedResources(ctx context.Context) []Resource { - if resources, ok := ctx.Value(resourceKey{}).([]Resource); ok { - return resources - } - - return nil -} - // InitFunc is the type of an AccessController factory function and is used // to register the constructor for different AccesController backends. type InitFunc func(options map[string]interface{}) (AccessController, error) diff --git a/registry/auth/htpasswd/access.go b/registry/auth/htpasswd/access.go index 8d8a0419..c8c43265 100644 --- a/registry/auth/htpasswd/access.go +++ b/registry/auth/htpasswd/access.go @@ -18,7 +18,7 @@ import ( "golang.org/x/crypto/bcrypt" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) @@ -49,12 +49,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, return &accessController{realm: realm.(string), path: path}, nil } -func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { - req, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - +func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (*auth.Grant, error) { username, password, ok := req.BasicAuth() if !ok { return nil, &challenge{ @@ -92,14 +87,14 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut ac.mu.Unlock() if err := localHTPasswd.authenticateUser(username, password); err != nil { - dcontext.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err) + dcontext.GetLogger(req.Context()).Errorf("error authenticating user %q: %v", username, err) return nil, &challenge{ realm: ac.realm, err: auth.ErrAuthenticationFailure, } } - return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil + return &auth.Grant{User: auth.UserInfo{Name: username}}, nil } // challenge implements the auth.Challenge interface. diff --git a/registry/auth/htpasswd/access_test.go b/registry/auth/htpasswd/access_test.go index 25947757..01f2ac9d 100644 --- a/registry/auth/htpasswd/access_test.go +++ b/registry/auth/htpasswd/access_test.go @@ -8,7 +8,6 @@ import ( "os" "testing" - "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/registry/auth" ) @@ -33,7 +32,6 @@ func TestBasicAccessController(t *testing.T) { "realm": testRealm, "path": tempFile.Name(), } - ctx := context.Background() accessController, err := newAccessController(options) if err != nil { @@ -45,8 +43,7 @@ func TestBasicAccessController(t *testing.T) { userNumber := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithRequest(ctx, r) - authCtx, err := accessController.Authorized(ctx) + grant, err := accessController.Authorized(r) if err != nil { switch err := err.(type) { case auth.Challenge: @@ -58,13 +55,12 @@ func TestBasicAccessController(t *testing.T) { } } - userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo) - if !ok { - t.Fatal("basic accessController did not set auth.user context") + if grant == nil { + t.Fatal("basic accessController did not return auth grant") } - if userInfo.Name != testUsers[userNumber] { - t.Fatalf("expected user name %q, got %q", testUsers[userNumber], userInfo.Name) + if grant.User.Name != testUsers[userNumber] { + t.Fatalf("expected user name %q, got %q", testUsers[userNumber], grant.User.Name) } w.WriteHeader(http.StatusNoContent) diff --git a/registry/auth/silly/access.go b/registry/auth/silly/access.go index b09373f3..1984ba20 100644 --- a/registry/auth/silly/access.go +++ b/registry/auth/silly/access.go @@ -8,12 +8,10 @@ package silly import ( - "context" "fmt" "net/http" "strings" - dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/registry/auth" ) @@ -43,12 +41,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, // Authorized simply checks for the existence of the authorization header, // responding with a bearer challenge if it doesn't exist. -func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { - req, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - +func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (*auth.Grant, error) { if req.Header.Get("Authorization") == "" { challenge := challenge{ realm: ac.realm, @@ -66,10 +59,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut return nil, &challenge } - ctx = auth.WithUser(ctx, auth.UserInfo{Name: "silly"}) - ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey)) - - return ctx, nil + return &auth.Grant{User: auth.UserInfo{Name: "silly"}}, nil } type challenge struct { diff --git a/registry/auth/silly/access_test.go b/registry/auth/silly/access_test.go index 482fb2f7..506af0bd 100644 --- a/registry/auth/silly/access_test.go +++ b/registry/auth/silly/access_test.go @@ -5,7 +5,6 @@ import ( "net/http/httptest" "testing" - "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/registry/auth" ) @@ -16,8 +15,7 @@ func TestSillyAccessController(t *testing.T) { } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithRequest(context.Background(), r) - authCtx, err := ac.Authorized(ctx) + grant, err := ac.Authorized(r) if err != nil { switch err := err.(type) { case auth.Challenge: @@ -29,13 +27,12 @@ func TestSillyAccessController(t *testing.T) { } } - userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo) - if !ok { - t.Fatal("silly accessController did not set auth.user context") + if grant == nil { + t.Fatal("silly accessController did not return auth grant") } - if userInfo.Name != "silly" { - t.Fatalf("expected user name %q, got %q", "silly", userInfo.Name) + if grant.User.Name != "silly" { + t.Fatalf("expected user name %q, got %q", "silly", grant.User.Name) } w.WriteHeader(http.StatusNoContent) diff --git a/registry/auth/token/accesscontroller.go b/registry/auth/token/accesscontroller.go index 24902019..bed4c827 100644 --- a/registry/auth/token/accesscontroller.go +++ b/registry/auth/token/accesscontroller.go @@ -1,7 +1,6 @@ package token import ( - "context" "crypto" "crypto/x509" "encoding/json" @@ -13,7 +12,6 @@ import ( "os" "strings" - dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/registry/auth" "github.com/go-jose/go-jose/v3" ) @@ -292,7 +290,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, // Authorized handles checking whether the given request is authorized // for actions on resources described by the given access items. -func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) { +func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (*auth.Grant, error) { challenge := &authChallenge{ realm: ac.realm, autoRedirect: ac.autoRedirect, @@ -300,11 +298,6 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth. accessSet: newAccessSet(accessItems...), } - req, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ") if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") { challenge.err = ErrTokenRequired @@ -338,9 +331,10 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth. } } - ctx = auth.WithResources(ctx, claims.resources()) - - return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil + return &auth.Grant{ + User: auth.UserInfo{Name: claims.Subject}, + Resources: claims.resources(), + }, nil } // init handles registering the token auth backend. diff --git a/registry/auth/token/token_test.go b/registry/auth/token/token_test.go index 417d0d9b..a96546af 100644 --- a/registry/auth/token/token_test.go +++ b/registry/auth/token/token_test.go @@ -18,7 +18,6 @@ import ( "testing" "time" - "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/registry/auth" "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" @@ -466,8 +465,7 @@ func TestAccessController(t *testing.T) { Action: "baz", } - ctx := context.WithRequest(context.Background(), req) - authCtx, err := accessController.Authorized(ctx, testAccess) + grant, err := accessController.Authorized(req, testAccess) challenge, ok := err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -477,8 +475,8 @@ func TestAccessController(t *testing.T) { t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired) } - if authCtx != nil { - t.Fatalf("expected nil auth context but got %s", authCtx) + if grant != nil { + t.Fatalf("expected nil auth grant but got %#v", grant) } // 2. Supply an invalid token. @@ -502,7 +500,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(ctx, testAccess) + grant, err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -512,8 +510,8 @@ func TestAccessController(t *testing.T) { t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired) } - if authCtx != nil { - t.Fatalf("expected nil auth context but got %s", authCtx) + if grant != nil { + t.Fatalf("expected nil auth grant but got %#v", grant) } // create a valid jwk @@ -534,7 +532,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(ctx, testAccess) + grant, err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -544,8 +542,8 @@ func TestAccessController(t *testing.T) { t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInsufficientScope) } - if authCtx != nil { - t.Fatalf("expected nil auth context but got %s", authCtx) + if grant != nil { + t.Fatalf("expected nil auth grant but got %#v", grant) } // 4. Supply the token we need, or deserve, or whatever. @@ -564,18 +562,13 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(ctx, testAccess) + grant, err = accessController.Authorized(req, testAccess) if err != nil { t.Fatalf("accessController returned unexpected error: %s", err) } - userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo) - if !ok { - t.Fatal("token accessController did not set auth.user context") - } - - if userInfo.Name != "foo" { - t.Fatalf("expected user name %q, got %q", "foo", userInfo.Name) + if grant.User.Name != "foo" { + t.Fatalf("expected user name %q, got %q", "foo", grant.User.Name) } // 5. Supply a token with full admin rights, which is represented as "*". @@ -594,7 +587,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - _, err = accessController.Authorized(ctx, testAccess) + _, err = accessController.Authorized(req, testAccess) if err != nil { t.Fatalf("accessController returned unexpected error: %s", err) } diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 5f017b70..fb8e9dd2 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -19,9 +19,9 @@ import ( "github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3/configuration" - dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/health" "github.com/distribution/distribution/v3/health/checks" + "github.com/distribution/distribution/v3/internal/dcontext" prometheus "github.com/distribution/distribution/v3/metrics" "github.com/distribution/distribution/v3/notifications" "github.com/distribution/distribution/v3/registry/api/errcode" @@ -635,7 +635,7 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler { } // Add username to request logging - context.Context = dcontext.WithLogger(context.Context, dcontext.GetLogger(context.Context, auth.UserNameKey)) + context.Context = dcontext.WithLogger(context.Context, dcontext.GetLogger(context.Context, userNameKey)) // sync up context on the request. r = r.WithContext(context) @@ -797,7 +797,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont accessRecords = appendCatalogAccessRecord(accessRecords, r) } - ctx, err := app.accessController.Authorized(context.Context, accessRecords...) + grant, err := app.accessController.Authorized(r.WithContext(context.Context), accessRecords...) if err != nil { switch err := err.(type) { case auth.Challenge: @@ -818,8 +818,14 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont return err } + if grant == nil { + return fmt.Errorf("access controller returned neither an access grant nor an error") + } - dcontext.GetLogger(ctx, auth.UserNameKey).Info("authorized request") + ctx := withUser(context.Context, grant.User) + ctx = withResources(ctx, grant.Resources) + + dcontext.GetLogger(ctx, userNameKey).Info("authorized request") // TODO(stevvooe): This pattern needs to be cleaned up a bit. One context // should be replaced by another, rather than replacing the context on a // mutable object. diff --git a/registry/handlers/app_test.go b/registry/handlers/app_test.go index 391deea6..d1862592 100644 --- a/registry/handlers/app_test.go +++ b/registry/handlers/app_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/distribution/distribution/v3/configuration" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" v2 "github.com/distribution/distribution/v3/registry/api/v2" "github.com/distribution/distribution/v3/registry/auth" @@ -25,7 +25,7 @@ import ( // tested individually. func TestAppDispatcher(t *testing.T) { driver := inmemory.New() - ctx := context.Background() + ctx := dcontext.Background() registry, err := storage.NewRegistry(ctx, driver, storage.BlobDescriptorCacheProvider(memorycache.NewInMemoryBlobDescriptorCacheProvider(0)), storage.EnableDelete, storage.EnableRedirect) if err != nil { t.Fatalf("error creating registry: %v", err) @@ -139,7 +139,7 @@ func TestAppDispatcher(t *testing.T) { // TestNewApp covers the creation of an application via NewApp with a // configuration. func TestNewApp(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() config := configuration.Configuration{ Storage: configuration.Storage{ "inmemory": nil, diff --git a/registry/handlers/blob.go b/registry/handlers/blob.go index 4d50754d..b4979782 100644 --- a/registry/handlers/blob.go +++ b/registry/handlers/blob.go @@ -4,7 +4,7 @@ import ( "net/http" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" "github.com/gorilla/handlers" "github.com/opencontainers/go-digest" @@ -53,7 +53,7 @@ type blobHandler struct { // GetBlob fetches the binary data from backend storage returns it in the // response. func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) { - context.GetLogger(bh).Debug("GetBlob") + dcontext.GetLogger(bh).Debug("GetBlob") blobs := bh.Repository.Blobs(bh) desc, err := blobs.Stat(bh, bh.Digest) if err != nil { @@ -66,7 +66,7 @@ func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) { } if err := blobs.ServeBlob(bh, w, r, desc.Digest); err != nil { - context.GetLogger(bh).Debugf("unexpected error getting blob HTTP handler: %v", err) + dcontext.GetLogger(bh).Debugf("unexpected error getting blob HTTP handler: %v", err) bh.Errors = append(bh.Errors, errcode.ErrorCodeUnknown.WithDetail(err)) return } @@ -74,7 +74,7 @@ func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) { // DeleteBlob deletes a layer blob func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) { - context.GetLogger(bh).Debug("DeleteBlob") + dcontext.GetLogger(bh).Debug("DeleteBlob") blobs := bh.Repository.Blobs(bh) err := blobs.Delete(bh, bh.Digest) @@ -88,7 +88,7 @@ func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) { return default: bh.Errors = append(bh.Errors, err) - context.GetLogger(bh).Errorf("Unknown error deleting blob: %s", err.Error()) + dcontext.GetLogger(bh).Errorf("Unknown error deleting blob: %s", err.Error()) return } } diff --git a/registry/handlers/blobupload.go b/registry/handlers/blobupload.go index 09c00bb8..6b6c640d 100644 --- a/registry/handlers/blobupload.go +++ b/registry/handlers/blobupload.go @@ -7,7 +7,7 @@ import ( "strconv" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/reference" diff --git a/registry/handlers/context.go b/registry/handlers/context.go index cac97a04..c272095c 100644 --- a/registry/handlers/context.go +++ b/registry/handlers/context.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/api/errcode" v2 "github.com/distribution/distribution/v3/registry/api/v2" "github.com/distribution/distribution/v3/registry/auth" @@ -77,10 +77,20 @@ func getUploadUUID(ctx context.Context) (uuid string) { return dcontext.GetStringValue(ctx, "vars.uuid") } +const ( + // userKey is used to get the user object from + // a user context + userKey = "auth.user" + + // userNameKey is used to get the user name from + // a user context + userNameKey = "auth.user.name" +) + // getUserName attempts to resolve a username from the context and request. If // a username cannot be resolved, the empty string is returned. func getUserName(ctx context.Context, r *http.Request) string { - username := dcontext.GetStringValue(ctx, auth.UserNameKey) + username := dcontext.GetStringValue(ctx, userNameKey) // Fallback to request user with basic auth if username == "" { @@ -93,3 +103,60 @@ func getUserName(ctx context.Context, r *http.Request) string { return username } + +// withUser returns a context with the authorized user info. +func withUser(ctx context.Context, user auth.UserInfo) context.Context { + return userInfoContext{ + Context: ctx, + user: user, + } +} + +type userInfoContext struct { + context.Context + user auth.UserInfo +} + +func (uic userInfoContext) Value(key interface{}) interface{} { + switch key { + case userKey: + return uic.user + case userNameKey: + return uic.user.Name + } + + return uic.Context.Value(key) +} + +// withResources returns a context with the authorized resources. +func withResources(ctx context.Context, resources []auth.Resource) context.Context { + return resourceContext{ + Context: ctx, + resources: resources, + } +} + +type resourceContext struct { + context.Context + resources []auth.Resource +} + +type resourceKey struct{} + +func (rc resourceContext) Value(key interface{}) interface{} { + if key == (resourceKey{}) { + return rc.resources + } + + return rc.Context.Value(key) +} + +// authorizedResources returns the list of resources which have +// been authorized for this request. +func authorizedResources(ctx context.Context) []auth.Resource { + if resources, ok := ctx.Value(resourceKey{}).([]auth.Resource); ok { + return resources + } + + return nil +} diff --git a/registry/handlers/health_test.go b/registry/handlers/health_test.go index e0d61555..dc706b87 100644 --- a/registry/handlers/health_test.go +++ b/registry/handlers/health_test.go @@ -9,8 +9,8 @@ import ( "time" "github.com/distribution/distribution/v3/configuration" - "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/health" + "github.com/distribution/distribution/v3/internal/dcontext" ) func TestFileHealthCheck(t *testing.T) { @@ -39,7 +39,7 @@ func TestFileHealthCheck(t *testing.T) { }, } - ctx := context.Background() + ctx := dcontext.Background() app := NewApp(ctx, config) healthRegistry := health.NewRegistry() @@ -103,7 +103,7 @@ func TestTCPHealthCheck(t *testing.T) { }, } - ctx := context.Background() + ctx := dcontext.Background() app := NewApp(ctx, config) healthRegistry := health.NewRegistry() @@ -165,7 +165,7 @@ func TestHTTPHealthCheck(t *testing.T) { }, } - ctx := context.Background() + ctx := dcontext.Background() app := NewApp(ctx, config) healthRegistry := health.NewRegistry() diff --git a/registry/handlers/helpers.go b/registry/handlers/helpers.go index d70306fd..3ccba555 100644 --- a/registry/handlers/helpers.go +++ b/registry/handlers/helpers.go @@ -9,7 +9,7 @@ import ( "strconv" "strings" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" ) // closeResources closes all the provided resources after running the target diff --git a/registry/handlers/manifests.go b/registry/handlers/manifests.go index f571afd9..1eea810f 100644 --- a/registry/handlers/manifests.go +++ b/registry/handlers/manifests.go @@ -8,12 +8,11 @@ import ( "strings" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/registry/api/errcode" - "github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/reference" "github.com/gorilla/handlers" @@ -394,7 +393,7 @@ func (imh *manifestHandler) applyResourcePolicy(manifest distribution.Manifest) return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("registry does not allow %s manifest", class)) } - resources := auth.AuthorizedResources(imh) + resources := authorizedResources(imh) n := imh.Repository.Named().Name() var foundResource bool diff --git a/registry/proxy/proxyauth.go b/registry/proxy/proxyauth.go index ee79ed7e..8cdc3ebf 100644 --- a/registry/proxy/proxyauth.go +++ b/registry/proxy/proxyauth.go @@ -5,9 +5,9 @@ import ( "net/url" "strings" - "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/client/auth" "github.com/distribution/distribution/v3/internal/client/auth/challenge" + "github.com/distribution/distribution/v3/internal/dcontext" ) const challengeHeader = "Docker-Distribution-Api-Version" @@ -44,7 +44,7 @@ func configureAuth(username, password, remoteURL string) (auth.CredentialStore, } for _, url := range authURLs { - context.GetLogger(context.Background()).Infof("Discovered token authentication URL: %s", url) + dcontext.GetLogger(dcontext.Background()).Infof("Discovered token authentication URL: %s", url) creds[url] = userpass{ username: username, password: password, diff --git a/registry/proxy/proxyblobstore.go b/registry/proxy/proxyblobstore.go index f83ef329..bf8ca22f 100644 --- a/registry/proxy/proxyblobstore.go +++ b/registry/proxy/proxyblobstore.go @@ -11,7 +11,7 @@ import ( "github.com/opencontainers/go-digest" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/proxy/scheduler" "github.com/distribution/reference" ) diff --git a/registry/proxy/proxymanifeststore.go b/registry/proxy/proxymanifeststore.go index fa60a07e..1b0e5a31 100644 --- a/registry/proxy/proxymanifeststore.go +++ b/registry/proxy/proxymanifeststore.go @@ -7,7 +7,7 @@ import ( "github.com/opencontainers/go-digest" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/proxy/scheduler" "github.com/distribution/reference" ) diff --git a/registry/proxy/proxyregistry.go b/registry/proxy/proxyregistry.go index 15dc783b..33dcc4af 100644 --- a/registry/proxy/proxyregistry.go +++ b/registry/proxy/proxyregistry.go @@ -10,11 +10,11 @@ import ( "github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3/configuration" - dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/client" "github.com/distribution/distribution/v3/internal/client/auth" "github.com/distribution/distribution/v3/internal/client/auth/challenge" "github.com/distribution/distribution/v3/internal/client/transport" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/proxy/scheduler" "github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage/driver" diff --git a/registry/proxy/scheduler/scheduler.go b/registry/proxy/scheduler/scheduler.go index e492cf71..ed1d9d41 100644 --- a/registry/proxy/scheduler/scheduler.go +++ b/registry/proxy/scheduler/scheduler.go @@ -7,7 +7,7 @@ import ( "sync" "time" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/reference" ) diff --git a/registry/proxy/scheduler/scheduler_test.go b/registry/proxy/scheduler/scheduler_test.go index 1309a1b0..38fa0f58 100644 --- a/registry/proxy/scheduler/scheduler_test.go +++ b/registry/proxy/scheduler/scheduler_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/distribution/reference" ) @@ -40,7 +40,7 @@ func TestSchedule(t *testing.T) { } var mu sync.Mutex - s := New(context.Background(), inmemory.New(), "/ttl") + s := New(dcontext.Background(), inmemory.New(), "/ttl") deleteFunc := func(repoName reference.Reference) error { if len(remainingRepos) == 0 { t.Fatalf("Incorrect expiry count") @@ -123,14 +123,14 @@ func TestRestoreOld(t *testing.T) { t.Fatalf("Error serializing test data: %s", err.Error()) } - ctx := context.Background() + ctx := dcontext.Background() pathToStatFile := "/ttl" fs := inmemory.New() err = fs.PutContent(ctx, pathToStatFile, serialized) if err != nil { t.Fatal("Unable to write serialized data to fs") } - s := New(context.Background(), fs, "/ttl") + s := New(dcontext.Background(), fs, "/ttl") s.OnBlobExpire(deleteFunc) err = s.Start() if err != nil { @@ -165,7 +165,7 @@ func TestStopRestore(t *testing.T) { fs := inmemory.New() pathToStateFile := "/ttl" - s := New(context.Background(), fs, pathToStateFile) + s := New(dcontext.Background(), fs, pathToStateFile) s.onBlobExpire = deleteFunc err := s.Start() @@ -181,7 +181,7 @@ func TestStopRestore(t *testing.T) { time.Sleep(10 * time.Millisecond) // v2 will restore state from fs - s2 := New(context.Background(), fs, pathToStateFile) + s2 := New(dcontext.Background(), fs, pathToStateFile) s2.onBlobExpire = deleteFunc err = s2.Start() if err != nil { @@ -197,7 +197,7 @@ func TestStopRestore(t *testing.T) { } func TestDoubleStart(t *testing.T) { - s := New(context.Background(), inmemory.New(), "/ttl") + s := New(dcontext.Background(), inmemory.New(), "/ttl") err := s.Start() if err != nil { t.Fatalf("Unable to start scheduler") diff --git a/registry/registry.go b/registry/registry.go index 98a29eed..933013c5 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -21,8 +21,8 @@ import ( "golang.org/x/crypto/acme/autocert" "github.com/distribution/distribution/v3/configuration" - dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/health" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/handlers" "github.com/distribution/distribution/v3/registry/listener" "github.com/distribution/distribution/v3/version" diff --git a/registry/registry_test.go b/registry/registry_test.go index 3e919e31..194da97a 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -25,7 +25,7 @@ import ( "time" "github.com/distribution/distribution/v3/configuration" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" _ "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" diff --git a/registry/root.go b/registry/root.go index fb370ca8..3119ea85 100644 --- a/registry/root.go +++ b/registry/root.go @@ -4,7 +4,7 @@ import ( "fmt" "os" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage/driver/factory" "github.com/distribution/distribution/v3/version" diff --git a/registry/storage/blobserver.go b/registry/storage/blobserver.go index 6392e355..6beef7e3 100644 --- a/registry/storage/blobserver.go +++ b/registry/storage/blobserver.go @@ -20,7 +20,7 @@ type blobServer struct { driver driver.StorageDriver statter distribution.BlobStatter pathFn func(dgst digest.Digest) (string, error) - redirect bool // allows disabling URLFor redirects + redirect bool // allows disabling RedirectURL redirects } func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error { @@ -35,19 +35,16 @@ func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *h } if bs.redirect { - redirectURL, err := bs.driver.URLFor(ctx, path, map[string]interface{}{"method": r.Method}) - switch err.(type) { - case nil: - // Redirect to storage URL. - http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect) - return err - - case driver.ErrUnsupportedMethod: - // Fallback to serving the content directly. - default: - // Some unexpected error. + redirectURL, err := bs.driver.RedirectURL(r, path) + if err != nil { return err } + if redirectURL != "" { + // Redirect to storage URL. + http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect) + return nil + } + // Fallback to serving the content directly. } br, err := newFileReader(ctx, bs.driver, path, desc.Size) diff --git a/registry/storage/blobstore.go b/registry/storage/blobstore.go index f54e901d..c03736ea 100644 --- a/registry/storage/blobstore.go +++ b/registry/storage/blobstore.go @@ -6,7 +6,7 @@ import ( "path" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/opencontainers/go-digest" ) diff --git a/registry/storage/blobwriter.go b/registry/storage/blobwriter.go index 70c8b2af..9de3b0c2 100644 --- a/registry/storage/blobwriter.go +++ b/registry/storage/blobwriter.go @@ -9,7 +9,7 @@ import ( "time" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/opencontainers/go-digest" "github.com/sirupsen/logrus" diff --git a/registry/storage/blobwriter_nonresumable.go b/registry/storage/blobwriter_nonresumable.go index a7df9a22..b3b3f6ab 100644 --- a/registry/storage/blobwriter_nonresumable.go +++ b/registry/storage/blobwriter_nonresumable.go @@ -4,7 +4,7 @@ package storage import ( - "github.com/distribution/distribution/v3/context" + "context" ) // resumeHashAt is a noop when resumable digest support is disabled. diff --git a/registry/storage/cache/cachedblobdescriptorstore.go b/registry/storage/cache/cachedblobdescriptorstore.go index b4dc828c..38dd1dcf 100644 --- a/registry/storage/cache/cachedblobdescriptorstore.go +++ b/registry/storage/cache/cachedblobdescriptorstore.go @@ -4,7 +4,7 @@ import ( "context" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" prometheus "github.com/distribution/distribution/v3/metrics" "github.com/opencontainers/go-digest" ) diff --git a/registry/storage/driver/azure/azure.go b/registry/storage/driver/azure/azure.go index 7e853c62..45fd85c9 100644 --- a/registry/storage/driver/azure/azure.go +++ b/registry/storage/driver/azure/azure.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "io" + "net/http" "strings" "time" @@ -302,7 +303,7 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) { // Move moves an object stored at sourcePath to destPath, removing the original // object. func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error { - sourceBlobURL, err := d.URLFor(ctx, sourcePath, nil) + sourceBlobURL, err := d.signBlobURL(ctx, sourcePath) if err != nil { return err } @@ -382,18 +383,15 @@ func (d *driver) Delete(ctx context.Context, path string) error { return nil } -// URLFor returns a publicly accessible URL for the blob stored at given path +// RedirectURL returns a publicly accessible URL for the blob stored at given path // for specified duration by making use of Azure Storage Shared Access Signatures (SAS). // See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx for more info. -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { +func (d *driver) RedirectURL(req *http.Request, path string) (string, error) { + return d.signBlobURL(req.Context(), path) +} + +func (d *driver) signBlobURL(ctx context.Context, path string) (string, error) { expiresTime := time.Now().UTC().Add(20 * time.Minute) // default expiration - expires, ok := options["expiry"] - if ok { - t, ok := expires.(time.Time) - if ok { - expiresTime = t - } - } blobName := d.blobName(path) blobRef := d.client.NewBlobClient(blobName) return d.azClient.SignBlobURL(ctx, blobRef.URL(), expiresTime) diff --git a/registry/storage/driver/base/base.go b/registry/storage/driver/base/base.go index 69371712..32c9037b 100644 --- a/registry/storage/driver/base/base.go +++ b/registry/storage/driver/base/base.go @@ -40,9 +40,10 @@ package base import ( "context" "io" + "net/http" "time" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" prometheus "github.com/distribution/distribution/v3/metrics" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/docker/go-metrics" @@ -208,18 +209,18 @@ func (base *Base) Delete(ctx context.Context, path string) error { return err } -// URLFor wraps URLFor of underlying storage driver. -func (base *Base) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - ctx, done := dcontext.WithTrace(ctx) - defer done("%s.URLFor(%q)", base.Name(), path) +// RedirectURL wraps RedirectURL of the underlying storage driver. +func (base *Base) RedirectURL(r *http.Request, path string) (string, error) { + ctx, done := dcontext.WithTrace(r.Context()) + defer done("%s.RedirectURL(%q)", base.Name(), path) if !storagedriver.PathRegexp.MatchString(path) { return "", storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()} } start := time.Now() - str, e := base.StorageDriver.URLFor(ctx, path, options) - storageAction.WithValues(base.Name(), "URLFor").UpdateSince(start) + str, e := base.StorageDriver.RedirectURL(r.WithContext(ctx), path) + storageAction.WithValues(base.Name(), "RedirectURL").UpdateSince(start) return str, base.setDriverName(e) } diff --git a/registry/storage/driver/base/regulator.go b/registry/storage/driver/base/regulator.go index 09a258e3..2cf7a3ec 100644 --- a/registry/storage/driver/base/regulator.go +++ b/registry/storage/driver/base/regulator.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/http" "reflect" "strconv" "sync" @@ -172,13 +173,11 @@ func (r *regulator) Delete(ctx context.Context, path string) error { return r.StorageDriver.Delete(ctx, path) } -// URLFor returns a URL which may be used to retrieve the content stored at -// the given path, possibly using the given options. -// May return an ErrUnsupportedMethod in certain StorageDriver -// implementations. -func (r *regulator) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { +// RedirectURL returns a URL which may be used to retrieve the content stored at +// the given path. +func (r *regulator) RedirectURL(req *http.Request, path string) (string, error) { r.enter() defer r.exit() - return r.StorageDriver.URLFor(ctx, path, options) + return r.StorageDriver.RedirectURL(req, path) } diff --git a/registry/storage/driver/filesystem/driver.go b/registry/storage/driver/filesystem/driver.go index d5514f0b..c8746bdd 100644 --- a/registry/storage/driver/filesystem/driver.go +++ b/registry/storage/driver/filesystem/driver.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "io" + "net/http" "os" "path" "time" @@ -282,10 +283,9 @@ func (d *driver) Delete(ctx context.Context, subPath string) error { return err } -// URLFor returns a URL which may be used to retrieve the content stored at the given path. -// May return an UnsupportedMethodErr in certain StorageDriver implementations. -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - return "", storagedriver.ErrUnsupportedMethod{} +// RedirectURL returns a URL which may be used to retrieve the content stored at the given path. +func (d *driver) RedirectURL(*http.Request, string) (string, error) { + return "", nil } // Walk traverses a filesystem defined within driver, starting diff --git a/registry/storage/driver/gcs/gcs.go b/registry/storage/driver/gcs/gcs.go index 79c03c70..4cbe2fc3 100644 --- a/registry/storage/driver/gcs/gcs.go +++ b/registry/storage/driver/gcs/gcs.go @@ -809,40 +809,24 @@ func storageCopyObject(ctx context.Context, srcBucket, srcName string, destBucke return attrs, err } -// URLFor returns a URL which may be used to retrieve the content stored at +// RedirectURL returns a URL which may be used to retrieve the content stored at // the given path, possibly using the given options. -// Returns ErrUnsupportedMethod if this driver has no privateKey -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { +func (d *driver) RedirectURL(r *http.Request, path string) (string, error) { if d.privateKey == nil { - return "", storagedriver.ErrUnsupportedMethod{} + return "", nil } - name := d.pathToKey(path) - methodString := http.MethodGet - method, ok := options["method"] - if ok { - methodString, ok = method.(string) - if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) { - return "", storagedriver.ErrUnsupportedMethod{} - } - } - - expiresTime := time.Now().Add(20 * time.Minute) - expires, ok := options["expiry"] - if ok { - et, ok := expires.(time.Time) - if ok { - expiresTime = et - } + if r.Method != http.MethodGet && r.Method != http.MethodHead { + return "", nil } opts := &storage.SignedURLOptions{ GoogleAccessID: d.email, PrivateKey: d.privateKey, - Method: methodString, - Expires: expiresTime, + Method: r.Method, + Expires: time.Now().Add(20 * time.Minute), } - return storage.SignedURL(d.bucket, name, opts) + return storage.SignedURL(d.bucket, d.pathToKey(path), opts) } // Walk traverses a filesystem defined within driver, starting diff --git a/registry/storage/driver/gcs/gcs_test.go b/registry/storage/driver/gcs/gcs_test.go index 65998e7b..a76015d0 100644 --- a/registry/storage/driver/gcs/gcs_test.go +++ b/registry/storage/driver/gcs/gcs_test.go @@ -10,7 +10,7 @@ import ( "testing" "cloud.google.com/go/storage" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/testsuites" "golang.org/x/oauth2" diff --git a/registry/storage/driver/inmemory/driver.go b/registry/storage/driver/inmemory/driver.go index 97c3bcde..1915c011 100644 --- a/registry/storage/driver/inmemory/driver.go +++ b/registry/storage/driver/inmemory/driver.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/http" "sync" "time" @@ -236,10 +237,9 @@ func (d *driver) Delete(ctx context.Context, path string) error { } } -// URLFor returns a URL which may be used to retrieve the content stored at the given path. -// May return an UnsupportedMethodErr in certain StorageDriver implementations. -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - return "", storagedriver.ErrUnsupportedMethod{} +// RedirectURL returns a URL which may be used to retrieve the content stored at the given path. +func (d *driver) RedirectURL(*http.Request, string) (string, error) { + return "", nil } // Walk traverses a filesystem defined within driver, starting diff --git a/registry/storage/driver/middleware/cloudfront/middleware.go b/registry/storage/driver/middleware/cloudfront/middleware.go index 741d618e..63cd2bf9 100644 --- a/registry/storage/driver/middleware/cloudfront/middleware.go +++ b/registry/storage/driver/middleware/cloudfront/middleware.go @@ -7,13 +7,14 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "net/http" "net/url" "os" "strings" "time" "github.com/aws/aws-sdk-go/service/cloudfront/sign" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagemiddleware "github.com/distribution/distribution/v3/registry/storage/driver/middleware" ) @@ -201,18 +202,18 @@ type S3BucketKeyer interface { S3BucketKey(path string) string } -// URLFor attempts to find a url which may be used to retrieve the file at the given path. +// RedirectURL attempts to find a url which may be used to retrieve the file at the given path. // Returns an error if the file cannot be found. -func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { +func (lh *cloudFrontStorageMiddleware) RedirectURL(r *http.Request, path string) (string, error) { // TODO(endophage): currently only supports S3 keyer, ok := lh.StorageDriver.(S3BucketKeyer) if !ok { - dcontext.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver") - return lh.StorageDriver.URLFor(ctx, path, options) + dcontext.GetLogger(r.Context()).Warn("the CloudFront middleware does not support this backend storage driver") + return lh.StorageDriver.RedirectURL(r, path) } - if eligibleForS3(ctx, lh.awsIPs) { - return lh.StorageDriver.URLFor(ctx, path, options) + if eligibleForS3(r, lh.awsIPs) { + return lh.StorageDriver.RedirectURL(r, path) } // Get signed cloudfront url. diff --git a/registry/storage/driver/middleware/cloudfront/s3filter.go b/registry/storage/driver/middleware/cloudfront/s3filter.go index 75158c91..190bc6d9 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter.go @@ -12,7 +12,8 @@ import ( "sync" "time" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" + "github.com/distribution/distribution/v3/internal/requestutil" ) const ( @@ -192,12 +193,8 @@ func (s *awsIPs) contains(ip net.IP) bool { // parseIPFromRequest attempts to extract the ip address of the // client that made the request -func parseIPFromRequest(ctx context.Context) (net.IP, error) { - request, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - ipStr := dcontext.RemoteIP(request) +func parseIPFromRequest(request *http.Request) (net.IP, error) { + ipStr := requestutil.RemoteIP(request) ip := net.ParseIP(ipStr) if ip == nil { return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr) @@ -208,25 +205,20 @@ func parseIPFromRequest(ctx context.Context) (net.IP, error) { // eligibleForS3 checks if a request is eligible for using S3 directly // Return true only when the IP belongs to a specific aws region and user-agent is docker -func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool { +func eligibleForS3(request *http.Request, awsIPs *awsIPs) bool { if awsIPs != nil && awsIPs.initialized { - if addr, err := parseIPFromRequest(ctx); err == nil { - request, err := dcontext.GetRequest(ctx) - if err != nil { - dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err) - } else { - loggerField := map[interface{}]interface{}{ - "user-client": request.UserAgent(), - "ip": dcontext.RemoteIP(request), - } - if awsIPs.contains(addr) { - dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront") - return true - } - dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront") + if addr, err := parseIPFromRequest(request); err == nil { + loggerField := map[interface{}]interface{}{ + "user-client": request.UserAgent(), + "ip": requestutil.RemoteIP(request), } + if awsIPs.contains(addr) { + dcontext.GetLoggerWithFields(request.Context(), loggerField).Info("request from the allowed AWS region, skipping CloudFront") + return true + } + dcontext.GetLoggerWithFields(request.Context(), loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront") } else { - dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront") + dcontext.GetLogger(request.Context()).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront") } } return false diff --git a/registry/storage/driver/middleware/cloudfront/s3filter_test.go b/registry/storage/driver/middleware/cloudfront/s3filter_test.go index c347d035..7c55c635 100644 --- a/registry/storage/driver/middleware/cloudfront/s3filter_test.go +++ b/registry/storage/driver/middleware/cloudfront/s3filter_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "crypto/rand" "encoding/json" "fmt" @@ -11,8 +10,6 @@ import ( "reflect" // used as a replacement for testify "testing" "time" - - dcontext "github.com/distribution/distribution/v3/context" ) // Rather than pull in all of testify @@ -276,29 +273,22 @@ func TestEligibleForS3(t *testing.T) { }}, initialized: true, } - empty := context.TODO() - makeContext := func(ip string) context.Context { - req := &http.Request{ - RemoteAddr: ip, - } - - return dcontext.WithRequest(empty, req) - } tests := []struct { - Context context.Context - Expected bool + RemoteAddr string + Expected bool }{ - {Context: empty, Expected: false}, - {Context: makeContext("192.168.1.2"), Expected: true}, - {Context: makeContext("192.168.0.2"), Expected: false}, + {RemoteAddr: "", Expected: false}, + {RemoteAddr: "192.168.1.2", Expected: true}, + {RemoteAddr: "192.168.0.2", Expected: false}, } for _, tc := range tests { tc := tc - t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) { + t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) { t.Parallel() - assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips)) + req := &http.Request{RemoteAddr: tc.RemoteAddr} + assertEqual(t, tc.Expected, eligibleForS3(req, ips)) }) } } @@ -312,29 +302,22 @@ func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) { }}, initialized: false, } - empty := context.TODO() - makeContext := func(ip string) context.Context { - req := &http.Request{ - RemoteAddr: ip, - } - - return dcontext.WithRequest(empty, req) - } tests := []struct { - Context context.Context - Expected bool + RemoteAddr string + Expected bool }{ - {Context: empty, Expected: false}, - {Context: makeContext("192.168.1.2"), Expected: false}, - {Context: makeContext("192.168.0.2"), Expected: false}, + {RemoteAddr: "", Expected: false}, + {RemoteAddr: "192.168.1.2", Expected: false}, + {RemoteAddr: "192.168.0.2", Expected: false}, } for _, tc := range tests { tc := tc - t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) { + t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) { t.Parallel() - assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips)) + req := &http.Request{RemoteAddr: tc.RemoteAddr} + assertEqual(t, tc.Expected, eligibleForS3(req, ips)) }) } } diff --git a/registry/storage/driver/middleware/redirect/middleware.go b/registry/storage/driver/middleware/redirect/middleware.go index 8976d868..e4b0663b 100644 --- a/registry/storage/driver/middleware/redirect/middleware.go +++ b/registry/storage/driver/middleware/redirect/middleware.go @@ -1,8 +1,8 @@ package middleware import ( - "context" "fmt" + "net/http" "net/url" "path" @@ -42,7 +42,7 @@ func newRedirectStorageMiddleware(ctx context.Context, sd storagedriver.StorageD return &redirectStorageMiddleware{StorageDriver: sd, scheme: u.Scheme, host: u.Host, basePath: u.Path}, nil } -func (r *redirectStorageMiddleware) URLFor(ctx context.Context, urlPath string, options map[string]interface{}) (string, error) { +func (r *redirectStorageMiddleware) RedirectURL(_ *http.Request, urlPath string) (string, error) { if r.basePath != "" { urlPath = path.Join(r.basePath, urlPath) } diff --git a/registry/storage/driver/middleware/redirect/middleware_test.go b/registry/storage/driver/middleware/redirect/middleware_test.go index 2c22dec3..7be383f5 100644 --- a/registry/storage/driver/middleware/redirect/middleware_test.go +++ b/registry/storage/driver/middleware/redirect/middleware_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "testing" "gopkg.in/check.v1" @@ -37,7 +36,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) { c.Assert(m.scheme, check.Equals, "https") c.Assert(m.host, check.Equals, "example.com:5443") - url, err := middleware.URLFor(context.TODO(), "/rick/data", nil) + url, err := middleware.RedirectURL(nil, "/rick/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com:5443/rick/data") } @@ -53,7 +52,7 @@ func (s *MiddlewareSuite) TestHTTP(c *check.C) { c.Assert(m.scheme, check.Equals, "http") c.Assert(m.host, check.Equals, "example.com") - url, err := middleware.URLFor(context.TODO(), "morty/data", nil) + url, err := middleware.RedirectURL(nil, "morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "http://example.com/morty/data") } @@ -71,12 +70,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) { c.Assert(m.host, check.Equals, "example.com") c.Assert(m.basePath, check.Equals, "/path") - // call URLFor() with no leading slash - url, err := middleware.URLFor(context.TODO(), "morty/data", nil) + // call RedirectURL() with no leading slash + url, err := middleware.RedirectURL(nil, "morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com/path/morty/data") - // call URLFor() with leading slash - url, err = middleware.URLFor(context.TODO(), "/morty/data", nil) + // call RedirectURL() with leading slash + url, err = middleware.RedirectURL(nil, "/morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com/path/morty/data") @@ -91,12 +90,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) { c.Assert(m.host, check.Equals, "example.com") c.Assert(m.basePath, check.Equals, "/path/") - // call URLFor() with no leading slash - url, err = middleware.URLFor(context.TODO(), "morty/data", nil) + // call RedirectURL() with no leading slash + url, err = middleware.RedirectURL(nil, "morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com/path/morty/data") - // call URLFor() with leading slash - url, err = middleware.URLFor(context.TODO(), "/morty/data", nil) + // call RedirectURL() with leading slash + url, err = middleware.RedirectURL(nil, "/morty/data") c.Assert(err, check.Equals, nil) c.Assert(url, check.Equals, "https://example.com/path/morty/data") } diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go index 08e39118..8092bb7e 100644 --- a/registry/storage/driver/s3-aws/s3.go +++ b/registry/storage/driver/s3-aws/s3.go @@ -36,7 +36,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/base" "github.com/distribution/distribution/v3/registry/storage/driver/factory" @@ -1036,30 +1036,13 @@ func (d *driver) Delete(ctx context.Context, path string) error { return nil } -// URLFor returns a URL which may be used to retrieve the content stored at the given path. -// May return an UnsupportedMethodErr in certain StorageDriver implementations. -func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { - methodString := http.MethodGet - method, ok := options["method"] - if ok { - methodString, ok = method.(string) - if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) { - return "", storagedriver.ErrUnsupportedMethod{} - } - } - +// RedirectURL returns a URL which may be used to retrieve the content stored at the given path. +func (d *driver) RedirectURL(r *http.Request, path string) (string, error) { expiresIn := 20 * time.Minute - expires, ok := options["expiry"] - if ok { - et, ok := expires.(time.Time) - if ok { - expiresIn = time.Until(et) - } - } var req *request.Request - switch methodString { + switch r.Method { case http.MethodGet: req, _ = d.S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: aws.String(d.Bucket), @@ -1071,7 +1054,7 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int Key: aws.String(d.s3Path(path)), }) default: - panic("unreachable") + return "", nil } return req.Presign(expiresIn) diff --git a/registry/storage/driver/s3-aws/s3_test.go b/registry/storage/driver/s3-aws/s3_test.go index 7bfae2aa..35af9249 100644 --- a/registry/storage/driver/s3-aws/s3_test.go +++ b/registry/storage/driver/s3-aws/s3_test.go @@ -16,7 +16,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/s3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/testsuites" ) @@ -180,7 +180,7 @@ func TestEmptyRootList(t *testing.T) { filename := "/test" contents := []byte("contents") - ctx := context.Background() + ctx := dcontext.Background() err = rootedDriver.PutContent(ctx, filename, contents) if err != nil { t.Fatalf("unexpected error creating content: %v", err) @@ -209,7 +209,7 @@ func TestStorageClass(t *testing.T) { rootDir := t.TempDir() contents := []byte("contents") - ctx := context.Background() + ctx := dcontext.Background() // We don't need to test all the storage classes, just that its selectable. // The first 3 are common to AWS and MinIO, so use those. @@ -377,7 +377,7 @@ func TestDelete(t *testing.T) { // init file structure matching objs var created []string for _, p := range objs { - err := drvr.PutContent(context.Background(), p, []byte("content "+p)) + err := drvr.PutContent(dcontext.Background(), p, []byte("content "+p)) if err != nil { fmt.Printf("unable to init file %s: %s\n", p, err) continue @@ -390,7 +390,7 @@ func TestDelete(t *testing.T) { cleanup := func(objs []string) { var lastErr error for _, p := range objs { - err := drvr.Delete(context.Background(), p) + err := drvr.Delete(dcontext.Background(), p) if err != nil { switch err.(type) { case storagedriver.PathNotFoundError: @@ -409,7 +409,7 @@ func TestDelete(t *testing.T) { t.Run(tc.name, func(t *testing.T) { objs := init() - err := drvr.Delete(context.Background(), tc.delete) + err := drvr.Delete(dcontext.Background(), tc.delete) if tc.err != nil { if err == nil { @@ -437,7 +437,7 @@ func TestDelete(t *testing.T) { return false } for _, path := range objs { - stat, err := drvr.Stat(context.Background(), path) + stat, err := drvr.Stat(dcontext.Background(), path) if err != nil { switch err.(type) { case storagedriver.PathNotFoundError: @@ -491,7 +491,7 @@ func TestWalk(t *testing.T) { // create file structure matching fileset above created := make([]string, 0, len(fileset)) for _, p := range fileset { - err := drvr.PutContent(context.Background(), p, []byte("content "+p)) + err := drvr.PutContent(dcontext.Background(), p, []byte("content "+p)) if err != nil { fmt.Printf("unable to create file %s: %s\n", p, err) continue @@ -503,7 +503,7 @@ func TestWalk(t *testing.T) { defer func() { var lastErr error for _, p := range created { - err := drvr.Delete(context.Background(), p) + err := drvr.Delete(dcontext.Background(), p) if err != nil { _ = fmt.Errorf("cleanup failed for path %s: %s", p, err) lastErr = err @@ -692,7 +692,7 @@ func TestWalk(t *testing.T) { tc.from = "/" } t.Run(tc.name, func(t *testing.T) { - err := drvr.Walk(context.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error { + err := drvr.Walk(dcontext.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error { walked = append(walked, fileInfo.Path()) return tc.fn(fileInfo) }, tc.options...) @@ -718,7 +718,7 @@ func TestOverThousandBlobs(t *testing.T) { t.Fatalf("unexpected error creating driver with standard storage: %v", err) } - ctx := context.Background() + ctx := dcontext.Background() for i := 0; i < 1005; i++ { filename := "/thousandfiletest/file" + strconv.Itoa(i) contents := []byte("contents") @@ -746,7 +746,7 @@ func TestMoveWithMultipartCopy(t *testing.T) { t.Fatalf("unexpected error creating driver: %v", err) } - ctx := context.Background() + ctx := dcontext.Background() sourcePath := "/source" destPath := "/dest" @@ -795,7 +795,7 @@ func TestListObjectsV2(t *testing.T) { t.Fatalf("unexpected error creating driver: %v", err) } - ctx := context.Background() + ctx := dcontext.Background() n := 6 prefix := "/test-list-objects-v2" var filePaths []string diff --git a/registry/storage/driver/storagedriver.go b/registry/storage/driver/storagedriver.go index f521b4a4..d8464576 100644 --- a/registry/storage/driver/storagedriver.go +++ b/registry/storage/driver/storagedriver.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "regexp" "strconv" "strings" @@ -92,11 +93,10 @@ type StorageDriver interface { // Delete recursively deletes all objects stored at "path" and its subpaths. Delete(ctx context.Context, path string) error - // URLFor returns a URL which may be used to retrieve the content stored at - // the given path, possibly using the given options. - // May return an ErrUnsupportedMethod in certain StorageDriver - // implementations. - URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) + // RedirectURL returns a URL which the client of the request r may use + // to retrieve the content stored at path. Returning the empty string + // signals that the request may not be redirected. + RedirectURL(r *http.Request, path string) (string, error) // Walk traverses a filesystem defined within driver, starting // from the given path, calling f on each file. diff --git a/registry/storage/driver/testsuites/testsuites.go b/registry/storage/driver/testsuites/testsuites.go index 6ad827fd..72c6ca45 100644 --- a/registry/storage/driver/testsuites/testsuites.go +++ b/registry/storage/driver/testsuites/testsuites.go @@ -8,6 +8,7 @@ import ( "io" "math/rand" "net/http" + "net/http/httptest" "os" "path" "sort" @@ -733,9 +734,9 @@ func (suite *DriverSuite) TestDelete(c *check.C) { c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true) } -// TestURLFor checks that the URLFor method functions properly, but only if it -// is implemented -func (suite *DriverSuite) TestURLFor(c *check.C) { +// TestRedirectURL checks that the RedirectURL method functions properly, +// but only if it is implemented +func (suite *DriverSuite) TestRedirectURL(c *check.C) { filename := randomPath(32) contents := randomContents(32) @@ -744,8 +745,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) { err := suite.StorageDriver.PutContent(suite.ctx, filename, contents) c.Assert(err, check.IsNil) - url, err := suite.StorageDriver.URLFor(suite.ctx, filename, nil) - if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok { + url, err := suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodGet, filename, nil), filename) + if url == "" && err == nil { return } c.Assert(err, check.IsNil) @@ -758,8 +759,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) { c.Assert(err, check.IsNil) c.Assert(read, check.DeepEquals, contents) - url, err = suite.StorageDriver.URLFor(suite.ctx, filename, map[string]interface{}{"method": http.MethodHead}) - if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok { + url, err = suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodHead, filename, nil), filename) + if url == "" && err == nil { return } c.Assert(err, check.IsNil) diff --git a/registry/storage/filereader_test.go b/registry/storage/filereader_test.go index 3f05582e..5dfa2947 100644 --- a/registry/storage/filereader_test.go +++ b/registry/storage/filereader_test.go @@ -7,13 +7,13 @@ import ( mrand "math/rand" "testing" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/opencontainers/go-digest" ) func TestSimpleRead(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() content := make([]byte, 1<<20) n, err := crand.Read(content) if err != nil { @@ -55,7 +55,7 @@ func TestFileReaderSeek(t *testing.T) { repititions := 1024 path := "/patterned" content := bytes.Repeat([]byte(pattern), repititions) - ctx := context.Background() + ctx := dcontext.Background() if err := driver.PutContent(ctx, path, content); err != nil { t.Fatalf("error putting patterned content: %v", err) @@ -156,7 +156,7 @@ func TestFileReaderSeek(t *testing.T) { // read method, with an io.EOF error. func TestFileReaderNonExistentFile(t *testing.T) { driver := inmemory.New() - fr, err := newFileReader(context.Background(), driver, "/doesnotexist", 10) + fr, err := newFileReader(dcontext.Background(), driver, "/doesnotexist", 10) if err != nil { t.Fatalf("unexpected error initializing reader: %v", err) } diff --git a/registry/storage/garbagecollect_test.go b/registry/storage/garbagecollect_test.go index edf7f1c3..5e69a178 100644 --- a/registry/storage/garbagecollect_test.go +++ b/registry/storage/garbagecollect_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/distribution/distribution/v3/testutil" @@ -21,7 +21,7 @@ type image struct { } func createRegistry(t *testing.T, driver driver.StorageDriver, options ...RegistryOption) distribution.Namespace { - ctx := context.Background() + ctx := dcontext.Background() options = append(options, EnableDelete) registry, err := NewRegistry(ctx, driver, options...) if err != nil { @@ -31,7 +31,7 @@ func createRegistry(t *testing.T, driver driver.StorageDriver, options ...Regist } func makeRepository(t *testing.T, registry distribution.Namespace, name string) distribution.Repository { - ctx := context.Background() + ctx := dcontext.Background() // Initialize a dummy repository named, err := reference.WithName(name) @@ -47,7 +47,7 @@ func makeRepository(t *testing.T, registry distribution.Namespace, name string) } func makeManifestService(t *testing.T, repository distribution.Repository) distribution.ManifestService { - ctx := context.Background() + ctx := dcontext.Background() manifestService, err := repository.Manifests(ctx) if err != nil { @@ -57,7 +57,7 @@ func makeManifestService(t *testing.T, repository distribution.Repository) distr } func allManifests(t *testing.T, manifestService distribution.ManifestService) map[digest.Digest]struct{} { - ctx := context.Background() + ctx := dcontext.Background() allManMap := make(map[digest.Digest]struct{}) manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator) if !ok { @@ -74,7 +74,7 @@ func allManifests(t *testing.T, manifestService distribution.ManifestService) ma } func allBlobs(t *testing.T, registry distribution.Namespace) map[digest.Digest]struct{} { - ctx := context.Background() + ctx := dcontext.Background() blobService := registry.Blobs() allBlobsMap := make(map[digest.Digest]struct{}) err := blobService.Enumerate(ctx, func(dgst digest.Digest) error { @@ -95,7 +95,7 @@ func uploadImage(t *testing.T, repository distribution.Repository, im image) dig } // upload manifest - ctx := context.Background() + ctx := dcontext.Background() manifestService := makeManifestService(t, repository) manifestDigest, err := manifestService.Put(ctx, im.manifest) if err != nil { @@ -130,7 +130,7 @@ func uploadRandomSchema2Image(t *testing.T, repository distribution.Repository) } func TestNoDeletionNoEffect(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver) @@ -158,7 +158,7 @@ func TestNoDeletionNoEffect(t *testing.T) { before := allBlobs(t, registry) // Run GC - err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: false, RemoveUntagged: false, }) @@ -173,7 +173,7 @@ func TestNoDeletionNoEffect(t *testing.T) { } func TestDeleteManifestIfTagNotFound(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver) @@ -233,7 +233,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) { before2 := allManifests(t, manifestService) // run GC with dry-run (should not remove anything) - err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: true, RemoveUntagged: true, }) @@ -250,7 +250,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) { } // Run GC (removes everything because no manifests with tags exist) - err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: false, RemoveUntagged: true, }) @@ -269,7 +269,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) { } func TestGCWithMissingManifests(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() d := inmemory.New() registry := createRegistry(t, d) @@ -288,7 +288,7 @@ func TestGCWithMissingManifests(t *testing.T) { t.Fatal(err) } - err = MarkAndSweep(context.Background(), d, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), d, registry, GCOpts{ DryRun: false, RemoveUntagged: false, }) @@ -303,7 +303,7 @@ func TestGCWithMissingManifests(t *testing.T) { } func TestDeletionHasEffect(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver) @@ -318,7 +318,7 @@ func TestDeletionHasEffect(t *testing.T) { manifests.Delete(ctx, image3.manifestDigest) // Run GC - err := MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err := MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: false, RemoveUntagged: false, }) @@ -368,7 +368,7 @@ func getKeys(digests map[digest.Digest]io.ReadSeeker) (ds []digest.Digest) { } func TestDeletionWithSharedLayer(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver) @@ -455,7 +455,7 @@ func TestOrphanBlobDeleted(t *testing.T) { uploadRandomSchema2Image(t, repo) // Run GC - err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ + err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{ DryRun: false, RemoveUntagged: false, }) diff --git a/registry/storage/linkedblobstore.go b/registry/storage/linkedblobstore.go index 392d39b0..18d40d27 100644 --- a/registry/storage/linkedblobstore.go +++ b/registry/storage/linkedblobstore.go @@ -9,7 +9,7 @@ import ( "time" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/reference" "github.com/google/uuid" diff --git a/registry/storage/manifestlisthandler.go b/registry/storage/manifestlisthandler.go index 1fc7aac7..caaf76bb 100644 --- a/registry/storage/manifestlisthandler.go +++ b/registry/storage/manifestlisthandler.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/opencontainers/go-digest" diff --git a/registry/storage/manifeststore.go b/registry/storage/manifeststore.go index 37c86495..027ab65c 100644 --- a/registry/storage/manifeststore.go +++ b/registry/storage/manifeststore.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest" "github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/ocischema" diff --git a/registry/storage/ociindexhandler.go b/registry/storage/ociindexhandler.go index 01864f06..b4d73413 100644 --- a/registry/storage/ociindexhandler.go +++ b/registry/storage/ociindexhandler.go @@ -4,7 +4,7 @@ import ( "context" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/opencontainers/go-digest" ) diff --git a/registry/storage/ocimanifesthandler.go b/registry/storage/ocimanifesthandler.go index 97216d2a..f69c2a2b 100644 --- a/registry/storage/ocimanifesthandler.go +++ b/registry/storage/ocimanifesthandler.go @@ -6,7 +6,7 @@ import ( "net/url" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/ocischema" "github.com/opencontainers/go-digest" v1 "github.com/opencontainers/image-spec/specs-go/v1" diff --git a/registry/storage/registry.go b/registry/storage/registry.go index 49b604f2..ecf483bf 100644 --- a/registry/storage/registry.go +++ b/registry/storage/registry.go @@ -34,7 +34,7 @@ type manifestURLs struct { type RegistryOption func(*registry) error // EnableRedirect is a functional option for NewRegistry. It causes the backend -// blob server to attempt using (StorageDriver).URLFor to serve all blobs. +// blob server to attempt using (StorageDriver).RedirectURL to serve all blobs. func EnableRedirect(registry *registry) error { registry.blobServer.redirect = true return nil @@ -102,7 +102,7 @@ func BlobDescriptorCacheProvider(blobDescriptorCacheProvider cache.BlobDescripto // NewRegistry creates a new registry instance from the provided driver. The // resulting registry may be shared by multiple goroutines but is cheap to // allocate. If the Redirect option is specified, the backend blob server will -// attempt to use (StorageDriver).URLFor to serve all blobs. +// attempt to use (StorageDriver).RedirectURL to serve all blobs. func NewRegistry(ctx context.Context, driver storagedriver.StorageDriver, options ...RegistryOption) (distribution.Namespace, error) { // create global statter statter := &blobStatter{ diff --git a/registry/storage/schema2manifesthandler.go b/registry/storage/schema2manifesthandler.go index aa8980ec..adf0677e 100644 --- a/registry/storage/schema2manifesthandler.go +++ b/registry/storage/schema2manifesthandler.go @@ -7,7 +7,7 @@ import ( "net/url" "github.com/distribution/distribution/v3" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/opencontainers/go-digest" ) diff --git a/registry/storage/schema2manifesthandler_test.go b/registry/storage/schema2manifesthandler_test.go index 908f8c13..26c19ec1 100644 --- a/registry/storage/schema2manifesthandler_test.go +++ b/registry/storage/schema2manifesthandler_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" @@ -14,7 +14,7 @@ import ( ) func TestVerifyManifestForeignLayer(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver, ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")), @@ -152,7 +152,7 @@ func TestVerifyManifestForeignLayer(t *testing.T) { } func TestVerifyManifestBlobLayerAndConfig(t *testing.T) { - ctx := context.Background() + ctx := dcontext.Background() inmemoryDriver := inmemory.New() registry := createRegistry(t, inmemoryDriver, ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")), diff --git a/registry/storage/vacuum.go b/registry/storage/vacuum.go index 749fb319..38ebbd67 100644 --- a/registry/storage/vacuum.go +++ b/registry/storage/vacuum.go @@ -4,7 +4,7 @@ import ( "context" "path" - dcontext "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/storage/driver" "github.com/opencontainers/go-digest" ) diff --git a/testutil/manifests.go b/testutil/manifests.go index 96af5703..dd1cc3a3 100644 --- a/testutil/manifests.go +++ b/testutil/manifests.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/schema2" "github.com/opencontainers/go-digest" @@ -12,7 +12,7 @@ import ( // MakeManifestList constructs a manifest list out of a list of manifest digests func MakeManifestList(blobstatter distribution.BlobStatter, manifestDigests []digest.Digest) (*manifestlist.DeserializedManifestList, error) { - ctx := context.Background() + ctx := dcontext.Background() manifestDescriptors := make([]manifestlist.ManifestDescriptor, 0, len(manifestDigests)) for _, manifestDigest := range manifestDigests { @@ -39,7 +39,7 @@ func MakeManifestList(blobstatter distribution.BlobStatter, manifestDigests []di // MakeSchema2Manifest constructs a schema 2 manifest from a given list of digests and returns // the digest of the manifest func MakeSchema2Manifest(repository distribution.Repository, digests []digest.Digest) (distribution.Manifest, error) { - ctx := context.Background() + ctx := dcontext.Background() blobStore := repository.Blobs(ctx) var configJSON []byte diff --git a/testutil/tarfile.go b/testutil/tarfile.go index 2b30a24d..d4bc7e2a 100644 --- a/testutil/tarfile.go +++ b/testutil/tarfile.go @@ -10,7 +10,7 @@ import ( "time" "github.com/distribution/distribution/v3" - "github.com/distribution/distribution/v3/context" + "github.com/distribution/distribution/v3/internal/dcontext" "github.com/opencontainers/go-digest" ) @@ -96,7 +96,7 @@ func CreateRandomLayers(n int) (map[digest.Digest]io.ReadSeeker, error) { // UploadBlobs lets you upload blobs to a repository func UploadBlobs(repository distribution.Repository, layers map[digest.Digest]io.ReadSeeker) error { - ctx := context.Background() + ctx := dcontext.Background() for dgst, rs := range layers { wr, err := repository.Blobs(ctx).Create(ctx) if err != nil {