diff --git a/docs/client/auth/authchallenge.go b/docs/client/auth/authchallenge.go index 5d371646..a6ad45d8 100644 --- a/docs/client/auth/authchallenge.go +++ b/docs/client/auth/authchallenge.go @@ -1,7 +1,9 @@ package auth import ( + "fmt" "net/http" + "net/url" "strings" ) @@ -15,6 +17,57 @@ type Challenge struct { Parameters map[string]string } +// ChallengeManager manages the challenges for endpoints. +// The challenges are pulled out of HTTP responses. Only +// responses which expect challenges should be added to +// the manager, since a non-unauthorized request will be +// viewed as not requiring challenges. +type ChallengeManager interface { + // GetChallenges returns the challenges for the given + // endpoint URL. + GetChallenges(endpoint string) ([]Challenge, error) + + // AddResponse adds the response to the challenge + // manager. The challenges will be parsed out of + // the WWW-Authenicate headers and added to the + // URL which was produced the response. If the + // response was authorized, any challenges for the + // endpoint will be cleared. + AddResponse(resp *http.Response) error +} + +// NewSimpleChallengeManager returns an instance of +// ChallengeManger which only maps endpoints to challenges +// based on the responses which have been added the +// manager. The simple manager will make no attempt to +// perform requests on the endpoints or cache the responses +// to a backend. +func NewSimpleChallengeManager() ChallengeManager { + return simpleChallengeManager{} +} + +type simpleChallengeManager map[string][]Challenge + +func (m simpleChallengeManager) GetChallenges(endpoint string) ([]Challenge, error) { + challenges := m[endpoint] + return challenges, nil +} + +func (m simpleChallengeManager) AddResponse(resp *http.Response) error { + challenges := ResponseChallenges(resp) + if resp.Request == nil { + return fmt.Errorf("missing request reference") + } + urlCopy := url.URL{ + Path: resp.Request.URL.Path, + Host: resp.Request.URL.Host, + Scheme: resp.Request.URL.Scheme, + } + m[urlCopy.String()] = challenges + + return nil +} + // Octet types from RFC 2616. type octetType byte diff --git a/docs/client/auth/session.go b/docs/client/auth/session.go index 5512a9a1..27e1d9e3 100644 --- a/docs/client/auth/session.go +++ b/docs/client/auth/session.go @@ -36,15 +36,15 @@ type CredentialStore interface { // schemes. The handlers are tried in order, the higher priority authentication // methods should be first. The challengeMap holds a list of challenges for // a given root API endpoint (for example "https://registry-1.docker.io/v2/"). -func NewAuthorizer(challengeMap map[string][]Challenge, handlers ...AuthenticationHandler) transport.RequestModifier { +func NewAuthorizer(manager ChallengeManager, handlers ...AuthenticationHandler) transport.RequestModifier { return &endpointAuthorizer{ - challenges: challengeMap, + challenges: manager, handlers: handlers, } } type endpointAuthorizer struct { - challenges map[string][]Challenge + challenges ChallengeManager handlers []AuthenticationHandler transport http.RoundTripper } @@ -63,18 +63,20 @@ func (ea *endpointAuthorizer) ModifyRequest(req *http.Request) error { pingEndpoint := ping.String() - challenges, ok := ea.challenges[pingEndpoint] - if !ok { - return nil + challenges, err := ea.challenges.GetChallenges(pingEndpoint) + if err != nil { + return err } - for _, handler := range ea.handlers { - for _, challenge := range challenges { - if challenge.Scheme != handler.Scheme() { - continue - } - if err := handler.AuthorizeRequest(req, challenge.Parameters); err != nil { - return err + if len(challenges) > 0 { + for _, handler := range ea.handlers { + for _, challenge := range challenges { + if challenge.Scheme != handler.Scheme() { + continue + } + if err := handler.AuthorizeRequest(req, challenge.Parameters); err != nil { + return err + } } } } diff --git a/docs/client/auth/session_test.go b/docs/client/auth/session_test.go index 3d19d4a7..1b4754ab 100644 --- a/docs/client/auth/session_test.go +++ b/docs/client/auth/session_test.go @@ -56,14 +56,18 @@ func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, au // ping pings the provided endpoint to determine its required authorization challenges. // If a version header is provided, the versions will be returned. -func ping(endpoint, versionHeader string) ([]Challenge, []APIVersion, error) { +func ping(manager ChallengeManager, endpoint, versionHeader string) ([]APIVersion, error) { resp, err := http.Get(endpoint) if err != nil { - return nil, nil, err + return nil, err } defer resp.Body.Close() - return ResponseChallenges(resp), APIVersions(resp, versionHeader), err + if err := manager.AddResponse(resp); err != nil { + return nil, err + } + + return APIVersions(resp, versionHeader), err } type testCredentialStore struct { @@ -125,7 +129,8 @@ func TestEndpointAuthorizeToken(t *testing.T) { e, c := testServerWithAuth(m, authenicate, validCheck) defer c() - challenges1, versions, err := ping(e+"/v2/", "x-api-version") + challengeManager1 := NewSimpleChallengeManager() + versions, err := ping(challengeManager1, e+"/v2/", "x-api-version") if err != nil { t.Fatal(err) } @@ -135,10 +140,7 @@ func TestEndpointAuthorizeToken(t *testing.T) { if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check { t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check) } - challengeMap1 := map[string][]Challenge{ - e + "/v2/": challenges1, - } - transport1 := transport.NewTransport(nil, NewAuthorizer(challengeMap1, NewTokenHandler(nil, nil, repo1, "pull", "push"))) + transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager1, NewTokenHandler(nil, nil, repo1, "pull", "push"))) client := &http.Client{Transport: transport1} req, _ := http.NewRequest("GET", e+"/v2/hello", nil) @@ -157,7 +159,8 @@ func TestEndpointAuthorizeToken(t *testing.T) { e2, c2 := testServerWithAuth(m, authenicate, badCheck) defer c2() - challenges2, versions, err := ping(e+"/v2/", "x-multi-api-version") + challengeManager2 := NewSimpleChallengeManager() + versions, err = ping(challengeManager2, e+"/v2/", "x-multi-api-version") if err != nil { t.Fatal(err) } @@ -173,10 +176,7 @@ func TestEndpointAuthorizeToken(t *testing.T) { if check := (APIVersion{Type: "trust", Version: "1.0"}); versions[2] != check { t.Fatalf("Unexpected api version: %#v, expected %#v", versions[2], check) } - challengeMap2 := map[string][]Challenge{ - e + "/v2/": challenges2, - } - transport2 := transport.NewTransport(nil, NewAuthorizer(challengeMap2, NewTokenHandler(nil, nil, repo2, "pull", "push"))) + transport2 := transport.NewTransport(nil, NewAuthorizer(challengeManager2, NewTokenHandler(nil, nil, repo2, "pull", "push"))) client2 := &http.Client{Transport: transport2} req, _ = http.NewRequest("GET", e2+"/v2/hello", nil) @@ -246,14 +246,12 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) { password: password, } - challenges, _, err := ping(e+"/v2/", "") + challengeManager := NewSimpleChallengeManager() + _, err := ping(challengeManager, e+"/v2/", "") if err != nil { t.Fatal(err) } - challengeMap := map[string][]Challenge{ - e + "/v2/": challenges, - } - transport1 := transport.NewTransport(nil, NewAuthorizer(challengeMap, NewTokenHandler(nil, creds, repo, "pull", "push"), NewBasicHandler(creds))) + transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, NewTokenHandler(nil, creds, repo, "pull", "push"), NewBasicHandler(creds))) client := &http.Client{Transport: transport1} req, _ := http.NewRequest("GET", e+"/v2/hello", nil) @@ -293,14 +291,12 @@ func TestEndpointAuthorizeBasic(t *testing.T) { password: password, } - challenges, _, err := ping(e+"/v2/", "") + challengeManager := NewSimpleChallengeManager() + _, err := ping(challengeManager, e+"/v2/", "") if err != nil { t.Fatal(err) } - challengeMap := map[string][]Challenge{ - e + "/v2/": challenges, - } - transport1 := transport.NewTransport(nil, NewAuthorizer(challengeMap, NewBasicHandler(creds))) + transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, NewBasicHandler(creds))) client := &http.Client{Transport: transport1} req, _ := http.NewRequest("GET", e+"/v2/hello", nil)