Add unit tests for auth challenge and endpoint
Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)
This commit is contained in:
parent
174a732c94
commit
b1ba2183ee
7 changed files with 315 additions and 14 deletions
|
@ -127,7 +127,7 @@ func expectTokenOrQuoted(s string) (value string, rest string) {
|
||||||
p := make([]byte, len(s)-1)
|
p := make([]byte, len(s)-1)
|
||||||
j := copy(p, s[:i])
|
j := copy(p, s[:i])
|
||||||
escape := true
|
escape := true
|
||||||
for i = i + i; i < len(s); i++ {
|
for i = i + 1; i < len(s); i++ {
|
||||||
b := s[i]
|
b := s[i]
|
||||||
switch {
|
switch {
|
||||||
case escape:
|
case escape:
|
||||||
|
|
37
registry/client/authchallenge_test.go
Normal file
37
registry/client/authchallenge_test.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthChallengeParse(t *testing.T) {
|
||||||
|
header := http.Header{}
|
||||||
|
header.Add("WWW-Authenticate", `Bearer realm="https://auth.example.com/token",service="registry.example.com",other=fun,slashed="he\"\l\lo"`)
|
||||||
|
|
||||||
|
challenges := parseAuthHeader(header)
|
||||||
|
if len(challenges) != 1 {
|
||||||
|
t.Fatalf("Unexpected number of auth challenges: %d, expected 1", len(challenges))
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected := "bearer"; challenges[0].Scheme != expected {
|
||||||
|
t.Fatalf("Unexpected scheme: %s, expected: %s", challenges[0].Scheme, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected := "https://auth.example.com/token"; challenges[0].Parameters["realm"] != expected {
|
||||||
|
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["realm"], expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected := "registry.example.com"; challenges[0].Parameters["service"] != expected {
|
||||||
|
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["service"], expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected := "fun"; challenges[0].Parameters["other"] != expected {
|
||||||
|
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["other"], expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected := "he\"llo"; challenges[0].Parameters["slashed"] != expected {
|
||||||
|
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["slashed"], expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -117,6 +117,8 @@ func (e *RepositoryEndpoint) URLBuilder() (*v2.URLBuilder, error) {
|
||||||
|
|
||||||
// HTTPClient returns a new HTTP client configured for this endpoint
|
// HTTPClient returns a new HTTP client configured for this endpoint
|
||||||
func (e *RepositoryEndpoint) HTTPClient(name string) (*http.Client, error) {
|
func (e *RepositoryEndpoint) HTTPClient(name string) (*http.Client, error) {
|
||||||
|
// TODO(dmcgowan): create http.Transport
|
||||||
|
|
||||||
transport := &repositoryTransport{
|
transport := &repositoryTransport{
|
||||||
Header: e.Header,
|
Header: e.Header,
|
||||||
}
|
}
|
||||||
|
|
259
registry/client/endpoint_test.go
Normal file
259
registry/client/endpoint_test.go
Normal file
|
@ -0,0 +1,259 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/docker/distribution/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testAuthenticationWrapper struct {
|
||||||
|
headers http.Header
|
||||||
|
authCheck func(string) bool
|
||||||
|
next http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testAuthenticationWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
auth := r.Header.Get("Authorization")
|
||||||
|
if auth == "" || !w.authCheck(auth) {
|
||||||
|
h := rw.Header()
|
||||||
|
for k, values := range w.headers {
|
||||||
|
h[k] = values
|
||||||
|
}
|
||||||
|
rw.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.next.ServeHTTP(rw, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, authCheck func(string) bool) (*RepositoryEndpoint, func()) {
|
||||||
|
h := testutil.NewHandler(rrm)
|
||||||
|
wrapper := &testAuthenticationWrapper{
|
||||||
|
|
||||||
|
headers: http.Header(map[string][]string{
|
||||||
|
"Docker-Distribution-API-Version": {"registry/2.0"},
|
||||||
|
"WWW-Authenticate": {authenticate},
|
||||||
|
}),
|
||||||
|
authCheck: authCheck,
|
||||||
|
next: h,
|
||||||
|
}
|
||||||
|
|
||||||
|
s := httptest.NewServer(wrapper)
|
||||||
|
e := RepositoryEndpoint{Endpoint: s.URL, Mirror: false}
|
||||||
|
return &e, s.Close
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCredentialStore struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tcs *testCredentialStore) Basic(*url.URL) (string, string) {
|
||||||
|
return tcs.username, tcs.password
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEndpointAuthorizeToken(t *testing.T) {
|
||||||
|
service := "localhost.localdomain"
|
||||||
|
repo1 := "some/registry"
|
||||||
|
repo2 := "other/registry"
|
||||||
|
scope1 := fmt.Sprintf("repository:%s:pull,push", repo1)
|
||||||
|
scope2 := fmt.Sprintf("repository:%s:pull,push", repo2)
|
||||||
|
|
||||||
|
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||||
|
{
|
||||||
|
Request: testutil.Request{
|
||||||
|
Method: "GET",
|
||||||
|
Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope1), service),
|
||||||
|
},
|
||||||
|
Response: testutil.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: []byte(`{"token":"statictoken"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Request: testutil.Request{
|
||||||
|
Method: "GET",
|
||||||
|
Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope2), service),
|
||||||
|
},
|
||||||
|
Response: testutil.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: []byte(`{"token":"badtoken"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
te, tc := testServer(tokenMap)
|
||||||
|
defer tc()
|
||||||
|
|
||||||
|
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||||
|
{
|
||||||
|
Request: testutil.Request{
|
||||||
|
Method: "GET",
|
||||||
|
Route: "/hello",
|
||||||
|
},
|
||||||
|
Response: testutil.Response{
|
||||||
|
StatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service)
|
||||||
|
validCheck := func(a string) bool {
|
||||||
|
return a == "Bearer statictoken"
|
||||||
|
}
|
||||||
|
e, c := testServerWithAuth(m, authenicate, validCheck)
|
||||||
|
defer c()
|
||||||
|
|
||||||
|
client, err := e.HTTPClient(repo1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating http client: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error sending get request: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusAccepted {
|
||||||
|
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
|
||||||
|
}
|
||||||
|
|
||||||
|
badCheck := func(a string) bool {
|
||||||
|
return a == "Bearer statictoken"
|
||||||
|
}
|
||||||
|
e2, c2 := testServerWithAuth(m, authenicate, badCheck)
|
||||||
|
defer c2()
|
||||||
|
|
||||||
|
client2, err := e2.HTTPClient(repo2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating http client: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ = http.NewRequest("GET", e.Endpoint+"/hello", nil)
|
||||||
|
resp, err = client2.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error sending get request: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func basicAuth(username, password string) string {
|
||||||
|
auth := username + ":" + password
|
||||||
|
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEndpointAuthorizeTokenBasic(t *testing.T) {
|
||||||
|
service := "localhost.localdomain"
|
||||||
|
repo := "some/fun/registry"
|
||||||
|
scope := fmt.Sprintf("repository:%s:pull,push", repo)
|
||||||
|
username := "tokenuser"
|
||||||
|
password := "superSecretPa$$word"
|
||||||
|
|
||||||
|
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||||
|
{
|
||||||
|
Request: testutil.Request{
|
||||||
|
Method: "GET",
|
||||||
|
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
|
||||||
|
},
|
||||||
|
Response: testutil.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: []byte(`{"token":"statictoken"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
authenicate1 := fmt.Sprintf("Basic realm=localhost")
|
||||||
|
basicCheck := func(a string) bool {
|
||||||
|
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
|
||||||
|
}
|
||||||
|
te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck)
|
||||||
|
defer tc()
|
||||||
|
|
||||||
|
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||||
|
{
|
||||||
|
Request: testutil.Request{
|
||||||
|
Method: "GET",
|
||||||
|
Route: "/hello",
|
||||||
|
},
|
||||||
|
Response: testutil.Response{
|
||||||
|
StatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service)
|
||||||
|
bearerCheck := func(a string) bool {
|
||||||
|
return a == "Bearer statictoken"
|
||||||
|
}
|
||||||
|
e, c := testServerWithAuth(m, authenicate2, bearerCheck)
|
||||||
|
defer c()
|
||||||
|
|
||||||
|
e.Credentials = &testCredentialStore{
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := e.HTTPClient(repo)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating http client: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error sending get request: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusAccepted {
|
||||||
|
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEndpointAuthorizeBasic(t *testing.T) {
|
||||||
|
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
|
||||||
|
{
|
||||||
|
Request: testutil.Request{
|
||||||
|
Method: "GET",
|
||||||
|
Route: "/hello",
|
||||||
|
},
|
||||||
|
Response: testutil.Response{
|
||||||
|
StatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
username := "user1"
|
||||||
|
password := "funSecretPa$$word"
|
||||||
|
authenicate := fmt.Sprintf("Basic realm=localhost")
|
||||||
|
validCheck := func(a string) bool {
|
||||||
|
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
|
||||||
|
}
|
||||||
|
e, c := testServerWithAuth(m, authenicate, validCheck)
|
||||||
|
defer c()
|
||||||
|
e.Credentials = &testCredentialStore{
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := e.HTTPClient("test/repo/basic")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating http client: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error sending get request: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusAccepted {
|
||||||
|
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
|
||||||
|
}
|
||||||
|
}
|
|
@ -25,8 +25,8 @@ import (
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewRepositoryClient creates a new Repository for the given repository name and endpoint
|
// NewRepository creates a new Repository for the given repository name and endpoint
|
||||||
func NewRepositoryClient(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) {
|
func NewRepository(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) {
|
||||||
if err := v2.ValidateRespositoryName(name); err != nil {
|
if err := v2.ValidateRespositoryName(name); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,7 +97,7 @@ func TestLayerFetch(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e)
|
r, err := NewRepository(context.Background(), "test.example.com/repo1", e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ func TestLayerExists(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e)
|
r, err := NewRepository(context.Background(), "test.example.com/repo1", e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -227,7 +227,7 @@ func TestLayerUploadChunked(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
r, err := NewRepository(context.Background(), repo, e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -334,7 +334,7 @@ func TestLayerUploadMonolithic(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
r, err := NewRepository(context.Background(), repo, e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -475,7 +475,7 @@ func TestManifestFetch(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
r, err := NewRepository(context.Background(), repo, e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -508,7 +508,7 @@ func TestManifestFetchByTag(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
r, err := NewRepository(context.Background(), repo, e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -553,7 +553,7 @@ func TestManifestDelete(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
r, err := NewRepository(context.Background(), repo, e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -591,7 +591,7 @@ func TestManifestPut(t *testing.T) {
|
||||||
e, c := testServer(m)
|
e, c := testServer(m)
|
||||||
defer c()
|
defer c()
|
||||||
|
|
||||||
r, err := NewRepositoryClient(context.Background(), repo, e)
|
r, err := NewRepository(context.Background(), repo, e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -40,16 +41,18 @@ type Request struct {
|
||||||
func (r Request) String() string {
|
func (r Request) String() string {
|
||||||
queryString := ""
|
queryString := ""
|
||||||
if len(r.QueryParams) > 0 {
|
if len(r.QueryParams) > 0 {
|
||||||
queryString = "?"
|
|
||||||
keys := make([]string, 0, len(r.QueryParams))
|
keys := make([]string, 0, len(r.QueryParams))
|
||||||
|
queryParts := make([]string, 0, len(r.QueryParams))
|
||||||
for k := range r.QueryParams {
|
for k := range r.QueryParams {
|
||||||
keys = append(keys, k)
|
keys = append(keys, k)
|
||||||
}
|
}
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
queryString += strings.Join(r.QueryParams[k], "&") + "&"
|
for _, val := range r.QueryParams[k] {
|
||||||
|
queryParts = append(queryParts, fmt.Sprintf("%s=%s", k, url.QueryEscape(val)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
queryString = queryString[:len(queryString)-1]
|
queryString = "?" + strings.Join(queryParts, "&")
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s %s%s\n%s", r.Method, r.Route, queryString, r.Body)
|
return fmt.Sprintf("%s %s%s\n%s", r.Method, r.Route, queryString, r.Body)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue