diff --git a/registry/client/auth/api_version.go b/registry/client/auth/api_version.go new file mode 100644 index 000000000..df095474d --- /dev/null +++ b/registry/client/auth/api_version.go @@ -0,0 +1,58 @@ +package auth + +import ( + "net/http" + "strings" +) + +// APIVersion represents a version of an API including its +// type and version number. +type APIVersion struct { + // Type refers to the name of a specific API specification + // such as "registry" + Type string + + // Version is the vesion of the API specification implemented, + // This may omit the revision number and only include + // the major and minor version, such as "2.0" + Version string +} + +// String returns the string formatted API Version +func (v APIVersion) String() string { + return v.Type + "/" + v.Version +} + +// APIVersions gets the API versions out of an HTTP response using the provided +// version header as the key for the HTTP header. +func APIVersions(resp *http.Response, versionHeader string) []APIVersion { + versions := []APIVersion{} + if versionHeader != "" { + for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey(versionHeader)] { + for _, version := range strings.Fields(supportedVersions) { + versions = append(versions, ParseAPIVersion(version)) + } + } + } + return versions +} + +// ParseAPIVersion parses an API version string into an APIVersion +// Format (Expected, not enforced): +// API version string = '/' +// API type = [a-z][a-z0-9]* +// API version = [0-9]+(\.[0-9]+)? +// TODO(dmcgowan): Enforce format, add error condition, remove unknown type +func ParseAPIVersion(versionStr string) APIVersion { + idx := strings.IndexRune(versionStr, '/') + if idx == -1 { + return APIVersion{ + Type: "unknown", + Version: versionStr, + } + } + return APIVersion{ + Type: strings.ToLower(versionStr[:idx]), + Version: versionStr[idx+1:], + } +} diff --git a/registry/client/auth/authchallenge.go b/registry/client/auth/authchallenge.go index e3abfb118..5d371646b 100644 --- a/registry/client/auth/authchallenge.go +++ b/registry/client/auth/authchallenge.go @@ -1,14 +1,10 @@ package auth import ( - "fmt" "net/http" "strings" ) -// Octet types from RFC 2616. -type octetType byte - // Challenge carries information from a WWW-Authenticate response header. // See RFC 2617. type Challenge struct { @@ -19,6 +15,9 @@ type Challenge struct { Parameters map[string]string } +// Octet types from RFC 2616. +type octetType byte + var octetTypes [256]octetType const ( @@ -58,36 +57,17 @@ func init() { } } -// Ping pings the provided endpoint to determine its required authorization challenges. -// If a version header is provided, the versions will be returned. -func Ping(client *http.Client, endpoint, versionHeader string) ([]Challenge, []string, error) { - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return nil, nil, err - } - - resp, err := client.Do(req) - if err != nil { - return nil, nil, err - } - defer resp.Body.Close() - - versions := []string{} - if versionHeader != "" { - for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey(versionHeader)] { - versions = append(versions, strings.Fields(supportedVersions)...) - } - } - +// ResponseChallenges returns a list of authorization challenges +// for the given http Response. Challenges are only checked if +// the response status code was a 401. +func ResponseChallenges(resp *http.Response) []Challenge { if resp.StatusCode == http.StatusUnauthorized { // Parse the WWW-Authenticate Header and store the challenges // on this endpoint object. - return parseAuthHeader(resp.Header), versions, nil - } else if resp.StatusCode != http.StatusOK { - return nil, versions, fmt.Errorf("unable to get valid ping response: %d", resp.StatusCode) + return parseAuthHeader(resp.Header) } - return nil, versions, nil + return nil } func parseAuthHeader(header http.Header) []Challenge { diff --git a/registry/client/auth/session_test.go b/registry/client/auth/session_test.go index f16836da3..3d19d4a7c 100644 --- a/registry/client/auth/session_test.go +++ b/registry/client/auth/session_test.go @@ -42,8 +42,9 @@ func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, au wrapper := &testAuthenticationWrapper{ headers: http.Header(map[string][]string{ - "Docker-Distribution-API-Version": {"registry/2.0"}, - "WWW-Authenticate": {authenticate}, + "X-API-Version": {"registry/2.0"}, + "X-Multi-API-Version": {"registry/2.0", "registry/2.1", "trust/1.0"}, + "WWW-Authenticate": {authenticate}, }), authCheck: authCheck, next: h, @@ -53,6 +54,18 @@ func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, au return s.URL, s.Close } +// ping pings the provided endpoint to determine its required authorization challenges. +// If a version header is provided, the versions will be returned. +func ping(endpoint, versionHeader string) ([]Challenge, []APIVersion, error) { + resp, err := http.Get(endpoint) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + return ResponseChallenges(resp), APIVersions(resp, versionHeader), err +} + type testCredentialStore struct { username string password string @@ -112,10 +125,16 @@ func TestEndpointAuthorizeToken(t *testing.T) { e, c := testServerWithAuth(m, authenicate, validCheck) defer c() - challenges1, _, err := Ping(&http.Client{}, e+"/v2/", "") + challenges1, versions, err := ping(e+"/v2/", "x-api-version") if err != nil { t.Fatal(err) } + if len(versions) != 1 { + t.Fatalf("Unexpected version count: %d, expected 1", len(versions)) + } + if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check { + t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check) + } challengeMap1 := map[string][]Challenge{ e + "/v2/": challenges1, } @@ -138,10 +157,22 @@ func TestEndpointAuthorizeToken(t *testing.T) { e2, c2 := testServerWithAuth(m, authenicate, badCheck) defer c2() - challenges2, _, err := Ping(&http.Client{}, e+"/v2/", "") + challenges2, versions, err := ping(e+"/v2/", "x-multi-api-version") if err != nil { t.Fatal(err) } + if len(versions) != 3 { + t.Fatalf("Unexpected version count: %d, expected 3", len(versions)) + } + if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check { + t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check) + } + if check := (APIVersion{Type: "registry", Version: "2.1"}); versions[1] != check { + t.Fatalf("Unexpected api version: %#v, expected %#v", versions[1], check) + } + if check := (APIVersion{Type: "trust", Version: "1.0"}); versions[2] != check { + t.Fatalf("Unexpected api version: %#v, expected %#v", versions[2], check) + } challengeMap2 := map[string][]Challenge{ e + "/v2/": challenges2, } @@ -215,7 +246,7 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) { password: password, } - challenges, _, err := Ping(&http.Client{}, e+"/v2/", "") + challenges, _, err := ping(e+"/v2/", "") if err != nil { t.Fatal(err) } @@ -262,7 +293,7 @@ func TestEndpointAuthorizeBasic(t *testing.T) { password: password, } - challenges, _, err := Ping(&http.Client{}, e+"/v2/", "") + challenges, _, err := ping(e+"/v2/", "") if err != nil { t.Fatal(err) }