Hide our misuses of contexts from the public interface (#4128)

This commit is contained in:
Milos Gajdos 2023-11-03 05:05:19 +00:00 committed by GitHub
commit bd0e476910
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
75 changed files with 510 additions and 594 deletions

View file

@ -7,7 +7,7 @@ import (
"sync"
"time"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode"
)
@ -279,7 +279,7 @@ func Handler(handler http.Handler) http.Handler {
func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks map[string]string) {
p, err := json.Marshal(checks)
if err != nil {
context.GetLogger(context.Background()).Errorf("error serializing health status: %v", err)
dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status: %v", err)
p, err = json.Marshal(struct {
ServerError string `json:"server_error"`
}{
@ -288,7 +288,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m
status = http.StatusInternalServerError
if err != nil {
context.GetLogger(context.Background()).Errorf("error serializing health status failure message: %v", err)
dcontext.GetLogger(dcontext.Background()).Errorf("error serializing health status failure message: %v", err)
return
}
}
@ -297,7 +297,7 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m
w.Header().Set("Content-Length", fmt.Sprint(len(p)))
w.WriteHeader(status)
if _, err := w.Write(p); err != nil {
context.GetLogger(context.Background()).Errorf("error writing health status response body: %v", err)
dcontext.GetLogger(dcontext.Background()).Errorf("error writing health status response body: %v", err)
}
}

View file

@ -17,7 +17,7 @@ import (
"time"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest"
"github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/distribution/distribution/v3/registry/api/errcode"
@ -108,7 +108,7 @@ func TestBlobServeBlob(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil)
if err != nil {
@ -157,7 +157,7 @@ func TestBlobServeBlobHEAD(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil)
if err != nil {
@ -250,7 +250,7 @@ func TestBlobResume(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -307,7 +307,7 @@ func TestBlobDelete(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -327,7 +327,7 @@ func TestBlobFetch(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil)
if err != nil {
@ -382,7 +382,7 @@ func TestBlobExistsNoContentLength(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -406,7 +406,7 @@ func TestBlobExists(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo1")
r, err := NewRepository(repo, e, nil)
if err != nil {
@ -512,7 +512,7 @@ func TestBlobUploadChunked(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -622,7 +622,7 @@ func TestBlobUploadMonolithic(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -728,7 +728,7 @@ func TestBlobUploadMonolithicDockerUploadUUIDFromURL(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -833,7 +833,7 @@ func TestBlobUploadMonolithicNoDockerUploadUUID(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -891,7 +891,7 @@ func TestBlobMount(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -1066,7 +1066,7 @@ func checkEqualManifest(m1, m2 *ocischema.DeserializedManifest) error {
}
func TestOCIManifestFetch(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo")
m1, dgst, pl := newRandomOCIManifest(t, 6)
var m testutil.RequestResponseMap
@ -1149,7 +1149,7 @@ func TestManifestFetchWithEtag(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -1171,7 +1171,7 @@ func TestManifestFetchWithEtag(t *testing.T) {
}
func TestManifestFetchWithAccept(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
repo, _ := reference.WithName("test.example.com/repo")
_, dgst, _ := newRandomOCIManifest(t, 6)
headers := make(chan []string, 1)
@ -1258,7 +1258,7 @@ func TestManifestDelete(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
ms, err := r.Manifests(ctx)
if err != nil {
t.Fatal(err)
@ -1315,7 +1315,7 @@ func TestManifestPut(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
ms, err := r.Manifests(ctx)
if err != nil {
t.Fatal(err)
@ -1372,7 +1372,7 @@ func TestManifestTags(t *testing.T) {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
tagService := r.Tags(ctx)
tags, err := tagService.All(ctx)
@ -1423,7 +1423,7 @@ func TestTagDelete(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
ts := r.Tags(ctx)
if err := ts.Untag(ctx, tag); err != nil {
@ -1460,7 +1460,7 @@ func TestObtainsErrorForMissingTag(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -1487,7 +1487,7 @@ func TestObtainsManifestForTagWithoutHeaders(t *testing.T) {
e, c := testServer(m)
defer c()
ctx := context.Background()
ctx := dcontext.Background()
r, err := NewRepository(repo, e, nil)
if err != nil {
t.Fatal(err)
@ -1566,7 +1566,7 @@ func TestManifestTagsPaginated(t *testing.T) {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
tagService := r.Tags(ctx)
tags, err := tagService.All(ctx)
@ -1614,7 +1614,7 @@ func TestManifestUnauthorized(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
ms, err := r.Manifests(ctx)
if err != nil {
t.Fatal(err)
@ -1652,7 +1652,7 @@ func TestCatalog(t *testing.T) {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
numFilled, err := r.Repositories(ctx, entries, "")
if err != io.EOF {
t.Fatal(err)
@ -1684,7 +1684,7 @@ func TestCatalogInParts(t *testing.T) {
t.Fatal(err)
}
ctx := context.Background()
ctx := dcontext.Background()
numFilled, err := r.Repositories(ctx, entries, "")
if err != nil {
t.Fatal(err)

View file

@ -1,4 +1,4 @@
package context
package dcontext
import (
"context"

View file

@ -1,16 +1,16 @@
// Package context provides several utilities for working with
// Package dcontext provides several utilities for working with
// Go's context in http requests. Primarily, the focus is on logging relevant
// request information but this package is not limited to that purpose.
//
// The easiest way to get started is to get the background context:
//
// ctx := context.Background()
// ctx := dcontext.Background()
//
// The returned context should be passed around your application and be the
// root of all other context instances. If the application has a version, this
// line should be called before anything else:
//
// ctx := context.WithVersion(context.Background(), version)
// ctx := dcontext.WithVersion(dcontext.Background(), version)
//
// The above will store the version in the context and will be available to
// the logger.
@ -27,7 +27,7 @@
// the context and reported with the logger. The following example would
// return a logger that prints the version with each log message:
//
// ctx := context.Context(context.Background(), "version", version)
// ctx := context.WithValue(dcontext.Background(), "version", version)
// GetLogger(ctx, "version").Infof("this log message has a version field")
//
// The above would print out a log message like this:
@ -85,4 +85,4 @@
// can be traced in log messages. Using the fields like "http.request.id", one
// can analyze call flow for a particular request with a simple grep of the
// logs.
package context
package dcontext

View file

@ -1,17 +1,16 @@
package context
package dcontext
import (
"context"
"errors"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/distribution/distribution/v3/internal/requestutil"
"github.com/google/uuid"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
)
// Common errors used with this package.
@ -20,48 +19,6 @@ var (
ErrNoResponseWriterContext = errors.New("no http response in context")
)
func parseIP(ipStr string) net.IP {
ip := net.ParseIP(ipStr)
if ip == nil {
log.Warnf("invalid remote IP address: %q", ipStr)
}
return ip
}
// RemoteAddr extracts the remote address of the request, taking into
// account proxy headers.
func RemoteAddr(r *http.Request) string {
if prior := r.Header.Get("X-Forwarded-For"); prior != "" {
remoteAddr, _, _ := strings.Cut(prior, ",")
remoteAddr = strings.Trim(remoteAddr, " ")
if parseIP(remoteAddr) != nil {
return remoteAddr
}
}
// X-Real-Ip is less supported, but worth checking in the
// absence of X-Forwarded-For
if realIP := r.Header.Get("X-Real-Ip"); realIP != "" {
if parseIP(realIP) != nil {
return realIP
}
}
return r.RemoteAddr
}
// RemoteIP extracts the remote IP of the request, taking into
// account proxy headers.
func RemoteIP(r *http.Request) string {
addr := RemoteAddr(r)
// Try parsing it as "IP:port"
if ip, _, err := net.SplitHostPort(addr); err == nil {
return ip
}
return addr
}
// WithRequest places the request on the context. The context of the request
// is assigned a unique id, available at "http.request.id". The request itself
// is available at "http.request". Other common attributes are available under
@ -83,16 +40,6 @@ func WithRequest(ctx context.Context, r *http.Request) context.Context {
}
}
// GetRequest returns the http request in the given context. Returns
// ErrNoRequestContext if the context does not have an http request associated
// with it.
func GetRequest(ctx context.Context) (*http.Request, error) {
if r, ok := ctx.Value("http.request").(*http.Request); r != nil && ok {
return r, nil
}
return nil, ErrNoRequestContext
}
// GetRequestID attempts to resolve the current request id, if possible. An
// error is return if it is not available on the context.
func GetRequestID(ctx context.Context) string {
@ -193,7 +140,7 @@ func (ctx *httpRequestContext) Value(key interface{}) interface{} {
case "http.request.uri":
return ctx.r.RequestURI
case "http.request.remoteaddr":
return RemoteAddr(ctx.r)
return requestutil.RemoteAddr(ctx.r)
case "http.request.method":
return ctx.r.Method
case "http.request.host":

View file

@ -1,10 +1,7 @@
package context
package dcontext
import (
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"reflect"
"testing"
"time"
@ -219,70 +216,3 @@ func TestWithVars(t *testing.T) {
}
}
}
// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
// at the transport layer to 127.0.0.1:<port> . However, as the X-Forwarded-For header
// just contains the IP address, it is different enough for testing.
func TestRemoteAddr(t *testing.T) {
var expectedRemote string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if r.RemoteAddr == expectedRemote {
t.Errorf("Unexpected matching remote addresses")
}
actualRemote := RemoteAddr(r)
if expectedRemote != actualRemote {
t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
}
w.WriteHeader(200)
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxy := httputil.NewSingleHostReverseProxy(backendURL)
frontend := httptest.NewServer(proxy)
defer frontend.Close()
// X-Forwarded-For set by proxy
expectedRemote = "127.0.0.1"
proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultClient.Do(proxyReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// RemoteAddr in X-Real-Ip
getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil)
if err != nil {
t.Fatal(err)
}
expectedRemote = "1.2.3.4"
getReq.Header["X-Real-ip"] = []string{expectedRemote}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// Valid X-Real-Ip and invalid X-Forwarded-For
getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
}

View file

@ -1,4 +1,4 @@
package context
package dcontext
import (
"context"

View file

@ -1,4 +1,4 @@
package context
package dcontext
import (
"context"

View file

@ -1,4 +1,4 @@
package context
package dcontext
import (
"runtime"

View file

@ -1,4 +1,4 @@
package context
package dcontext
import (
"context"

View file

@ -1,4 +1,4 @@
package context
package dcontext
import "context"

View file

@ -1,4 +1,4 @@
package context
package dcontext
import "testing"

View file

@ -0,0 +1,51 @@
package requestutil
import (
"net"
"net/http"
"strings"
log "github.com/sirupsen/logrus"
)
func parseIP(ipStr string) net.IP {
ip := net.ParseIP(ipStr)
if ip == nil {
log.Warnf("invalid remote IP address: %q", ipStr)
}
return ip
}
// RemoteAddr extracts the remote address of the request, taking into
// account proxy headers.
func RemoteAddr(r *http.Request) string {
if prior := r.Header.Get("X-Forwarded-For"); prior != "" {
remoteAddr, _, _ := strings.Cut(prior, ",")
remoteAddr = strings.Trim(remoteAddr, " ")
if parseIP(remoteAddr) != nil {
return remoteAddr
}
}
// X-Real-Ip is less supported, but worth checking in the
// absence of X-Forwarded-For
if realIP := r.Header.Get("X-Real-Ip"); realIP != "" {
if parseIP(realIP) != nil {
return realIP
}
}
return r.RemoteAddr
}
// RemoteIP extracts the remote IP of the request, taking into
// account proxy headers.
func RemoteIP(r *http.Request) string {
addr := RemoteAddr(r)
// Try parsing it as "IP:port"
if ip, _, err := net.SplitHostPort(addr); err == nil {
return ip
}
return addr
}

View file

@ -0,0 +1,76 @@
package requestutil
import (
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"testing"
)
// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
// at the transport layer to 127.0.0.1:<port> . However, as the X-Forwarded-For header
// just contains the IP address, it is different enough for testing.
func TestRemoteAddr(t *testing.T) {
var expectedRemote string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if r.RemoteAddr == expectedRemote {
t.Errorf("Unexpected matching remote addresses")
}
actualRemote := RemoteAddr(r)
if expectedRemote != actualRemote {
t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
}
w.WriteHeader(200)
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxy := httputil.NewSingleHostReverseProxy(backendURL)
frontend := httptest.NewServer(proxy)
defer frontend.Close()
// X-Forwarded-For set by proxy
expectedRemote = "127.0.0.1"
proxyReq, err := http.NewRequest(http.MethodGet, frontend.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultClient.Do(proxyReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// RemoteAddr in X-Real-Ip
getReq, err := http.NewRequest(http.MethodGet, backend.URL, nil)
if err != nil {
t.Fatal(err)
}
expectedRemote = "1.2.3.4"
getReq.Header["X-Real-ip"] = []string{expectedRemote}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
// Valid X-Real-Ip and invalid X-Forwarded-For
getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
resp, err = http.DefaultClient.Do(getReq)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
}

View file

@ -5,7 +5,7 @@ import (
"time"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/requestutil"
"github.com/distribution/reference"
events "github.com/docker/go-events"
"github.com/google/uuid"
@ -49,7 +49,7 @@ func NewBridge(ub URLBuilder, source SourceRecord, actor ActorRecord, request Re
func NewRequestRecord(id string, r *http.Request) RequestRecord {
return RequestRecord{
ID: id,
Addr: context.RemoteAddr(r),
Addr: requestutil.RemoteAddr(r),
Host: r.Host,
Method: r.Method,
UserAgent: r.UserAgent(),

View file

@ -7,7 +7,7 @@ import (
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/reference"
"github.com/opencontainers/go-digest"
)

View file

@ -6,7 +6,7 @@ import (
"testing"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/schema2"
"github.com/distribution/distribution/v3/registry/storage"
"github.com/distribution/distribution/v3/registry/storage/cache/memory"

View file

@ -18,7 +18,7 @@
// resource := auth.Resource{Type: "customerOrder", Name: orderNumber}
// access := auth.Access{Resource: resource, Action: "update"}
//
// if ctx, err := accessController.Authorized(ctx, access); err != nil {
// if ctx, err := accessController.Authorized(r, access); err != nil {
// if challenge, ok := err.(auth.Challenge) {
// // Let the challenge write the response.
// challenge.SetHeaders(r, w)
@ -32,22 +32,11 @@
package auth
import (
"context"
"errors"
"fmt"
"net/http"
)
const (
// UserKey is used to get the user object from
// a user context
UserKey = "auth.user"
// UserNameKey is used to get the user name from
// a user context
UserNameKey = "auth.user.name"
)
var (
// ErrInvalidCredential is returned when the auth token does not authenticate correctly.
ErrInvalidCredential = errors.New("invalid authorization credential")
@ -76,6 +65,12 @@ type Access struct {
Action string
}
// Grant describes the permitted level of access for an authorized request.
type Grant struct {
User UserInfo // The authenticated user for the request.
Resources []Resource // The list of resources which have been authorized for the request.
}
// Challenge is a special error type which is used for HTTP 401 Unauthorized
// responses and is able to write the response with WWW-Authenticate challenge
// header values based on the error.
@ -93,16 +88,15 @@ type Challenge interface {
// and required access levels for a request. Implementations can support both
// complete denial and http authorization challenges.
type AccessController interface {
// Authorized returns a non-nil error if the context is granted access and
// returns a new authorized context. If one or more Access structs are
// provided, the requested access will be compared with what is available
// to the context. The given context will contain a "http.request" key with
// a `*http.Request` value. If the error is non-nil, access should always
// be denied. The error may be of type Challenge, in which case the caller
// may have the Challenge handle the request or choose what action to take
// based on the Challenge header or response status. The returned context
// object should have a "auth.user" value set to a UserInfo struct.
Authorized(ctx context.Context, access ...Access) (context.Context, error)
// Authorized determines if the request is granted access. If one or more
// Access structs are provided, the requested access will be compared with
// what is available to the request.
//
// Return a Grant to grant the request access. Return an error to deny
// access. The error may be of type Challenge, in which case the caller may
// have the Challenge handle the request or choose what action to take based
// on the Challenge header or response status.
Authorized(r *http.Request, access ...Access) (*Grant, error)
}
// CredentialAuthenticator is an object which is able to authenticate credentials
@ -110,63 +104,6 @@ type CredentialAuthenticator interface {
AuthenticateUser(username, password string) error
}
// WithUser returns a context with the authorized user info.
func WithUser(ctx context.Context, user UserInfo) context.Context {
return userInfoContext{
Context: ctx,
user: user,
}
}
type userInfoContext struct {
context.Context
user UserInfo
}
func (uic userInfoContext) Value(key interface{}) interface{} {
switch key {
case UserKey:
return uic.user
case UserNameKey:
return uic.user.Name
}
return uic.Context.Value(key)
}
// WithResources returns a context with the authorized resources.
func WithResources(ctx context.Context, resources []Resource) context.Context {
return resourceContext{
Context: ctx,
resources: resources,
}
}
type resourceContext struct {
context.Context
resources []Resource
}
type resourceKey struct{}
func (rc resourceContext) Value(key interface{}) interface{} {
if key == (resourceKey{}) {
return rc.resources
}
return rc.Context.Value(key)
}
// AuthorizedResources returns the list of resources which have
// been authorized for this request.
func AuthorizedResources(ctx context.Context) []Resource {
if resources, ok := ctx.Value(resourceKey{}).([]Resource); ok {
return resources
}
return nil
}
// InitFunc is the type of an AccessController factory function and is used
// to register the constructor for different AccesController backends.
type InitFunc func(options map[string]interface{}) (AccessController, error)

View file

@ -18,7 +18,7 @@ import (
"golang.org/x/crypto/bcrypt"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth"
)
@ -49,12 +49,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
return &accessController{realm: realm.(string), path: path}, nil
}
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) {
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (*auth.Grant, error) {
username, password, ok := req.BasicAuth()
if !ok {
return nil, &challenge{
@ -92,14 +87,14 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
ac.mu.Unlock()
if err := localHTPasswd.authenticateUser(username, password); err != nil {
dcontext.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err)
dcontext.GetLogger(req.Context()).Errorf("error authenticating user %q: %v", username, err)
return nil, &challenge{
realm: ac.realm,
err: auth.ErrAuthenticationFailure,
}
}
return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil
return &auth.Grant{User: auth.UserInfo{Name: username}}, nil
}
// challenge implements the auth.Challenge interface.

View file

@ -8,7 +8,6 @@ import (
"os"
"testing"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth"
)
@ -33,7 +32,6 @@ func TestBasicAccessController(t *testing.T) {
"realm": testRealm,
"path": tempFile.Name(),
}
ctx := context.Background()
accessController, err := newAccessController(options)
if err != nil {
@ -45,8 +43,7 @@ func TestBasicAccessController(t *testing.T) {
userNumber := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithRequest(ctx, r)
authCtx, err := accessController.Authorized(ctx)
grant, err := accessController.Authorized(r)
if err != nil {
switch err := err.(type) {
case auth.Challenge:
@ -58,13 +55,12 @@ func TestBasicAccessController(t *testing.T) {
}
}
userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo)
if !ok {
t.Fatal("basic accessController did not set auth.user context")
if grant == nil {
t.Fatal("basic accessController did not return auth grant")
}
if userInfo.Name != testUsers[userNumber] {
t.Fatalf("expected user name %q, got %q", testUsers[userNumber], userInfo.Name)
if grant.User.Name != testUsers[userNumber] {
t.Fatalf("expected user name %q, got %q", testUsers[userNumber], grant.User.Name)
}
w.WriteHeader(http.StatusNoContent)

View file

@ -8,12 +8,10 @@
package silly
import (
"context"
"fmt"
"net/http"
"strings"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth"
)
@ -43,12 +41,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized simply checks for the existence of the authorization header,
// responding with a bearer challenge if it doesn't exist.
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) {
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (*auth.Grant, error) {
if req.Header.Get("Authorization") == "" {
challenge := challenge{
realm: ac.realm,
@ -66,10 +59,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
return nil, &challenge
}
ctx = auth.WithUser(ctx, auth.UserInfo{Name: "silly"})
ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey))
return ctx, nil
return &auth.Grant{User: auth.UserInfo{Name: "silly"}}, nil
}
type challenge struct {

View file

@ -5,7 +5,6 @@ import (
"net/http/httptest"
"testing"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth"
)
@ -16,8 +15,7 @@ func TestSillyAccessController(t *testing.T) {
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithRequest(context.Background(), r)
authCtx, err := ac.Authorized(ctx)
grant, err := ac.Authorized(r)
if err != nil {
switch err := err.(type) {
case auth.Challenge:
@ -29,13 +27,12 @@ func TestSillyAccessController(t *testing.T) {
}
}
userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo)
if !ok {
t.Fatal("silly accessController did not set auth.user context")
if grant == nil {
t.Fatal("silly accessController did not return auth grant")
}
if userInfo.Name != "silly" {
t.Fatalf("expected user name %q, got %q", "silly", userInfo.Name)
if grant.User.Name != "silly" {
t.Fatalf("expected user name %q, got %q", "silly", grant.User.Name)
}
w.WriteHeader(http.StatusNoContent)

View file

@ -1,7 +1,6 @@
package token
import (
"context"
"crypto"
"crypto/x509"
"encoding/json"
@ -13,7 +12,6 @@ import (
"os"
"strings"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3"
)
@ -292,7 +290,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized handles checking whether the given request is authorized
// for actions on resources described by the given access items.
func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) {
func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (*auth.Grant, error) {
challenge := &authChallenge{
realm: ac.realm,
autoRedirect: ac.autoRedirect,
@ -300,11 +298,6 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
accessSet: newAccessSet(accessItems...),
}
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ")
if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") {
challenge.err = ErrTokenRequired
@ -338,9 +331,10 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
}
}
ctx = auth.WithResources(ctx, claims.resources())
return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil
return &auth.Grant{
User: auth.UserInfo{Name: claims.Subject},
Resources: claims.resources(),
}, nil
}
// init handles registering the token auth backend.

View file

@ -18,7 +18,6 @@ import (
"testing"
"time"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
@ -466,8 +465,7 @@ func TestAccessController(t *testing.T) {
Action: "baz",
}
ctx := context.WithRequest(context.Background(), req)
authCtx, err := accessController.Authorized(ctx, testAccess)
grant, err := accessController.Authorized(req, testAccess)
challenge, ok := err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
@ -477,8 +475,8 @@ func TestAccessController(t *testing.T) {
t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired)
}
if authCtx != nil {
t.Fatalf("expected nil auth context but got %s", authCtx)
if grant != nil {
t.Fatalf("expected nil auth grant but got %#v", grant)
}
// 2. Supply an invalid token.
@ -502,7 +500,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess)
grant, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
@ -512,8 +510,8 @@ func TestAccessController(t *testing.T) {
t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired)
}
if authCtx != nil {
t.Fatalf("expected nil auth context but got %s", authCtx)
if grant != nil {
t.Fatalf("expected nil auth grant but got %#v", grant)
}
// create a valid jwk
@ -534,7 +532,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess)
grant, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
@ -544,8 +542,8 @@ func TestAccessController(t *testing.T) {
t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInsufficientScope)
}
if authCtx != nil {
t.Fatalf("expected nil auth context but got %s", authCtx)
if grant != nil {
t.Fatalf("expected nil auth grant but got %#v", grant)
}
// 4. Supply the token we need, or deserve, or whatever.
@ -564,18 +562,13 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess)
grant, err = accessController.Authorized(req, testAccess)
if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err)
}
userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo)
if !ok {
t.Fatal("token accessController did not set auth.user context")
}
if userInfo.Name != "foo" {
t.Fatalf("expected user name %q, got %q", "foo", userInfo.Name)
if grant.User.Name != "foo" {
t.Fatalf("expected user name %q, got %q", "foo", grant.User.Name)
}
// 5. Supply a token with full admin rights, which is represented as "*".
@ -594,7 +587,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
_, err = accessController.Authorized(ctx, testAccess)
_, err = accessController.Authorized(req, testAccess)
if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err)
}

View file

@ -19,9 +19,9 @@ import (
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/configuration"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/health"
"github.com/distribution/distribution/v3/health/checks"
"github.com/distribution/distribution/v3/internal/dcontext"
prometheus "github.com/distribution/distribution/v3/metrics"
"github.com/distribution/distribution/v3/notifications"
"github.com/distribution/distribution/v3/registry/api/errcode"
@ -635,7 +635,7 @@ func (app *App) dispatcher(dispatch dispatchFunc) http.Handler {
}
// Add username to request logging
context.Context = dcontext.WithLogger(context.Context, dcontext.GetLogger(context.Context, auth.UserNameKey))
context.Context = dcontext.WithLogger(context.Context, dcontext.GetLogger(context.Context, userNameKey))
// sync up context on the request.
r = r.WithContext(context)
@ -797,7 +797,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont
accessRecords = appendCatalogAccessRecord(accessRecords, r)
}
ctx, err := app.accessController.Authorized(context.Context, accessRecords...)
grant, err := app.accessController.Authorized(r.WithContext(context.Context), accessRecords...)
if err != nil {
switch err := err.(type) {
case auth.Challenge:
@ -818,8 +818,14 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont
return err
}
if grant == nil {
return fmt.Errorf("access controller returned neither an access grant nor an error")
}
dcontext.GetLogger(ctx, auth.UserNameKey).Info("authorized request")
ctx := withUser(context.Context, grant.User)
ctx = withResources(ctx, grant.Resources)
dcontext.GetLogger(ctx, userNameKey).Info("authorized request")
// TODO(stevvooe): This pattern needs to be cleaned up a bit. One context
// should be replaced by another, rather than replacing the context on a
// mutable object.

View file

@ -9,7 +9,7 @@ import (
"testing"
"github.com/distribution/distribution/v3/configuration"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode"
v2 "github.com/distribution/distribution/v3/registry/api/v2"
"github.com/distribution/distribution/v3/registry/auth"
@ -25,7 +25,7 @@ import (
// tested individually.
func TestAppDispatcher(t *testing.T) {
driver := inmemory.New()
ctx := context.Background()
ctx := dcontext.Background()
registry, err := storage.NewRegistry(ctx, driver, storage.BlobDescriptorCacheProvider(memorycache.NewInMemoryBlobDescriptorCacheProvider(0)), storage.EnableDelete, storage.EnableRedirect)
if err != nil {
t.Fatalf("error creating registry: %v", err)
@ -139,7 +139,7 @@ func TestAppDispatcher(t *testing.T) {
// TestNewApp covers the creation of an application via NewApp with a
// configuration.
func TestNewApp(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
config := configuration.Configuration{
Storage: configuration.Storage{
"inmemory": nil,

View file

@ -4,7 +4,7 @@ import (
"net/http"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode"
"github.com/gorilla/handlers"
"github.com/opencontainers/go-digest"
@ -53,7 +53,7 @@ type blobHandler struct {
// GetBlob fetches the binary data from backend storage returns it in the
// response.
func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) {
context.GetLogger(bh).Debug("GetBlob")
dcontext.GetLogger(bh).Debug("GetBlob")
blobs := bh.Repository.Blobs(bh)
desc, err := blobs.Stat(bh, bh.Digest)
if err != nil {
@ -66,7 +66,7 @@ func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) {
}
if err := blobs.ServeBlob(bh, w, r, desc.Digest); err != nil {
context.GetLogger(bh).Debugf("unexpected error getting blob HTTP handler: %v", err)
dcontext.GetLogger(bh).Debugf("unexpected error getting blob HTTP handler: %v", err)
bh.Errors = append(bh.Errors, errcode.ErrorCodeUnknown.WithDetail(err))
return
}
@ -74,7 +74,7 @@ func (bh *blobHandler) GetBlob(w http.ResponseWriter, r *http.Request) {
// DeleteBlob deletes a layer blob
func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) {
context.GetLogger(bh).Debug("DeleteBlob")
dcontext.GetLogger(bh).Debug("DeleteBlob")
blobs := bh.Repository.Blobs(bh)
err := blobs.Delete(bh, bh.Digest)
@ -88,7 +88,7 @@ func (bh *blobHandler) DeleteBlob(w http.ResponseWriter, r *http.Request) {
return
default:
bh.Errors = append(bh.Errors, err)
context.GetLogger(bh).Errorf("Unknown error deleting blob: %s", err.Error())
dcontext.GetLogger(bh).Errorf("Unknown error deleting blob: %s", err.Error())
return
}
}

View file

@ -7,7 +7,7 @@ import (
"strconv"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode"
"github.com/distribution/distribution/v3/registry/storage"
"github.com/distribution/reference"

View file

@ -6,7 +6,7 @@ import (
"net/http"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/api/errcode"
v2 "github.com/distribution/distribution/v3/registry/api/v2"
"github.com/distribution/distribution/v3/registry/auth"
@ -77,10 +77,20 @@ func getUploadUUID(ctx context.Context) (uuid string) {
return dcontext.GetStringValue(ctx, "vars.uuid")
}
const (
// userKey is used to get the user object from
// a user context
userKey = "auth.user"
// userNameKey is used to get the user name from
// a user context
userNameKey = "auth.user.name"
)
// getUserName attempts to resolve a username from the context and request. If
// a username cannot be resolved, the empty string is returned.
func getUserName(ctx context.Context, r *http.Request) string {
username := dcontext.GetStringValue(ctx, auth.UserNameKey)
username := dcontext.GetStringValue(ctx, userNameKey)
// Fallback to request user with basic auth
if username == "" {
@ -93,3 +103,60 @@ func getUserName(ctx context.Context, r *http.Request) string {
return username
}
// withUser returns a context with the authorized user info.
func withUser(ctx context.Context, user auth.UserInfo) context.Context {
return userInfoContext{
Context: ctx,
user: user,
}
}
type userInfoContext struct {
context.Context
user auth.UserInfo
}
func (uic userInfoContext) Value(key interface{}) interface{} {
switch key {
case userKey:
return uic.user
case userNameKey:
return uic.user.Name
}
return uic.Context.Value(key)
}
// withResources returns a context with the authorized resources.
func withResources(ctx context.Context, resources []auth.Resource) context.Context {
return resourceContext{
Context: ctx,
resources: resources,
}
}
type resourceContext struct {
context.Context
resources []auth.Resource
}
type resourceKey struct{}
func (rc resourceContext) Value(key interface{}) interface{} {
if key == (resourceKey{}) {
return rc.resources
}
return rc.Context.Value(key)
}
// authorizedResources returns the list of resources which have
// been authorized for this request.
func authorizedResources(ctx context.Context) []auth.Resource {
if resources, ok := ctx.Value(resourceKey{}).([]auth.Resource); ok {
return resources
}
return nil
}

View file

@ -9,8 +9,8 @@ import (
"time"
"github.com/distribution/distribution/v3/configuration"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/health"
"github.com/distribution/distribution/v3/internal/dcontext"
)
func TestFileHealthCheck(t *testing.T) {
@ -39,7 +39,7 @@ func TestFileHealthCheck(t *testing.T) {
},
}
ctx := context.Background()
ctx := dcontext.Background()
app := NewApp(ctx, config)
healthRegistry := health.NewRegistry()
@ -103,7 +103,7 @@ func TestTCPHealthCheck(t *testing.T) {
},
}
ctx := context.Background()
ctx := dcontext.Background()
app := NewApp(ctx, config)
healthRegistry := health.NewRegistry()
@ -165,7 +165,7 @@ func TestHTTPHealthCheck(t *testing.T) {
},
}
ctx := context.Background()
ctx := dcontext.Background()
app := NewApp(ctx, config)
healthRegistry := health.NewRegistry()

View file

@ -9,7 +9,7 @@ import (
"strconv"
"strings"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
)
// closeResources closes all the provided resources after running the target

View file

@ -8,12 +8,11 @@ import (
"strings"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/manifestlist"
"github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/distribution/distribution/v3/manifest/schema2"
"github.com/distribution/distribution/v3/registry/api/errcode"
"github.com/distribution/distribution/v3/registry/auth"
"github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/reference"
"github.com/gorilla/handlers"
@ -394,7 +393,7 @@ func (imh *manifestHandler) applyResourcePolicy(manifest distribution.Manifest)
return errcode.ErrorCodeDenied.WithMessage(fmt.Sprintf("registry does not allow %s manifest", class))
}
resources := auth.AuthorizedResources(imh)
resources := authorizedResources(imh)
n := imh.Repository.Named().Name()
var foundResource bool

View file

@ -5,9 +5,9 @@ import (
"net/url"
"strings"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/client/auth"
"github.com/distribution/distribution/v3/internal/client/auth/challenge"
"github.com/distribution/distribution/v3/internal/dcontext"
)
const challengeHeader = "Docker-Distribution-Api-Version"
@ -44,7 +44,7 @@ func configureAuth(username, password, remoteURL string) (auth.CredentialStore,
}
for _, url := range authURLs {
context.GetLogger(context.Background()).Infof("Discovered token authentication URL: %s", url)
dcontext.GetLogger(dcontext.Background()).Infof("Discovered token authentication URL: %s", url)
creds[url] = userpass{
username: username,
password: password,

View file

@ -11,7 +11,7 @@ import (
"github.com/opencontainers/go-digest"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/proxy/scheduler"
"github.com/distribution/reference"
)

View file

@ -7,7 +7,7 @@ import (
"github.com/opencontainers/go-digest"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/proxy/scheduler"
"github.com/distribution/reference"
)

View file

@ -10,11 +10,11 @@ import (
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/configuration"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/client"
"github.com/distribution/distribution/v3/internal/client/auth"
"github.com/distribution/distribution/v3/internal/client/auth/challenge"
"github.com/distribution/distribution/v3/internal/client/transport"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/proxy/scheduler"
"github.com/distribution/distribution/v3/registry/storage"
"github.com/distribution/distribution/v3/registry/storage/driver"

View file

@ -7,7 +7,7 @@ import (
"sync"
"time"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/reference"
)

View file

@ -6,7 +6,7 @@ import (
"testing"
"time"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
"github.com/distribution/reference"
)
@ -40,7 +40,7 @@ func TestSchedule(t *testing.T) {
}
var mu sync.Mutex
s := New(context.Background(), inmemory.New(), "/ttl")
s := New(dcontext.Background(), inmemory.New(), "/ttl")
deleteFunc := func(repoName reference.Reference) error {
if len(remainingRepos) == 0 {
t.Fatalf("Incorrect expiry count")
@ -123,14 +123,14 @@ func TestRestoreOld(t *testing.T) {
t.Fatalf("Error serializing test data: %s", err.Error())
}
ctx := context.Background()
ctx := dcontext.Background()
pathToStatFile := "/ttl"
fs := inmemory.New()
err = fs.PutContent(ctx, pathToStatFile, serialized)
if err != nil {
t.Fatal("Unable to write serialized data to fs")
}
s := New(context.Background(), fs, "/ttl")
s := New(dcontext.Background(), fs, "/ttl")
s.OnBlobExpire(deleteFunc)
err = s.Start()
if err != nil {
@ -165,7 +165,7 @@ func TestStopRestore(t *testing.T) {
fs := inmemory.New()
pathToStateFile := "/ttl"
s := New(context.Background(), fs, pathToStateFile)
s := New(dcontext.Background(), fs, pathToStateFile)
s.onBlobExpire = deleteFunc
err := s.Start()
@ -181,7 +181,7 @@ func TestStopRestore(t *testing.T) {
time.Sleep(10 * time.Millisecond)
// v2 will restore state from fs
s2 := New(context.Background(), fs, pathToStateFile)
s2 := New(dcontext.Background(), fs, pathToStateFile)
s2.onBlobExpire = deleteFunc
err = s2.Start()
if err != nil {
@ -197,7 +197,7 @@ func TestStopRestore(t *testing.T) {
}
func TestDoubleStart(t *testing.T) {
s := New(context.Background(), inmemory.New(), "/ttl")
s := New(dcontext.Background(), inmemory.New(), "/ttl")
err := s.Start()
if err != nil {
t.Fatalf("Unable to start scheduler")

View file

@ -21,8 +21,8 @@ import (
"golang.org/x/crypto/acme/autocert"
"github.com/distribution/distribution/v3/configuration"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/health"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/handlers"
"github.com/distribution/distribution/v3/registry/listener"
"github.com/distribution/distribution/v3/version"

View file

@ -25,7 +25,7 @@ import (
"time"
"github.com/distribution/distribution/v3/configuration"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
_ "github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"

View file

@ -4,7 +4,7 @@ import (
"fmt"
"os"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage"
"github.com/distribution/distribution/v3/registry/storage/driver/factory"
"github.com/distribution/distribution/v3/version"

View file

@ -20,7 +20,7 @@ type blobServer struct {
driver driver.StorageDriver
statter distribution.BlobStatter
pathFn func(dgst digest.Digest) (string, error)
redirect bool // allows disabling URLFor redirects
redirect bool // allows disabling RedirectURL redirects
}
func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
@ -35,19 +35,16 @@ func (bs *blobServer) ServeBlob(ctx context.Context, w http.ResponseWriter, r *h
}
if bs.redirect {
redirectURL, err := bs.driver.URLFor(ctx, path, map[string]interface{}{"method": r.Method})
switch err.(type) {
case nil:
// Redirect to storage URL.
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
return err
case driver.ErrUnsupportedMethod:
// Fallback to serving the content directly.
default:
// Some unexpected error.
redirectURL, err := bs.driver.RedirectURL(r, path)
if err != nil {
return err
}
if redirectURL != "" {
// Redirect to storage URL.
http.Redirect(w, r, redirectURL, http.StatusTemporaryRedirect)
return nil
}
// Fallback to serving the content directly.
}
br, err := newFileReader(ctx, bs.driver, path, desc.Size)

View file

@ -6,7 +6,7 @@ import (
"path"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/opencontainers/go-digest"
)

View file

@ -9,7 +9,7 @@ import (
"time"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/opencontainers/go-digest"
"github.com/sirupsen/logrus"

View file

@ -4,7 +4,7 @@
package storage
import (
"github.com/distribution/distribution/v3/context"
"context"
)
// resumeHashAt is a noop when resumable digest support is disabled.

View file

@ -4,7 +4,7 @@ import (
"context"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
prometheus "github.com/distribution/distribution/v3/metrics"
"github.com/opencontainers/go-digest"
)

View file

@ -8,6 +8,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
@ -302,7 +303,7 @@ func (d *driver) List(ctx context.Context, path string) ([]string, error) {
// Move moves an object stored at sourcePath to destPath, removing the original
// object.
func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) error {
sourceBlobURL, err := d.URLFor(ctx, sourcePath, nil)
sourceBlobURL, err := d.signBlobURL(ctx, sourcePath)
if err != nil {
return err
}
@ -382,18 +383,15 @@ func (d *driver) Delete(ctx context.Context, path string) error {
return nil
}
// URLFor returns a publicly accessible URL for the blob stored at given path
// RedirectURL returns a publicly accessible URL for the blob stored at given path
// for specified duration by making use of Azure Storage Shared Access Signatures (SAS).
// See https://msdn.microsoft.com/en-us/library/azure/ee395415.aspx for more info.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
func (d *driver) RedirectURL(req *http.Request, path string) (string, error) {
return d.signBlobURL(req.Context(), path)
}
func (d *driver) signBlobURL(ctx context.Context, path string) (string, error) {
expiresTime := time.Now().UTC().Add(20 * time.Minute) // default expiration
expires, ok := options["expiry"]
if ok {
t, ok := expires.(time.Time)
if ok {
expiresTime = t
}
}
blobName := d.blobName(path)
blobRef := d.client.NewBlobClient(blobName)
return d.azClient.SignBlobURL(ctx, blobRef.URL(), expiresTime)

View file

@ -40,9 +40,10 @@ package base
import (
"context"
"io"
"net/http"
"time"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
prometheus "github.com/distribution/distribution/v3/metrics"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/docker/go-metrics"
@ -208,18 +209,18 @@ func (base *Base) Delete(ctx context.Context, path string) error {
return err
}
// URLFor wraps URLFor of underlying storage driver.
func (base *Base) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
ctx, done := dcontext.WithTrace(ctx)
defer done("%s.URLFor(%q)", base.Name(), path)
// RedirectURL wraps RedirectURL of the underlying storage driver.
func (base *Base) RedirectURL(r *http.Request, path string) (string, error) {
ctx, done := dcontext.WithTrace(r.Context())
defer done("%s.RedirectURL(%q)", base.Name(), path)
if !storagedriver.PathRegexp.MatchString(path) {
return "", storagedriver.InvalidPathError{Path: path, DriverName: base.StorageDriver.Name()}
}
start := time.Now()
str, e := base.StorageDriver.URLFor(ctx, path, options)
storageAction.WithValues(base.Name(), "URLFor").UpdateSince(start)
str, e := base.StorageDriver.RedirectURL(r.WithContext(ctx), path)
storageAction.WithValues(base.Name(), "RedirectURL").UpdateSince(start)
return str, base.setDriverName(e)
}

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"sync"
@ -172,13 +173,11 @@ func (r *regulator) Delete(ctx context.Context, path string) error {
return r.StorageDriver.Delete(ctx, path)
}
// URLFor returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// May return an ErrUnsupportedMethod in certain StorageDriver
// implementations.
func (r *regulator) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
// RedirectURL returns a URL which may be used to retrieve the content stored at
// the given path.
func (r *regulator) RedirectURL(req *http.Request, path string) (string, error) {
r.enter()
defer r.exit()
return r.StorageDriver.URLFor(ctx, path, options)
return r.StorageDriver.RedirectURL(req, path)
}

View file

@ -6,6 +6,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"os"
"path"
"time"
@ -282,10 +283,9 @@ func (d *driver) Delete(ctx context.Context, subPath string) error {
return err
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{}
// RedirectURL returns a URL which may be used to retrieve the content stored at the given path.
func (d *driver) RedirectURL(*http.Request, string) (string, error) {
return "", nil
}
// Walk traverses a filesystem defined within driver, starting

View file

@ -809,40 +809,24 @@ func storageCopyObject(ctx context.Context, srcBucket, srcName string, destBucke
return attrs, err
}
// URLFor returns a URL which may be used to retrieve the content stored at
// RedirectURL returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// Returns ErrUnsupportedMethod if this driver has no privateKey
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
func (d *driver) RedirectURL(r *http.Request, path string) (string, error) {
if d.privateKey == nil {
return "", storagedriver.ErrUnsupportedMethod{}
return "", nil
}
name := d.pathToKey(path)
methodString := http.MethodGet
method, ok := options["method"]
if ok {
methodString, ok = method.(string)
if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) {
return "", storagedriver.ErrUnsupportedMethod{}
}
}
expiresTime := time.Now().Add(20 * time.Minute)
expires, ok := options["expiry"]
if ok {
et, ok := expires.(time.Time)
if ok {
expiresTime = et
}
if r.Method != http.MethodGet && r.Method != http.MethodHead {
return "", nil
}
opts := &storage.SignedURLOptions{
GoogleAccessID: d.email,
PrivateKey: d.privateKey,
Method: methodString,
Expires: expiresTime,
Method: r.Method,
Expires: time.Now().Add(20 * time.Minute),
}
return storage.SignedURL(d.bucket, name, opts)
return storage.SignedURL(d.bucket, d.pathToKey(path), opts)
}
// Walk traverses a filesystem defined within driver, starting

View file

@ -10,7 +10,7 @@ import (
"testing"
"cloud.google.com/go/storage"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/distribution/v3/registry/storage/driver/testsuites"
"golang.org/x/oauth2"

View file

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net/http"
"sync"
"time"
@ -236,10 +237,9 @@ func (d *driver) Delete(ctx context.Context, path string) error {
}
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
return "", storagedriver.ErrUnsupportedMethod{}
// RedirectURL returns a URL which may be used to retrieve the content stored at the given path.
func (d *driver) RedirectURL(*http.Request, string) (string, error) {
return "", nil
}
// Walk traverses a filesystem defined within driver, starting

View file

@ -7,13 +7,14 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go/service/cloudfront/sign"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
storagemiddleware "github.com/distribution/distribution/v3/registry/storage/driver/middleware"
)
@ -201,18 +202,18 @@ type S3BucketKeyer interface {
S3BucketKey(path string) string
}
// URLFor attempts to find a url which may be used to retrieve the file at the given path.
// RedirectURL attempts to find a url which may be used to retrieve the file at the given path.
// Returns an error if the file cannot be found.
func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
func (lh *cloudFrontStorageMiddleware) RedirectURL(r *http.Request, path string) (string, error) {
// TODO(endophage): currently only supports S3
keyer, ok := lh.StorageDriver.(S3BucketKeyer)
if !ok {
dcontext.GetLogger(ctx).Warn("the CloudFront middleware does not support this backend storage driver")
return lh.StorageDriver.URLFor(ctx, path, options)
dcontext.GetLogger(r.Context()).Warn("the CloudFront middleware does not support this backend storage driver")
return lh.StorageDriver.RedirectURL(r, path)
}
if eligibleForS3(ctx, lh.awsIPs) {
return lh.StorageDriver.URLFor(ctx, path, options)
if eligibleForS3(r, lh.awsIPs) {
return lh.StorageDriver.RedirectURL(r, path)
}
// Get signed cloudfront url.

View file

@ -12,7 +12,8 @@ import (
"sync"
"time"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/internal/requestutil"
)
const (
@ -192,12 +193,8 @@ func (s *awsIPs) contains(ip net.IP) bool {
// parseIPFromRequest attempts to extract the ip address of the
// client that made the request
func parseIPFromRequest(ctx context.Context) (net.IP, error) {
request, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
ipStr := dcontext.RemoteIP(request)
func parseIPFromRequest(request *http.Request) (net.IP, error) {
ipStr := requestutil.RemoteIP(request)
ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr)
@ -208,25 +205,20 @@ func parseIPFromRequest(ctx context.Context) (net.IP, error) {
// eligibleForS3 checks if a request is eligible for using S3 directly
// Return true only when the IP belongs to a specific aws region and user-agent is docker
func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool {
func eligibleForS3(request *http.Request, awsIPs *awsIPs) bool {
if awsIPs != nil && awsIPs.initialized {
if addr, err := parseIPFromRequest(ctx); err == nil {
request, err := dcontext.GetRequest(ctx)
if err != nil {
dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err)
} else {
loggerField := map[interface{}]interface{}{
"user-client": request.UserAgent(),
"ip": dcontext.RemoteIP(request),
}
if awsIPs.contains(addr) {
dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront")
return true
}
dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
if addr, err := parseIPFromRequest(request); err == nil {
loggerField := map[interface{}]interface{}{
"user-client": request.UserAgent(),
"ip": requestutil.RemoteIP(request),
}
if awsIPs.contains(addr) {
dcontext.GetLoggerWithFields(request.Context(), loggerField).Info("request from the allowed AWS region, skipping CloudFront")
return true
}
dcontext.GetLoggerWithFields(request.Context(), loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
} else {
dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront")
dcontext.GetLogger(request.Context()).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront")
}
}
return false

View file

@ -1,7 +1,6 @@
package middleware
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
@ -11,8 +10,6 @@ import (
"reflect" // used as a replacement for testify
"testing"
"time"
dcontext "github.com/distribution/distribution/v3/context"
)
// Rather than pull in all of testify
@ -276,29 +273,22 @@ func TestEligibleForS3(t *testing.T) {
}},
initialized: true,
}
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
tests := []struct {
Context context.Context
Expected bool
RemoteAddr string
Expected bool
}{
{Context: empty, Expected: false},
{Context: makeContext("192.168.1.2"), Expected: true},
{Context: makeContext("192.168.0.2"), Expected: false},
{RemoteAddr: "", Expected: false},
{RemoteAddr: "192.168.1.2", Expected: true},
{RemoteAddr: "192.168.0.2", Expected: false},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) {
t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) {
t.Parallel()
assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips))
req := &http.Request{RemoteAddr: tc.RemoteAddr}
assertEqual(t, tc.Expected, eligibleForS3(req, ips))
})
}
}
@ -312,29 +302,22 @@ func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) {
}},
initialized: false,
}
empty := context.TODO()
makeContext := func(ip string) context.Context {
req := &http.Request{
RemoteAddr: ip,
}
return dcontext.WithRequest(empty, req)
}
tests := []struct {
Context context.Context
Expected bool
RemoteAddr string
Expected bool
}{
{Context: empty, Expected: false},
{Context: makeContext("192.168.1.2"), Expected: false},
{Context: makeContext("192.168.0.2"), Expected: false},
{RemoteAddr: "", Expected: false},
{RemoteAddr: "192.168.1.2", Expected: false},
{RemoteAddr: "192.168.0.2", Expected: false},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("Client IP = %v", tc.Context.Value("http.request.ip")), func(t *testing.T) {
t.Run(fmt.Sprintf("Client IP = %v", tc.RemoteAddr), func(t *testing.T) {
t.Parallel()
assertEqual(t, tc.Expected, eligibleForS3(tc.Context, ips))
req := &http.Request{RemoteAddr: tc.RemoteAddr}
assertEqual(t, tc.Expected, eligibleForS3(req, ips))
})
}
}

View file

@ -1,8 +1,8 @@
package middleware
import (
"context"
"fmt"
"net/http"
"net/url"
"path"
@ -42,7 +42,7 @@ func newRedirectStorageMiddleware(ctx context.Context, sd storagedriver.StorageD
return &redirectStorageMiddleware{StorageDriver: sd, scheme: u.Scheme, host: u.Host, basePath: u.Path}, nil
}
func (r *redirectStorageMiddleware) URLFor(ctx context.Context, urlPath string, options map[string]interface{}) (string, error) {
func (r *redirectStorageMiddleware) RedirectURL(_ *http.Request, urlPath string) (string, error) {
if r.basePath != "" {
urlPath = path.Join(r.basePath, urlPath)
}

View file

@ -1,7 +1,6 @@
package middleware
import (
"context"
"testing"
"gopkg.in/check.v1"
@ -37,7 +36,7 @@ func (s *MiddlewareSuite) TestHttpsPort(c *check.C) {
c.Assert(m.scheme, check.Equals, "https")
c.Assert(m.host, check.Equals, "example.com:5443")
url, err := middleware.URLFor(context.TODO(), "/rick/data", nil)
url, err := middleware.RedirectURL(nil, "/rick/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com:5443/rick/data")
}
@ -53,7 +52,7 @@ func (s *MiddlewareSuite) TestHTTP(c *check.C) {
c.Assert(m.scheme, check.Equals, "http")
c.Assert(m.host, check.Equals, "example.com")
url, err := middleware.URLFor(context.TODO(), "morty/data", nil)
url, err := middleware.RedirectURL(nil, "morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "http://example.com/morty/data")
}
@ -71,12 +70,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
c.Assert(m.host, check.Equals, "example.com")
c.Assert(m.basePath, check.Equals, "/path")
// call URLFor() with no leading slash
url, err := middleware.URLFor(context.TODO(), "morty/data", nil)
// call RedirectURL() with no leading slash
url, err := middleware.RedirectURL(nil, "morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data")
// call URLFor() with leading slash
url, err = middleware.URLFor(context.TODO(), "/morty/data", nil)
// call RedirectURL() with leading slash
url, err = middleware.RedirectURL(nil, "/morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data")
@ -91,12 +90,12 @@ func (s *MiddlewareSuite) TestPath(c *check.C) {
c.Assert(m.host, check.Equals, "example.com")
c.Assert(m.basePath, check.Equals, "/path/")
// call URLFor() with no leading slash
url, err = middleware.URLFor(context.TODO(), "morty/data", nil)
// call RedirectURL() with no leading slash
url, err = middleware.RedirectURL(nil, "morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data")
// call URLFor() with leading slash
url, err = middleware.URLFor(context.TODO(), "/morty/data", nil)
// call RedirectURL() with leading slash
url, err = middleware.RedirectURL(nil, "/morty/data")
c.Assert(err, check.Equals, nil)
c.Assert(url, check.Equals, "https://example.com/path/morty/data")
}

View file

@ -36,7 +36,7 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/distribution/v3/registry/storage/driver/base"
"github.com/distribution/distribution/v3/registry/storage/driver/factory"
@ -1036,30 +1036,13 @@ func (d *driver) Delete(ctx context.Context, path string) error {
return nil
}
// URLFor returns a URL which may be used to retrieve the content stored at the given path.
// May return an UnsupportedMethodErr in certain StorageDriver implementations.
func (d *driver) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
methodString := http.MethodGet
method, ok := options["method"]
if ok {
methodString, ok = method.(string)
if !ok || (methodString != http.MethodGet && methodString != http.MethodHead) {
return "", storagedriver.ErrUnsupportedMethod{}
}
}
// RedirectURL returns a URL which may be used to retrieve the content stored at the given path.
func (d *driver) RedirectURL(r *http.Request, path string) (string, error) {
expiresIn := 20 * time.Minute
expires, ok := options["expiry"]
if ok {
et, ok := expires.(time.Time)
if ok {
expiresIn = time.Until(et)
}
}
var req *request.Request
switch methodString {
switch r.Method {
case http.MethodGet:
req, _ = d.S3.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String(d.Bucket),
@ -1071,7 +1054,7 @@ func (d *driver) URLFor(ctx context.Context, path string, options map[string]int
Key: aws.String(d.s3Path(path)),
})
default:
panic("unreachable")
return "", nil
}
return req.Presign(expiresIn)

View file

@ -16,7 +16,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
storagedriver "github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/distribution/v3/registry/storage/driver/testsuites"
)
@ -180,7 +180,7 @@ func TestEmptyRootList(t *testing.T) {
filename := "/test"
contents := []byte("contents")
ctx := context.Background()
ctx := dcontext.Background()
err = rootedDriver.PutContent(ctx, filename, contents)
if err != nil {
t.Fatalf("unexpected error creating content: %v", err)
@ -209,7 +209,7 @@ func TestStorageClass(t *testing.T) {
rootDir := t.TempDir()
contents := []byte("contents")
ctx := context.Background()
ctx := dcontext.Background()
// We don't need to test all the storage classes, just that its selectable.
// The first 3 are common to AWS and MinIO, so use those.
@ -377,7 +377,7 @@ func TestDelete(t *testing.T) {
// init file structure matching objs
var created []string
for _, p := range objs {
err := drvr.PutContent(context.Background(), p, []byte("content "+p))
err := drvr.PutContent(dcontext.Background(), p, []byte("content "+p))
if err != nil {
fmt.Printf("unable to init file %s: %s\n", p, err)
continue
@ -390,7 +390,7 @@ func TestDelete(t *testing.T) {
cleanup := func(objs []string) {
var lastErr error
for _, p := range objs {
err := drvr.Delete(context.Background(), p)
err := drvr.Delete(dcontext.Background(), p)
if err != nil {
switch err.(type) {
case storagedriver.PathNotFoundError:
@ -409,7 +409,7 @@ func TestDelete(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
objs := init()
err := drvr.Delete(context.Background(), tc.delete)
err := drvr.Delete(dcontext.Background(), tc.delete)
if tc.err != nil {
if err == nil {
@ -437,7 +437,7 @@ func TestDelete(t *testing.T) {
return false
}
for _, path := range objs {
stat, err := drvr.Stat(context.Background(), path)
stat, err := drvr.Stat(dcontext.Background(), path)
if err != nil {
switch err.(type) {
case storagedriver.PathNotFoundError:
@ -491,7 +491,7 @@ func TestWalk(t *testing.T) {
// create file structure matching fileset above
created := make([]string, 0, len(fileset))
for _, p := range fileset {
err := drvr.PutContent(context.Background(), p, []byte("content "+p))
err := drvr.PutContent(dcontext.Background(), p, []byte("content "+p))
if err != nil {
fmt.Printf("unable to create file %s: %s\n", p, err)
continue
@ -503,7 +503,7 @@ func TestWalk(t *testing.T) {
defer func() {
var lastErr error
for _, p := range created {
err := drvr.Delete(context.Background(), p)
err := drvr.Delete(dcontext.Background(), p)
if err != nil {
_ = fmt.Errorf("cleanup failed for path %s: %s", p, err)
lastErr = err
@ -692,7 +692,7 @@ func TestWalk(t *testing.T) {
tc.from = "/"
}
t.Run(tc.name, func(t *testing.T) {
err := drvr.Walk(context.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error {
err := drvr.Walk(dcontext.Background(), tc.from, func(fileInfo storagedriver.FileInfo) error {
walked = append(walked, fileInfo.Path())
return tc.fn(fileInfo)
}, tc.options...)
@ -718,7 +718,7 @@ func TestOverThousandBlobs(t *testing.T) {
t.Fatalf("unexpected error creating driver with standard storage: %v", err)
}
ctx := context.Background()
ctx := dcontext.Background()
for i := 0; i < 1005; i++ {
filename := "/thousandfiletest/file" + strconv.Itoa(i)
contents := []byte("contents")
@ -746,7 +746,7 @@ func TestMoveWithMultipartCopy(t *testing.T) {
t.Fatalf("unexpected error creating driver: %v", err)
}
ctx := context.Background()
ctx := dcontext.Background()
sourcePath := "/source"
destPath := "/dest"
@ -795,7 +795,7 @@ func TestListObjectsV2(t *testing.T) {
t.Fatalf("unexpected error creating driver: %v", err)
}
ctx := context.Background()
ctx := dcontext.Background()
n := 6
prefix := "/test-list-objects-v2"
var filePaths []string

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
@ -92,11 +93,10 @@ type StorageDriver interface {
// Delete recursively deletes all objects stored at "path" and its subpaths.
Delete(ctx context.Context, path string) error
// URLFor returns a URL which may be used to retrieve the content stored at
// the given path, possibly using the given options.
// May return an ErrUnsupportedMethod in certain StorageDriver
// implementations.
URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error)
// RedirectURL returns a URL which the client of the request r may use
// to retrieve the content stored at path. Returning the empty string
// signals that the request may not be redirected.
RedirectURL(r *http.Request, path string) (string, error)
// Walk traverses a filesystem defined within driver, starting
// from the given path, calling f on each file.

View file

@ -8,6 +8,7 @@ import (
"io"
"math/rand"
"net/http"
"net/http/httptest"
"os"
"path"
"sort"
@ -733,9 +734,9 @@ func (suite *DriverSuite) TestDelete(c *check.C) {
c.Assert(strings.Contains(err.Error(), suite.Name()), check.Equals, true)
}
// TestURLFor checks that the URLFor method functions properly, but only if it
// is implemented
func (suite *DriverSuite) TestURLFor(c *check.C) {
// TestRedirectURL checks that the RedirectURL method functions properly,
// but only if it is implemented
func (suite *DriverSuite) TestRedirectURL(c *check.C) {
filename := randomPath(32)
contents := randomContents(32)
@ -744,8 +745,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) {
err := suite.StorageDriver.PutContent(suite.ctx, filename, contents)
c.Assert(err, check.IsNil)
url, err := suite.StorageDriver.URLFor(suite.ctx, filename, nil)
if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok {
url, err := suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodGet, filename, nil), filename)
if url == "" && err == nil {
return
}
c.Assert(err, check.IsNil)
@ -758,8 +759,8 @@ func (suite *DriverSuite) TestURLFor(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(read, check.DeepEquals, contents)
url, err = suite.StorageDriver.URLFor(suite.ctx, filename, map[string]interface{}{"method": http.MethodHead})
if _, ok := err.(storagedriver.ErrUnsupportedMethod); ok {
url, err = suite.StorageDriver.RedirectURL(httptest.NewRequest(http.MethodHead, filename, nil), filename)
if url == "" && err == nil {
return
}
c.Assert(err, check.IsNil)

View file

@ -7,13 +7,13 @@ import (
mrand "math/rand"
"testing"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
"github.com/opencontainers/go-digest"
)
func TestSimpleRead(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
content := make([]byte, 1<<20)
n, err := crand.Read(content)
if err != nil {
@ -55,7 +55,7 @@ func TestFileReaderSeek(t *testing.T) {
repititions := 1024
path := "/patterned"
content := bytes.Repeat([]byte(pattern), repititions)
ctx := context.Background()
ctx := dcontext.Background()
if err := driver.PutContent(ctx, path, content); err != nil {
t.Fatalf("error putting patterned content: %v", err)
@ -156,7 +156,7 @@ func TestFileReaderSeek(t *testing.T) {
// read method, with an io.EOF error.
func TestFileReaderNonExistentFile(t *testing.T) {
driver := inmemory.New()
fr, err := newFileReader(context.Background(), driver, "/doesnotexist", 10)
fr, err := newFileReader(dcontext.Background(), driver, "/doesnotexist", 10)
if err != nil {
t.Fatalf("unexpected error initializing reader: %v", err)
}

View file

@ -6,7 +6,7 @@ import (
"testing"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
"github.com/distribution/distribution/v3/testutil"
@ -21,7 +21,7 @@ type image struct {
}
func createRegistry(t *testing.T, driver driver.StorageDriver, options ...RegistryOption) distribution.Namespace {
ctx := context.Background()
ctx := dcontext.Background()
options = append(options, EnableDelete)
registry, err := NewRegistry(ctx, driver, options...)
if err != nil {
@ -31,7 +31,7 @@ func createRegistry(t *testing.T, driver driver.StorageDriver, options ...Regist
}
func makeRepository(t *testing.T, registry distribution.Namespace, name string) distribution.Repository {
ctx := context.Background()
ctx := dcontext.Background()
// Initialize a dummy repository
named, err := reference.WithName(name)
@ -47,7 +47,7 @@ func makeRepository(t *testing.T, registry distribution.Namespace, name string)
}
func makeManifestService(t *testing.T, repository distribution.Repository) distribution.ManifestService {
ctx := context.Background()
ctx := dcontext.Background()
manifestService, err := repository.Manifests(ctx)
if err != nil {
@ -57,7 +57,7 @@ func makeManifestService(t *testing.T, repository distribution.Repository) distr
}
func allManifests(t *testing.T, manifestService distribution.ManifestService) map[digest.Digest]struct{} {
ctx := context.Background()
ctx := dcontext.Background()
allManMap := make(map[digest.Digest]struct{})
manifestEnumerator, ok := manifestService.(distribution.ManifestEnumerator)
if !ok {
@ -74,7 +74,7 @@ func allManifests(t *testing.T, manifestService distribution.ManifestService) ma
}
func allBlobs(t *testing.T, registry distribution.Namespace) map[digest.Digest]struct{} {
ctx := context.Background()
ctx := dcontext.Background()
blobService := registry.Blobs()
allBlobsMap := make(map[digest.Digest]struct{})
err := blobService.Enumerate(ctx, func(dgst digest.Digest) error {
@ -95,7 +95,7 @@ func uploadImage(t *testing.T, repository distribution.Repository, im image) dig
}
// upload manifest
ctx := context.Background()
ctx := dcontext.Background()
manifestService := makeManifestService(t, repository)
manifestDigest, err := manifestService.Put(ctx, im.manifest)
if err != nil {
@ -130,7 +130,7 @@ func uploadRandomSchema2Image(t *testing.T, repository distribution.Repository)
}
func TestNoDeletionNoEffect(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver)
@ -158,7 +158,7 @@ func TestNoDeletionNoEffect(t *testing.T) {
before := allBlobs(t, registry)
// Run GC
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false,
RemoveUntagged: false,
})
@ -173,7 +173,7 @@ func TestNoDeletionNoEffect(t *testing.T) {
}
func TestDeleteManifestIfTagNotFound(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver)
@ -233,7 +233,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) {
before2 := allManifests(t, manifestService)
// run GC with dry-run (should not remove anything)
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: true,
RemoveUntagged: true,
})
@ -250,7 +250,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) {
}
// Run GC (removes everything because no manifests with tags exist)
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false,
RemoveUntagged: true,
})
@ -269,7 +269,7 @@ func TestDeleteManifestIfTagNotFound(t *testing.T) {
}
func TestGCWithMissingManifests(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
d := inmemory.New()
registry := createRegistry(t, d)
@ -288,7 +288,7 @@ func TestGCWithMissingManifests(t *testing.T) {
t.Fatal(err)
}
err = MarkAndSweep(context.Background(), d, registry, GCOpts{
err = MarkAndSweep(dcontext.Background(), d, registry, GCOpts{
DryRun: false,
RemoveUntagged: false,
})
@ -303,7 +303,7 @@ func TestGCWithMissingManifests(t *testing.T) {
}
func TestDeletionHasEffect(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver)
@ -318,7 +318,7 @@ func TestDeletionHasEffect(t *testing.T) {
manifests.Delete(ctx, image3.manifestDigest)
// Run GC
err := MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
err := MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false,
RemoveUntagged: false,
})
@ -368,7 +368,7 @@ func getKeys(digests map[digest.Digest]io.ReadSeeker) (ds []digest.Digest) {
}
func TestDeletionWithSharedLayer(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver)
@ -455,7 +455,7 @@ func TestOrphanBlobDeleted(t *testing.T) {
uploadRandomSchema2Image(t, repo)
// Run GC
err = MarkAndSweep(context.Background(), inmemoryDriver, registry, GCOpts{
err = MarkAndSweep(dcontext.Background(), inmemoryDriver, registry, GCOpts{
DryRun: false,
RemoveUntagged: false,
})

View file

@ -9,7 +9,7 @@ import (
"time"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/distribution/reference"
"github.com/google/uuid"

View file

@ -5,7 +5,7 @@ import (
"fmt"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/manifestlist"
"github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/opencontainers/go-digest"

View file

@ -6,7 +6,7 @@ import (
"fmt"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest"
"github.com/distribution/distribution/v3/manifest/manifestlist"
"github.com/distribution/distribution/v3/manifest/ocischema"

View file

@ -4,7 +4,7 @@ import (
"context"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/opencontainers/go-digest"
)

View file

@ -6,7 +6,7 @@ import (
"net/url"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/ocischema"
"github.com/opencontainers/go-digest"
v1 "github.com/opencontainers/image-spec/specs-go/v1"

View file

@ -34,7 +34,7 @@ type manifestURLs struct {
type RegistryOption func(*registry) error
// EnableRedirect is a functional option for NewRegistry. It causes the backend
// blob server to attempt using (StorageDriver).URLFor to serve all blobs.
// blob server to attempt using (StorageDriver).RedirectURL to serve all blobs.
func EnableRedirect(registry *registry) error {
registry.blobServer.redirect = true
return nil
@ -102,7 +102,7 @@ func BlobDescriptorCacheProvider(blobDescriptorCacheProvider cache.BlobDescripto
// NewRegistry creates a new registry instance from the provided driver. The
// resulting registry may be shared by multiple goroutines but is cheap to
// allocate. If the Redirect option is specified, the backend blob server will
// attempt to use (StorageDriver).URLFor to serve all blobs.
// attempt to use (StorageDriver).RedirectURL to serve all blobs.
func NewRegistry(ctx context.Context, driver storagedriver.StorageDriver, options ...RegistryOption) (distribution.Namespace, error) {
// create global statter
statter := &blobStatter{

View file

@ -7,7 +7,7 @@ import (
"net/url"
"github.com/distribution/distribution/v3"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/schema2"
"github.com/opencontainers/go-digest"
)

View file

@ -6,7 +6,7 @@ import (
"testing"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest"
"github.com/distribution/distribution/v3/manifest/schema2"
"github.com/distribution/distribution/v3/registry/storage/driver/inmemory"
@ -14,7 +14,7 @@ import (
)
func TestVerifyManifestForeignLayer(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver,
ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")),
@ -152,7 +152,7 @@ func TestVerifyManifestForeignLayer(t *testing.T) {
}
func TestVerifyManifestBlobLayerAndConfig(t *testing.T) {
ctx := context.Background()
ctx := dcontext.Background()
inmemoryDriver := inmemory.New()
registry := createRegistry(t, inmemoryDriver,
ManifestURLsAllowRegexp(regexp.MustCompile("^https?://foo")),

View file

@ -4,7 +4,7 @@ import (
"context"
"path"
dcontext "github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/storage/driver"
"github.com/opencontainers/go-digest"
)

View file

@ -4,7 +4,7 @@ import (
"fmt"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/manifest/manifestlist"
"github.com/distribution/distribution/v3/manifest/schema2"
"github.com/opencontainers/go-digest"
@ -12,7 +12,7 @@ import (
// MakeManifestList constructs a manifest list out of a list of manifest digests
func MakeManifestList(blobstatter distribution.BlobStatter, manifestDigests []digest.Digest) (*manifestlist.DeserializedManifestList, error) {
ctx := context.Background()
ctx := dcontext.Background()
manifestDescriptors := make([]manifestlist.ManifestDescriptor, 0, len(manifestDigests))
for _, manifestDigest := range manifestDigests {
@ -39,7 +39,7 @@ func MakeManifestList(blobstatter distribution.BlobStatter, manifestDigests []di
// MakeSchema2Manifest constructs a schema 2 manifest from a given list of digests and returns
// the digest of the manifest
func MakeSchema2Manifest(repository distribution.Repository, digests []digest.Digest) (distribution.Manifest, error) {
ctx := context.Background()
ctx := dcontext.Background()
blobStore := repository.Blobs(ctx)
var configJSON []byte

View file

@ -10,7 +10,7 @@ import (
"time"
"github.com/distribution/distribution/v3"
"github.com/distribution/distribution/v3/context"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/opencontainers/go-digest"
)
@ -96,7 +96,7 @@ func CreateRandomLayers(n int) (map[digest.Digest]io.ReadSeeker, error) {
// UploadBlobs lets you upload blobs to a repository
func UploadBlobs(repository distribution.Repository, layers map[digest.Digest]io.ReadSeeker) error {
ctx := context.Background()
ctx := dcontext.Background()
for dgst, rs := range layers {
wr, err := repository.Blobs(ctx).Create(ctx)
if err != nil {