Add expires_in and issued_at to the auth spec.

This extends the specification for the Bearer token response to include
information pertaining to when an issued Bearer token will expire.

This also allows the client to accept `access_token` as an alias for `token`.

Signed-off-by: Matt Moore <mattmoor@google.com>
This commit is contained in:
Matt Moore 2015-09-30 08:47:01 -07:00
parent 8cc790cccf
commit b38b98c8a8
2 changed files with 350 additions and 15 deletions

View file

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/Sirupsen/logrus"
"github.com/docker/distribution/registry/client" "github.com/docker/distribution/registry/client"
"github.com/docker/distribution/registry/client/transport" "github.com/docker/distribution/registry/client/transport"
) )
@ -85,11 +86,24 @@ func (ea *endpointAuthorizer) ModifyRequest(req *http.Request) error {
return nil return nil
} }
// This is the minimum duration a token can last (in seconds).
// A token must not live less than 60 seconds because older versions
// of the Docker client didn't read their expiration from the token
// response and assumed 60 seconds. So to remain compatible with
// those implementations, a token must live at least this long.
const minimumTokenLifetimeSeconds = 60
// Private interface for time used by this package to enable tests to provide their own implementation.
type clock interface {
Now() time.Time
}
type tokenHandler struct { type tokenHandler struct {
header http.Header header http.Header
creds CredentialStore creds CredentialStore
scope tokenScope scope tokenScope
transport http.RoundTripper transport http.RoundTripper
clock clock
tokenLock sync.Mutex tokenLock sync.Mutex
tokenCache string tokenCache string
@ -108,12 +122,24 @@ func (ts tokenScope) String() string {
return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ",")) return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ","))
} }
// An implementation of clock for providing real time data.
type realClock struct{}
// Now implements clock
func (realClock) Now() time.Time { return time.Now() }
// NewTokenHandler creates a new AuthenicationHandler which supports // NewTokenHandler creates a new AuthenicationHandler which supports
// fetching tokens from a remote token server. // fetching tokens from a remote token server.
func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope string, actions ...string) AuthenticationHandler { func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope string, actions ...string) AuthenticationHandler {
return newTokenHandler(transport, creds, realClock{}, scope, actions...)
}
// newTokenHandler exposes the option to provide a clock to manipulate time in unit testing.
func newTokenHandler(transport http.RoundTripper, creds CredentialStore, c clock, scope string, actions ...string) AuthenticationHandler {
return &tokenHandler{ return &tokenHandler{
transport: transport, transport: transport,
creds: creds, creds: creds,
clock: c,
scope: tokenScope{ scope: tokenScope{
Resource: "repository", Resource: "repository",
Scope: scope, Scope: scope,
@ -146,14 +172,14 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
func (th *tokenHandler) refreshToken(params map[string]string) error { func (th *tokenHandler) refreshToken(params map[string]string) error {
th.tokenLock.Lock() th.tokenLock.Lock()
defer th.tokenLock.Unlock() defer th.tokenLock.Unlock()
now := time.Now() now := th.clock.Now()
if now.After(th.tokenExpiration) { if now.After(th.tokenExpiration) {
token, err := th.fetchToken(params) tr, err := th.fetchToken(params)
if err != nil { if err != nil {
return err return err
} }
th.tokenCache = token th.tokenCache = tr.Token
th.tokenExpiration = now.Add(time.Minute) th.tokenExpiration = tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second)
} }
return nil return nil
@ -161,25 +187,28 @@ func (th *tokenHandler) refreshToken(params map[string]string) error {
type tokenResponse struct { type tokenResponse struct {
Token string `json:"token"` Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
} }
func (th *tokenHandler) fetchToken(params map[string]string) (token string, err error) { func (th *tokenHandler) fetchToken(params map[string]string) (token *tokenResponse, err error) {
//log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username) //log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username)
realm, ok := params["realm"] realm, ok := params["realm"]
if !ok { if !ok {
return "", errors.New("no realm specified for token auth challenge") return nil, errors.New("no realm specified for token auth challenge")
} }
// TODO(dmcgowan): Handle empty scheme // TODO(dmcgowan): Handle empty scheme
realmURL, err := url.Parse(realm) realmURL, err := url.Parse(realm)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid token auth challenge realm: %s", err) return nil, fmt.Errorf("invalid token auth challenge realm: %s", err)
} }
req, err := http.NewRequest("GET", realmURL.String(), nil) req, err := http.NewRequest("GET", realmURL.String(), nil)
if err != nil { if err != nil {
return "", err return nil, err
} }
reqParams := req.URL.Query() reqParams := req.URL.Query()
@ -206,26 +235,44 @@ func (th *tokenHandler) fetchToken(params map[string]string) (token string, err
resp, err := th.client().Do(req) resp, err := th.client().Do(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if !client.SuccessStatus(resp.StatusCode) { if !client.SuccessStatus(resp.StatusCode) {
return "", fmt.Errorf("token auth attempt for registry: %s request failed with status: %d %s", req.URL, resp.StatusCode, http.StatusText(resp.StatusCode)) return nil, fmt.Errorf("token auth attempt for registry: %s request failed with status: %d %s", req.URL, resp.StatusCode, http.StatusText(resp.StatusCode))
} }
decoder := json.NewDecoder(resp.Body) decoder := json.NewDecoder(resp.Body)
tr := new(tokenResponse) tr := new(tokenResponse)
if err = decoder.Decode(tr); err != nil { if err = decoder.Decode(tr); err != nil {
return "", fmt.Errorf("unable to decode token response: %s", err) return nil, fmt.Errorf("unable to decode token response: %s", err)
}
// `access_token` is equivalent to `token` and if both are specified
// the choice is undefined. Canonicalize `access_token` by sticking
// things in `token`.
if tr.AccessToken != "" {
tr.Token = tr.AccessToken
} }
if tr.Token == "" { if tr.Token == "" {
return "", errors.New("authorization server did not include a token in the response") return nil, errors.New("authorization server did not include a token in the response")
} }
return tr.Token, nil if tr.ExpiresIn < minimumTokenLifetimeSeconds {
logrus.Debugf("Increasing token expiration to: %d seconds", tr.ExpiresIn)
// The default/minimum lifetime.
tr.ExpiresIn = minimumTokenLifetimeSeconds
}
if tr.IssuedAt.IsZero() {
// issued_at is optional in the token response.
tr.IssuedAt = th.clock.Now()
}
return tr, nil
} }
type basicHandler struct { type basicHandler struct {

View file

@ -7,11 +7,20 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"time"
"github.com/docker/distribution/registry/client/transport" "github.com/docker/distribution/registry/client/transport"
"github.com/docker/distribution/testutil" "github.com/docker/distribution/testutil"
) )
// An implementation of clock for providing fake time data.
type fakeClock struct {
current time.Time
}
// Now implements clock
func (fc *fakeClock) Now() time.Time { return fc.current }
func testServer(rrm testutil.RequestResponseMap) (string, func()) { func testServer(rrm testutil.RequestResponseMap) (string, func()) {
h := testutil.NewHandler(rrm) h := testutil.NewHandler(rrm)
s := httptest.NewServer(h) s := httptest.NewServer(h)
@ -210,7 +219,7 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) {
}, },
Response: testutil.Response{ Response: testutil.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: []byte(`{"token":"statictoken"}`), Body: []byte(`{"access_token":"statictoken"}`),
}, },
}, },
}) })
@ -265,6 +274,285 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) {
} }
} }
func TestEndpointAuthorizeTokenBasicWithExpiresIn(t *testing.T) {
service := "localhost.localdomain"
repo := "some/fun/registry"
scope := fmt.Sprintf("repository:%s:pull,push", repo)
username := "tokenuser"
password := "superSecretPa$$word"
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"token":"statictoken", "expires_in": 3001}`),
},
},
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"access_token":"statictoken", "expires_in": 3001}`),
},
},
})
authenicate1 := fmt.Sprintf("Basic realm=localhost")
tokenExchanges := 0
basicCheck := func(a string) bool {
tokenExchanges = tokenExchanges + 1
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
}
te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
bearerCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate2, bearerCheck)
defer c()
creds := &testCredentialStore{
username: username,
password: password,
}
challengeManager := NewSimpleChallengeManager()
_, err := ping(challengeManager, e+"/v2/", "")
if err != nil {
t.Fatal(err)
}
clock := &fakeClock{current: time.Now()}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds)))
client := &http.Client{Transport: transport1}
// First call should result in a token exchange
// Subsequent calls should recycle the token from the first request, until the expiration has lapsed.
timeIncrement := 1000 * time.Second
for i := 0; i < 4; i++ {
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
if tokenExchanges != 1 {
t.Fatalf("Unexpected number of token exchanges, want: 1, got %d (iteration: %d)", tokenExchanges, i)
}
clock.current = clock.current.Add(timeIncrement)
}
// After we've exceeded the expiration, we should see a second token exchange.
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
if tokenExchanges != 2 {
t.Fatalf("Unexpected number of token exchanges, want: 2, got %d", tokenExchanges)
}
}
func TestEndpointAuthorizeTokenBasicWithExpiresInAndIssuedAt(t *testing.T) {
service := "localhost.localdomain"
repo := "some/fun/registry"
scope := fmt.Sprintf("repository:%s:pull,push", repo)
username := "tokenuser"
password := "superSecretPa$$word"
// This test sets things up such that the token was issued one increment
// earlier than its sibling in TestEndpointAuthorizeTokenBasicWithExpiresIn.
// This will mean that the token expires after 3 increments instead of 4.
clock := &fakeClock{current: time.Now()}
timeIncrement := 1000 * time.Second
firstIssuedAt := clock.Now()
clock.current = clock.current.Add(timeIncrement)
secondIssuedAt := clock.current.Add(2 * timeIncrement)
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"token":"statictoken", "issued_at": "` + firstIssuedAt.Format(time.RFC3339Nano) + `", "expires_in": 3001}`),
},
},
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"access_token":"statictoken", "issued_at": "` + secondIssuedAt.Format(time.RFC3339Nano) + `", "expires_in": 3001}`),
},
},
})
authenicate1 := fmt.Sprintf("Basic realm=localhost")
tokenExchanges := 0
basicCheck := func(a string) bool {
tokenExchanges = tokenExchanges + 1
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
}
te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
bearerCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate2, bearerCheck)
defer c()
creds := &testCredentialStore{
username: username,
password: password,
}
challengeManager := NewSimpleChallengeManager()
_, err := ping(challengeManager, e+"/v2/", "")
if err != nil {
t.Fatal(err)
}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds)))
client := &http.Client{Transport: transport1}
// First call should result in a token exchange
// Subsequent calls should recycle the token from the first request, until the expiration has lapsed.
// We shaved one increment off of the equivalent logic in TestEndpointAuthorizeTokenBasicWithExpiresIn
// so this loop should have one fewer iteration.
for i := 0; i < 3; i++ {
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
if tokenExchanges != 1 {
t.Fatalf("Unexpected number of token exchanges, want: 1, got %d (iteration: %d)", tokenExchanges, i)
}
clock.current = clock.current.Add(timeIncrement)
}
// After we've exceeded the expiration, we should see a second token exchange.
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
if tokenExchanges != 2 {
t.Fatalf("Unexpected number of token exchanges, want: 2, got %d", tokenExchanges)
}
}
func TestEndpointAuthorizeBasic(t *testing.T) { func TestEndpointAuthorizeBasic(t *testing.T) {
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{ {