From b1ba2183ee74a482fe97e925cbc157669ac40d2e Mon Sep 17 00:00:00 2001 From: Derek McGowan Date: Thu, 7 May 2015 13:16:52 -0700 Subject: [PATCH] Add unit tests for auth challenge and endpoint Signed-off-by: Derek McGowan (github: dmcgowan) --- registry/client/authchallenge.go | 2 +- registry/client/authchallenge_test.go | 37 ++++ registry/client/endpoint.go | 2 + registry/client/endpoint_test.go | 259 ++++++++++++++++++++++++++ registry/client/repository.go | 4 +- registry/client/repository_test.go | 16 +- testutil/handler.go | 9 +- 7 files changed, 315 insertions(+), 14 deletions(-) create mode 100644 registry/client/authchallenge_test.go create mode 100644 registry/client/endpoint_test.go diff --git a/registry/client/authchallenge.go b/registry/client/authchallenge.go index 0485f42d..f45704b1 100644 --- a/registry/client/authchallenge.go +++ b/registry/client/authchallenge.go @@ -127,7 +127,7 @@ func expectTokenOrQuoted(s string) (value string, rest string) { p := make([]byte, len(s)-1) j := copy(p, s[:i]) escape := true - for i = i + i; i < len(s); i++ { + for i = i + 1; i < len(s); i++ { b := s[i] switch { case escape: diff --git a/registry/client/authchallenge_test.go b/registry/client/authchallenge_test.go new file mode 100644 index 00000000..bb3016ee --- /dev/null +++ b/registry/client/authchallenge_test.go @@ -0,0 +1,37 @@ +package client + +import ( + "net/http" + "testing" +) + +func TestAuthChallengeParse(t *testing.T) { + header := http.Header{} + header.Add("WWW-Authenticate", `Bearer realm="https://auth.example.com/token",service="registry.example.com",other=fun,slashed="he\"\l\lo"`) + + challenges := parseAuthHeader(header) + if len(challenges) != 1 { + t.Fatalf("Unexpected number of auth challenges: %d, expected 1", len(challenges)) + } + + if expected := "bearer"; challenges[0].Scheme != expected { + t.Fatalf("Unexpected scheme: %s, expected: %s", challenges[0].Scheme, expected) + } + + if expected := "https://auth.example.com/token"; challenges[0].Parameters["realm"] != expected { + t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["realm"], expected) + } + + if expected := "registry.example.com"; challenges[0].Parameters["service"] != expected { + t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["service"], expected) + } + + if expected := "fun"; challenges[0].Parameters["other"] != expected { + t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["other"], expected) + } + + if expected := "he\"llo"; challenges[0].Parameters["slashed"] != expected { + t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["slashed"], expected) + } + +} diff --git a/registry/client/endpoint.go b/registry/client/endpoint.go index 83d3d991..9889dc66 100644 --- a/registry/client/endpoint.go +++ b/registry/client/endpoint.go @@ -117,6 +117,8 @@ func (e *RepositoryEndpoint) URLBuilder() (*v2.URLBuilder, error) { // HTTPClient returns a new HTTP client configured for this endpoint func (e *RepositoryEndpoint) HTTPClient(name string) (*http.Client, error) { + // TODO(dmcgowan): create http.Transport + transport := &repositoryTransport{ Header: e.Header, } diff --git a/registry/client/endpoint_test.go b/registry/client/endpoint_test.go new file mode 100644 index 00000000..42bdc357 --- /dev/null +++ b/registry/client/endpoint_test.go @@ -0,0 +1,259 @@ +package client + +import ( + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/docker/distribution/testutil" +) + +type testAuthenticationWrapper struct { + headers http.Header + authCheck func(string) bool + next http.Handler +} + +func (w *testAuthenticationWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth == "" || !w.authCheck(auth) { + h := rw.Header() + for k, values := range w.headers { + h[k] = values + } + rw.WriteHeader(http.StatusUnauthorized) + return + } + w.next.ServeHTTP(rw, r) +} + +func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, authCheck func(string) bool) (*RepositoryEndpoint, func()) { + h := testutil.NewHandler(rrm) + wrapper := &testAuthenticationWrapper{ + + headers: http.Header(map[string][]string{ + "Docker-Distribution-API-Version": {"registry/2.0"}, + "WWW-Authenticate": {authenticate}, + }), + authCheck: authCheck, + next: h, + } + + s := httptest.NewServer(wrapper) + e := RepositoryEndpoint{Endpoint: s.URL, Mirror: false} + return &e, s.Close +} + +type testCredentialStore struct { + username string + password string +} + +func (tcs *testCredentialStore) Basic(*url.URL) (string, string) { + return tcs.username, tcs.password +} + +func TestEndpointAuthorizeToken(t *testing.T) { + service := "localhost.localdomain" + repo1 := "some/registry" + repo2 := "other/registry" + scope1 := fmt.Sprintf("repository:%s:pull,push", repo1) + scope2 := fmt.Sprintf("repository:%s:pull,push", repo2) + + tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope1), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"token":"statictoken"}`), + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope2), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"token":"badtoken"}`), + }, + }, + }) + te, tc := testServer(tokenMap) + defer tc() + + m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: "/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + }) + + authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service) + validCheck := func(a string) bool { + return a == "Bearer statictoken" + } + e, c := testServerWithAuth(m, authenicate, validCheck) + defer c() + + client, err := e.HTTPClient(repo1) + if err != nil { + t.Fatalf("Error creating http client: %s", err) + } + + req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } + + badCheck := func(a string) bool { + return a == "Bearer statictoken" + } + e2, c2 := testServerWithAuth(m, authenicate, badCheck) + defer c2() + + client2, err := e2.HTTPClient(repo2) + if err != nil { + t.Fatalf("Error creating http client: %s", err) + } + + req, _ = http.NewRequest("GET", e.Endpoint+"/hello", nil) + resp, err = client2.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized) + } +} + +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func TestEndpointAuthorizeTokenBasic(t *testing.T) { + service := "localhost.localdomain" + repo := "some/fun/registry" + scope := fmt.Sprintf("repository:%s:pull,push", repo) + username := "tokenuser" + password := "superSecretPa$$word" + + tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"token":"statictoken"}`), + }, + }, + }) + + authenicate1 := fmt.Sprintf("Basic realm=localhost") + basicCheck := func(a string) bool { + return a == fmt.Sprintf("Basic %s", basicAuth(username, password)) + } + te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck) + defer tc() + + m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: "/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + }) + + authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service) + bearerCheck := func(a string) bool { + return a == "Bearer statictoken" + } + e, c := testServerWithAuth(m, authenicate2, bearerCheck) + defer c() + + e.Credentials = &testCredentialStore{ + username: username, + password: password, + } + + client, err := e.HTTPClient(repo) + if err != nil { + t.Fatalf("Error creating http client: %s", err) + } + + req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } +} + +func TestEndpointAuthorizeBasic(t *testing.T) { + m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: "/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + }) + + username := "user1" + password := "funSecretPa$$word" + authenicate := fmt.Sprintf("Basic realm=localhost") + validCheck := func(a string) bool { + return a == fmt.Sprintf("Basic %s", basicAuth(username, password)) + } + e, c := testServerWithAuth(m, authenicate, validCheck) + defer c() + e.Credentials = &testCredentialStore{ + username: username, + password: password, + } + + client, err := e.HTTPClient("test/repo/basic") + if err != nil { + t.Fatalf("Error creating http client: %s", err) + } + + req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } +} diff --git a/registry/client/repository.go b/registry/client/repository.go index a96390fa..578c3fca 100644 --- a/registry/client/repository.go +++ b/registry/client/repository.go @@ -25,8 +25,8 @@ import ( "golang.org/x/net/context" ) -// NewRepositoryClient creates a new Repository for the given repository name and endpoint -func NewRepositoryClient(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) { +// NewRepository creates a new Repository for the given repository name and endpoint +func NewRepository(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) { if err := v2.ValidateRespositoryName(name); err != nil { return nil, err } diff --git a/registry/client/repository_test.go b/registry/client/repository_test.go index 67138db6..b96c52e5 100644 --- a/registry/client/repository_test.go +++ b/registry/client/repository_test.go @@ -97,7 +97,7 @@ func TestLayerFetch(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e) + r, err := NewRepository(context.Background(), "test.example.com/repo1", e) if err != nil { t.Fatal(err) } @@ -127,7 +127,7 @@ func TestLayerExists(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e) + r, err := NewRepository(context.Background(), "test.example.com/repo1", e) if err != nil { t.Fatal(err) } @@ -227,7 +227,7 @@ func TestLayerUploadChunked(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepositoryClient(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e) if err != nil { t.Fatal(err) } @@ -334,7 +334,7 @@ func TestLayerUploadMonolithic(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepositoryClient(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e) if err != nil { t.Fatal(err) } @@ -475,7 +475,7 @@ func TestManifestFetch(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepositoryClient(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e) if err != nil { t.Fatal(err) } @@ -508,7 +508,7 @@ func TestManifestFetchByTag(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepositoryClient(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e) if err != nil { t.Fatal(err) } @@ -553,7 +553,7 @@ func TestManifestDelete(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepositoryClient(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e) if err != nil { t.Fatal(err) } @@ -591,7 +591,7 @@ func TestManifestPut(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepositoryClient(context.Background(), repo, e) + r, err := NewRepository(context.Background(), repo, e) if err != nil { t.Fatal(err) } diff --git a/testutil/handler.go b/testutil/handler.go index fa118cd1..10850e24 100644 --- a/testutil/handler.go +++ b/testutil/handler.go @@ -6,6 +6,7 @@ import ( "io" "io/ioutil" "net/http" + "net/url" "sort" "strings" ) @@ -40,16 +41,18 @@ type Request struct { func (r Request) String() string { queryString := "" if len(r.QueryParams) > 0 { - queryString = "?" keys := make([]string, 0, len(r.QueryParams)) + queryParts := make([]string, 0, len(r.QueryParams)) for k := range r.QueryParams { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { - queryString += strings.Join(r.QueryParams[k], "&") + "&" + for _, val := range r.QueryParams[k] { + queryParts = append(queryParts, fmt.Sprintf("%s=%s", k, url.QueryEscape(val))) + } } - queryString = queryString[:len(queryString)-1] + queryString = "?" + strings.Join(queryParts, "&") } return fmt.Sprintf("%s %s%s\n%s", r.Method, r.Route, queryString, r.Body) }