Separate version and challenge parsing from ping

Replace ping logic with individual functions to extract API version and authorization challenges. The response from a ping operation can be passed into these function. If an error occurs in parsing, the version or challenge will not be used. Sending the ping request is the responsibility of the caller.
APIVersion has been converted from a string to a structure type. A parse function was added to convert from string to the structure type.

Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)
This commit is contained in:
Derek McGowan 2015-06-15 16:10:48 -07:00
parent 5a7dab4670
commit c8fac94617
3 changed files with 104 additions and 35 deletions

View file

@ -0,0 +1,58 @@
package auth
import (
"net/http"
"strings"
)
// APIVersion represents a version of an API including its
// type and version number.
type APIVersion struct {
// Type refers to the name of a specific API specification
// such as "registry"
Type string
// Version is the vesion of the API specification implemented,
// This may omit the revision number and only include
// the major and minor version, such as "2.0"
Version string
}
// String returns the string formatted API Version
func (v APIVersion) String() string {
return v.Type + "/" + v.Version
}
// APIVersions gets the API versions out of an HTTP response using the provided
// version header as the key for the HTTP header.
func APIVersions(resp *http.Response, versionHeader string) []APIVersion {
versions := []APIVersion{}
if versionHeader != "" {
for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey(versionHeader)] {
for _, version := range strings.Fields(supportedVersions) {
versions = append(versions, ParseAPIVersion(version))
}
}
}
return versions
}
// ParseAPIVersion parses an API version string into an APIVersion
// Format (Expected, not enforced):
// API version string = <API type> '/' <API version>
// API type = [a-z][a-z0-9]*
// API version = [0-9]+(\.[0-9]+)?
// TODO(dmcgowan): Enforce format, add error condition, remove unknown type
func ParseAPIVersion(versionStr string) APIVersion {
idx := strings.IndexRune(versionStr, '/')
if idx == -1 {
return APIVersion{
Type: "unknown",
Version: versionStr,
}
}
return APIVersion{
Type: strings.ToLower(versionStr[:idx]),
Version: versionStr[idx+1:],
}
}

View file

@ -1,14 +1,10 @@
package auth package auth
import ( import (
"fmt"
"net/http" "net/http"
"strings" "strings"
) )
// Octet types from RFC 2616.
type octetType byte
// Challenge carries information from a WWW-Authenticate response header. // Challenge carries information from a WWW-Authenticate response header.
// See RFC 2617. // See RFC 2617.
type Challenge struct { type Challenge struct {
@ -19,6 +15,9 @@ type Challenge struct {
Parameters map[string]string Parameters map[string]string
} }
// Octet types from RFC 2616.
type octetType byte
var octetTypes [256]octetType var octetTypes [256]octetType
const ( const (
@ -58,36 +57,17 @@ func init() {
} }
} }
// Ping pings the provided endpoint to determine its required authorization challenges. // ResponseChallenges returns a list of authorization challenges
// If a version header is provided, the versions will be returned. // for the given http Response. Challenges are only checked if
func Ping(client *http.Client, endpoint, versionHeader string) ([]Challenge, []string, error) { // the response status code was a 401.
req, err := http.NewRequest("GET", endpoint, nil) func ResponseChallenges(resp *http.Response) []Challenge {
if err != nil {
return nil, nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, nil, err
}
defer resp.Body.Close()
versions := []string{}
if versionHeader != "" {
for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey(versionHeader)] {
versions = append(versions, strings.Fields(supportedVersions)...)
}
}
if resp.StatusCode == http.StatusUnauthorized { if resp.StatusCode == http.StatusUnauthorized {
// Parse the WWW-Authenticate Header and store the challenges // Parse the WWW-Authenticate Header and store the challenges
// on this endpoint object. // on this endpoint object.
return parseAuthHeader(resp.Header), versions, nil return parseAuthHeader(resp.Header)
} else if resp.StatusCode != http.StatusOK {
return nil, versions, fmt.Errorf("unable to get valid ping response: %d", resp.StatusCode)
} }
return nil, versions, nil return nil
} }
func parseAuthHeader(header http.Header) []Challenge { func parseAuthHeader(header http.Header) []Challenge {

View file

@ -42,8 +42,9 @@ func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, au
wrapper := &testAuthenticationWrapper{ wrapper := &testAuthenticationWrapper{
headers: http.Header(map[string][]string{ headers: http.Header(map[string][]string{
"Docker-Distribution-API-Version": {"registry/2.0"}, "X-API-Version": {"registry/2.0"},
"WWW-Authenticate": {authenticate}, "X-Multi-API-Version": {"registry/2.0", "registry/2.1", "trust/1.0"},
"WWW-Authenticate": {authenticate},
}), }),
authCheck: authCheck, authCheck: authCheck,
next: h, next: h,
@ -53,6 +54,18 @@ func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, au
return s.URL, s.Close return s.URL, s.Close
} }
// ping pings the provided endpoint to determine its required authorization challenges.
// If a version header is provided, the versions will be returned.
func ping(endpoint, versionHeader string) ([]Challenge, []APIVersion, error) {
resp, err := http.Get(endpoint)
if err != nil {
return nil, nil, err
}
defer resp.Body.Close()
return ResponseChallenges(resp), APIVersions(resp, versionHeader), err
}
type testCredentialStore struct { type testCredentialStore struct {
username string username string
password string password string
@ -112,10 +125,16 @@ func TestEndpointAuthorizeToken(t *testing.T) {
e, c := testServerWithAuth(m, authenicate, validCheck) e, c := testServerWithAuth(m, authenicate, validCheck)
defer c() defer c()
challenges1, _, err := Ping(&http.Client{}, e+"/v2/", "") challenges1, versions, err := ping(e+"/v2/", "x-api-version")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(versions) != 1 {
t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
challengeMap1 := map[string][]Challenge{ challengeMap1 := map[string][]Challenge{
e + "/v2/": challenges1, e + "/v2/": challenges1,
} }
@ -138,10 +157,22 @@ func TestEndpointAuthorizeToken(t *testing.T) {
e2, c2 := testServerWithAuth(m, authenicate, badCheck) e2, c2 := testServerWithAuth(m, authenicate, badCheck)
defer c2() defer c2()
challenges2, _, err := Ping(&http.Client{}, e+"/v2/", "") challenges2, versions, err := ping(e+"/v2/", "x-multi-api-version")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(versions) != 3 {
t.Fatalf("Unexpected version count: %d, expected 3", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
if check := (APIVersion{Type: "registry", Version: "2.1"}); versions[1] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[1], check)
}
if check := (APIVersion{Type: "trust", Version: "1.0"}); versions[2] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[2], check)
}
challengeMap2 := map[string][]Challenge{ challengeMap2 := map[string][]Challenge{
e + "/v2/": challenges2, e + "/v2/": challenges2,
} }
@ -215,7 +246,7 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) {
password: password, password: password,
} }
challenges, _, err := Ping(&http.Client{}, e+"/v2/", "") challenges, _, err := ping(e+"/v2/", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -262,7 +293,7 @@ func TestEndpointAuthorizeBasic(t *testing.T) {
password: password, password: password,
} }
challenges, _, err := Ping(&http.Client{}, e+"/v2/", "") challenges, _, err := ping(e+"/v2/", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }