Get token on each request

Signed-off-by: Derek McGowan <derek@mcgstyle.net>
This commit is contained in:
Derek McGowan 2014-12-19 16:14:04 -08:00
parent 6f36ce3a01
commit 22c7328529
2 changed files with 62 additions and 32 deletions

View file

@ -38,56 +38,70 @@ type ConfigFile struct {
} }
type RequestAuthorization struct { type RequestAuthorization struct {
Token string authConfig *AuthConfig
Username string registryEndpoint *Endpoint
Password string resource string
scope string
actions []string
} }
func NewRequestAuthorization(authConfig *AuthConfig, registryEndpoint *Endpoint, resource, scope string, actions []string) (*RequestAuthorization, error) { func NewRequestAuthorization(authConfig *AuthConfig, registryEndpoint *Endpoint, resource, scope string, actions []string) *RequestAuthorization {
var auth RequestAuthorization return &RequestAuthorization{
authConfig: authConfig,
registryEndpoint: registryEndpoint,
resource: resource,
scope: scope,
actions: actions,
}
}
func (auth *RequestAuthorization) getToken() (string, error) {
// TODO check if already has token and before expiration
client := &http.Client{ client := &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DisableKeepAlives: true, DisableKeepAlives: true,
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment},
},
CheckRedirect: AddRequiredHeadersToRedirectedRequests, CheckRedirect: AddRequiredHeadersToRedirectedRequests,
} }
factory := HTTPRequestFactory(nil) factory := HTTPRequestFactory(nil)
for _, challenge := range registryEndpoint.AuthChallenges { for _, challenge := range auth.registryEndpoint.AuthChallenges {
log.Debugf("Using %q auth challenge with params %s for %s", challenge.Scheme, challenge.Parameters, authConfig.Username)
switch strings.ToLower(challenge.Scheme) { switch strings.ToLower(challenge.Scheme) {
case "basic": case "basic":
auth.Username = authConfig.Username // no token necessary
auth.Password = authConfig.Password
case "bearer": case "bearer":
log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, auth.authConfig.Username)
params := map[string]string{} params := map[string]string{}
for k, v := range challenge.Parameters { for k, v := range challenge.Parameters {
params[k] = v params[k] = v
} }
params["scope"] = fmt.Sprintf("%s:%s:%s", resource, scope, strings.Join(actions, ",")) params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ","))
token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client, factory) token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client, factory)
if err != nil { if err != nil {
return nil, err return "", err
} }
// TODO cache token and set expiration to one minute from now
auth.Token = token return token, nil
default: default:
log.Infof("Unsupported auth scheme: %q", challenge.Scheme) log.Infof("Unsupported auth scheme: %q", challenge.Scheme)
} }
} }
// TODO no expiration, do not reattempt to get a token
return &auth, nil return "", nil
} }
func (auth *RequestAuthorization) Authorize(req *http.Request) { func (auth *RequestAuthorization) Authorize(req *http.Request) error {
if auth.Token != "" { token, err := auth.getToken()
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", auth.Token)) if err != nil {
} else if auth.Username != "" && auth.Password != "" { return err
req.SetBasicAuth(auth.Username, auth.Password)
} }
if token != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
} else if auth.authConfig.Username != "" && auth.authConfig.Password != "" {
req.SetBasicAuth(auth.authConfig.Username, auth.authConfig.Password)
}
return nil
} }
// create a base64 encoded auth string to store in config // create a base64 encoded auth string to store in config

View file

@ -42,7 +42,7 @@ func (r *Session) GetV2Authorization(imageName string, readOnly bool) (auth *Req
r.indexEndpoint = registry r.indexEndpoint = registry
log.Debugf("Getting authorization for %s %s", imageName, scopes) log.Debugf("Getting authorization for %s %s", imageName, scopes)
return NewRequestAuthorization(r.GetAuthConfig(true), registry, "repository", imageName, scopes) return NewRequestAuthorization(r.GetAuthConfig(true), registry, "repository", imageName, scopes), nil
} }
// //
@ -65,7 +65,9 @@ func (r *Session) GetV2ImageManifest(imageName, tagName string, auth *RequestAut
if err != nil { if err != nil {
return nil, err return nil, err
} }
auth.Authorize(req) if err := auth.Authorize(req) {
return nil, err
}
res, _, err := r.doRequest(req) res, _, err := r.doRequest(req)
if err != nil { if err != nil {
return nil, err return nil, err
@ -103,7 +105,9 @@ func (r *Session) PostV2ImageMountBlob(imageName, sumType, sum string, auth *Req
if err != nil { if err != nil {
return false, err return false, err
} }
auth.Authorize(req) if err := auth.Authorize(req) {
return nil, err
}
res, _, err := r.doRequest(req) res, _, err := r.doRequest(req)
if err != nil { if err != nil {
return false, err return false, err
@ -132,7 +136,9 @@ func (r *Session) GetV2ImageBlob(imageName, sumType, sum string, blobWrtr io.Wri
if err != nil { if err != nil {
return err return err
} }
auth.Authorize(req) if err := auth.Authorize(req) {
return nil, err
}
res, _, err := r.doRequest(req) res, _, err := r.doRequest(req)
if err != nil { if err != nil {
return err return err
@ -161,7 +167,9 @@ func (r *Session) GetV2ImageBlobReader(imageName, sumType, sum string, auth *Req
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
auth.Authorize(req) if err := auth.Authorize(req) {
return nil, err
}
res, _, err := r.doRequest(req) res, _, err := r.doRequest(req)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -196,7 +204,9 @@ func (r *Session) PutV2ImageBlob(imageName, sumType, sumStr string, blobRdr io.R
return err return err
} }
auth.Authorize(req) if err := auth.Authorize(req) {
return nil, err
}
res, _, err := r.doRequest(req) res, _, err := r.doRequest(req)
if err != nil { if err != nil {
return err return err
@ -212,7 +222,9 @@ func (r *Session) PutV2ImageBlob(imageName, sumType, sumStr string, blobRdr io.R
queryParams := url.Values{} queryParams := url.Values{}
queryParams.Add("digest", sumType+":"+sumStr) queryParams.Add("digest", sumType+":"+sumStr)
req.URL.RawQuery = queryParams.Encode() req.URL.RawQuery = queryParams.Encode()
auth.Authorize(req) if err := auth.Authorize(req) {
return nil, err
}
res, _, err = r.doRequest(req) res, _, err = r.doRequest(req)
if err != nil { if err != nil {
return err return err
@ -242,7 +254,9 @@ func (r *Session) PutV2ImageManifest(imageName, tagName string, manifestRdr io.R
if err != nil { if err != nil {
return err return err
} }
auth.Authorize(req) if err := auth.Authorize(req) {
return nil, err
}
res, _, err := r.doRequest(req) res, _, err := r.doRequest(req)
if err != nil { if err != nil {
return err return err
@ -274,7 +288,9 @@ func (r *Session) GetV2RemoteTags(imageName string, auth *RequestAuthorization)
if err != nil { if err != nil {
return nil, err return nil, err
} }
auth.Authorize(req) if err := auth.Authorize(req) {
return nil, err
}
res, _, err := r.doRequest(req) res, _, err := r.doRequest(req)
if err != nil { if err != nil {
return nil, err return nil, err