From fbdfd1ac3580f1aafdbd17c09ebd02c86ccfc4a1 Mon Sep 17 00:00:00 2001 From: Aaron Lehmann Date: Wed, 10 Aug 2022 10:38:30 -0700 Subject: [PATCH] Use http.NewRequestWithContext for outgoing HTTP requests This simple change mainly affects the distribution client. By respecting the context the caller passes in, timeouts and cancellations will work as expected. Also, transports which rely on the context (such as tracing transports that retrieve a span from the context) will work properly. Signed-off-by: Aaron Lehmann --- registry/client/auth/session.go | 27 ++++++++++++++++----------- registry/client/blob_writer.go | 10 ++++++---- registry/client/blob_writer_test.go | 5 +++++ registry/client/repository.go | 10 +++++----- 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/registry/client/auth/session.go b/registry/client/auth/session.go index 8a1723f6..fe212831 100644 --- a/registry/client/auth/session.go +++ b/registry/client/auth/session.go @@ -1,6 +1,7 @@ package auth import ( + "context" "encoding/json" "errors" "fmt" @@ -258,7 +259,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st }.String()) } - token, err := th.getToken(params, additionalScopes...) + token, err := th.getToken(req.Context(), params, additionalScopes...) if err != nil { return err } @@ -268,7 +269,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st return nil } -func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) { +func (th *tokenHandler) getToken(ctx context.Context, params map[string]string, additionalScopes ...string) (string, error) { th.tokenLock.Lock() defer th.tokenLock.Unlock() scopes := make([]string, 0, len(th.scopes)+len(additionalScopes)) @@ -286,7 +287,7 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s now := th.clock.Now() if now.After(th.tokenExpiration) || addedScopes { - token, expiration, err := th.fetchToken(params, scopes) + token, expiration, err := th.fetchToken(ctx, params, scopes) if err != nil { return "", err } @@ -320,7 +321,7 @@ type postTokenResponse struct { Scope string `json:"scope"` } -func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) { +func (th *tokenHandler) fetchTokenWithOAuth(ctx context.Context, realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) { form := url.Values{} form.Set("scope", strings.Join(scopes, " ")) form.Set("service", service) @@ -348,7 +349,12 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic return "", time.Time{}, fmt.Errorf("no supported grant type") } - resp, err := th.client().PostForm(realm.String(), form) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm.String(), strings.NewReader(form.Encode())) + if err != nil { + return "", time.Time{}, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := th.client().Do(req) if err != nil { return "", time.Time{}, err } @@ -396,9 +402,8 @@ type getTokenResponse struct { RefreshToken string `json:"refresh_token"` } -func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) { - - req, err := http.NewRequest("GET", realm.String(), nil) +func (th *tokenHandler) fetchTokenWithBasicAuth(ctx context.Context, realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm.String(), nil) if err != nil { return "", time.Time{}, err } @@ -479,7 +484,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil } -func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) { +func (th *tokenHandler) fetchToken(ctx context.Context, params map[string]string, scopes []string) (token string, expiration time.Time, err error) { realm, ok := params["realm"] if !ok { return "", time.Time{}, errors.New("no realm specified for token auth challenge") @@ -500,10 +505,10 @@ func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (t } if refreshToken != "" || th.forceOAuth { - return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes) + return th.fetchTokenWithOAuth(ctx, realmURL, refreshToken, service, scopes) } - return th.fetchTokenWithBasicAuth(realmURL, service, scopes) + return th.fetchTokenWithBasicAuth(ctx, realmURL, service, scopes) } type basicHandler struct { diff --git a/registry/client/blob_writer.go b/registry/client/blob_writer.go index 75ff20f9..cfc7922e 100644 --- a/registry/client/blob_writer.go +++ b/registry/client/blob_writer.go @@ -13,6 +13,8 @@ import ( ) type httpBlobUpload struct { + ctx context.Context + statter distribution.BlobStatter client *http.Client @@ -36,7 +38,7 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error { } func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) { - req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r)) + req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, ioutil.NopCloser(r)) if err != nil { return 0, err } @@ -69,7 +71,7 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) { } func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) { - req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p)) + req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, bytes.NewReader(p)) if err != nil { return 0, err } @@ -117,7 +119,7 @@ func (hbu *httpBlobUpload) StartedAt() time.Time { func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { // TODO(dmcgowan): Check if already finished, if so just fetch - req, err := http.NewRequest("PUT", hbu.location, nil) + req, err := http.NewRequestWithContext(hbu.ctx, "PUT", hbu.location, nil) if err != nil { return distribution.Descriptor{}, err } @@ -140,7 +142,7 @@ func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descrip } func (hbu *httpBlobUpload) Cancel(ctx context.Context) error { - req, err := http.NewRequest("DELETE", hbu.location, nil) + req, err := http.NewRequestWithContext(hbu.ctx, "DELETE", hbu.location, nil) if err != nil { return err } diff --git a/registry/client/blob_writer_test.go b/registry/client/blob_writer_test.go index f7a6e7ca..8006864b 100644 --- a/registry/client/blob_writer_test.go +++ b/registry/client/blob_writer_test.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "context" "fmt" "net/http" "testing" @@ -126,6 +127,7 @@ func TestUploadReadFrom(t *testing.T) { defer c() blobUpload := &httpBlobUpload{ + ctx: context.Background(), client: &http.Client{}, } @@ -265,6 +267,7 @@ func TestUploadSize(t *testing.T) { // Writing with ReadFrom blobUpload := &httpBlobUpload{ + ctx: context.Background(), client: &http.Client{}, location: e + readFromLocationPath, } @@ -284,6 +287,7 @@ func TestUploadSize(t *testing.T) { // Writing with Write blobUpload = &httpBlobUpload{ + ctx: context.Background(), client: &http.Client{}, location: e + writeLocationPath, } @@ -409,6 +413,7 @@ func TestUploadWrite(t *testing.T) { defer c() blobUpload := &httpBlobUpload{ + ctx: context.Background(), client: &http.Client{}, } diff --git a/registry/client/repository.go b/registry/client/repository.go index a3379c0a..7bcb03b4 100644 --- a/registry/client/repository.go +++ b/registry/client/repository.go @@ -118,9 +118,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri return 0, err } - for cnt := range ctlg.Repositories { - entries[cnt] = ctlg.Repositories[cnt] - } + copy(entries, ctlg.Repositories) numFilled = len(ctlg.Repositories) link := resp.Header.Get("Link") @@ -373,7 +371,7 @@ func (t *tags) Untag(ctx context.Context, tag string) error { return err } - req, err := http.NewRequest("DELETE", u, nil) + req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil) if err != nil { return err } @@ -792,7 +790,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO return nil, err } - req, err := http.NewRequest("POST", u, nil) + req, err := http.NewRequestWithContext(ctx, "POST", u, nil) if err != nil { return nil, err } @@ -827,6 +825,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO } return &httpBlobUpload{ + ctx: ctx, statter: bs.statter, client: bs.client, uuid: uuid, @@ -845,6 +844,7 @@ func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter } return &httpBlobUpload{ + ctx: ctx, statter: bs.statter, client: bs.client, uuid: id,