registry/auth: pass request to AccessController

Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
Cory Snider 2023-10-24 14:08:04 -04:00
parent 9157226e7b
commit 49e22cbf3e
8 changed files with 23 additions and 47 deletions

View file

@ -18,7 +18,7 @@
// resource := auth.Resource{Type: "customerOrder", Name: orderNumber} // resource := auth.Resource{Type: "customerOrder", Name: orderNumber}
// access := auth.Access{Resource: resource, Action: "update"} // access := auth.Access{Resource: resource, Action: "update"}
// //
// if ctx, err := accessController.Authorized(ctx, access); err != nil { // if ctx, err := accessController.Authorized(r, access); err != nil {
// if challenge, ok := err.(auth.Challenge) { // if challenge, ok := err.(auth.Challenge) {
// // Let the challenge write the response. // // Let the challenge write the response.
// challenge.SetHeaders(r, w) // challenge.SetHeaders(r, w)
@ -93,16 +93,15 @@ type Challenge interface {
// and required access levels for a request. Implementations can support both // and required access levels for a request. Implementations can support both
// complete denial and http authorization challenges. // complete denial and http authorization challenges.
type AccessController interface { type AccessController interface {
// Authorized returns a non-nil error if the context is granted access and // Authorized returns a nil error if the request is granted access and
// returns a new authorized context. If one or more Access structs are // returns a new authorized context. If one or more Access structs are
// provided, the requested access will be compared with what is available // provided, the requested access will be compared with what is available
// to the context. The given context will contain a "http.request" key with // to the request. Access is denied if the error is non-nil. The error may
// a `*http.Request` value. If the error is non-nil, access should always // be of type Challenge, in which case the caller may have the Challenge
// be denied. The error may be of type Challenge, in which case the caller // handle the request or choose what action to take based on the Challenge
// may have the Challenge handle the request or choose what action to take // header or response status. The returned context object should be derived
// based on the Challenge header or response status. The returned context // from r.Context() and have a "auth.user" value set to a UserInfo struct.
// object should have a "auth.user" value set to a UserInfo struct. Authorized(r *http.Request, access ...Access) (context.Context, error)
Authorized(ctx context.Context, access ...Access) (context.Context, error)
} }
// CredentialAuthenticator is an object which is able to authenticate credentials // CredentialAuthenticator is an object which is able to authenticate credentials

View file

@ -49,12 +49,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
return &accessController{realm: realm.(string), path: path}, nil return &accessController{realm: realm.(string), path: path}, nil
} }
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) {
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
username, password, ok := req.BasicAuth() username, password, ok := req.BasicAuth()
if !ok { if !ok {
return nil, &challenge{ return nil, &challenge{
@ -92,14 +87,14 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
ac.mu.Unlock() ac.mu.Unlock()
if err := localHTPasswd.authenticateUser(username, password); err != nil { if err := localHTPasswd.authenticateUser(username, password); err != nil {
dcontext.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err) dcontext.GetLogger(req.Context()).Errorf("error authenticating user %q: %v", username, err)
return nil, &challenge{ return nil, &challenge{
realm: ac.realm, realm: ac.realm,
err: auth.ErrAuthenticationFailure, err: auth.ErrAuthenticationFailure,
} }
} }
return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil return auth.WithUser(req.Context(), auth.UserInfo{Name: username}), nil
} }
// challenge implements the auth.Challenge interface. // challenge implements the auth.Challenge interface.

View file

@ -8,7 +8,6 @@ import (
"os" "os"
"testing" "testing"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
) )
@ -33,7 +32,6 @@ func TestBasicAccessController(t *testing.T) {
"realm": testRealm, "realm": testRealm,
"path": tempFile.Name(), "path": tempFile.Name(),
} }
ctx := dcontext.Background()
accessController, err := newAccessController(options) accessController, err := newAccessController(options)
if err != nil { if err != nil {
@ -45,8 +43,7 @@ func TestBasicAccessController(t *testing.T) {
userNumber := 0 userNumber := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := dcontext.WithRequest(ctx, r) authCtx, err := accessController.Authorized(r)
authCtx, err := accessController.Authorized(ctx)
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case auth.Challenge: case auth.Challenge:

View file

@ -43,12 +43,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized simply checks for the existence of the authorization header, // Authorized simply checks for the existence of the authorization header,
// responding with a bearer challenge if it doesn't exist. // responding with a bearer challenge if it doesn't exist.
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) { func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) {
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
if req.Header.Get("Authorization") == "" { if req.Header.Get("Authorization") == "" {
challenge := challenge{ challenge := challenge{
realm: ac.realm, realm: ac.realm,
@ -66,7 +61,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
return nil, &challenge return nil, &challenge
} }
ctx = auth.WithUser(ctx, auth.UserInfo{Name: "silly"}) ctx := auth.WithUser(req.Context(), auth.UserInfo{Name: "silly"})
ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey)) ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey))
return ctx, nil return ctx, nil

View file

@ -5,7 +5,6 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
) )
@ -16,8 +15,7 @@ func TestSillyAccessController(t *testing.T) {
} }
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := dcontext.WithRequest(dcontext.Background(), r) authCtx, err := ac.Authorized(r)
authCtx, err := ac.Authorized(ctx)
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case auth.Challenge: case auth.Challenge:

View file

@ -13,7 +13,6 @@ import (
"os" "os"
"strings" "strings"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
) )
@ -292,7 +291,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized handles checking whether the given request is authorized // Authorized handles checking whether the given request is authorized
// for actions on resources described by the given access items. // for actions on resources described by the given access items.
func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) { func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (context.Context, error) {
challenge := &authChallenge{ challenge := &authChallenge{
realm: ac.realm, realm: ac.realm,
autoRedirect: ac.autoRedirect, autoRedirect: ac.autoRedirect,
@ -300,11 +299,6 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
accessSet: newAccessSet(accessItems...), accessSet: newAccessSet(accessItems...),
} }
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ") prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ")
if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") { if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") {
challenge.err = ErrTokenRequired challenge.err = ErrTokenRequired
@ -338,7 +332,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
} }
} }
ctx = auth.WithResources(ctx, claims.resources()) ctx := auth.WithResources(req.Context(), claims.resources())
return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil
} }

View file

@ -18,7 +18,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth" "github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt" "github.com/go-jose/go-jose/v3/jwt"
@ -466,8 +465,7 @@ func TestAccessController(t *testing.T) {
Action: "baz", Action: "baz",
} }
ctx := dcontext.WithRequest(dcontext.Background(), req) authCtx, err := accessController.Authorized(req, testAccess)
authCtx, err := accessController.Authorized(ctx, testAccess)
challenge, ok := err.(auth.Challenge) challenge, ok := err.(auth.Challenge)
if !ok { if !ok {
t.Fatal("accessController did not return a challenge") t.Fatal("accessController did not return a challenge")
@ -502,7 +500,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess) authCtx, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge) challenge, ok = err.(auth.Challenge)
if !ok { if !ok {
t.Fatal("accessController did not return a challenge") t.Fatal("accessController did not return a challenge")
@ -534,7 +532,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess) authCtx, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge) challenge, ok = err.(auth.Challenge)
if !ok { if !ok {
t.Fatal("accessController did not return a challenge") t.Fatal("accessController did not return a challenge")
@ -564,7 +562,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess) authCtx, err = accessController.Authorized(req, testAccess)
if err != nil { if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err) t.Fatalf("accessController returned unexpected error: %s", err)
} }
@ -594,7 +592,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
_, err = accessController.Authorized(ctx, testAccess) _, err = accessController.Authorized(req, testAccess)
if err != nil { if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err) t.Fatalf("accessController returned unexpected error: %s", err)
} }

View file

@ -797,7 +797,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont
accessRecords = appendCatalogAccessRecord(accessRecords, r) accessRecords = appendCatalogAccessRecord(accessRecords, r)
} }
ctx, err := app.accessController.Authorized(context.Context, accessRecords...) ctx, err := app.accessController.Authorized(r.WithContext(context.Context), accessRecords...)
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case auth.Challenge: case auth.Challenge: