diff --git a/docs/client/auth/session.go b/docs/client/auth/session.go index 27a2aa71..6c92fc34 100644 --- a/docs/client/auth/session.go +++ b/docs/client/auth/session.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/Sirupsen/logrus" "github.com/docker/distribution/registry/client" "github.com/docker/distribution/registry/client/transport" ) @@ -85,11 +86,24 @@ func (ea *endpointAuthorizer) ModifyRequest(req *http.Request) error { return nil } +// This is the minimum duration a token can last (in seconds). +// A token must not live less than 60 seconds because older versions +// of the Docker client didn't read their expiration from the token +// response and assumed 60 seconds. So to remain compatible with +// those implementations, a token must live at least this long. +const minimumTokenLifetimeSeconds = 60 + +// Private interface for time used by this package to enable tests to provide their own implementation. +type clock interface { + Now() time.Time +} + type tokenHandler struct { header http.Header creds CredentialStore scope tokenScope transport http.RoundTripper + clock clock tokenLock sync.Mutex tokenCache string @@ -108,12 +122,24 @@ func (ts tokenScope) String() string { return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ",")) } +// An implementation of clock for providing real time data. +type realClock struct{} + +// Now implements clock +func (realClock) Now() time.Time { return time.Now() } + // NewTokenHandler creates a new AuthenicationHandler which supports // fetching tokens from a remote token server. func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope string, actions ...string) AuthenticationHandler { + return newTokenHandler(transport, creds, realClock{}, scope, actions...) +} + +// newTokenHandler exposes the option to provide a clock to manipulate time in unit testing. +func newTokenHandler(transport http.RoundTripper, creds CredentialStore, c clock, scope string, actions ...string) AuthenticationHandler { return &tokenHandler{ transport: transport, creds: creds, + clock: c, scope: tokenScope{ Resource: "repository", Scope: scope, @@ -146,40 +172,43 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st func (th *tokenHandler) refreshToken(params map[string]string) error { th.tokenLock.Lock() defer th.tokenLock.Unlock() - now := time.Now() + now := th.clock.Now() if now.After(th.tokenExpiration) { - token, err := th.fetchToken(params) + tr, err := th.fetchToken(params) if err != nil { return err } - th.tokenCache = token - th.tokenExpiration = now.Add(time.Minute) + th.tokenCache = tr.Token + th.tokenExpiration = tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second) } return nil } type tokenResponse struct { - Token string `json:"token"` + Token string `json:"token"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + IssuedAt time.Time `json:"issued_at"` } -func (th *tokenHandler) fetchToken(params map[string]string) (token string, err error) { +func (th *tokenHandler) fetchToken(params map[string]string) (token *tokenResponse, err error) { //log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username) realm, ok := params["realm"] if !ok { - return "", errors.New("no realm specified for token auth challenge") + return nil, errors.New("no realm specified for token auth challenge") } // TODO(dmcgowan): Handle empty scheme realmURL, err := url.Parse(realm) if err != nil { - return "", fmt.Errorf("invalid token auth challenge realm: %s", err) + return nil, fmt.Errorf("invalid token auth challenge realm: %s", err) } req, err := http.NewRequest("GET", realmURL.String(), nil) if err != nil { - return "", err + return nil, err } reqParams := req.URL.Query() @@ -206,26 +235,44 @@ func (th *tokenHandler) fetchToken(params map[string]string) (token string, err resp, err := th.client().Do(req) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() if !client.SuccessStatus(resp.StatusCode) { - return "", fmt.Errorf("token auth attempt for registry: %s request failed with status: %d %s", req.URL, resp.StatusCode, http.StatusText(resp.StatusCode)) + return nil, fmt.Errorf("token auth attempt for registry: %s request failed with status: %d %s", req.URL, resp.StatusCode, http.StatusText(resp.StatusCode)) } decoder := json.NewDecoder(resp.Body) tr := new(tokenResponse) if err = decoder.Decode(tr); err != nil { - return "", fmt.Errorf("unable to decode token response: %s", err) + return nil, fmt.Errorf("unable to decode token response: %s", err) + } + + // `access_token` is equivalent to `token` and if both are specified + // the choice is undefined. Canonicalize `access_token` by sticking + // things in `token`. + if tr.AccessToken != "" { + tr.Token = tr.AccessToken } if tr.Token == "" { - return "", errors.New("authorization server did not include a token in the response") + return nil, errors.New("authorization server did not include a token in the response") } - return tr.Token, nil + if tr.ExpiresIn < minimumTokenLifetimeSeconds { + logrus.Debugf("Increasing token expiration to: %d seconds", tr.ExpiresIn) + // The default/minimum lifetime. + tr.ExpiresIn = minimumTokenLifetimeSeconds + } + + if tr.IssuedAt.IsZero() { + // issued_at is optional in the token response. + tr.IssuedAt = th.clock.Now() + } + + return tr, nil } type basicHandler struct { diff --git a/docs/client/auth/session_test.go b/docs/client/auth/session_test.go index 1b4754ab..f1686942 100644 --- a/docs/client/auth/session_test.go +++ b/docs/client/auth/session_test.go @@ -7,11 +7,20 @@ import ( "net/http/httptest" "net/url" "testing" + "time" "github.com/docker/distribution/registry/client/transport" "github.com/docker/distribution/testutil" ) +// An implementation of clock for providing fake time data. +type fakeClock struct { + current time.Time +} + +// Now implements clock +func (fc *fakeClock) Now() time.Time { return fc.current } + func testServer(rrm testutil.RequestResponseMap) (string, func()) { h := testutil.NewHandler(rrm) s := httptest.NewServer(h) @@ -210,7 +219,7 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) { }, Response: testutil.Response{ StatusCode: http.StatusOK, - Body: []byte(`{"token":"statictoken"}`), + Body: []byte(`{"access_token":"statictoken"}`), }, }, }) @@ -265,6 +274,285 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) { } } +func TestEndpointAuthorizeTokenBasicWithExpiresIn(t *testing.T) { + service := "localhost.localdomain" + repo := "some/fun/registry" + scope := fmt.Sprintf("repository:%s:pull,push", repo) + username := "tokenuser" + password := "superSecretPa$$word" + + tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"token":"statictoken", "expires_in": 3001}`), + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"access_token":"statictoken", "expires_in": 3001}`), + }, + }, + }) + + authenicate1 := fmt.Sprintf("Basic realm=localhost") + tokenExchanges := 0 + basicCheck := func(a string) bool { + tokenExchanges = tokenExchanges + 1 + return a == fmt.Sprintf("Basic %s", basicAuth(username, password)) + } + te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck) + defer tc() + + m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + }) + + authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service) + bearerCheck := func(a string) bool { + return a == "Bearer statictoken" + } + e, c := testServerWithAuth(m, authenicate2, bearerCheck) + defer c() + + creds := &testCredentialStore{ + username: username, + password: password, + } + + challengeManager := NewSimpleChallengeManager() + _, err := ping(challengeManager, e+"/v2/", "") + if err != nil { + t.Fatal(err) + } + clock := &fakeClock{current: time.Now()} + transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds))) + client := &http.Client{Transport: transport1} + + // First call should result in a token exchange + // Subsequent calls should recycle the token from the first request, until the expiration has lapsed. + timeIncrement := 1000 * time.Second + for i := 0; i < 4; i++ { + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } + if tokenExchanges != 1 { + t.Fatalf("Unexpected number of token exchanges, want: 1, got %d (iteration: %d)", tokenExchanges, i) + } + clock.current = clock.current.Add(timeIncrement) + } + + // After we've exceeded the expiration, we should see a second token exchange. + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } + if tokenExchanges != 2 { + t.Fatalf("Unexpected number of token exchanges, want: 2, got %d", tokenExchanges) + } +} + +func TestEndpointAuthorizeTokenBasicWithExpiresInAndIssuedAt(t *testing.T) { + service := "localhost.localdomain" + repo := "some/fun/registry" + scope := fmt.Sprintf("repository:%s:pull,push", repo) + username := "tokenuser" + password := "superSecretPa$$word" + + // This test sets things up such that the token was issued one increment + // earlier than its sibling in TestEndpointAuthorizeTokenBasicWithExpiresIn. + // This will mean that the token expires after 3 increments instead of 4. + clock := &fakeClock{current: time.Now()} + timeIncrement := 1000 * time.Second + firstIssuedAt := clock.Now() + clock.current = clock.current.Add(timeIncrement) + secondIssuedAt := clock.current.Add(2 * timeIncrement) + tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"token":"statictoken", "issued_at": "` + firstIssuedAt.Format(time.RFC3339Nano) + `", "expires_in": 3001}`), + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"access_token":"statictoken", "issued_at": "` + secondIssuedAt.Format(time.RFC3339Nano) + `", "expires_in": 3001}`), + }, + }, + }) + + authenicate1 := fmt.Sprintf("Basic realm=localhost") + tokenExchanges := 0 + basicCheck := func(a string) bool { + tokenExchanges = tokenExchanges + 1 + return a == fmt.Sprintf("Basic %s", basicAuth(username, password)) + } + te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck) + defer tc() + + m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + }) + + authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service) + bearerCheck := func(a string) bool { + return a == "Bearer statictoken" + } + e, c := testServerWithAuth(m, authenicate2, bearerCheck) + defer c() + + creds := &testCredentialStore{ + username: username, + password: password, + } + + challengeManager := NewSimpleChallengeManager() + _, err := ping(challengeManager, e+"/v2/", "") + if err != nil { + t.Fatal(err) + } + transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds))) + client := &http.Client{Transport: transport1} + + // First call should result in a token exchange + // Subsequent calls should recycle the token from the first request, until the expiration has lapsed. + // We shaved one increment off of the equivalent logic in TestEndpointAuthorizeTokenBasicWithExpiresIn + // so this loop should have one fewer iteration. + for i := 0; i < 3; i++ { + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } + if tokenExchanges != 1 { + t.Fatalf("Unexpected number of token exchanges, want: 1, got %d (iteration: %d)", tokenExchanges, i) + } + clock.current = clock.current.Add(timeIncrement) + } + + // After we've exceeded the expiration, we should see a second token exchange. + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } + if tokenExchanges != 2 { + t.Fatalf("Unexpected number of token exchanges, want: 2, got %d", tokenExchanges) + } +} + func TestEndpointAuthorizeBasic(t *testing.T) { m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ {