diff --git a/registry/auth/auth.go b/registry/auth/auth.go index 9cb036f1..1f28ea85 100644 --- a/registry/auth/auth.go +++ b/registry/auth/auth.go @@ -18,7 +18,7 @@ // resource := auth.Resource{Type: "customerOrder", Name: orderNumber} // access := auth.Access{Resource: resource, Action: "update"} // -// if ctx, err := accessController.Authorized(ctx, access); err != nil { +// if ctx, err := accessController.Authorized(r, access); err != nil { // if challenge, ok := err.(auth.Challenge) { // // Let the challenge write the response. // challenge.SetHeaders(r, w) @@ -93,16 +93,15 @@ type Challenge interface { // and required access levels for a request. Implementations can support both // complete denial and http authorization challenges. type AccessController interface { - // Authorized returns a non-nil error if the context is granted access and + // Authorized returns a nil error if the request is granted access and // returns a new authorized context. If one or more Access structs are // provided, the requested access will be compared with what is available - // to the context. The given context will contain a "http.request" key with - // a `*http.Request` value. If the error is non-nil, access should always - // be denied. The error may be of type Challenge, in which case the caller - // may have the Challenge handle the request or choose what action to take - // based on the Challenge header or response status. The returned context - // object should have a "auth.user" value set to a UserInfo struct. - Authorized(ctx context.Context, access ...Access) (context.Context, error) + // to the request. Access is denied if the error is non-nil. The error may + // be of type Challenge, in which case the caller may have the Challenge + // handle the request or choose what action to take based on the Challenge + // header or response status. The returned context object should be derived + // from r.Context() and have a "auth.user" value set to a UserInfo struct. + Authorized(r *http.Request, access ...Access) (context.Context, error) } // CredentialAuthenticator is an object which is able to authenticate credentials diff --git a/registry/auth/htpasswd/access.go b/registry/auth/htpasswd/access.go index 0a1d0c1c..a5c89a42 100644 --- a/registry/auth/htpasswd/access.go +++ b/registry/auth/htpasswd/access.go @@ -49,12 +49,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, return &accessController{realm: realm.(string), path: path}, nil } -func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { - req, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - +func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) { username, password, ok := req.BasicAuth() if !ok { return nil, &challenge{ @@ -92,14 +87,14 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut ac.mu.Unlock() if err := localHTPasswd.authenticateUser(username, password); err != nil { - dcontext.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err) + dcontext.GetLogger(req.Context()).Errorf("error authenticating user %q: %v", username, err) return nil, &challenge{ realm: ac.realm, err: auth.ErrAuthenticationFailure, } } - return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil + return auth.WithUser(req.Context(), auth.UserInfo{Name: username}), nil } // challenge implements the auth.Challenge interface. diff --git a/registry/auth/htpasswd/access_test.go b/registry/auth/htpasswd/access_test.go index 0871ef41..ad5e7f70 100644 --- a/registry/auth/htpasswd/access_test.go +++ b/registry/auth/htpasswd/access_test.go @@ -8,7 +8,6 @@ import ( "os" "testing" - "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) @@ -33,7 +32,6 @@ func TestBasicAccessController(t *testing.T) { "realm": testRealm, "path": tempFile.Name(), } - ctx := dcontext.Background() accessController, err := newAccessController(options) if err != nil { @@ -45,8 +43,7 @@ func TestBasicAccessController(t *testing.T) { userNumber := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := dcontext.WithRequest(ctx, r) - authCtx, err := accessController.Authorized(ctx) + authCtx, err := accessController.Authorized(r) if err != nil { switch err := err.(type) { case auth.Challenge: diff --git a/registry/auth/silly/access.go b/registry/auth/silly/access.go index 685cf6a6..c8f383e2 100644 --- a/registry/auth/silly/access.go +++ b/registry/auth/silly/access.go @@ -43,12 +43,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, // Authorized simply checks for the existence of the authorization header, // responding with a bearer challenge if it doesn't exist. -func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { - req, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - +func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) { if req.Header.Get("Authorization") == "" { challenge := challenge{ realm: ac.realm, @@ -66,7 +61,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut return nil, &challenge } - ctx = auth.WithUser(ctx, auth.UserInfo{Name: "silly"}) + ctx := auth.WithUser(req.Context(), auth.UserInfo{Name: "silly"}) ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey)) return ctx, nil diff --git a/registry/auth/silly/access_test.go b/registry/auth/silly/access_test.go index f463e98c..1a137c71 100644 --- a/registry/auth/silly/access_test.go +++ b/registry/auth/silly/access_test.go @@ -5,7 +5,6 @@ import ( "net/http/httptest" "testing" - "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" ) @@ -16,8 +15,7 @@ func TestSillyAccessController(t *testing.T) { } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := dcontext.WithRequest(dcontext.Background(), r) - authCtx, err := ac.Authorized(ctx) + authCtx, err := ac.Authorized(r) if err != nil { switch err := err.(type) { case auth.Challenge: diff --git a/registry/auth/token/accesscontroller.go b/registry/auth/token/accesscontroller.go index b2e4e4b2..e019d0f5 100644 --- a/registry/auth/token/accesscontroller.go +++ b/registry/auth/token/accesscontroller.go @@ -13,7 +13,6 @@ import ( "os" "strings" - "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" "github.com/go-jose/go-jose/v3" ) @@ -292,7 +291,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController, // Authorized handles checking whether the given request is authorized // for actions on resources described by the given access items. -func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) { +func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (context.Context, error) { challenge := &authChallenge{ realm: ac.realm, autoRedirect: ac.autoRedirect, @@ -300,11 +299,6 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth. accessSet: newAccessSet(accessItems...), } - req, err := dcontext.GetRequest(ctx) - if err != nil { - return nil, err - } - prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ") if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") { challenge.err = ErrTokenRequired @@ -338,7 +332,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth. } } - ctx = auth.WithResources(ctx, claims.resources()) + ctx := auth.WithResources(req.Context(), claims.resources()) return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil } diff --git a/registry/auth/token/token_test.go b/registry/auth/token/token_test.go index a331a93b..52d34a70 100644 --- a/registry/auth/token/token_test.go +++ b/registry/auth/token/token_test.go @@ -18,7 +18,6 @@ import ( "testing" "time" - "github.com/distribution/distribution/v3/internal/dcontext" "github.com/distribution/distribution/v3/registry/auth" "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" @@ -466,8 +465,7 @@ func TestAccessController(t *testing.T) { Action: "baz", } - ctx := dcontext.WithRequest(dcontext.Background(), req) - authCtx, err := accessController.Authorized(ctx, testAccess) + authCtx, err := accessController.Authorized(req, testAccess) challenge, ok := err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -502,7 +500,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(ctx, testAccess) + authCtx, err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -534,7 +532,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(ctx, testAccess) + authCtx, err = accessController.Authorized(req, testAccess) challenge, ok = err.(auth.Challenge) if !ok { t.Fatal("accessController did not return a challenge") @@ -564,7 +562,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - authCtx, err = accessController.Authorized(ctx, testAccess) + authCtx, err = accessController.Authorized(req, testAccess) if err != nil { t.Fatalf("accessController returned unexpected error: %s", err) } @@ -594,7 +592,7 @@ func TestAccessController(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) - _, err = accessController.Authorized(ctx, testAccess) + _, err = accessController.Authorized(req, testAccess) if err != nil { t.Fatalf("accessController returned unexpected error: %s", err) } diff --git a/registry/handlers/app.go b/registry/handlers/app.go index 7ce27d6d..8bb5bbbc 100644 --- a/registry/handlers/app.go +++ b/registry/handlers/app.go @@ -797,7 +797,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont accessRecords = appendCatalogAccessRecord(accessRecords, r) } - ctx, err := app.accessController.Authorized(context.Context, accessRecords...) + ctx, err := app.accessController.Authorized(r.WithContext(context.Context), accessRecords...) if err != nil { switch err := err.(type) { case auth.Challenge: