From c8fac94617d72197223490e0d57662587699fc58 Mon Sep 17 00:00:00 2001 From: Derek McGowan Date: Mon, 15 Jun 2015 16:10:48 -0700 Subject: [PATCH] Separate version and challenge parsing from ping Replace ping logic with individual functions to extract API version and authorization challenges. The response from a ping operation can be passed into these function. If an error occurs in parsing, the version or challenge will not be used. Sending the ping request is the responsibility of the caller. APIVersion has been converted from a string to a structure type. A parse function was added to convert from string to the structure type. Signed-off-by: Derek McGowan (github: dmcgowan) --- registry/client/auth/api_version.go | 58 +++++++++++++++++++++++++++ registry/client/auth/authchallenge.go | 38 +++++------------- registry/client/auth/session_test.go | 43 +++++++++++++++++--- 3 files changed, 104 insertions(+), 35 deletions(-) create mode 100644 registry/client/auth/api_version.go diff --git a/registry/client/auth/api_version.go b/registry/client/auth/api_version.go new file mode 100644 index 000000000..df095474d --- /dev/null +++ b/registry/client/auth/api_version.go @@ -0,0 +1,58 @@ +package auth + +import ( + "net/http" + "strings" +) + +// APIVersion represents a version of an API including its +// type and version number. +type APIVersion struct { + // Type refers to the name of a specific API specification + // such as "registry" + Type string + + // Version is the vesion of the API specification implemented, + // This may omit the revision number and only include + // the major and minor version, such as "2.0" + Version string +} + +// String returns the string formatted API Version +func (v APIVersion) String() string { + return v.Type + "/" + v.Version +} + +// APIVersions gets the API versions out of an HTTP response using the provided +// version header as the key for the HTTP header. +func APIVersions(resp *http.Response, versionHeader string) []APIVersion { + versions := []APIVersion{} + if versionHeader != "" { + for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey(versionHeader)] { + for _, version := range strings.Fields(supportedVersions) { + versions = append(versions, ParseAPIVersion(version)) + } + } + } + return versions +} + +// ParseAPIVersion parses an API version string into an APIVersion +// Format (Expected, not enforced): +// API version string = '/' +// API type = [a-z][a-z0-9]* +// API version = [0-9]+(\.[0-9]+)? +// TODO(dmcgowan): Enforce format, add error condition, remove unknown type +func ParseAPIVersion(versionStr string) APIVersion { + idx := strings.IndexRune(versionStr, '/') + if idx == -1 { + return APIVersion{ + Type: "unknown", + Version: versionStr, + } + } + return APIVersion{ + Type: strings.ToLower(versionStr[:idx]), + Version: versionStr[idx+1:], + } +} diff --git a/registry/client/auth/authchallenge.go b/registry/client/auth/authchallenge.go index e3abfb118..5d371646b 100644 --- a/registry/client/auth/authchallenge.go +++ b/registry/client/auth/authchallenge.go @@ -1,14 +1,10 @@ package auth import ( - "fmt" "net/http" "strings" ) -// Octet types from RFC 2616. -type octetType byte - // Challenge carries information from a WWW-Authenticate response header. // See RFC 2617. type Challenge struct { @@ -19,6 +15,9 @@ type Challenge struct { Parameters map[string]string } +// Octet types from RFC 2616. +type octetType byte + var octetTypes [256]octetType const ( @@ -58,36 +57,17 @@ func init() { } } -// Ping pings the provided endpoint to determine its required authorization challenges. -// If a version header is provided, the versions will be returned. -func Ping(client *http.Client, endpoint, versionHeader string) ([]Challenge, []string, error) { - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return nil, nil, err - } - - resp, err := client.Do(req) - if err != nil { - return nil, nil, err - } - defer resp.Body.Close() - - versions := []string{} - if versionHeader != "" { - for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey(versionHeader)] { - versions = append(versions, strings.Fields(supportedVersions)...) - } - } - +// ResponseChallenges returns a list of authorization challenges +// for the given http Response. Challenges are only checked if +// the response status code was a 401. +func ResponseChallenges(resp *http.Response) []Challenge { if resp.StatusCode == http.StatusUnauthorized { // Parse the WWW-Authenticate Header and store the challenges // on this endpoint object. - return parseAuthHeader(resp.Header), versions, nil - } else if resp.StatusCode != http.StatusOK { - return nil, versions, fmt.Errorf("unable to get valid ping response: %d", resp.StatusCode) + return parseAuthHeader(resp.Header) } - return nil, versions, nil + return nil } func parseAuthHeader(header http.Header) []Challenge { diff --git a/registry/client/auth/session_test.go b/registry/client/auth/session_test.go index f16836da3..3d19d4a7c 100644 --- a/registry/client/auth/session_test.go +++ b/registry/client/auth/session_test.go @@ -42,8 +42,9 @@ func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, au wrapper := &testAuthenticationWrapper{ headers: http.Header(map[string][]string{ - "Docker-Distribution-API-Version": {"registry/2.0"}, - "WWW-Authenticate": {authenticate}, + "X-API-Version": {"registry/2.0"}, + "X-Multi-API-Version": {"registry/2.0", "registry/2.1", "trust/1.0"}, + "WWW-Authenticate": {authenticate}, }), authCheck: authCheck, next: h, @@ -53,6 +54,18 @@ func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, au return s.URL, s.Close } +// ping pings the provided endpoint to determine its required authorization challenges. +// If a version header is provided, the versions will be returned. +func ping(endpoint, versionHeader string) ([]Challenge, []APIVersion, error) { + resp, err := http.Get(endpoint) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + return ResponseChallenges(resp), APIVersions(resp, versionHeader), err +} + type testCredentialStore struct { username string password string @@ -112,10 +125,16 @@ func TestEndpointAuthorizeToken(t *testing.T) { e, c := testServerWithAuth(m, authenicate, validCheck) defer c() - challenges1, _, err := Ping(&http.Client{}, e+"/v2/", "") + challenges1, versions, err := ping(e+"/v2/", "x-api-version") if err != nil { t.Fatal(err) } + if len(versions) != 1 { + t.Fatalf("Unexpected version count: %d, expected 1", len(versions)) + } + if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check { + t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check) + } challengeMap1 := map[string][]Challenge{ e + "/v2/": challenges1, } @@ -138,10 +157,22 @@ func TestEndpointAuthorizeToken(t *testing.T) { e2, c2 := testServerWithAuth(m, authenicate, badCheck) defer c2() - challenges2, _, err := Ping(&http.Client{}, e+"/v2/", "") + challenges2, versions, err := ping(e+"/v2/", "x-multi-api-version") if err != nil { t.Fatal(err) } + if len(versions) != 3 { + t.Fatalf("Unexpected version count: %d, expected 3", len(versions)) + } + if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check { + t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check) + } + if check := (APIVersion{Type: "registry", Version: "2.1"}); versions[1] != check { + t.Fatalf("Unexpected api version: %#v, expected %#v", versions[1], check) + } + if check := (APIVersion{Type: "trust", Version: "1.0"}); versions[2] != check { + t.Fatalf("Unexpected api version: %#v, expected %#v", versions[2], check) + } challengeMap2 := map[string][]Challenge{ e + "/v2/": challenges2, } @@ -215,7 +246,7 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) { password: password, } - challenges, _, err := Ping(&http.Client{}, e+"/v2/", "") + challenges, _, err := ping(e+"/v2/", "") if err != nil { t.Fatal(err) } @@ -262,7 +293,7 @@ func TestEndpointAuthorizeBasic(t *testing.T) { password: password, } - challenges, _, err := Ping(&http.Client{}, e+"/v2/", "") + challenges, _, err := ping(e+"/v2/", "") if err != nil { t.Fatal(err) }