Merge pull request #1475 from dmcgowan/oauth-registry-client

Add oauth support to registry client auth
This commit is contained in:
Richard Scothern 2016-03-04 11:51:53 -08:00
commit 5b8592b72c
3 changed files with 432 additions and 108 deletions

View file

@ -36,6 +36,14 @@ type AuthenticationHandler interface {
type CredentialStore interface {
// Basic returns basic auth for the given URL
Basic(*url.URL) (string, string)
// RefreshToken returns a refresh token for the
// given URL and service
RefreshToken(*url.URL, string) string
// SetRefreshToken sets the refresh token if none
// is provided for the given url and service
SetRefreshToken(realm *url.URL, service, token string)
}
// NewAuthorizer creates an authorizer which can handle multiple authentication
@ -105,27 +113,45 @@ type clock interface {
type tokenHandler struct {
header http.Header
creds CredentialStore
scope tokenScope
transport http.RoundTripper
clock clock
forceOAuth bool
clientID string
scopes []Scope
tokenLock sync.Mutex
tokenCache string
tokenExpiration time.Time
additionalScopes map[string]struct{}
}
// tokenScope represents the scope at which a token will be requested.
// This represents a specific action on a registry resource.
type tokenScope struct {
Resource string
Scope string
// Scope is a type which is serializable to a string
// using the allow scope grammar.
type Scope interface {
String() string
}
// RepositoryScope represents a token scope for access
// to a repository.
type RepositoryScope struct {
Repository string
Actions []string
}
func (ts tokenScope) String() string {
return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ","))
// String returns the string representation of the repository
// using the scope grammar
func (rs RepositoryScope) String() string {
return fmt.Sprintf("repository:%s:%s", rs.Repository, strings.Join(rs.Actions, ","))
}
// TokenHandlerOptions is used to configure a new token handler
type TokenHandlerOptions struct {
Transport http.RoundTripper
Credentials CredentialStore
ForceOAuth bool
ClientID string
Scopes []Scope
}
// An implementation of clock for providing real time data.
@ -137,22 +163,32 @@ func (realClock) Now() time.Time { return time.Now() }
// NewTokenHandler creates a new AuthenicationHandler which supports
// fetching tokens from a remote token server.
func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope string, actions ...string) AuthenticationHandler {
return newTokenHandler(transport, creds, realClock{}, scope, actions...)
}
// newTokenHandler exposes the option to provide a clock to manipulate time in unit testing.
func newTokenHandler(transport http.RoundTripper, creds CredentialStore, c clock, scope string, actions ...string) AuthenticationHandler {
return &tokenHandler{
transport: transport,
creds: creds,
clock: c,
scope: tokenScope{
Resource: "repository",
Scope: scope,
// Create options...
return NewTokenHandlerWithOptions(TokenHandlerOptions{
Transport: transport,
Credentials: creds,
Scopes: []Scope{
RepositoryScope{
Repository: scope,
Actions: actions,
},
additionalScopes: map[string]struct{}{},
},
})
}
// NewTokenHandlerWithOptions creates a new token handler using the provided
// options structure.
func NewTokenHandlerWithOptions(options TokenHandlerOptions) AuthenticationHandler {
handler := &tokenHandler{
transport: options.Transport,
creds: options.Credentials,
forceOAuth: options.ForceOAuth,
clientID: options.ClientID,
scopes: options.Scopes,
clock: realClock{},
}
return handler
}
func (th *tokenHandler) client() *http.Client {
@ -169,9 +205,8 @@ func (th *tokenHandler) Scheme() string {
func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error {
var additionalScopes []string
if fromParam := req.URL.Query().Get("from"); fromParam != "" {
additionalScopes = append(additionalScopes, tokenScope{
Resource: "repository",
Scope: fromParam,
additionalScopes = append(additionalScopes, RepositoryScope{
Repository: fromParam,
Actions: []string{"pull"},
}.String())
}
@ -187,104 +222,89 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
func (th *tokenHandler) refreshToken(params map[string]string, additionalScopes ...string) error {
th.tokenLock.Lock()
defer th.tokenLock.Unlock()
scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
for _, scope := range th.scopes {
scopes = append(scopes, scope.String())
}
var addedScopes bool
for _, scope := range additionalScopes {
if _, ok := th.additionalScopes[scope]; !ok {
th.additionalScopes[scope] = struct{}{}
scopes = append(scopes, scope)
addedScopes = true
}
}
now := th.clock.Now()
if now.After(th.tokenExpiration) || addedScopes {
tr, err := th.fetchToken(params)
token, expiration, err := th.fetchToken(params, scopes)
if err != nil {
return err
}
th.tokenCache = tr.Token
th.tokenExpiration = tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second)
// do not update cache for added scope tokens
if !addedScopes {
th.tokenCache = token
th.tokenExpiration = expiration
}
}
return nil
}
type tokenResponse struct {
Token string `json:"token"`
type postTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"`
}
func (th *tokenHandler) fetchToken(params map[string]string) (token *tokenResponse, err error) {
realm, ok := params["realm"]
if !ok {
return nil, errors.New("no realm specified for token auth challenge")
func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
form := url.Values{}
form.Set("scope", strings.Join(scopes, " "))
form.Set("service", service)
clientID := th.clientID
if clientID == "" {
// Use default client, this is a required field
clientID = "registry-client"
}
form.Set("client_id", clientID)
if refreshToken != "" {
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", refreshToken)
} else if th.creds != nil {
form.Set("grant_type", "password")
username, password := th.creds.Basic(realm)
form.Set("username", username)
form.Set("password", password)
// attempt to get a refresh token
form.Set("access_type", "offline")
} else {
// refuse to do oauth without a grant type
return "", time.Time{}, fmt.Errorf("no supported grant type")
}
// TODO(dmcgowan): Handle empty scheme
realmURL, err := url.Parse(realm)
resp, err := th.client().PostForm(realm.String(), form)
if err != nil {
return nil, fmt.Errorf("invalid token auth challenge realm: %s", err)
}
req, err := http.NewRequest("GET", realmURL.String(), nil)
if err != nil {
return nil, err
}
reqParams := req.URL.Query()
service := params["service"]
scope := th.scope.String()
if service != "" {
reqParams.Add("service", service)
}
for _, scopeField := range strings.Fields(scope) {
reqParams.Add("scope", scopeField)
}
for scope := range th.additionalScopes {
reqParams.Add("scope", scope)
}
if th.creds != nil {
username, password := th.creds.Basic(realmURL)
if username != "" && password != "" {
reqParams.Add("account", username)
req.SetBasicAuth(username, password)
}
}
req.URL.RawQuery = reqParams.Encode()
resp, err := th.client().Do(req)
if err != nil {
return nil, err
return "", time.Time{}, err
}
defer resp.Body.Close()
if !client.SuccessStatus(resp.StatusCode) {
err := client.HandleErrorResponse(resp)
return nil, err
return "", time.Time{}, err
}
decoder := json.NewDecoder(resp.Body)
tr := new(tokenResponse)
if err = decoder.Decode(tr); err != nil {
return nil, fmt.Errorf("unable to decode token response: %s", err)
var tr postTokenResponse
if err = decoder.Decode(&tr); err != nil {
return "", time.Time{}, fmt.Errorf("unable to decode token response: %s", err)
}
// `access_token` is equivalent to `token` and if both are specified
// the choice is undefined. Canonicalize `access_token` by sticking
// things in `token`.
if tr.AccessToken != "" {
tr.Token = tr.AccessToken
}
if tr.Token == "" {
return nil, errors.New("authorization server did not include a token in the response")
if tr.RefreshToken != "" && tr.RefreshToken != refreshToken {
th.creds.SetRefreshToken(realm, service, tr.RefreshToken)
}
if tr.ExpiresIn < minimumTokenLifetimeSeconds {
@ -295,10 +315,119 @@ func (th *tokenHandler) fetchToken(params map[string]string) (token *tokenRespon
if tr.IssuedAt.IsZero() {
// issued_at is optional in the token response.
tr.IssuedAt = th.clock.Now()
tr.IssuedAt = th.clock.Now().UTC()
}
return tr, nil
return tr.AccessToken, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
}
type getTokenResponse struct {
Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"`
}
func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequest("GET", realm.String(), nil)
if err != nil {
return "", time.Time{}, err
}
reqParams := req.URL.Query()
if service != "" {
reqParams.Add("service", service)
}
for _, scope := range scopes {
reqParams.Add("scope", scope)
}
if th.creds != nil {
username, password := th.creds.Basic(realm)
if username != "" && password != "" {
reqParams.Add("account", username)
req.SetBasicAuth(username, password)
}
}
req.URL.RawQuery = reqParams.Encode()
resp, err := th.client().Do(req)
if err != nil {
return "", time.Time{}, err
}
defer resp.Body.Close()
if !client.SuccessStatus(resp.StatusCode) {
err := client.HandleErrorResponse(resp)
return "", time.Time{}, err
}
decoder := json.NewDecoder(resp.Body)
var tr getTokenResponse
if err = decoder.Decode(&tr); err != nil {
return "", time.Time{}, fmt.Errorf("unable to decode token response: %s", err)
}
if tr.RefreshToken != "" && th.creds != nil {
th.creds.SetRefreshToken(realm, service, tr.RefreshToken)
}
// `access_token` is equivalent to `token` and if both are specified
// the choice is undefined. Canonicalize `access_token` by sticking
// things in `token`.
if tr.AccessToken != "" {
tr.Token = tr.AccessToken
}
if tr.Token == "" {
return "", time.Time{}, errors.New("authorization server did not include a token in the response")
}
if tr.ExpiresIn < minimumTokenLifetimeSeconds {
// The default/minimum lifetime.
tr.ExpiresIn = minimumTokenLifetimeSeconds
logrus.Debugf("Increasing token expiration to: %d seconds", tr.ExpiresIn)
}
if tr.IssuedAt.IsZero() {
// issued_at is optional in the token response.
tr.IssuedAt = th.clock.Now().UTC()
}
return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
}
func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
realm, ok := params["realm"]
if !ok {
return "", time.Time{}, errors.New("no realm specified for token auth challenge")
}
// TODO(dmcgowan): Handle empty scheme and relative realm
realmURL, err := url.Parse(realm)
if err != nil {
return "", time.Time{}, fmt.Errorf("invalid token auth challenge realm: %s", err)
}
service := params["service"]
var refreshToken string
if th.creds != nil {
refreshToken = th.creds.RefreshToken(realmURL, service)
}
if refreshToken != "" || th.forceOAuth {
return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes)
}
return th.fetchTokenWithBasicAuth(realmURL, service, scopes)
}
type basicHandler struct {

View file

@ -82,12 +82,23 @@ func ping(manager ChallengeManager, endpoint, versionHeader string) ([]APIVersio
type testCredentialStore struct {
username string
password string
refreshTokens map[string]string
}
func (tcs *testCredentialStore) Basic(*url.URL) (string, string) {
return tcs.username, tcs.password
}
func (tcs *testCredentialStore) RefreshToken(u *url.URL, service string) string {
return tcs.refreshTokens[service]
}
func (tcs *testCredentialStore) SetRefreshToken(u *url.URL, service string, token string) {
if tcs.refreshTokens != nil {
tcs.refreshTokens[service] = token
}
}
func TestEndpointAuthorizeToken(t *testing.T) {
service := "localhost.localdomain"
repo1 := "some/registry"
@ -162,14 +173,11 @@ func TestEndpointAuthorizeToken(t *testing.T) {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
badCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e2, c2 := testServerWithAuth(m, authenicate, badCheck)
e2, c2 := testServerWithAuth(m, authenicate, validCheck)
defer c2()
challengeManager2 := NewSimpleChallengeManager()
versions, err = ping(challengeManager2, e+"/v2/", "x-multi-api-version")
versions, err = ping(challengeManager2, e2+"/v2/", "x-multi-api-version")
if err != nil {
t.Fatal(err)
}
@ -199,6 +207,161 @@ func TestEndpointAuthorizeToken(t *testing.T) {
}
}
func TestEndpointAuthorizeRefreshToken(t *testing.T) {
service := "localhost.localdomain"
repo1 := "some/registry"
repo2 := "other/registry"
scope1 := fmt.Sprintf("repository:%s:pull,push", repo1)
scope2 := fmt.Sprintf("repository:%s:pull,push", repo2)
refreshToken1 := "0123456790abcdef"
refreshToken2 := "0123456790fedcba"
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken1, url.QueryEscape(scope1), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(fmt.Sprintf(`{"access_token":"statictoken","refresh_token":"%s"}`, refreshToken1)),
},
},
{
// In the future this test may fail and require using basic auth to get a different refresh token
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken1, url.QueryEscape(scope2), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(fmt.Sprintf(`{"access_token":"statictoken","refresh_token":"%s"}`, refreshToken2)),
},
},
{
Request: testutil.Request{
Method: "POST",
Route: "/token",
Body: []byte(fmt.Sprintf("client_id=registry-client&grant_type=refresh_token&refresh_token=%s&scope=%s&service=%s", refreshToken2, url.QueryEscape(scope2), service)),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"access_token":"badtoken","refresh_token":"%s"}`),
},
},
})
te, tc := testServer(tokenMap)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/v2/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service)
validCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate, validCheck)
defer c()
challengeManager1 := NewSimpleChallengeManager()
versions, err := ping(challengeManager1, e+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if len(versions) != 1 {
t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
creds := &testCredentialStore{
refreshTokens: map[string]string{
service: refreshToken1,
},
}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager1, NewTokenHandler(nil, creds, repo1, "pull", "push")))
client := &http.Client{Transport: transport1}
req, _ := http.NewRequest("GET", e+"/v2/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
// Try with refresh token setting
e2, c2 := testServerWithAuth(m, authenicate, validCheck)
defer c2()
challengeManager2 := NewSimpleChallengeManager()
versions, err = ping(challengeManager2, e2+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if len(versions) != 1 {
t.Fatalf("Unexpected version count: %d, expected 1", len(versions))
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
transport2 := transport.NewTransport(nil, NewAuthorizer(challengeManager2, NewTokenHandler(nil, creds, repo2, "pull", "push")))
client2 := &http.Client{Transport: transport2}
req, _ = http.NewRequest("GET", e2+"/v2/hello", nil)
resp, err = client2.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
}
if creds.refreshTokens[service] != refreshToken2 {
t.Fatalf("Refresh token not set after change")
}
// Try with bad token
e3, c3 := testServerWithAuth(m, authenicate, validCheck)
defer c3()
challengeManager3 := NewSimpleChallengeManager()
versions, err = ping(challengeManager3, e3+"/v2/", "x-api-version")
if err != nil {
t.Fatal(err)
}
if check := (APIVersion{Type: "registry", Version: "2.0"}); versions[0] != check {
t.Fatalf("Unexpected api version: %#v, expected %#v", versions[0], check)
}
transport3 := transport.NewTransport(nil, NewAuthorizer(challengeManager3, NewTokenHandler(nil, creds, repo2, "pull", "push")))
client3 := &http.Client{Transport: transport3}
req, _ = http.NewRequest("GET", e3+"/v2/hello", nil)
resp, err = client3.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
}
}
func basicAuth(username, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
@ -379,7 +542,19 @@ func TestEndpointAuthorizeTokenBasicWithExpiresIn(t *testing.T) {
t.Fatal(err)
}
clock := &fakeClock{current: time.Now()}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds)))
options := TokenHandlerOptions{
Transport: nil,
Credentials: creds,
Scopes: []Scope{
RepositoryScope{
Repository: repo,
Actions: []string{"pull", "push"},
},
},
}
tHandler := NewTokenHandlerWithOptions(options)
tHandler.(*tokenHandler).clock = clock
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, tHandler, NewBasicHandler(creds)))
client := &http.Client{Transport: transport1}
// First call should result in a token exchange
@ -517,7 +692,20 @@ func TestEndpointAuthorizeTokenBasicWithExpiresInAndIssuedAt(t *testing.T) {
if err != nil {
t.Fatal(err)
}
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, newTokenHandler(nil, creds, clock, repo, "pull", "push"), NewBasicHandler(creds)))
options := TokenHandlerOptions{
Transport: nil,
Credentials: creds,
Scopes: []Scope{
RepositoryScope{
Repository: repo,
Actions: []string{"pull", "push"},
},
},
}
tHandler := NewTokenHandlerWithOptions(options)
tHandler.(*tokenHandler).clock = clock
transport1 := transport.NewTransport(nil, NewAuthorizer(challengeManager, tHandler, NewBasicHandler(creds)))
client := &http.Client{Transport: transport1}
// First call should result in a token exchange

View file

@ -25,6 +25,13 @@ func (c credentials) Basic(u *url.URL) (string, string) {
return up.username, up.password
}
func (c credentials) RefreshToken(u *url.URL, service string) string {
return ""
}
func (c credentials) SetRefreshToken(u *url.URL, service, token string) {
}
// configureAuth stores credentials for challenge responses
func configureAuth(username, password string) (auth.CredentialStore, error) {
creds := map[string]userpass{