diff --git a/registry/client/layer_upload_test.go b/registry/client/layer_upload_test.go index 9e22cb7c..3879c867 100644 --- a/registry/client/layer_upload_test.go +++ b/registry/client/layer_upload_test.go @@ -124,13 +124,8 @@ func TestUploadReadFrom(t *testing.T) { e, c := testServer(m) defer c() - repoConfig := &RepositoryConfig{} - client, err := repoConfig.HTTPClient() - if err != nil { - t.Fatalf("Error creating client: %s", err) - } layerUpload := &httpLayerUpload{ - client: client, + client: &http.Client{}, } // Valid case diff --git a/registry/client/repository.go b/registry/client/repository.go index e7fcfa9f..0bd89b11 100644 --- a/registry/client/repository.go +++ b/registry/client/repository.go @@ -20,7 +20,7 @@ import ( ) // NewRepository creates a new Repository for the given repository name and endpoint -func NewRepository(ctx context.Context, name, endpoint string, repoConfig *RepositoryConfig) (distribution.Repository, error) { +func NewRepository(ctx context.Context, name, endpoint string, transport http.RoundTripper) (distribution.Repository, error) { if err := v2.ValidateRespositoryName(name); err != nil { return nil, err } @@ -30,9 +30,10 @@ func NewRepository(ctx context.Context, name, endpoint string, repoConfig *Repos return nil, err } - client, err := repoConfig.HTTPClient() - if err != nil { - return nil, err + client := &http.Client{ + Transport: transport, + Timeout: 1 * time.Minute, + // TODO(dmcgowan): create cookie jar } return &repository{ diff --git a/registry/client/repository_test.go b/registry/client/repository_test.go index fe8ffeb7..650391c4 100644 --- a/registry/client/repository_test.go +++ b/registry/client/repository_test.go @@ -97,7 +97,7 @@ func TestLayerFetch(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), "test.example.com/repo1", e, &RepositoryConfig{}) + r, err := NewRepository(context.Background(), "test.example.com/repo1", e, nil) if err != nil { t.Fatal(err) } @@ -127,7 +127,7 @@ func TestLayerExists(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), "test.example.com/repo1", e, &RepositoryConfig{}) + r, err := NewRepository(context.Background(), "test.example.com/repo1", e, nil) if err != nil { t.Fatal(err) } @@ -227,7 +227,7 @@ func TestLayerUploadChunked(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) + r, err := NewRepository(context.Background(), repo, e, nil) if err != nil { t.Fatal(err) } @@ -334,7 +334,7 @@ func TestLayerUploadMonolithic(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) + r, err := NewRepository(context.Background(), repo, e, nil) if err != nil { t.Fatal(err) } @@ -475,7 +475,7 @@ func TestManifestFetch(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) + r, err := NewRepository(context.Background(), repo, e, nil) if err != nil { t.Fatal(err) } @@ -508,7 +508,7 @@ func TestManifestFetchByTag(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) + r, err := NewRepository(context.Background(), repo, e, nil) if err != nil { t.Fatal(err) } @@ -553,7 +553,7 @@ func TestManifestDelete(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) + r, err := NewRepository(context.Background(), repo, e, nil) if err != nil { t.Fatal(err) } @@ -591,7 +591,7 @@ func TestManifestPut(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) + r, err := NewRepository(context.Background(), repo, e, nil) if err != nil { t.Fatal(err) } @@ -636,7 +636,7 @@ func TestManifestTags(t *testing.T) { e, c := testServer(m) defer c() - r, err := NewRepository(context.Background(), repo, e, &RepositoryConfig{}) + r, err := NewRepository(context.Background(), repo, e, nil) if err != nil { t.Fatal(err) } diff --git a/registry/client/session.go b/registry/client/session.go index e4e92383..41bb4f31 100644 --- a/registry/client/session.go +++ b/registry/client/session.go @@ -11,12 +11,6 @@ import ( "time" ) -// Authorizer is used to apply Authorization to an HTTP request -type Authorizer interface { - // Authorizer updates an HTTP request with the needed authorization - Authorize(req *http.Request) error -} - // AuthenticationHandler is an interface for authorizing a request from // params from a "WWW-Authenicate" header for a single scheme. type AuthenticationHandler interface { @@ -31,54 +25,11 @@ type CredentialStore interface { Basic(*url.URL) (string, string) } -// RepositoryConfig holds the base configuration needed to communicate -// with a registry including a method of authorization and HTTP headers. -type RepositoryConfig struct { - Header http.Header - AuthSource Authorizer - - BaseTransport http.RoundTripper -} - -// HTTPClient returns a new HTTP client configured for this configuration -func (rc *RepositoryConfig) HTTPClient() (*http.Client, error) { - transport := &Transport{ - ExtraHeader: rc.Header, - AuthSource: rc.AuthSource, - Base: rc.BaseTransport, - } - - client := &http.Client{ - Transport: transport, - } - - return client, nil -} - -// NewTokenAuthorizer returns an authorizer which is capable of getting a token -// from a token server. The expected authorization method will be discovered -// by the authorizer, getting the token server endpoint from the URL being -// requested. Basic authentication may either be done to the token source or -// directly with the requested endpoint depending on the endpoint's -// WWW-Authenticate header. -func NewTokenAuthorizer(creds CredentialStore, transport http.RoundTripper, header http.Header, scope TokenScope) Authorizer { - return &tokenAuthorizer{ - header: header, - challenges: map[string]map[string]authorizationChallenge{}, - handlers: []AuthenticationHandler{ - NewTokenHandler(transport, creds, scope, header), - NewBasicHandler(creds), - }, - transport: transport, - } -} - // NewAuthorizer creates an authorizer which can handle multiple authentication // schemes. The handlers are tried in order, the higher priority authentication // methods should be first. -func NewAuthorizer(transport http.RoundTripper, header http.Header, handlers ...AuthenticationHandler) Authorizer { +func NewAuthorizer(transport http.RoundTripper, handlers ...AuthenticationHandler) RequestModifier { return &tokenAuthorizer{ - header: header, challenges: map[string]map[string]authorizationChallenge{}, handlers: handlers, transport: transport, @@ -86,7 +37,6 @@ func NewAuthorizer(transport http.RoundTripper, header http.Header, handlers ... } type tokenAuthorizer struct { - header http.Header challenges map[string]map[string]authorizationChallenge handlers []AuthenticationHandler transport http.RoundTripper @@ -99,10 +49,7 @@ func (ta *tokenAuthorizer) ping(endpoint string) (map[string]authorizationChalle } client := &http.Client{ - Transport: &Transport{ - ExtraHeader: ta.header, - Base: ta.transport, - }, + Transport: ta.transport, // Ping should fail fast Timeout: 5 * time.Second, } @@ -140,7 +87,7 @@ HeaderLoop: return nil, nil } -func (ta *tokenAuthorizer) Authorize(req *http.Request) error { +func (ta *tokenAuthorizer) ModifyRequest(req *http.Request) error { v2Root := strings.Index(req.URL.Path, "/v2/") if v2Root == -1 { return nil @@ -195,54 +142,52 @@ type TokenScope struct { Actions []string } -// NewTokenHandler creates a new AuthenicationHandler which supports -// fetching tokens from a remote token server. -func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope TokenScope, header http.Header) AuthenticationHandler { - return &tokenHandler{ - header: header, - creds: creds, - scope: scope, - } -} - func (ts TokenScope) String() string { return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ",")) } -func (ts *tokenHandler) client() *http.Client { - return &http.Client{ - Transport: &Transport{ - ExtraHeader: ts.header, - Base: ts.transport, - }, +// NewTokenHandler creates a new AuthenicationHandler which supports +// fetching tokens from a remote token server. +func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope TokenScope) AuthenticationHandler { + return &tokenHandler{ + transport: transport, + creds: creds, + scope: scope, } } -func (ts *tokenHandler) Scheme() string { +func (th *tokenHandler) client() *http.Client { + return &http.Client{ + Transport: th.transport, + Timeout: 15 * time.Second, + } +} + +func (th *tokenHandler) Scheme() string { return "bearer" } -func (ts *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error { - if err := ts.refreshToken(params); err != nil { +func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error { + if err := th.refreshToken(params); err != nil { return err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.tokenCache)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.tokenCache)) return nil } -func (ts *tokenHandler) refreshToken(params map[string]string) error { - ts.tokenLock.Lock() - defer ts.tokenLock.Unlock() +func (th *tokenHandler) refreshToken(params map[string]string) error { + th.tokenLock.Lock() + defer th.tokenLock.Unlock() now := time.Now() - if now.After(ts.tokenExpiration) { - token, err := ts.fetchToken(params) + if now.After(th.tokenExpiration) { + token, err := th.fetchToken(params) if err != nil { return err } - ts.tokenCache = token - ts.tokenExpiration = now.Add(time.Minute) + th.tokenCache = token + th.tokenExpiration = now.Add(time.Minute) } return nil @@ -252,7 +197,7 @@ type tokenResponse struct { Token string `json:"token"` } -func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err error) { +func (th *tokenHandler) fetchToken(params map[string]string) (token string, err error) { //log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username) realm, ok := params["realm"] if !ok { @@ -273,7 +218,7 @@ func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err reqParams := req.URL.Query() service := params["service"] - scope := ts.scope.String() + scope := th.scope.String() if service != "" { reqParams.Add("service", service) @@ -283,8 +228,8 @@ func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err reqParams.Add("scope", scopeField) } - if ts.creds != nil { - username, password := ts.creds.Basic(realmURL) + if th.creds != nil { + username, password := th.creds.Basic(realmURL) if username != "" && password != "" { reqParams.Add("account", username) req.SetBasicAuth(username, password) @@ -293,7 +238,7 @@ func (ts *tokenHandler) fetchToken(params map[string]string) (token string, err req.URL.RawQuery = reqParams.Encode() - resp, err := ts.client().Do(req) + resp, err := th.client().Do(req) if err != nil { return "", err } diff --git a/registry/client/session_test.go b/registry/client/session_test.go index ee306cf6..cf8e546e 100644 --- a/registry/client/session_test.go +++ b/registry/client/session_test.go @@ -116,14 +116,8 @@ func TestEndpointAuthorizeToken(t *testing.T) { e, c := testServerWithAuth(m, authenicate, validCheck) defer c() - repo1Config := &RepositoryConfig{ - AuthSource: NewTokenAuthorizer(nil, nil, nil, tokenScope1), - } - - client, err := repo1Config.HTTPClient() - if err != nil { - t.Fatalf("Error creating http client: %s", err) - } + transport1 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, nil, tokenScope1))) + client := &http.Client{Transport: transport1} req, _ := http.NewRequest("GET", e+"/v2/hello", nil) resp, err := client.Do(req) @@ -141,13 +135,8 @@ func TestEndpointAuthorizeToken(t *testing.T) { e2, c2 := testServerWithAuth(m, authenicate, badCheck) defer c2() - repo2Config := &RepositoryConfig{ - AuthSource: NewTokenAuthorizer(nil, nil, nil, tokenScope2), - } - client2, err := repo2Config.HTTPClient() - if err != nil { - t.Fatalf("Error creating http client: %s", err) - } + transport2 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, nil, tokenScope2))) + client2 := &http.Client{Transport: transport2} req, _ = http.NewRequest("GET", e2+"/v2/hello", nil) resp, err = client2.Do(req) @@ -220,14 +209,9 @@ func TestEndpointAuthorizeTokenBasic(t *testing.T) { username: username, password: password, } - repoConfig := &RepositoryConfig{ - AuthSource: NewTokenAuthorizer(creds, nil, nil, tokenScope), - } - client, err := repoConfig.HTTPClient() - if err != nil { - t.Fatalf("Error creating http client: %s", err) - } + transport1 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, creds, tokenScope), NewBasicHandler(creds))) + client := &http.Client{Transport: transport1} req, _ := http.NewRequest("GET", e+"/v2/hello", nil) resp, err := client.Do(req) @@ -265,14 +249,9 @@ func TestEndpointAuthorizeBasic(t *testing.T) { username: username, password: password, } - repoConfig := &RepositoryConfig{ - AuthSource: NewTokenAuthorizer(creds, nil, nil, TokenScope{}), - } - client, err := repoConfig.HTTPClient() - if err != nil { - t.Fatalf("Error creating http client: %s", err) - } + transport1 := NewTransport(nil, NewAuthorizer(nil, NewBasicHandler(creds))) + client := &http.Client{Transport: transport1} req, _ := http.NewRequest("GET", e+"/v2/hello", nil) resp, err := client.Do(req) diff --git a/registry/client/transport.go b/registry/client/transport.go index e92ba543..0b241619 100644 --- a/registry/client/transport.go +++ b/registry/client/transport.go @@ -6,14 +6,36 @@ import ( "sync" ) -// Transport is an http.RoundTripper that makes registry HTTP requests, -// wrapping a base RoundTripper and adding an Authorization header -// from an Auth source -type Transport struct { - AuthSource Authorizer - ExtraHeader http.Header +type RequestModifier interface { + ModifyRequest(*http.Request) error +} - Base http.RoundTripper +type headerModifier http.Header + +func NewHeaderRequestModifier(header http.Header) RequestModifier { + return headerModifier(header) +} + +func (h headerModifier) ModifyRequest(req *http.Request) error { + for k, s := range http.Header(h) { + req.Header[k] = append(req.Header[k], s...) + } + + return nil +} + +func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper { + return &transport{ + Modifiers: modifiers, + Base: base, + } +} + +// transport is an http.RoundTripper that makes HTTP requests after +// copying and modifying the request +type transport struct { + Modifiers []RequestModifier + Base http.RoundTripper mu sync.Mutex // guards modReq modReq map[*http.Request]*http.Request // original -> modified @@ -22,13 +44,14 @@ type Transport struct { // RoundTrip authorizes and authenticates the request with an // access token. If no token exists or token is expired, // tries to refresh/fetch a new token. -func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - req2 := t.cloneRequest(req) - if t.AuthSource != nil { - if err := t.AuthSource.Authorize(req2); err != nil { +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := cloneRequest(req) + for _, modifier := range t.Modifiers { + if err := modifier.ModifyRequest(req2); err != nil { return nil, err } } + t.setModReq(req, req2) res, err := t.base().RoundTrip(req2) if err != nil { @@ -43,7 +66,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { } // CancelRequest cancels an in-flight request by closing its connection. -func (t *Transport) CancelRequest(req *http.Request) { +func (t *transport) CancelRequest(req *http.Request) { type canceler interface { CancelRequest(*http.Request) } @@ -56,14 +79,14 @@ func (t *Transport) CancelRequest(req *http.Request) { } } -func (t *Transport) base() http.RoundTripper { +func (t *transport) base() http.RoundTripper { if t.Base != nil { return t.Base } return http.DefaultTransport } -func (t *Transport) setModReq(orig, mod *http.Request) { +func (t *transport) setModReq(orig, mod *http.Request) { t.mu.Lock() defer t.mu.Unlock() if t.modReq == nil { @@ -78,7 +101,7 @@ func (t *Transport) setModReq(orig, mod *http.Request) { // cloneRequest returns a clone of the provided *http.Request. // The clone is a shallow copy of the struct and its Header map. -func (t *Transport) cloneRequest(r *http.Request) *http.Request { +func cloneRequest(r *http.Request) *http.Request { // shallow copy of the struct r2 := new(http.Request) *r2 = *r @@ -87,9 +110,7 @@ func (t *Transport) cloneRequest(r *http.Request) *http.Request { for k, s := range r.Header { r2.Header[k] = append([]string(nil), s...) } - for k, s := range t.ExtraHeader { - r2.Header[k] = append(r2.Header[k], s...) - } + return r2 }