registry/auth: pass request to AccessController

Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
Cory Snider 2023-10-24 14:08:04 -04:00
parent 9157226e7b
commit 49e22cbf3e
8 changed files with 23 additions and 47 deletions

View file

@ -18,7 +18,7 @@
// resource := auth.Resource{Type: "customerOrder", Name: orderNumber}
// access := auth.Access{Resource: resource, Action: "update"}
//
// if ctx, err := accessController.Authorized(ctx, access); err != nil {
// if ctx, err := accessController.Authorized(r, access); err != nil {
// if challenge, ok := err.(auth.Challenge) {
// // Let the challenge write the response.
// challenge.SetHeaders(r, w)
@ -93,16 +93,15 @@ type Challenge interface {
// and required access levels for a request. Implementations can support both
// complete denial and http authorization challenges.
type AccessController interface {
// Authorized returns a non-nil error if the context is granted access and
// Authorized returns a nil error if the request is granted access and
// returns a new authorized context. If one or more Access structs are
// provided, the requested access will be compared with what is available
// to the context. The given context will contain a "http.request" key with
// a `*http.Request` value. If the error is non-nil, access should always
// be denied. The error may be of type Challenge, in which case the caller
// may have the Challenge handle the request or choose what action to take
// based on the Challenge header or response status. The returned context
// object should have a "auth.user" value set to a UserInfo struct.
Authorized(ctx context.Context, access ...Access) (context.Context, error)
// to the request. Access is denied if the error is non-nil. The error may
// be of type Challenge, in which case the caller may have the Challenge
// handle the request or choose what action to take based on the Challenge
// header or response status. The returned context object should be derived
// from r.Context() and have a "auth.user" value set to a UserInfo struct.
Authorized(r *http.Request, access ...Access) (context.Context, error)
}
// CredentialAuthenticator is an object which is able to authenticate credentials

View file

@ -49,12 +49,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
return &accessController{realm: realm.(string), path: path}, nil
}
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) {
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) {
username, password, ok := req.BasicAuth()
if !ok {
return nil, &challenge{
@ -92,14 +87,14 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
ac.mu.Unlock()
if err := localHTPasswd.authenticateUser(username, password); err != nil {
dcontext.GetLogger(ctx).Errorf("error authenticating user %q: %v", username, err)
dcontext.GetLogger(req.Context()).Errorf("error authenticating user %q: %v", username, err)
return nil, &challenge{
realm: ac.realm,
err: auth.ErrAuthenticationFailure,
}
}
return auth.WithUser(ctx, auth.UserInfo{Name: username}), nil
return auth.WithUser(req.Context(), auth.UserInfo{Name: username}), nil
}
// challenge implements the auth.Challenge interface.

View file

@ -8,7 +8,6 @@ import (
"os"
"testing"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth"
)
@ -33,7 +32,6 @@ func TestBasicAccessController(t *testing.T) {
"realm": testRealm,
"path": tempFile.Name(),
}
ctx := dcontext.Background()
accessController, err := newAccessController(options)
if err != nil {
@ -45,8 +43,7 @@ func TestBasicAccessController(t *testing.T) {
userNumber := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := dcontext.WithRequest(ctx, r)
authCtx, err := accessController.Authorized(ctx)
authCtx, err := accessController.Authorized(r)
if err != nil {
switch err := err.(type) {
case auth.Challenge:

View file

@ -43,12 +43,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized simply checks for the existence of the authorization header,
// responding with a bearer challenge if it doesn't exist.
func (ac *accessController) Authorized(ctx context.Context, accessRecords ...auth.Access) (context.Context, error) {
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
func (ac *accessController) Authorized(req *http.Request, accessRecords ...auth.Access) (context.Context, error) {
if req.Header.Get("Authorization") == "" {
challenge := challenge{
realm: ac.realm,
@ -66,7 +61,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessRecords ...aut
return nil, &challenge
}
ctx = auth.WithUser(ctx, auth.UserInfo{Name: "silly"})
ctx := auth.WithUser(req.Context(), auth.UserInfo{Name: "silly"})
ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, auth.UserNameKey, auth.UserKey))
return ctx, nil

View file

@ -5,7 +5,6 @@ import (
"net/http/httptest"
"testing"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth"
)
@ -16,8 +15,7 @@ func TestSillyAccessController(t *testing.T) {
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := dcontext.WithRequest(dcontext.Background(), r)
authCtx, err := ac.Authorized(ctx)
authCtx, err := ac.Authorized(r)
if err != nil {
switch err := err.(type) {
case auth.Challenge:

View file

@ -13,7 +13,6 @@ import (
"os"
"strings"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3"
)
@ -292,7 +291,7 @@ func newAccessController(options map[string]interface{}) (auth.AccessController,
// Authorized handles checking whether the given request is authorized
// for actions on resources described by the given access items.
func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.Access) (context.Context, error) {
func (ac *accessController) Authorized(req *http.Request, accessItems ...auth.Access) (context.Context, error) {
challenge := &authChallenge{
realm: ac.realm,
autoRedirect: ac.autoRedirect,
@ -300,11 +299,6 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
accessSet: newAccessSet(accessItems...),
}
req, err := dcontext.GetRequest(ctx)
if err != nil {
return nil, err
}
prefix, rawToken, ok := strings.Cut(req.Header.Get("Authorization"), " ")
if !ok || rawToken == "" || !strings.EqualFold(prefix, "bearer") {
challenge.err = ErrTokenRequired
@ -338,7 +332,7 @@ func (ac *accessController) Authorized(ctx context.Context, accessItems ...auth.
}
}
ctx = auth.WithResources(ctx, claims.resources())
ctx := auth.WithResources(req.Context(), claims.resources())
return auth.WithUser(ctx, auth.UserInfo{Name: claims.Subject}), nil
}

View file

@ -18,7 +18,6 @@ import (
"testing"
"time"
"github.com/distribution/distribution/v3/internal/dcontext"
"github.com/distribution/distribution/v3/registry/auth"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
@ -466,8 +465,7 @@ func TestAccessController(t *testing.T) {
Action: "baz",
}
ctx := dcontext.WithRequest(dcontext.Background(), req)
authCtx, err := accessController.Authorized(ctx, testAccess)
authCtx, err := accessController.Authorized(req, testAccess)
challenge, ok := err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
@ -502,7 +500,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess)
authCtx, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
@ -534,7 +532,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess)
authCtx, err = accessController.Authorized(req, testAccess)
challenge, ok = err.(auth.Challenge)
if !ok {
t.Fatal("accessController did not return a challenge")
@ -564,7 +562,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
authCtx, err = accessController.Authorized(ctx, testAccess)
authCtx, err = accessController.Authorized(req, testAccess)
if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err)
}
@ -594,7 +592,7 @@ func TestAccessController(t *testing.T) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Raw))
_, err = accessController.Authorized(ctx, testAccess)
_, err = accessController.Authorized(req, testAccess)
if err != nil {
t.Fatalf("accessController returned unexpected error: %s", err)
}

View file

@ -797,7 +797,7 @@ func (app *App) authorized(w http.ResponseWriter, r *http.Request, context *Cont
accessRecords = appendCatalogAccessRecord(accessRecords, r)
}
ctx, err := app.accessController.Authorized(context.Context, accessRecords...)
ctx, err := app.accessController.Authorized(r.WithContext(context.Context), accessRecords...)
if err != nil {
switch err := err.(type) {
case auth.Challenge: