Hide our misuses of contexts from the public interface (#4128)

This commit is contained in:
Milos Gajdos 2023-11-03 05:05:19 +00:00 committed by GitHub
commit bd0e476910
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
75 changed files with 510 additions and 594 deletions

View file

@ -7,7 +7,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/api/errcode"
) )
@ -279,7 +279,7 @@ func Handler(handler http.Handler) http.Handler {
func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks map[string]string) { func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks map[string]string) {
p, err := json.Marshal(checks) p, err := json.Marshal(checks)
if err != nil { if err != nil {
context.GetLogger(context.Background()).Errorf("error serializing health status: %v", err) dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status: %v", err)
p, err = json.Marshal(struct { p, err = json.Marshal(struct {
ServerError string `json:"server_error"` ServerError string `json:"server_error"`
}{ }{
@ -288,7 +288,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m
status = http.StatusInternalServerError status = http.StatusInternalServerError
if err != nil { if err != nil {
context.GetLogger(context.Background()).Errorf("error serializing health status failure message: %v", err) dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status failure message: %v", err)
return return
} }
} }
@ -297,7 +297,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m
w.Header().Set("Content-Length", fmt.Sprint(len(p))) w.Header().Set("Content-Length", fmt.Sprint(len(p)))
w.WriteHeader(status) w.WriteHeader(status)
if _, err := w.Write(p); err != nil { if _, err := w.Write(p); err != nil {
context.GetLogger(context.Background()).Errorf("error writing health status response body: %v", err) dcontext.GetLogger(dcontext.Background()).Errorf("error writing health status response body: %v", err)
} }
} }

View file

@ -17,7 +17,7 @@ import (
"time" "time"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest" "github.com/distribution/distribution/v3/manifest"
"github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/api/errcode"
@ -108,7 +108,7 @@ func TestBlobServeBlob(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1") repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
@ -157,7 +157,7 @@ func TestBlobServeBlobHEAD(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1") repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
@ -250,7 +250,7 @@ func TestBlobResume(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -307,7 +307,7 @@ func TestBlobDelete(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -327,7 +327,7 @@ func TestBlobFetch(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1") repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
@ -382,7 +382,7 @@ func TestBlobExistsNoContentLength(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -406,7 +406,7 @@ func TestBlobExists(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1") repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
@ -512,7 +512,7 @@ func TestBlobUploadChunked(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -622,7 +622,7 @@ func TestBlobUploadMonolithic(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -728,7 +728,7 @@ func TestBlobUploadMonolithicDockerUploadUUIDFromURL(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -833,7 +833,7 @@ func TestBlobUploadMonolithicNoDockerUploadUUID(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -891,7 +891,7 @@ func TestBlobMount(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -1066,7 +1066,7 @@ func checkEqualManifest(m1, m2 *ocischema.DeserializedManifest) error {
} }
func TestOCIManifestFetch(t *testing.T) { func TestOCIManifestFetch(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo") repo, _ := reference.WithName("test.example.com/repo")
m1, dgst, pl := newRandomOCIManifest(t, 6) m1, dgst, pl := newRandomOCIManifest(t, 6)
var m testutil.RequestResponseMap var m testutil.RequestResponseMap
@ -1149,7 +1149,7 @@ func TestManifestFetchWithEtag(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -1171,7 +1171,7 @@ func TestManifestFetchWithEtag(t *testing.T) {
} }
func TestManifestFetchWithAccept(t *testing.T) { func TestManifestFetchWithAccept(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo") repo, _ := reference.WithName("test.example.com/repo")
_, dgst, _ := newRandomOCIManifest(t, 6) _, dgst, _ := newRandomOCIManifest(t, 6)
headers := make(chan []string, 1) headers := make(chan []string, 1)
@ -1258,7 +1258,7 @@ func TestManifestDelete(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ctx := context.Background() ctx := dcontext.Background()
ms, err := r.Manifests(ctx) ms, err := r.Manifests(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -1315,7 +1315,7 @@ func TestManifestPut(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ctx := context.Background() ctx := dcontext.Background()
ms, err := r.Manifests(ctx) ms, err := r.Manifests(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -1372,7 +1372,7 @@ func TestManifestTags(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ctx := context.Background() ctx := dcontext.Background()
tagService := r.Tags(ctx) tagService := r.Tags(ctx)
tags, err := tagService.All(ctx) tags, err := tagService.All(ctx)
@ -1423,7 +1423,7 @@ func TestTagDelete(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ctx := context.Background() ctx := dcontext.Background()
ts := r.Tags(ctx) ts := r.Tags(ctx)
if err := ts.Untag(ctx, tag); err != nil { if err := ts.Untag(ctx, tag); err != nil {
@ -1460,7 +1460,7 @@ func TestObtainsErrorForMissingTag(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -1487,7 +1487,7 @@ func TestObtainsManifestForTagWithoutHeaders(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
ctx := context.Background() ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil) r, err := NewRepository(repo, e, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -1566,7 +1566,7 @@ func TestManifestTagsPaginated(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ctx := context.Background() ctx := dcontext.Background()
tagService := r.Tags(ctx) tagService := r.Tags(ctx)
tags, err := tagService.All(ctx) tags, err := tagService.All(ctx)
@ -1614,7 +1614,7 @@ func TestManifestUnauthorized(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ctx := context.Background() ctx := dcontext.Background()
ms, err := r.Manifests(ctx) ms, err := r.Manifests(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -1652,7 +1652,7 @@ func TestCatalog(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ctx := context.Background() ctx := dcontext.Background()
numFilled, err := r.Repositories(ctx, entries, "") numFilled, err := r.Repositories(ctx, entries, "")
if err != io.EOF { if err != io.EOF {
t.Fatal(err) t.Fatal(err)
@ -1684,7 +1684,7 @@ func TestCatalogInParts(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ctx := context.Background() ctx := dcontext.Background()
numFilled, err := r.Repositories(ctx, entries, "") numFilled, err := r.Repositories(ctx, entries, "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View file

@ -1,4 +1,4 @@
package context package dcontext
import ( import (
"context" "context"

View file

@ -1,16 +1,16 @@
// Package context provides several utilities for working with // Package dcontext provides several utilities for working with
// Go's context in http requests. Primarily, the focus is on logging relevant // Go's context in http requests. Primarily, the focus is on logging relevant
// request information but this package is not limited to that purpose. // request information but this package is not limited to that purpose.
// //
// The easiest way to get started is to get the background context: // The easiest way to get started is to get the background context:
// //
// ctx := context.Background() // ctx := dcontext.Background()
// //
// The returned context should be passed around your application and be the // The returned context should be passed around your application and be the
// root of all other context instances. If the application has a version, this // root of all other context instances. If the application has a version, this
// line should be called before anything else: // line should be called before anything else:
// //
// ctx := context.WithVersion(context.Background(), version) // ctx := dcontext.WithVersion(dcontext.Background(), version)
// //
// The above will store the version in the context and will be available to // The above will store the version in the context and will be available to
// the logger. // the logger.
@ -27,7 +27,7 @@
// the context and reported with the logger. The following example would // the context and reported with the logger. The following example would
// return a logger that prints the version with each log message: // return a logger that prints the version with each log message:
// //
// ctx := context.Context(context.Background(), "version", version) // ctx := context.WithValue(dcontext.Background(), "version", version)
// GetLogger(ctx, "version").Infof("this log message has a version field") // GetLogger(ctx, "version").Infof("this log message has a version field")
// //
// The above would print out a log message like this: // The above would print out a log message like this:
@ -85,4 +85,4 @@
// can be traced in log messages. Using the fields like "http.request.id", one // can be traced in log messages. Using the fields like "http.request.id", one
// can analyze call flow for a particular request with a simple grep of the // can analyze call flow for a particular request with a simple grep of the
// logs. // logs.
package context package dcontext

View file

@ -1,17 +1,16 @@
package context package dcontext
import ( import (
"context" "context"
"errors" "errors"
"net"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/distribution/distribution/v3/internal/requestutil"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
) )
// Common errors used with this package. // Common errors used with this package.
@ -20,48 +19,6 @@ var (
ErrNoResponseWriterContext = errors.New("no http response in context") ErrNoResponseWriterContext = errors.New("no http response in context")
) )
func parseIP(ipStr string) net.IP {
ip := net.ParseIP(ipStr)
if ip == nil {
log.Warnf("invalid remote IP address: %q", ipStr)
}
return ip
}
// RemoteAddr extracts the remote address of the request, taking into
// account proxy headers.
func RemoteAddr(r *http.Request) string {
if prior := r.Header.Get("X-Forwarded-For"); prior != "" {
remoteAddr, _, _ := strings.Cut(prior, ",")
remoteAddr = strings.Trim(remoteAddr, " ")
if parseIP(remoteAddr) != nil {
return remoteAddr
}
}
// X-Real-Ip is less supported, but worth checking in the
// absence of X-Forwarded-For
if realIP := r.Header.Get("X-Real-Ip"); realIP != "" {
if parseIP(realIP) != nil {
return realIP
}
}
return r.RemoteAddr
}
// RemoteIP extracts the remote IP of the request, taking into
// account proxy headers.
func RemoteIP(r *http.Request) string {
addr := RemoteAddr(r)
// Try parsing it as "IP:port"
if ip, _, err := net.SplitHostPort(addr); err == nil {
return ip
}
return addr
}
// WithRequest places the request on the context. The context of the request // WithRequest places the request on the context. The context of the request
// is assigned a unique id, available at "http.request.id". The request itself // is assigned a unique id, available at "http.request.id". The request itself
// is available at "http.request". Other common attributes are available under // is available at "http.request". Other common attributes are available under
@ -83,16 +40,6 @@ func WithRequest(ctx context.Context, r *http.Request) context.Context {
} }
} }
// GetRequest returns the http request in the given context. Returns
// ErrNoRequestContext if the context does not have an http request associated
// with it.
func GetRequest(ctx context.Context) (*http.Request, error) {
if r, ok := ctx.Value("http.request").(*http.Request); r != nil && ok {
return r, nil
}
return nil, ErrNoRequestContext
}
// GetRequestID attempts to resolve the current request id, if possible. An // GetRequestID attempts to resolve the current request id, if possible. An
// error is return if it is not available on the context. // error is return if it is not available on the context.
func GetRequestID(ctx context.Context) string { func GetRequestID(ctx context.Context) string {
@ -193,7 +140,7 @@ func (ctx *httpRequestContext) Value(key interface{}) interface{} {
case "http.request.uri": case "http.request.uri":
return ctx.r.RequestURI return ctx.r.RequestURI
case "http.request.remoteaddr": case "http.request.remoteaddr":
return RemoteAddr(ctx.r) return requestutil.RemoteAddr(ctx.r)
case "http.request.method": case "http.request.method":
return ctx.r.Method return ctx.r.Method
case "http.request.host": case "http.request.host":

View file

@ -1,10 +1,7 @@
package context package dcontext
import ( import (
"net/http" "net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -219,70 +216,3 @@ func TestWithVars(t *testing.T) {
} }
} }
} }
// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
// at the transport layer to 127.0.0.1:<port> . However, as the X-Forwarded-For header
// just contains the IP address, it is different enough for testing.
func TestRemoteAddr(t *testing.T) {
var expectedRemote string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if r.RemoteAddr == expectedRemote {
t.Errorf("Unexpected matching remote addresses")
}
actualRemote := RemoteAddr(r)
if expectedRemote != actualRemote {
t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
}
w.WriteHeader(200)
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxy := httputil.NewSingleHostReverseProxy(backendURL)
frontend := httptest.NewServer(proxy)
defer frontend.Close()
// X-Forwarded-For set by proxy
expectedRemote = "127.0.0.1"
proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultClient.Do(proxyReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// RemoteAddr in X-Real-Ip
getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil)
if err != nil {
t.Fatal(err)
}
expectedRemote = "1.2.3.4"
getReq.Header["X-Real-ip"] = []string{expectedRemote}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// Valid X-Real-Ip and invalid X-Forwarded-For
getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
}

View file

@ -1,4 +1,4 @@
package context package dcontext
import ( import (
"context" "context"

View file

@ -1,4 +1,4 @@
package context package dcontext
import ( import (
"context" "context"

View file

@ -1,4 +1,4 @@
package context package dcontext
import ( import (
"runtime" "runtime"

View file

@ -1,4 +1,4 @@
package context package dcontext
import ( import (
"context" "context"

View file

@ -1,4 +1,4 @@
package context package dcontext
import "context" import "context"

View file

@ -1,4 +1,4 @@
package context package dcontext
import "testing" import "testing"

View file

@ -0,0 +1,51 @@
package requestutil
import (
"net"
"net/http"
"strings"
log "github.com/sirupsen/logrus"
)
func parseIP(ipStr string) net.IP {
ip := net.ParseIP(ipStr)
if ip == nil {
log.Warnf("invalid remote IP address: %q", ipStr)
}
return ip
}
// RemoteAddr extracts the remote address of the request, taking into
// account proxy headers.
func RemoteAddr(r *http.Request) string {
if prior := r.Header.Get("X-Forwarded-For"); prior != "" {
remoteAddr, _, _ := strings.Cut(prior, ",")
remoteAddr = strings.Trim(remoteAddr, " ")
if parseIP(remoteAddr) != nil {
return remoteAddr
}
}
// X-Real-Ip is less supported, but worth checking in the
// absence of X-Forwarded-For
if realIP := r.Header.Get("X-Real-Ip"); realIP != "" {
if parseIP(realIP) != nil {
return realIP
}
}
return r.RemoteAddr
}
// RemoteIP extracts the remote IP of the request, taking into
// account proxy headers.
func RemoteIP(r *http.Request) string {
addr := RemoteAddr(r)
// Try parsing it as "IP:port"
if ip, _, err := net.SplitHostPort(addr); err == nil {
return ip
}
return addr
}

View file

@ -0,0 +1,76 @@
package requestutil
import (
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"testing"
)
// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
// at the transport layer to 127.0.0.1:<port> . However, as the X-Forwarded-For header
// just contains the IP address, it is different enough for testing.
func TestRemoteAddr(t *testing.T) {
var expectedRemote string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if r.RemoteAddr == expectedRemote {
t.Errorf("Unexpected matching remote addresses")
}
actualRemote := RemoteAddr(r)
if expectedRemote != actualRemote {
t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
}
w.WriteHeader(200)
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxy := httputil.NewSingleHostReverseProxy(backendURL)
frontend := httptest.NewServer(proxy)
defer frontend.Close()
// X-Forwarded-For set by proxy
expectedRemote = "127.0.0.1"
proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultClient.Do(proxyReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// RemoteAddr in X-Real-Ip
getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil)
if err != nil {
t.Fatal(err)
}
expectedRemote = "1.2.3.4"
getReq.Header["X-Real-ip"] = []string{expectedRemote}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// Valid X-Real-Ip and invalid X-Forwarded-For
getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
}

View file

@ -5,7 +5,7 @@ import (
"time" "time"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/requestutil"
"github.com/distribution/reference" "github.com/distribution/reference"
events "github.com/docker/go-events" events "github.com/docker/go-events"
"github.com/google/uuid" "github.com/google/uuid"
@ -49,7 +49,7 @@ func NewBridge(ub URLBuilder, source SourceRecord, actor ActorRecord, request Re
func NewRequestRecord(id string, r *http.Request) RequestRecord { func NewRequestRecord(id string, r *http.Request) RequestRecord {
return RequestRecord{ return RequestRecord{
ID: id, ID: id,
Addr: context.RemoteAddr(r), Addr: requestutil.RemoteAddr(r),
Host: r.Host, Host: r.Host,
Method: r.Method, Method: r.Method,
UserAgent: r.UserAgent(), UserAgent: r.UserAgent(),

View file

@ -7,7 +7,7 @@ import (
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/reference" "github.com/distribution/reference"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
) )

View file

@ -6,7 +6,7 @@ import (
"testing" "testing"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/manifest/schema2"
"github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage"
"github.com/distribution/distribution/v3/registry/storage/cache/memory" "github.com/distribution/distribution/v3/registry/storage/cache/memory"

View file

@ -18,7 +18,7 @@
// resource := auth.Resource{Type: "customerOrder", Name: orderNumber} // resource := auth.Resource{Type: "customerOrder", Name: orderNumber}
// access := auth.Access{Resource: resource, Action: "update"} // access := auth.Access{Resource: resource, Action: "update"}
// //
// if ctx, err := accessController.Authorized(ctx, access); err != nil { // if ctx, err := accessController.Authorized(r, access); err != nil {
// if challenge, ok := err.(auth.Challenge) { // if challenge, ok := err.(auth.Challenge) {
// // Let the challenge write the response. // // Let the challenge write the response.
// challenge.SetHeaders(r, w) // challenge.SetHeaders(r, w)
@ -32,22 +32,11 @@
package auth package auth
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
) )
const (
// UserKey is used to get the user object from
// a user context
UserKey = "auth.user"
// UserNameKey is used to get the user name from
// a user context
UserNameKey = "auth.user.name"
)
var ( var (
// ErrInvalidCredential is returned when the auth token does not authenticate correctly. // ErrInvalidCredential is returned when the auth token does not authenticate correctly.
ErrInvalidCredential = errors.New("invalid authorization credential") ErrInvalidCredential = errors.New("invalid authorization credential")
@ -76,6 +65,12 @@ type Access struct {
Action string Action string
} }
// Grant describes the permitted level of access for an authorized request.
type Grant struct {
User UserInfo // The authenticated user for the request.
Resources []Resource // The list of resources which have been authorized for the request.
}
// Challenge is a special error type which is used for HTTP 401 Unauthorized // Challenge is a special error type which is used for HTTP 401 Unauthorized
// responses and is able to write the response with WWW-Authenticate challenge // responses and is able to write the response with WWW-Authenticate challenge
// header values based on the error. // header values based on the error.
@ -93,16 +88,15 @@ type Challenge interface {
// and required access levels for a request. Implementations can support both // and required access levels for a request. Implementations can support both
// complete denial and http authorization challenges. // complete denial and http authorization challenges.
type AccessController interface { type AccessController interface {
// Authorized returns a non-nil error if the context is granted access and // Authorized determines if the request is granted access. If one or more
// returns a new authorized context. If one or more Access structs are // Access structs are provided, the requested access will be compared with
// provided, the requested access will be compared with what is available // what is available to the request.
// to the context. The given context will contain a "http.request" key with //
// a `*http.Request` value. If the error is non-nil, access should always // Return a Grant to grant the request access. Return an error to deny
// be denied. The error may be of type Challenge, in which case the caller // access. The error may be of type Challenge, in which case the caller may
// may have the Challenge handle the request or choose what action to take // have the Challenge handle the request or choose what action to take based
// based on the Challenge header or response status. The returned context // on the Challenge header or response status.
// object should have a "auth.user" value set to a UserInfo struct. Authorized(r *http.Request, access ...Access) (*Grant, error)
Authorized(ctx context.Context, access ...Access) (context.Context, error)
} }
// CredentialAuthenticator is an object which is able to authenticate credentials // CredentialAuthenticator is an object which is able to authenticate credentials
@ -110,63 +104,6 @@ type CredentialAuthenticator interface {
AuthenticateUser(username, password string) error AuthenticateUser(username, password string) error
} }
// WithUser returns a context with the authorized user info.
func WithUser(ctx context.Context, user UserInfo) context.Context {
return userInfoContext{
Context: ctx,
user: user,
}
}
type userInfoContext struct {
context.Context
user UserInfo
}
func (uic userInfoContext) Value(key interface{}) interface{} {
switch key {
case UserKey:
return uic.user
case UserNameKey:
return uic.user.Name
}
return uic.Context.Value(key)
}
// WithResources returns a context with the authorized resources.
func WithResources(ctx context.Context, resources []Resource) context.Context {
return resourceContext{
Context: ctx,
resources: resources,
}
}
type resourceContext struct {
context.Context
resources []Resource
}
type resourceKey struct{}
func (rc resourceContext) Value(key interface{}) interface{} {
if key == (resourceKey{}) {
return rc.resources
}
return rc.Context.Value(key)
}
// AuthorizedResources returns the list of resources which have
// been authorized for this request.
func AuthorizedResources(ctx context.Context) []Resource {
if resources, ok := ctx.Value(resourceKey{}).([]Resource); ok {
return resources
}
return nil
}
// InitFunc is the type of an AccessController factory function and is used // InitFunc is the type of an AccessController factory function and is used
// to register the constructor for different AccesController backends. // to register the constructor for different AccesController backends.
type InitFunc func(options map[string]interface{}) (AccessController, error) type InitFunc func(options map[string]interface{}) (AccessController, error)

View file

@ -18,7 +18,7 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
) )
@ -49,12 +49,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
return &accessController{realm: realm.(string), path: path}, nil return &accessController{realm: realm.(string), path: path}, nil
} }
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (*auth.Grant, error) {
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
username, password, ok := req.BasicAuth() username, password, ok := req.BasicAuth()
if !ok { if !ok {
return nil, &challenge{ return nil, &challenge{
@ -92,14 +87,14 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
ac.mu.Unlock() ac.mu.Unlock()
if err := localHTPasswd.authenticateUser(username, password); err != nil { if err := localHTPasswd.authenticateUser(username, password); err != nil {
dcontext.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err) dcontext.GetLogger(req.Context()).Errorf("error authenticating user %q: %v", username, err)
return nil, &challenge{ return nil, &challenge{
realm: ac.realm, realm: ac.realm,
err: auth.ErrAuthenticationFailure, err: auth.ErrAuthenticationFailure,
} }
} }
return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil return &auth.Grant{User: auth.UserInfo{Name: username}}, nil
} }
// challenge implements the auth.Challenge interface. // challenge implements the auth.Challenge interface.

View file

@ -8,7 +8,6 @@ import (
"os" "os"
"testing" "testing"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
) )
@ -33,7 +32,6 @@ func TestBasicAccessController(t *testing.T) {
"realm": testRealm, "realm": testRealm,
"path": tempFile.Name(), "path": tempFile.Name(),
} }
ctx := context.Background()
accessController, err := newAccessController(options) accessController, err := newAccessController(options)
if err != nil { if err != nil {
@ -45,8 +43,7 @@ func TestBasicAccessController(t *testing.T) {
userNumber := 0 userNumber := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithRequest(ctx, r) grant, err := accessController.Authorized(r)
authCtx, err := accessController.Authorized(ctx)
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case auth.Challenge: case auth.Challenge:
@ -58,13 +55,12 @@ func TestBasicAccessController(t *testing.T) {
} }
} }
userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo) if grant == nil {
if !ok { t.Fatal("basic accessController did not return auth grant")
t.Fatal("basic accessController did not set auth.user context")
} }
if userInfo.Name != testUsers[userNumber] { if grant.User.Name != testUsers[userNumber] {
t.Fatalf("expected user name %q, got %q", testUsers[userNumber], userInfo.Name) t.Fatalf("expected user name %q, got %q", testUsers[userNumber], grant.User.Name)
} }
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)

View file

@ -8,12 +8,10 @@
package silly package silly
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
) )
@ -43,12 +41,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized simply checks for the existence of the authorization header, // Authorized simply checks for the existence of the authorization header,
// responding with a bearer challenge if it doesn't exist. // responding with a bearer challenge if it doesn't exist.
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (*auth.Grant, error) {
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
if req.Header.Get("Authorization") == "" { if req.Header.Get("Authorization") == "" {
challenge := challenge{ challenge := challenge{
realm: ac.realm, realm: ac.realm,
@ -66,10 +59,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
return nil, &challenge return nil, &challenge
} }
ctx = auth.WithUser(ctx, auth.UserInfo{Name: "silly"}) return &auth.Grant{User: auth.UserInfo{Name: "silly"}}, nil
ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey))
return ctx, nil
} }
type challenge struct { type challenge struct {

View file

@ -5,7 +5,6 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
) )
@ -16,8 +15,7 @@ func TestSillyAccessController(t *testing.T) {
} }
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithRequest(context.Background(), r) grant, err := ac.Authorized(r)
authCtx, err := ac.Authorized(ctx)
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case auth.Challenge: case auth.Challenge:
@ -29,13 +27,12 @@ func TestSillyAccessController(t *testing.T) {
} }
} }
userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo) if grant == nil {
if !ok { t.Fatal("silly accessController did not return auth grant")
t.Fatal("silly accessController did not set auth.user context")
} }
if userInfo.Name != "silly" { if grant.User.Name != "silly" {
t.Fatalf("expected user name %q, got %q", "silly", userInfo.Name) t.Fatalf("expected user name %q, got %q", "silly", grant.User.Name)
} }
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)

View file

@ -1,7 +1,6 @@
package token package token
import ( import (
"context"
"crypto" "crypto"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
@ -13,7 +12,6 @@ import (
"os" "os"
"strings" "strings"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
) )
@ -292,7 +290,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized handles checking whether the given request is authorized // Authorized handles checking whether the given request is authorized
// for actions on resources described by the given access items. // for actions on resources described by the given access items.
func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) { func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (*auth.Grant, error) {
challenge := &authChallenge{ challenge := &authChallenge{
realm: ac.realm, realm: ac.realm,
autoRedirect: ac.autoRedirect, autoRedirect: ac.autoRedirect,
@ -300,11 +298,6 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
accessSet: newAccessSet(accessItems...), accessSet: newAccessSet(accessItems...),
} }
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ") prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ")
if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") { if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") {
challenge.err = ErrTokenRequired challenge.err = ErrTokenRequired
@ -338,9 +331,10 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
} }
} }
ctx = auth.WithResources(ctx, claims.resources()) return &auth.Grant{
User: auth.UserInfo{Name: claims.Subject},
return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil Resources: claims.resources(),
}, nil
} }
// init handles registering the token auth backend. // init handles registering the token auth backend.

View file

@ -18,7 +18,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt" "github.com/go-jose/go-jose/v3/jwt"
@ -466,8 +465,7 @@ func TestAccessController(t *testing.T) {
Action: "baz", Action: "baz",
} }
ctx := context.WithRequest(context.Background(), req) grant, err := accessController.Authorized(req, testAccess)
authCtx, err := accessController.Authorized(ctx, testAccess)
challenge, ok := err.(auth.Challenge) challenge, ok := err.(auth.Challenge)
if !ok { if !ok {
t.Fatal("accessController did not return a challenge") t.Fatal("accessController did not return a challenge")
@ -477,8 +475,8 @@ func TestAccessController(t *testing.T) {
t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired) t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired)
} }
if authCtx != nil { if grant != nil {
t.Fatalf("expected nil auth context but got %s", authCtx) t.Fatalf("expected nil auth grant but got %#v", grant)
} }
// 2. Supply an invalid token. // 2. Supply an invalid token.
@ -502,7 +500,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess) grant, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge) challenge, ok = err.(auth.Challenge)
if !ok { if !ok {
t.Fatal("accessController did not return a challenge") t.Fatal("accessController did not return a challenge")
@ -512,8 +510,8 @@ func TestAccessController(t *testing.T) {
t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired) t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired)
} }
if authCtx != nil { if grant != nil {
t.Fatalf("expected nil auth context but got %s", authCtx) t.Fatalf("expected nil auth grant but got %#v", grant)
} }
// create a valid jwk // create a valid jwk
@ -534,7 +532,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess) grant, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge) challenge, ok = err.(auth.Challenge)
if !ok { if !ok {
t.Fatal("accessController did not return a challenge") t.Fatal("accessController did not return a challenge")
@ -544,8 +542,8 @@ func TestAccessController(t *testing.T) {
t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInsufficientScope) t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInsufficientScope)
} }
if authCtx != nil { if grant != nil {
t.Fatalf("expected nil auth context but got %s", authCtx) t.Fatalf("expected nil auth grant but got %#v", grant)
} }
// 4. Supply the token we need, or deserve, or whatever. // 4. Supply the token we need, or deserve, or whatever.
@ -564,18 +562,13 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess) grant, err = accessController.Authorized(req, testAccess)
if err != nil { if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err) t.Fatalf("accessController returned unexpected error: %s", err)
} }
userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo) if grant.User.Name != "foo" {
if !ok { t.Fatalf("expected user name %q, got %q", "foo", grant.User.Name)
t.Fatal("token accessController did not set auth.user context")
}
if userInfo.Name != "foo" {
t.Fatalf("expected user name %q, got %q", "foo", userInfo.Name)
} }
// 5. Supply a token with full admin rights, which is represented as "*". // 5. Supply a token with full admin rights, which is represented as "*".
@ -594,7 +587,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
_, err = accessController.Authorized(ctx, testAccess) _, err = accessController.Authorized(req, testAccess)
if err != nil { if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err) t.Fatalf("accessController returned unexpected error: %s", err)
} }

View file

@ -19,9 +19,9 @@ import (
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/configuration" "github.com/distribution/distribution/v3/configuration"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/health" "github.com/distribution/distribution/v3/health"
"github.com/distribution/distribution/v3/health/checks" "github.com/distribution/distribution/v3/health/checks"
"github.com/distribution/distribution/v3/internal/dcontext"
prometheus "github.com/distribution/distribution/v3/metrics" prometheus "github.com/distribution/distribution/v3/metrics"
"github.com/distribution/distribution/v3/notifications" "github.com/distribution/distribution/v3/notifications"
"github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/api/errcode"
@ -635,7 +635,7 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
} }
// Add username to request logging // Add username to request logging
context.Context = dcontext.WithLogger(context.Context, dcontext.GetLogger(context.Context, auth.UserNameKey)) context.Context = dcontext.WithLogger(context.Context, dcontext.GetLogger(context.Context, userNameKey))
// sync up context on the request. // sync up context on the request.
r = r.WithContext(context) r = r.WithContext(context)
@ -797,7 +797,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont
accessRecords = appendCatalogAccessRecord(accessRecords, r) accessRecords = appendCatalogAccessRecord(accessRecords, r)
} }
ctx, err := app.accessController.Authorized(context.Context, accessRecords...) grant, err := app.accessController.Authorized(r.WithContext(context.Context), accessRecords...)
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case auth.Challenge: case auth.Challenge:
@ -818,8 +818,14 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont
return err return err
} }
if grant == nil {
return fmt.Errorf("access controller returned neither an access grant nor an error")
}
dcontext.GetLogger(ctx, auth.UserNameKey).Info("authorized request") ctx := withUser(context.Context, grant.User)
ctx = withResources(ctx, grant.Resources)
dcontext.GetLogger(ctx, userNameKey).Info("authorized request")
// TODO(stevvooe): This pattern needs to be cleaned up a bit. One context // TODO(stevvooe): This pattern needs to be cleaned up a bit. One context
// should be replaced by another, rather than replacing the context on a // should be replaced by another, rather than replacing the context on a
// mutable object. // mutable object.

View file

@ -9,7 +9,7 @@ import (
"testing" "testing"
"github.com/distribution/distribution/v3/configuration" "github.com/distribution/distribution/v3/configuration"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/api/errcode"
v2 "github.com/distribution/distribution/v3/registry/api/v2" v2 "github.com/distribution/distribution/v3/registry/api/v2"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
@ -25,7 +25,7 @@ import (
// tested individually. // tested individually.
func TestAppDispatcher(t *testing.T) { func TestAppDispatcher(t *testing.T) {
driver := inmemory.New() driver := inmemory.New()
ctx := context.Background() ctx := dcontext.Background()
registry, err := storage.NewRegistry(ctx, driver, storage.BlobDescriptorCacheProvider(memorycache.NewInMemoryBlobDescriptorCacheProvider(0)), storage.EnableDelete, storage.EnableRedirect) registry, err := storage.NewRegistry(ctx, driver, storage.BlobDescriptorCacheProvider(memorycache.NewInMemoryBlobDescriptorCacheProvider(0)), storage.EnableDelete, storage.EnableRedirect)
if err != nil { if err != nil {
t.Fatalf("error creating registry: %v", err) t.Fatalf("error creating registry: %v", err)
@ -139,7 +139,7 @@ func TestAppDispatcher(t *testing.T) {
// TestNewApp covers the creation of an application via NewApp with a // TestNewApp covers the creation of an application via NewApp with a
// configuration. // configuration.
func TestNewApp(t *testing.T) { func TestNewApp(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
config := configuration.Configuration{ config := configuration.Configuration{
Storage: configuration.Storage{ Storage: configuration.Storage{
"inmemory": nil, "inmemory": nil,

View file

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/api/errcode"
"github.com/gorilla/handlers" "github.com/gorilla/handlers"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
@ -53,7 +53,7 @@ type blobHandler struct {
// GetBlob fetches the binary data from backend storage returns it in the // GetBlob fetches the binary data from backend storage returns it in the
// response. // response.
func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) { func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) {
context.GetLogger(bh).Debug("GetBlob") dcontext.GetLogger(bh).Debug("GetBlob")
blobs := bh.Repository.Blobs(bh) blobs := bh.Repository.Blobs(bh)
desc, err := blobs.Stat(bh, bh.Digest) desc, err := blobs.Stat(bh, bh.Digest)
if err != nil { if err != nil {
@ -66,7 +66,7 @@ func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) {
} }
if err := blobs.ServeBlob(bh, w, r, desc.Digest); err != nil { if err := blobs.ServeBlob(bh, w, r, desc.Digest); err != nil {
context.GetLogger(bh).Debugf("unexpected error getting blob HTTP handler: %v", err) dcontext.GetLogger(bh).Debugf("unexpected error getting blob HTTP handler: %v", err)
bh.Errors = append(bh.Errors, errcode.ErrorCodeUnknown.WithDetail(err)) bh.Errors = append(bh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return return
} }
@ -74,7 +74,7 @@ func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) {
// DeleteBlob deletes a layer blob // DeleteBlob deletes a layer blob
func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) { func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) {
context.GetLogger(bh).Debug("DeleteBlob") dcontext.GetLogger(bh).Debug("DeleteBlob")
blobs := bh.Repository.Blobs(bh) blobs := bh.Repository.Blobs(bh)
err := blobs.Delete(bh, bh.Digest) err := blobs.Delete(bh, bh.Digest)
@ -88,7 +88,7 @@ func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) {
return return
default: default:
bh.Errors = append(bh.Errors, err) bh.Errors = append(bh.Errors, err)
context.GetLogger(bh).Errorf("Unknown error deleting blob: %s", err.Error()) dcontext.GetLogger(bh).Errorf("Unknown error deleting blob: %s", err.Error())
return return
} }
} }

View file

@ -7,7 +7,7 @@ import (
"strconv" "strconv"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/api/errcode"
"github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage"
"github.com/distribution/reference" "github.com/distribution/reference"

View file

@ -6,7 +6,7 @@ import (
"net/http" "net/http"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/api/errcode"
v2 "github.com/distribution/distribution/v3/registry/api/v2" v2 "github.com/distribution/distribution/v3/registry/api/v2"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
@ -77,10 +77,20 @@ func getUploadUUID(ctx context.Context) (uuid string) {
return dcontext.GetStringValue(ctx, "vars.uuid") return dcontext.GetStringValue(ctx, "vars.uuid")
} }
const (
// userKey is used to get the user object from
// a user context
userKey = "auth.user"
// userNameKey is used to get the user name from
// a user context
userNameKey = "auth.user.name"
)
// getUserName attempts to resolve a username from the context and request. If // getUserName attempts to resolve a username from the context and request. If
// a username cannot be resolved, the empty string is returned. // a username cannot be resolved, the empty string is returned.
func getUserName(ctx context.Context, r *http.Request) string { func getUserName(ctx context.Context, r *http.Request) string {
username := dcontext.GetStringValue(ctx, auth.UserNameKey) username := dcontext.GetStringValue(ctx, userNameKey)
// Fallback to request user with basic auth // Fallback to request user with basic auth
if username == "" { if username == "" {
@ -93,3 +103,60 @@ func getUserName(ctx context.Context, r *http.Request) string {
return username return username
} }
// withUser returns a context with the authorized user info.
func withUser(ctx context.Context, user auth.UserInfo) context.Context {
return userInfoContext{
Context: ctx,
user: user,
}
}
type userInfoContext struct {
context.Context
user auth.UserInfo
}
func (uic userInfoContext) Value(key interface{}) interface{} {
switch key {
case userKey:
return uic.user
case userNameKey:
return uic.user.Name
}
return uic.Context.Value(key)
}
// withResources returns a context with the authorized resources.
func withResources(ctx context.Context, resources []auth.Resource) context.Context {
return resourceContext{
Context: ctx,
resources: resources,
}
}
type resourceContext struct {
context.Context
resources []auth.Resource
}
type resourceKey struct{}
func (rc resourceContext) Value(key interface{}) interface{} {
if key == (resourceKey{}) {
return rc.resources
}
return rc.Context.Value(key)
}
// authorizedResources returns the list of resources which have
// been authorized for this request.
func authorizedResources(ctx context.Context) []auth.Resource {
if resources, ok := ctx.Value(resourceKey{}).([]auth.Resource); ok {
return resources
}
return nil
}

View file

@ -9,8 +9,8 @@ import (
"time" "time"
"github.com/distribution/distribution/v3/configuration" "github.com/distribution/distribution/v3/configuration"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/health" "github.com/distribution/distribution/v3/health"
"github.com/distribution/distribution/v3/internal/dcontext"
) )
func TestFileHealthCheck(t *testing.T) { func TestFileHealthCheck(t *testing.T) {
@ -39,7 +39,7 @@ func TestFileHealthCheck(t *testing.T) {
}, },
} }
ctx := context.Background() ctx := dcontext.Background()
app := NewApp(ctx, config) app := NewApp(ctx, config)
healthRegistry := health.NewRegistry() healthRegistry := health.NewRegistry()
@ -103,7 +103,7 @@ func TestTCPHealthCheck(t *testing.T) {
}, },
} }
ctx := context.Background() ctx := dcontext.Background()
app := NewApp(ctx, config) app := NewApp(ctx, config)
healthRegistry := health.NewRegistry() healthRegistry := health.NewRegistry()
@ -165,7 +165,7 @@ func TestHTTPHealthCheck(t *testing.T) {
}, },
} }
ctx := context.Background() ctx := dcontext.Background()
app := NewApp(ctx, config) app := NewApp(ctx, config)
healthRegistry := health.NewRegistry() healthRegistry := health.NewRegistry()

View file

@ -9,7 +9,7 @@ import (
"strconv" "strconv"
"strings" "strings"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
) )
// closeResources closes all the provided resources after running the target // closeResources closes all the provided resources after running the target

View file

@ -8,12 +8,11 @@ import (
"strings" "strings"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/manifestlist"
"github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/manifest/schema2"
"github.com/distribution/distribution/v3/registry/api/errcode" "github.com/distribution/distribution/v3/registry/api/errcode"
"github.com/distribution/distribution/v3/registry/auth"
"github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/reference" "github.com/distribution/reference"
"github.com/gorilla/handlers" "github.com/gorilla/handlers"
@ -394,7 +393,7 @@ func (imh *manifestHandler) applyResourcePolicy(manifest distribution.Manifest)
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("registry does not allow %s manifest", class)) return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("registry does not allow %s manifest", class))
} }
resources := auth.AuthorizedResources(imh) resources := authorizedResources(imh)
n := imh.Repository.Named().Name() n := imh.Repository.Named().Name()
var foundResource bool var foundResource bool

View file

@ -5,9 +5,9 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/client/auth" "github.com/distribution/distribution/v3/internal/client/auth"
"github.com/distribution/distribution/v3/internal/client/auth/challenge" "github.com/distribution/distribution/v3/internal/client/auth/challenge"
"github.com/distribution/distribution/v3/internal/dcontext"
) )
const challengeHeader = "Docker-Distribution-Api-Version" const challengeHeader = "Docker-Distribution-Api-Version"
@ -44,7 +44,7 @@ func configureAuth(username, password, remoteURL string) (auth.CredentialStore,
} }
for _, url := range authURLs { for _, url := range authURLs {
context.GetLogger(context.Background()).Infof("Discovered token authentication URL: %s", url) dcontext.GetLogger(dcontext.Background()).Infof("Discovered token authentication URL: %s", url)
creds[url] = userpass{ creds[url] = userpass{
username: username, username: username,
password: password, password: password,

View file

@ -11,7 +11,7 @@ import (
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/proxy/scheduler" "github.com/distribution/distribution/v3/registry/proxy/scheduler"
"github.com/distribution/reference" "github.com/distribution/reference"
) )

View file

@ -7,7 +7,7 @@ import (
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/proxy/scheduler" "github.com/distribution/distribution/v3/registry/proxy/scheduler"
"github.com/distribution/reference" "github.com/distribution/reference"
) )

View file

@ -10,11 +10,11 @@ import (
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/configuration" "github.com/distribution/distribution/v3/configuration"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/client" "github.com/distribution/distribution/v3/internal/client"
"github.com/distribution/distribution/v3/internal/client/auth" "github.com/distribution/distribution/v3/internal/client/auth"
"github.com/distribution/distribution/v3/internal/client/auth/challenge" "github.com/distribution/distribution/v3/internal/client/auth/challenge"
"github.com/distribution/distribution/v3/internal/client/transport" "github.com/distribution/distribution/v3/internal/client/transport"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/proxy/scheduler" "github.com/distribution/distribution/v3/registry/proxy/scheduler"
"github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage"
"github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver"

View file

@ -7,7 +7,7 @@ import (
"sync" "sync"
"time" "time"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/reference" "github.com/distribution/reference"
) )

View file

@ -6,7 +6,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
"github.com/distribution/reference" "github.com/distribution/reference"
) )
@ -40,7 +40,7 @@ func TestSchedule(t *testing.T) {
} }
var mu sync.Mutex var mu sync.Mutex
s := New(context.Background(), inmemory.New(), "/ttl") s := New(dcontext.Background(), inmemory.New(), "/ttl")
deleteFunc := func(repoName reference.Reference) error { deleteFunc := func(repoName reference.Reference) error {
if len(remainingRepos) == 0 { if len(remainingRepos) == 0 {
t.Fatalf("Incorrect expiry count") t.Fatalf("Incorrect expiry count")
@ -123,14 +123,14 @@ func TestRestoreOld(t *testing.T) {
t.Fatalf("Error serializing test data: %s", err.Error()) t.Fatalf("Error serializing test data: %s", err.Error())
} }
ctx := context.Background() ctx := dcontext.Background()
pathToStatFile := "/ttl" pathToStatFile := "/ttl"
fs := inmemory.New() fs := inmemory.New()
err = fs.PutContent(ctx, pathToStatFile, serialized) err = fs.PutContent(ctx, pathToStatFile, serialized)
if err != nil { if err != nil {
t.Fatal("Unable to write serialized data to fs") t.Fatal("Unable to write serialized data to fs")
} }
s := New(context.Background(), fs, "/ttl") s := New(dcontext.Background(), fs, "/ttl")
s.OnBlobExpire(deleteFunc) s.OnBlobExpire(deleteFunc)
err = s.Start() err = s.Start()
if err != nil { if err != nil {
@ -165,7 +165,7 @@ func TestStopRestore(t *testing.T) {
fs := inmemory.New() fs := inmemory.New()
pathToStateFile := "/ttl" pathToStateFile := "/ttl"
s := New(context.Background(), fs, pathToStateFile) s := New(dcontext.Background(), fs, pathToStateFile)
s.onBlobExpire = deleteFunc s.onBlobExpire = deleteFunc
err := s.Start() err := s.Start()
@ -181,7 +181,7 @@ func TestStopRestore(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
// v2 will restore state from fs // v2 will restore state from fs
s2 := New(context.Background(), fs, pathToStateFile) s2 := New(dcontext.Background(), fs, pathToStateFile)
s2.onBlobExpire = deleteFunc s2.onBlobExpire = deleteFunc
err = s2.Start() err = s2.Start()
if err != nil { if err != nil {
@ -197,7 +197,7 @@ func TestStopRestore(t *testing.T) {
} }
func TestDoubleStart(t *testing.T) { func TestDoubleStart(t *testing.T) {
s := New(context.Background(), inmemory.New(), "/ttl") s := New(dcontext.Background(), inmemory.New(), "/ttl")
err := s.Start() err := s.Start()
if err != nil { if err != nil {
t.Fatalf("Unable to start scheduler") t.Fatalf("Unable to start scheduler")

View file

@ -21,8 +21,8 @@ import (
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"github.com/distribution/distribution/v3/configuration" "github.com/distribution/distribution/v3/configuration"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/health" "github.com/distribution/distribution/v3/health"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/handlers" "github.com/distribution/distribution/v3/registry/handlers"
"github.com/distribution/distribution/v3/registry/listener" "github.com/distribution/distribution/v3/registry/listener"
"github.com/distribution/distribution/v3/version" "github.com/distribution/distribution/v3/version"

View file

@ -25,7 +25,7 @@ import (
"time" "time"
"github.com/distribution/distribution/v3/configuration" "github.com/distribution/distribution/v3/configuration"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
_ "github.com/distribution/distribution/v3/registry/storage/driver/inmemory" _ "github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"

View file

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage" "github.com/distribution/distribution/v3/registry/storage"
"github.com/distribution/distribution/v3/registry/storage/driver/factory" "github.com/distribution/distribution/v3/registry/storage/driver/factory"
"github.com/distribution/distribution/v3/version" "github.com/distribution/distribution/v3/version"

View file

@ -20,7 +20,7 @@ type blobServer struct {
driver driver.StorageDriver driver driver.StorageDriver
statter distribution.BlobStatter statter distribution.BlobStatter
pathFn func(dgst digest.Digest) (string, error) pathFn func(dgst digest.Digest) (string, error)
redirect bool // allows disabling URLFor redirects redirect bool // allows disabling RedirectURL redirects
} }
func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error { func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
@ -35,19 +35,16 @@ func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *h
} }
if bs.redirect { if bs.redirect {
redirectURL, err := bs.driver.URLFor(ctx, path, map[string]interface{}{"method": r.Method}) redirectURL, err := bs.driver.RedirectURL(r, path)
switch err.(type) { if err != nil {
case nil:
// Redirect to storage URL.
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
return err
case driver.ErrUnsupportedMethod:
// Fallback to serving the content directly.
default:
// Some unexpected error.
return err return err
} }
if redirectURL != "" {
// Redirect to storage URL.
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
return nil
}
// Fallback to serving the content directly.
} }
br, err := newFileReader(ctx, bs.driver, path, desc.Size) br, err := newFileReader(ctx, bs.driver, path, desc.Size)

View file

@ -6,7 +6,7 @@ import (
"path" "path"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
) )

View file

@ -9,7 +9,7 @@ import (
"time" "time"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"

View file

@ -4,7 +4,7 @@
package storage package storage
import ( import (
"github.com/distribution/distribution/v3/context" "context"
) )
// resumeHashAt is a noop when resumable digest support is disabled. // resumeHashAt is a noop when resumable digest support is disabled.

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
prometheus "github.com/distribution/distribution/v3/metrics" prometheus "github.com/distribution/distribution/v3/metrics"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
) )

View file

@ -8,6 +8,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"net/http"
"strings" "strings"
"time" "time"
@ -302,7 +303,7 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
// Move moves an object stored at sourcePath to destPath, removing the original // Move moves an object stored at sourcePath to destPath, removing the original
// object. // object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error { func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
sourceBlobURL, err := d.URLFor(ctx, sourcePath, nil) sourceBlobURL, err := d.signBlobURL(ctx, sourcePath)
if err != nil { if err != nil {
return err return err
} }
@ -382,18 +383,15 @@ func (d *driver) Delete(ctx context.Context, path string) error {
return nil return nil
} }
// URLFor returns a publicly accessible URL for the blob stored at given path // RedirectURL returns a publicly accessible URL for the blob stored at given path
// for specified duration by making use of Azure Storage Shared Access Signatures (SAS). // for specified duration by making use of Azure Storage Shared Access Signatures (SAS).
// See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx for more info. // See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx for more info.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { func (d *driver) RedirectURL(req *http.Request, path string) (string, error) {
return d.signBlobURL(req.Context(), path)
}
func (d *driver) signBlobURL(ctx context.Context, path string) (string, error) {
expiresTime := time.Now().UTC().Add(20 * time.Minute) // default expiration expiresTime := time.Now().UTC().Add(20 * time.Minute) // default expiration
expires, ok := options["expiry"]
if ok {
t, ok := expires.(time.Time)
if ok {
expiresTime = t
}
}
blobName := d.blobName(path) blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName) blobRef := d.client.NewBlobClient(blobName)
return d.azClient.SignBlobURL(ctx, blobRef.URL(), expiresTime) return d.azClient.SignBlobURL(ctx, blobRef.URL(), expiresTime)

View file

@ -40,9 +40,10 @@ package base
import ( import (
"context" "context"
"io" "io"
"net/http"
"time" "time"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
prometheus "github.com/distribution/distribution/v3/metrics" prometheus "github.com/distribution/distribution/v3/metrics"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/docker/go-metrics" "github.com/docker/go-metrics"
@ -208,18 +209,18 @@ func (base *Base) Delete(ctx context.Context, path string) error {
return err return err
} }
// URLFor wraps URLFor of underlying storage driver. // RedirectURL wraps RedirectURL of the underlying storage driver.
func (base *Base) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { func (base *Base) RedirectURL(r *http.Request, path string) (string, error) {
ctx, done := dcontext.WithTrace(ctx) ctx, done := dcontext.WithTrace(r.Context())
defer done("%s.URLFor(%q)", base.Name(), path) defer done("%s.RedirectURL(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) { if !storagedriver.PathRegexp.MatchString(path) {
return "", storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()} return "", storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
} }
start := time.Now() start := time.Now()
str, e := base.StorageDriver.URLFor(ctx, path, options) str, e := base.StorageDriver.RedirectURL(r.WithContext(ctx), path)
storageAction.WithValues(base.Name(), "URLFor").UpdateSince(start) storageAction.WithValues(base.Name(), "RedirectURL").UpdateSince(start)
return str, base.setDriverName(e) return str, base.setDriverName(e)
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"net/http"
"reflect" "reflect"
"strconv" "strconv"
"sync" "sync"
@ -172,13 +173,11 @@ func (r *regulator) Delete(ctx context.Context, path string) error {
return r.StorageDriver.Delete(ctx, path) return r.StorageDriver.Delete(ctx, path)
} }
// URLFor returns a URL which may be used to retrieve the content stored at // RedirectURL returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options. // the given path.
// May return an ErrUnsupportedMethod in certain StorageDriver func (r *regulator) RedirectURL(req *http.Request, path string) (string, error) {
// implementations.
func (r *regulator) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
r.enter() r.enter()
defer r.exit() defer r.exit()
return r.StorageDriver.URLFor(ctx, path, options) return r.StorageDriver.RedirectURL(req, path)
} }

View file

@ -6,6 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"path" "path"
"time" "time"
@ -282,10 +283,9 @@ func (d *driver) Delete(ctx context.Context, subPath string) error {
return err return err
} }
// URLFor returns a URL which may be used to retrieve the content stored at the given path. // RedirectURL returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations. func (d *driver) RedirectURL(*http.Request, string) (string, error) {
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { return "", nil
return "", storagedriver.ErrUnsupportedMethod{}
} }
// Walk traverses a filesystem defined within driver, starting // Walk traverses a filesystem defined within driver, starting

View file

@ -809,40 +809,24 @@ func storageCopyObject(ctx context.Context, srcBucket, srcName string, destBucke
return attrs, err return attrs, err
} }
// URLFor returns a URL which may be used to retrieve the content stored at // RedirectURL returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options. // the given path, possibly using the given options.
// Returns ErrUnsupportedMethod if this driver has no privateKey func (d *driver) RedirectURL(r *http.Request, path string) (string, error) {
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
if d.privateKey == nil { if d.privateKey == nil {
return "", storagedriver.ErrUnsupportedMethod{} return "", nil
} }
name := d.pathToKey(path) if r.Method != http.MethodGet && r.Method != http.MethodHead {
methodString := http.MethodGet return "", nil
method, ok := options["method"]
if ok {
methodString, ok = method.(string)
if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) {
return "", storagedriver.ErrUnsupportedMethod{}
}
}
expiresTime := time.Now().Add(20 * time.Minute)
expires, ok := options["expiry"]
if ok {
et, ok := expires.(time.Time)
if ok {
expiresTime = et
}
} }
opts := &storage.SignedURLOptions{ opts := &storage.SignedURLOptions{
GoogleAccessID: d.email, GoogleAccessID: d.email,
PrivateKey: d.privateKey, PrivateKey: d.privateKey,
Method: methodString, Method: r.Method,
Expires: expiresTime, Expires: time.Now().Add(20 * time.Minute),
} }
return storage.SignedURL(d.bucket, name, opts) return storage.SignedURL(d.bucket, d.pathToKey(path), opts)
} }
// Walk traverses a filesystem defined within driver, starting // Walk traverses a filesystem defined within driver, starting

View file

@ -10,7 +10,7 @@ import (
"testing" "testing"
"cloud.google.com/go/storage" "cloud.google.com/go/storage"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/distribution/v3/registry/storage/driver/testsuites" "github.com/distribution/distribution/v3/registry/storage/driver/testsuites"
"golang.org/x/oauth2" "golang.org/x/oauth2"

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"net/http"
"sync" "sync"
"time" "time"
@ -236,10 +237,9 @@ func (d *driver) Delete(ctx context.Context, path string) error {
} }
} }
// URLFor returns a URL which may be used to retrieve the content stored at the given path. // RedirectURL returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations. func (d *driver) RedirectURL(*http.Request, string) (string, error) {
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { return "", nil
return "", storagedriver.ErrUnsupportedMethod{}
} }
// Walk traverses a filesystem defined within driver, starting // Walk traverses a filesystem defined within driver, starting

View file

@ -7,13 +7,14 @@ import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"os" "os"
"strings" "strings"
"time" "time"
"github.com/aws/aws-sdk-go/service/cloudfront/sign" "github.com/aws/aws-sdk-go/service/cloudfront/sign"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
storagemiddleware "github.com/distribution/distribution/v3/registry/storage/driver/middleware" storagemiddleware "github.com/distribution/distribution/v3/registry/storage/driver/middleware"
) )
@ -201,18 +202,18 @@ type S3BucketKeyer interface {
S3BucketKey(path string) string S3BucketKey(path string) string
} }
// URLFor attempts to find a url which may be used to retrieve the file at the given path. // RedirectURL attempts to find a url which may be used to retrieve the file at the given path.
// Returns an error if the file cannot be found. // Returns an error if the file cannot be found.
func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) { func (lh *cloudFrontStorageMiddleware) RedirectURL(r *http.Request, path string) (string, error) {
// TODO(endophage): currently only supports S3 // TODO(endophage): currently only supports S3
keyer, ok := lh.StorageDriver.(S3BucketKeyer) keyer, ok := lh.StorageDriver.(S3BucketKeyer)
if !ok { if !ok {
dcontext.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver") dcontext.GetLogger(r.Context()).Warn("the CloudFront middleware does not support this backend storage driver")
return lh.StorageDriver.URLFor(ctx, path, options) return lh.StorageDriver.RedirectURL(r, path)
} }
if eligibleForS3(ctx, lh.awsIPs) { if eligibleForS3(r, lh.awsIPs) {
return lh.StorageDriver.URLFor(ctx, path, options) return lh.StorageDriver.RedirectURL(r, path)
} }
// Get signed cloudfront url. // Get signed cloudfront url.

View file

@ -12,7 +12,8 @@ import (
"sync" "sync"
"time" "time"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/internal/requestutil"
) )
const ( const (
@ -192,12 +193,8 @@ func (s *awsIPs) contains(ip net.IP) bool {
// parseIPFromRequest attempts to extract the ip address of the // parseIPFromRequest attempts to extract the ip address of the
// client that made the request // client that made the request
func parseIPFromRequest(ctx context.Context) (net.IP, error) { func parseIPFromRequest(request *http.Request) (net.IP, error) {
request, err := dcontext.GetRequest(ctx) ipStr := requestutil.RemoteIP(request)
if err != nil {
return nil, err
}
ipStr := dcontext.RemoteIP(request)
ip := net.ParseIP(ipStr) ip := net.ParseIP(ipStr)
if ip == nil { if ip == nil {
return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr) return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr)
@ -208,25 +205,20 @@ func parseIPFromRequest(ctx context.Context) (net.IP, error) {
// eligibleForS3 checks if a request is eligible for using S3 directly // eligibleForS3 checks if a request is eligible for using S3 directly
// Return true only when the IP belongs to a specific aws region and user-agent is docker // Return true only when the IP belongs to a specific aws region and user-agent is docker
func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool { func eligibleForS3(request *http.Request, awsIPs *awsIPs) bool {
if awsIPs != nil && awsIPs.initialized { if awsIPs != nil && awsIPs.initialized {
if addr, err := parseIPFromRequest(ctx); err == nil { if addr, err := parseIPFromRequest(request); err == nil {
request, err := dcontext.GetRequest(ctx) loggerField := map[interface{}]interface{}{
if err != nil { "user-client": request.UserAgent(),
dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err) "ip": requestutil.RemoteIP(request),
} else {
loggerField := map[interface{}]interface{}{
"user-client": request.UserAgent(),
"ip": dcontext.RemoteIP(request),
}
if awsIPs.contains(addr) {
dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront")
return true
}
dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
} }
if awsIPs.contains(addr) {
dcontext.GetLoggerWithFields(request.Context(), loggerField).Info("request from the allowed AWS region, skipping CloudFront")
return true
}
dcontext.GetLoggerWithFields(request.Context(), loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
} else { } else {
dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront") dcontext.GetLogger(request.Context()).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront")
} }
} }
return false return false

View file

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"context"
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -11,8 +10,6 @@ import (
"reflect" // used as a replacement for testify "reflect" // used as a replacement for testify
"testing" "testing"
"time" "time"
dcontext "github.com/distribution/distribution/v3/context"
) )
// Rather than pull in all of testify // Rather than pull in all of testify
@ -276,29 +273,22 @@ func TestEligibleForS3(t *testing.T) {
}}, }},
initialized: true, initialized: true,
} }
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
tests := []struct { tests := []struct {
Context context.Context RemoteAddr string
Expected bool Expected bool
}{ }{
{Context: empty, Expected: false}, {RemoteAddr: "", Expected: false},
{Context: makeContext("192.168.1.2"), Expected: true}, {RemoteAddr: "192.168.1.2", Expected: true},
{Context: makeContext("192.168.0.2"), Expected: false}, {RemoteAddr: "192.168.0.2", Expected: false},
} }
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) { t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) {
t.Parallel() t.Parallel()
assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips)) req := &http.Request{RemoteAddr: tc.RemoteAddr}
assertEqual(t, tc.Expected, eligibleForS3(req, ips))
}) })
} }
} }
@ -312,29 +302,22 @@ func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) {
}}, }},
initialized: false, initialized: false,
} }
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
tests := []struct { tests := []struct {
Context context.Context RemoteAddr string
Expected bool Expected bool
}{ }{
{Context: empty, Expected: false}, {RemoteAddr: "", Expected: false},
{Context: makeContext("192.168.1.2"), Expected: false}, {RemoteAddr: "192.168.1.2", Expected: false},
{Context: makeContext("192.168.0.2"), Expected: false}, {RemoteAddr: "192.168.0.2", Expected: false},
} }
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) { t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) {
t.Parallel() t.Parallel()
assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips)) req := &http.Request{RemoteAddr: tc.RemoteAddr}
assertEqual(t, tc.Expected, eligibleForS3(req, ips))
}) })
} }
} }

View file

@ -1,8 +1,8 @@
package middleware package middleware
import ( import (
"context"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"path" "path"
@ -42,7 +42,7 @@ func newRedirectStorageMiddleware(ctx context.Context, sd storagedriver.StorageD
return &redirectStorageMiddleware{StorageDriver: sd, scheme: u.Scheme, host: u.Host, basePath: u.Path}, nil return &redirectStorageMiddleware{StorageDriver: sd, scheme: u.Scheme, host: u.Host, basePath: u.Path}, nil
} }
func (r *redirectStorageMiddleware) URLFor(ctx context.Context, urlPath string, options map[string]interface{}) (string, error) { func (r *redirectStorageMiddleware) RedirectURL(_ *http.Request, urlPath string) (string, error) {
if r.basePath != "" { if r.basePath != "" {
urlPath = path.Join(r.basePath, urlPath) urlPath = path.Join(r.basePath, urlPath)
} }

View file

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"context"
"testing" "testing"
"gopkg.in/check.v1" "gopkg.in/check.v1"
@ -37,7 +36,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
c.Assert(m.scheme, check.Equals, "https") c.Assert(m.scheme, check.Equals, "https")
c.Assert(m.host, check.Equals, "example.com:5443") c.Assert(m.host, check.Equals, "example.com:5443")
url, err := middleware.URLFor(context.TODO(), "/rick/data", nil) url, err := middleware.RedirectURL(nil, "/rick/data")
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com:5443/rick/data") c.Assert(url, check.Equals, "https://example.com:5443/rick/data")
} }
@ -53,7 +52,7 @@ func (s *MiddlewareSuite) TestHTTP(c *check.C) {
c.Assert(m.scheme, check.Equals, "http") c.Assert(m.scheme, check.Equals, "http")
c.Assert(m.host, check.Equals, "example.com") c.Assert(m.host, check.Equals, "example.com")
url, err := middleware.URLFor(context.TODO(), "morty/data", nil) url, err := middleware.RedirectURL(nil, "morty/data")
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "http://example.com/morty/data") c.Assert(url, check.Equals, "http://example.com/morty/data")
} }
@ -71,12 +70,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
c.Assert(m.host, check.Equals, "example.com") c.Assert(m.host, check.Equals, "example.com")
c.Assert(m.basePath, check.Equals, "/path") c.Assert(m.basePath, check.Equals, "/path")
// call URLFor() with no leading slash // call RedirectURL() with no leading slash
url, err := middleware.URLFor(context.TODO(), "morty/data", nil) url, err := middleware.RedirectURL(nil, "morty/data")
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data") c.Assert(url, check.Equals, "https://example.com/path/morty/data")
// call URLFor() with leading slash // call RedirectURL() with leading slash
url, err = middleware.URLFor(context.TODO(), "/morty/data", nil) url, err = middleware.RedirectURL(nil, "/morty/data")
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data") c.Assert(url, check.Equals, "https://example.com/path/morty/data")
@ -91,12 +90,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
c.Assert(m.host, check.Equals, "example.com") c.Assert(m.host, check.Equals, "example.com")
c.Assert(m.basePath, check.Equals, "/path/") c.Assert(m.basePath, check.Equals, "/path/")
// call URLFor() with no leading slash // call RedirectURL() with no leading slash
url, err = middleware.URLFor(context.TODO(), "morty/data", nil) url, err = middleware.RedirectURL(nil, "morty/data")
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data") c.Assert(url, check.Equals, "https://example.com/path/morty/data")
// call URLFor() with leading slash // call RedirectURL() with leading slash
url, err = middleware.URLFor(context.TODO(), "/morty/data", nil) url, err = middleware.RedirectURL(nil, "/morty/data")
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data") c.Assert(url, check.Equals, "https://example.com/path/morty/data")
} }

View file

@ -36,7 +36,7 @@ import (
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/distribution/v3/registry/storage/driver/base" "github.com/distribution/distribution/v3/registry/storage/driver/base"
"github.com/distribution/distribution/v3/registry/storage/driver/factory" "github.com/distribution/distribution/v3/registry/storage/driver/factory"
@ -1036,30 +1036,13 @@ func (d *driver) Delete(ctx context.Context, path string) error {
return nil return nil
} }
// URLFor returns a URL which may be used to retrieve the content stored at the given path. // RedirectURL returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations. func (d *driver) RedirectURL(r *http.Request, path string) (string, error) {
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
methodString := http.MethodGet
method, ok := options["method"]
if ok {
methodString, ok = method.(string)
if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) {
return "", storagedriver.ErrUnsupportedMethod{}
}
}
expiresIn := 20 * time.Minute expiresIn := 20 * time.Minute
expires, ok := options["expiry"]
if ok {
et, ok := expires.(time.Time)
if ok {
expiresIn = time.Until(et)
}
}
var req *request.Request var req *request.Request
switch methodString { switch r.Method {
case http.MethodGet: case http.MethodGet:
req, _ = d.S3.GetObjectRequest(&s3.GetObjectInput{ req, _ = d.S3.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String(d.Bucket), Bucket: aws.String(d.Bucket),
@ -1071,7 +1054,7 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
Key: aws.String(d.s3Path(path)), Key: aws.String(d.s3Path(path)),
}) })
default: default:
panic("unreachable") return "", nil
} }
return req.Presign(expiresIn) return req.Presign(expiresIn)

View file

@ -16,7 +16,7 @@ import (
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver" storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/distribution/v3/registry/storage/driver/testsuites" "github.com/distribution/distribution/v3/registry/storage/driver/testsuites"
) )
@ -180,7 +180,7 @@ func TestEmptyRootList(t *testing.T) {
filename := "/test" filename := "/test"
contents := []byte("contents") contents := []byte("contents")
ctx := context.Background() ctx := dcontext.Background()
err = rootedDriver.PutContent(ctx, filename, contents) err = rootedDriver.PutContent(ctx, filename, contents)
if err != nil { if err != nil {
t.Fatalf("unexpected error creating content: %v", err) t.Fatalf("unexpected error creating content: %v", err)
@ -209,7 +209,7 @@ func TestStorageClass(t *testing.T) {
rootDir := t.TempDir() rootDir := t.TempDir()
contents := []byte("contents") contents := []byte("contents")
ctx := context.Background() ctx := dcontext.Background()
// We don't need to test all the storage classes, just that its selectable. // We don't need to test all the storage classes, just that its selectable.
// The first 3 are common to AWS and MinIO, so use those. // The first 3 are common to AWS and MinIO, so use those.
@ -377,7 +377,7 @@ func TestDelete(t *testing.T) {
// init file structure matching objs // init file structure matching objs
var created []string var created []string
for _, p := range objs { for _, p := range objs {
err := drvr.PutContent(context.Background(), p, []byte("content "+p)) err := drvr.PutContent(dcontext.Background(), p, []byte("content "+p))
if err != nil { if err != nil {
fmt.Printf("unable to init file %s: %s\n", p, err) fmt.Printf("unable to init file %s: %s\n", p, err)
continue continue
@ -390,7 +390,7 @@ func TestDelete(t *testing.T) {
cleanup := func(objs []string) { cleanup := func(objs []string) {
var lastErr error var lastErr error
for _, p := range objs { for _, p := range objs {
err := drvr.Delete(context.Background(), p) err := drvr.Delete(dcontext.Background(), p)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case storagedriver.PathNotFoundError: case storagedriver.PathNotFoundError:
@ -409,7 +409,7 @@ func TestDelete(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
objs := init() objs := init()
err := drvr.Delete(context.Background(), tc.delete) err := drvr.Delete(dcontext.Background(), tc.delete)
if tc.err != nil { if tc.err != nil {
if err == nil { if err == nil {
@ -437,7 +437,7 @@ func TestDelete(t *testing.T) {
return false return false
} }
for _, path := range objs { for _, path := range objs {
stat, err := drvr.Stat(context.Background(), path) stat, err := drvr.Stat(dcontext.Background(), path)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case storagedriver.PathNotFoundError: case storagedriver.PathNotFoundError:
@ -491,7 +491,7 @@ func TestWalk(t *testing.T) {
// create file structure matching fileset above // create file structure matching fileset above
created := make([]string, 0, len(fileset)) created := make([]string, 0, len(fileset))
for _, p := range fileset { for _, p := range fileset {
err := drvr.PutContent(context.Background(), p, []byte("content "+p)) err := drvr.PutContent(dcontext.Background(), p, []byte("content "+p))
if err != nil { if err != nil {
fmt.Printf("unable to create file %s: %s\n", p, err) fmt.Printf("unable to create file %s: %s\n", p, err)
continue continue
@ -503,7 +503,7 @@ func TestWalk(t *testing.T) {
defer func() { defer func() {
var lastErr error var lastErr error
for _, p := range created { for _, p := range created {
err := drvr.Delete(context.Background(), p) err := drvr.Delete(dcontext.Background(), p)
if err != nil { if err != nil {
_ = fmt.Errorf("cleanup failed for path %s: %s", p, err) _ = fmt.Errorf("cleanup failed for path %s: %s", p, err)
lastErr = err lastErr = err
@ -692,7 +692,7 @@ func TestWalk(t *testing.T) {
tc.from = "/" tc.from = "/"
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err := drvr.Walk(context.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error { err := drvr.Walk(dcontext.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error {
walked = append(walked, fileInfo.Path()) walked = append(walked, fileInfo.Path())
return tc.fn(fileInfo) return tc.fn(fileInfo)
}, tc.options...) }, tc.options...)
@ -718,7 +718,7 @@ func TestOverThousandBlobs(t *testing.T) {
t.Fatalf("unexpected error creating driver with standard storage: %v", err) t.Fatalf("unexpected error creating driver with standard storage: %v", err)
} }
ctx := context.Background() ctx := dcontext.Background()
for i := 0; i < 1005; i++ { for i := 0; i < 1005; i++ {
filename := "/thousandfiletest/file" + strconv.Itoa(i) filename := "/thousandfiletest/file" + strconv.Itoa(i)
contents := []byte("contents") contents := []byte("contents")
@ -746,7 +746,7 @@ func TestMoveWithMultipartCopy(t *testing.T) {
t.Fatalf("unexpected error creating driver: %v", err) t.Fatalf("unexpected error creating driver: %v", err)
} }
ctx := context.Background() ctx := dcontext.Background()
sourcePath := "/source" sourcePath := "/source"
destPath := "/dest" destPath := "/dest"
@ -795,7 +795,7 @@ func TestListObjectsV2(t *testing.T) {
t.Fatalf("unexpected error creating driver: %v", err) t.Fatalf("unexpected error creating driver: %v", err)
} }
ctx := context.Background() ctx := dcontext.Background()
n := 6 n := 6
prefix := "/test-list-objects-v2" prefix := "/test-list-objects-v2"
var filePaths []string var filePaths []string

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@ -92,11 +93,10 @@ type StorageDriver interface {
// Delete recursively deletes all objects stored at "path" and its subpaths. // Delete recursively deletes all objects stored at "path" and its subpaths.
Delete(ctx context.Context, path string) error Delete(ctx context.Context, path string) error
// URLFor returns a URL which may be used to retrieve the content stored at // RedirectURL returns a URL which the client of the request r may use
// the given path, possibly using the given options. // to retrieve the content stored at path. Returning the empty string
// May return an ErrUnsupportedMethod in certain StorageDriver // signals that the request may not be redirected.
// implementations. RedirectURL(r *http.Request, path string) (string, error)
URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error)
// Walk traverses a filesystem defined within driver, starting // Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file. // from the given path, calling f on each file.

View file

@ -8,6 +8,7 @@ import (
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest"
"os" "os"
"path" "path"
"sort" "sort"
@ -733,9 +734,9 @@ func (suite *DriverSuite) TestDelete(c *check.C) {
c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true) c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true)
} }
// TestURLFor checks that the URLFor method functions properly, but only if it // TestRedirectURL checks that the RedirectURL method functions properly,
// is implemented // but only if it is implemented
func (suite *DriverSuite) TestURLFor(c *check.C) { func (suite *DriverSuite) TestRedirectURL(c *check.C) {
filename := randomPath(32) filename := randomPath(32)
contents := randomContents(32) contents := randomContents(32)
@ -744,8 +745,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) {
err := suite.StorageDriver.PutContent(suite.ctx, filename, contents) err := suite.StorageDriver.PutContent(suite.ctx, filename, contents)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
url, err := suite.StorageDriver.URLFor(suite.ctx, filename, nil) url, err := suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodGet, filename, nil), filename)
if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok { if url == "" && err == nil {
return return
} }
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -758,8 +759,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(read, check.DeepEquals, contents) c.Assert(read, check.DeepEquals, contents)
url, err = suite.StorageDriver.URLFor(suite.ctx, filename, map[string]interface{}{"method": http.MethodHead}) url, err = suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodHead, filename, nil), filename)
if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok { if url == "" && err == nil {
return return
} }
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

View file

@ -7,13 +7,13 @@ import (
mrand "math/rand" mrand "math/rand"
"testing" "testing"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
) )
func TestSimpleRead(t *testing.T) { func TestSimpleRead(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
content := make([]byte, 1<<20) content := make([]byte, 1<<20)
n, err := crand.Read(content) n, err := crand.Read(content)
if err != nil { if err != nil {
@ -55,7 +55,7 @@ func TestFileReaderSeek(t *testing.T) {
repititions := 1024 repititions := 1024
path := "/patterned" path := "/patterned"
content := bytes.Repeat([]byte(pattern), repititions) content := bytes.Repeat([]byte(pattern), repititions)
ctx := context.Background() ctx := dcontext.Background()
if err := driver.PutContent(ctx, path, content); err != nil { if err := driver.PutContent(ctx, path, content); err != nil {
t.Fatalf("error putting patterned content: %v", err) t.Fatalf("error putting patterned content: %v", err)
@ -156,7 +156,7 @@ func TestFileReaderSeek(t *testing.T) {
// read method, with an io.EOF error. // read method, with an io.EOF error.
func TestFileReaderNonExistentFile(t *testing.T) { func TestFileReaderNonExistentFile(t *testing.T) {
driver := inmemory.New() driver := inmemory.New()
fr, err := newFileReader(context.Background(), driver, "/doesnotexist", 10) fr, err := newFileReader(dcontext.Background(), driver, "/doesnotexist", 10)
if err != nil { if err != nil {
t.Fatalf("unexpected error initializing reader: %v", err) t.Fatalf("unexpected error initializing reader: %v", err)
} }

View file

@ -6,7 +6,7 @@ import (
"testing" "testing"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
"github.com/distribution/distribution/v3/testutil" "github.com/distribution/distribution/v3/testutil"
@ -21,7 +21,7 @@ type image struct {
} }
func createRegistry(t *testing.T, driver driver.StorageDriver, options ...RegistryOption) distribution.Namespace { func createRegistry(t *testing.T, driver driver.StorageDriver, options ...RegistryOption) distribution.Namespace {
ctx := context.Background() ctx := dcontext.Background()
options = append(options, EnableDelete) options = append(options, EnableDelete)
registry, err := NewRegistry(ctx, driver, options...) registry, err := NewRegistry(ctx, driver, options...)
if err != nil { if err != nil {
@ -31,7 +31,7 @@ func createRegistry(t *testing.T, driver driver.StorageDriver, options ...Regist
} }
func makeRepository(t *testing.T, registry distribution.Namespace, name string) distribution.Repository { func makeRepository(t *testing.T, registry distribution.Namespace, name string) distribution.Repository {
ctx := context.Background() ctx := dcontext.Background()
// Initialize a dummy repository // Initialize a dummy repository
named, err := reference.WithName(name) named, err := reference.WithName(name)
@ -47,7 +47,7 @@ func makeRepository(t *testing.T, registry distribution.Namespace, name string)
} }
func makeManifestService(t *testing.T, repository distribution.Repository) distribution.ManifestService { func makeManifestService(t *testing.T, repository distribution.Repository) distribution.ManifestService {
ctx := context.Background() ctx := dcontext.Background()
manifestService, err := repository.Manifests(ctx) manifestService, err := repository.Manifests(ctx)
if err != nil { if err != nil {
@ -57,7 +57,7 @@ func makeManifestService(t *testing.T, repository distribution.Repository) distr
} }
func allManifests(t *testing.T, manifestService distribution.ManifestService) map[digest.Digest]struct{} { func allManifests(t *testing.T, manifestService distribution.ManifestService) map[digest.Digest]struct{} {
ctx := context.Background() ctx := dcontext.Background()
allManMap := make(map[digest.Digest]struct{}) allManMap := make(map[digest.Digest]struct{})
manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator) manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator)
if !ok { if !ok {
@ -74,7 +74,7 @@ func allManifests(t *testing.T, manifestService distribution.ManifestService) ma
} }
func allBlobs(t *testing.T, registry distribution.Namespace) map[digest.Digest]struct{} { func allBlobs(t *testing.T, registry distribution.Namespace) map[digest.Digest]struct{} {
ctx := context.Background() ctx := dcontext.Background()
blobService := registry.Blobs() blobService := registry.Blobs()
allBlobsMap := make(map[digest.Digest]struct{}) allBlobsMap := make(map[digest.Digest]struct{})
err := blobService.Enumerate(ctx, func(dgst digest.Digest) error { err := blobService.Enumerate(ctx, func(dgst digest.Digest) error {
@ -95,7 +95,7 @@ func uploadImage(t *testing.T, repository distribution.Repository, im image) dig
} }
// upload manifest // upload manifest
ctx := context.Background() ctx := dcontext.Background()
manifestService := makeManifestService(t, repository) manifestService := makeManifestService(t, repository)
manifestDigest, err := manifestService.Put(ctx, im.manifest) manifestDigest, err := manifestService.Put(ctx, im.manifest)
if err != nil { if err != nil {
@ -130,7 +130,7 @@ func uploadRandomSchema2Image(t *testing.T, repository distribution.Repository)
} }
func TestNoDeletionNoEffect(t *testing.T) { func TestNoDeletionNoEffect(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
inmemoryDriver := inmemory.New() inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver) registry := createRegistry(t, inmemoryDriver)
@ -158,7 +158,7 @@ func TestNoDeletionNoEffect(t *testing.T) {
before := allBlobs(t, registry) before := allBlobs(t, registry)
// Run GC // Run GC
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false, DryRun: false,
RemoveUntagged: false, RemoveUntagged: false,
}) })
@ -173,7 +173,7 @@ func TestNoDeletionNoEffect(t *testing.T) {
} }
func TestDeleteManifestIfTagNotFound(t *testing.T) { func TestDeleteManifestIfTagNotFound(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
inmemoryDriver := inmemory.New() inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver) registry := createRegistry(t, inmemoryDriver)
@ -233,7 +233,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) {
before2 := allManifests(t, manifestService) before2 := allManifests(t, manifestService)
// run GC with dry-run (should not remove anything) // run GC with dry-run (should not remove anything)
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: true, DryRun: true,
RemoveUntagged: true, RemoveUntagged: true,
}) })
@ -250,7 +250,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) {
} }
// Run GC (removes everything because no manifests with tags exist) // Run GC (removes everything because no manifests with tags exist)
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false, DryRun: false,
RemoveUntagged: true, RemoveUntagged: true,
}) })
@ -269,7 +269,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) {
} }
func TestGCWithMissingManifests(t *testing.T) { func TestGCWithMissingManifests(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
d := inmemory.New() d := inmemory.New()
registry := createRegistry(t, d) registry := createRegistry(t, d)
@ -288,7 +288,7 @@ func TestGCWithMissingManifests(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = MarkAndSweep(context.Background(), d, registry, GCOpts{ err = MarkAndSweep(dcontext.Background(), d, registry, GCOpts{
DryRun: false, DryRun: false,
RemoveUntagged: false, RemoveUntagged: false,
}) })
@ -303,7 +303,7 @@ func TestGCWithMissingManifests(t *testing.T) {
} }
func TestDeletionHasEffect(t *testing.T) { func TestDeletionHasEffect(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
inmemoryDriver := inmemory.New() inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver) registry := createRegistry(t, inmemoryDriver)
@ -318,7 +318,7 @@ func TestDeletionHasEffect(t *testing.T) {
manifests.Delete(ctx, image3.manifestDigest) manifests.Delete(ctx, image3.manifestDigest)
// Run GC // Run GC
err := MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ err := MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false, DryRun: false,
RemoveUntagged: false, RemoveUntagged: false,
}) })
@ -368,7 +368,7 @@ func getKeys(digests map[digest.Digest]io.ReadSeeker) (ds []digest.Digest) {
} }
func TestDeletionWithSharedLayer(t *testing.T) { func TestDeletionWithSharedLayer(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
inmemoryDriver := inmemory.New() inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver) registry := createRegistry(t, inmemoryDriver)
@ -455,7 +455,7 @@ func TestOrphanBlobDeleted(t *testing.T) {
uploadRandomSchema2Image(t, repo) uploadRandomSchema2Image(t, repo)
// Run GC // Run GC
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{ err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false, DryRun: false,
RemoveUntagged: false, RemoveUntagged: false,
}) })

View file

@ -9,7 +9,7 @@ import (
"time" "time"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/reference" "github.com/distribution/reference"
"github.com/google/uuid" "github.com/google/uuid"

View file

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/manifestlist"
"github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"

View file

@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest" "github.com/distribution/distribution/v3/manifest"
"github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/manifestlist"
"github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/manifest/ocischema"

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
) )

View file

@ -6,7 +6,7 @@ import (
"net/url" "net/url"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/ocischema" "github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
v1 "github.com/opencontainers/image-spec/specs-go/v1" v1 "github.com/opencontainers/image-spec/specs-go/v1"

View file

@ -34,7 +34,7 @@ type manifestURLs struct {
type RegistryOption func(*registry) error type RegistryOption func(*registry) error
// EnableRedirect is a functional option for NewRegistry. It causes the backend // EnableRedirect is a functional option for NewRegistry. It causes the backend
// blob server to attempt using (StorageDriver).URLFor to serve all blobs. // blob server to attempt using (StorageDriver).RedirectURL to serve all blobs.
func EnableRedirect(registry *registry) error { func EnableRedirect(registry *registry) error {
registry.blobServer.redirect = true registry.blobServer.redirect = true
return nil return nil
@ -102,7 +102,7 @@ func BlobDescriptorCacheProvider(blobDescriptorCacheProvider cache.BlobDescripto
// NewRegistry creates a new registry instance from the provided driver. The // NewRegistry creates a new registry instance from the provided driver. The
// resulting registry may be shared by multiple goroutines but is cheap to // resulting registry may be shared by multiple goroutines but is cheap to
// allocate. If the Redirect option is specified, the backend blob server will // allocate. If the Redirect option is specified, the backend blob server will
// attempt to use (StorageDriver).URLFor to serve all blobs. // attempt to use (StorageDriver).RedirectURL to serve all blobs.
func NewRegistry(ctx context.Context, driver storagedriver.StorageDriver, options ...RegistryOption) (distribution.Namespace, error) { func NewRegistry(ctx context.Context, driver storagedriver.StorageDriver, options ...RegistryOption) (distribution.Namespace, error) {
// create global statter // create global statter
statter := &blobStatter{ statter := &blobStatter{

View file

@ -7,7 +7,7 @@ import (
"net/url" "net/url"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/manifest/schema2"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
) )

View file

@ -6,7 +6,7 @@ import (
"testing" "testing"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest" "github.com/distribution/distribution/v3/manifest"
"github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/manifest/schema2"
"github.com/distribution/distribution/v3/registry/storage/driver/inmemory" "github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
@ -14,7 +14,7 @@ import (
) )
func TestVerifyManifestForeignLayer(t *testing.T) { func TestVerifyManifestForeignLayer(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
inmemoryDriver := inmemory.New() inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver, registry := createRegistry(t, inmemoryDriver,
ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")), ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")),
@ -152,7 +152,7 @@ func TestVerifyManifestForeignLayer(t *testing.T) {
} }
func TestVerifyManifestBlobLayerAndConfig(t *testing.T) { func TestVerifyManifestBlobLayerAndConfig(t *testing.T) {
ctx := context.Background() ctx := dcontext.Background()
inmemoryDriver := inmemory.New() inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver, registry := createRegistry(t, inmemoryDriver,
ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")), ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")),

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"path" "path"
dcontext "github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver" "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
) )

View file

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/manifestlist" "github.com/distribution/distribution/v3/manifest/manifestlist"
"github.com/distribution/distribution/v3/manifest/schema2" "github.com/distribution/distribution/v3/manifest/schema2"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
@ -12,7 +12,7 @@ import (
// MakeManifestList constructs a manifest list out of a list of manifest digests // MakeManifestList constructs a manifest list out of a list of manifest digests
func MakeManifestList(blobstatter distribution.BlobStatter, manifestDigests []digest.Digest) (*manifestlist.DeserializedManifestList, error) { func MakeManifestList(blobstatter distribution.BlobStatter, manifestDigests []digest.Digest) (*manifestlist.DeserializedManifestList, error) {
ctx := context.Background() ctx := dcontext.Background()
manifestDescriptors := make([]manifestlist.ManifestDescriptor, 0, len(manifestDigests)) manifestDescriptors := make([]manifestlist.ManifestDescriptor, 0, len(manifestDigests))
for _, manifestDigest := range manifestDigests { for _, manifestDigest := range manifestDigests {
@ -39,7 +39,7 @@ func MakeManifestList(blobstatter distribution.BlobStatter, manifestDigests []di
// MakeSchema2Manifest constructs a schema 2 manifest from a given list of digests and returns // MakeSchema2Manifest constructs a schema 2 manifest from a given list of digests and returns
// the digest of the manifest // the digest of the manifest
func MakeSchema2Manifest(repository distribution.Repository, digests []digest.Digest) (distribution.Manifest, error) { func MakeSchema2Manifest(repository distribution.Repository, digests []digest.Digest) (distribution.Manifest, error) {
ctx := context.Background() ctx := dcontext.Background()
blobStore := repository.Blobs(ctx) blobStore := repository.Blobs(ctx)
var configJSON []byte var configJSON []byte

View file

@ -10,7 +10,7 @@ import (
"time" "time"
"github.com/distribution/distribution/v3" "github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context" "github.com/distribution/distribution/v3/internal/dcontext"
"github.com/opencontainers/go-digest" "github.com/opencontainers/go-digest"
) )
@ -96,7 +96,7 @@ func CreateRandomLayers(n int) (map[digest.Digest]io.ReadSeeker, error) {
// UploadBlobs lets you upload blobs to a repository // UploadBlobs lets you upload blobs to a repository
func UploadBlobs(repository distribution.Repository, layers map[digest.Digest]io.ReadSeeker) error { func UploadBlobs(repository distribution.Repository, layers map[digest.Digest]io.ReadSeeker) error {
ctx := context.Background() ctx := dcontext.Background()
for dgst, rs := range layers { for dgst, rs := range layers {
wr, err := repository.Blobs(ctx).Create(ctx) wr, err := repository.Blobs(ctx).Create(ctx)
if err != nil { if err != nil {