diff --git a/registry/client/auth/session.go b/registry/client/auth/session.go index 8a1723f61..fe2128316 100644 --- a/registry/client/auth/session.go +++ b/registry/client/auth/session.go @@ -1,6 +1,7 @@ package auth import ( + "context" "encoding/json" "errors" "fmt" @@ -258,7 +259,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st }.String()) } - token, err := th.getToken(params, additionalScopes...) + token, err := th.getToken(req.Context(), params, additionalScopes...) if err != nil { return err } @@ -268,7 +269,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st return nil } -func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) { +func (th *tokenHandler) getToken(ctx context.Context, params map[string]string, additionalScopes ...string) (string, error) { th.tokenLock.Lock() defer th.tokenLock.Unlock() scopes := make([]string, 0, len(th.scopes)+len(additionalScopes)) @@ -286,7 +287,7 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s now := th.clock.Now() if now.After(th.tokenExpiration) || addedScopes { - token, expiration, err := th.fetchToken(params, scopes) + token, expiration, err := th.fetchToken(ctx, params, scopes) if err != nil { return "", err } @@ -320,7 +321,7 @@ type postTokenResponse struct { Scope string `json:"scope"` } -func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) { +func (th *tokenHandler) fetchTokenWithOAuth(ctx context.Context, realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) { form := url.Values{} form.Set("scope", strings.Join(scopes, " ")) form.Set("service", service) @@ -348,7 +349,12 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic return "", time.Time{}, fmt.Errorf("no supported grant type") } - resp, err := th.client().PostForm(realm.String(), form) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm.String(), strings.NewReader(form.Encode())) + if err != nil { + return "", time.Time{}, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := th.client().Do(req) if err != nil { return "", time.Time{}, err } @@ -396,9 +402,8 @@ type getTokenResponse struct { RefreshToken string `json:"refresh_token"` } -func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) { - - req, err := http.NewRequest("GET", realm.String(), nil) +func (th *tokenHandler) fetchTokenWithBasicAuth(ctx context.Context, realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm.String(), nil) if err != nil { return "", time.Time{}, err } @@ -479,7 +484,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil } -func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) { +func (th *tokenHandler) fetchToken(ctx context.Context, params map[string]string, scopes []string) (token string, expiration time.Time, err error) { realm, ok := params["realm"] if !ok { return "", time.Time{}, errors.New("no realm specified for token auth challenge") @@ -500,10 +505,10 @@ func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (t } if refreshToken != "" || th.forceOAuth { - return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes) + return th.fetchTokenWithOAuth(ctx, realmURL, refreshToken, service, scopes) } - return th.fetchTokenWithBasicAuth(realmURL, service, scopes) + return th.fetchTokenWithBasicAuth(ctx, realmURL, service, scopes) } type basicHandler struct { diff --git a/registry/client/blob_writer.go b/registry/client/blob_writer.go index 75ff20f9a..cfc7922e4 100644 --- a/registry/client/blob_writer.go +++ b/registry/client/blob_writer.go @@ -13,6 +13,8 @@ import ( ) type httpBlobUpload struct { + ctx context.Context + statter distribution.BlobStatter client *http.Client @@ -36,7 +38,7 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error { } func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) { - req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r)) + req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, ioutil.NopCloser(r)) if err != nil { return 0, err } @@ -69,7 +71,7 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) { } func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) { - req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p)) + req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, bytes.NewReader(p)) if err != nil { return 0, err } @@ -117,7 +119,7 @@ func (hbu *httpBlobUpload) StartedAt() time.Time { func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { // TODO(dmcgowan): Check if already finished, if so just fetch - req, err := http.NewRequest("PUT", hbu.location, nil) + req, err := http.NewRequestWithContext(hbu.ctx, "PUT", hbu.location, nil) if err != nil { return distribution.Descriptor{}, err } @@ -140,7 +142,7 @@ func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descrip } func (hbu *httpBlobUpload) Cancel(ctx context.Context) error { - req, err := http.NewRequest("DELETE", hbu.location, nil) + req, err := http.NewRequestWithContext(hbu.ctx, "DELETE", hbu.location, nil) if err != nil { return err } diff --git a/registry/client/blob_writer_test.go b/registry/client/blob_writer_test.go index f7a6e7ca6..8006864be 100644 --- a/registry/client/blob_writer_test.go +++ b/registry/client/blob_writer_test.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "context" "fmt" "net/http" "testing" @@ -126,6 +127,7 @@ func TestUploadReadFrom(t *testing.T) { defer c() blobUpload := &httpBlobUpload{ + ctx: context.Background(), client: &http.Client{}, } @@ -265,6 +267,7 @@ func TestUploadSize(t *testing.T) { // Writing with ReadFrom blobUpload := &httpBlobUpload{ + ctx: context.Background(), client: &http.Client{}, location: e + readFromLocationPath, } @@ -284,6 +287,7 @@ func TestUploadSize(t *testing.T) { // Writing with Write blobUpload = &httpBlobUpload{ + ctx: context.Background(), client: &http.Client{}, location: e + writeLocationPath, } @@ -409,6 +413,7 @@ func TestUploadWrite(t *testing.T) { defer c() blobUpload := &httpBlobUpload{ + ctx: context.Background(), client: &http.Client{}, } diff --git a/registry/client/repository.go b/registry/client/repository.go index a3379c0a0..7bcb03b49 100644 --- a/registry/client/repository.go +++ b/registry/client/repository.go @@ -118,9 +118,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri return 0, err } - for cnt := range ctlg.Repositories { - entries[cnt] = ctlg.Repositories[cnt] - } + copy(entries, ctlg.Repositories) numFilled = len(ctlg.Repositories) link := resp.Header.Get("Link") @@ -373,7 +371,7 @@ func (t *tags) Untag(ctx context.Context, tag string) error { return err } - req, err := http.NewRequest("DELETE", u, nil) + req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil) if err != nil { return err } @@ -792,7 +790,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO return nil, err } - req, err := http.NewRequest("POST", u, nil) + req, err := http.NewRequestWithContext(ctx, "POST", u, nil) if err != nil { return nil, err } @@ -827,6 +825,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO } return &httpBlobUpload{ + ctx: ctx, statter: bs.statter, client: bs.client, uuid: uuid, @@ -845,6 +844,7 @@ func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter } return &httpBlobUpload{ + ctx: ctx, statter: bs.statter, client: bs.client, uuid: id,