Use http.NewRequestWithContext for outgoing HTTP requests

This simple change mainly affects the distribution client. By respecting
the context the caller passes in, timeouts and cancellations will work
as expected. Also, transports which rely on the context (such as tracing
transports that retrieve a span from the context) will work properly.

Signed-off-by: Aaron Lehmann <alehmann@netflix.com>
This commit is contained in:
Aaron Lehmann 2022-08-10 10:38:30 -07:00
parent 26163d8256
commit fbdfd1ac35
4 changed files with 32 additions and 20 deletions

View file

@ -1,6 +1,7 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -258,7 +259,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
}.String())
}
token, err := th.getToken(params, additionalScopes...)
token, err := th.getToken(req.Context(), params, additionalScopes...)
if err != nil {
return err
}
@ -268,7 +269,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
return nil
}
func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) {
func (th *tokenHandler) getToken(ctx context.Context, params map[string]string, additionalScopes ...string) (string, error) {
th.tokenLock.Lock()
defer th.tokenLock.Unlock()
scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
@ -286,7 +287,7 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s
now := th.clock.Now()
if now.After(th.tokenExpiration) || addedScopes {
token, expiration, err := th.fetchToken(params, scopes)
token, expiration, err := th.fetchToken(ctx, params, scopes)
if err != nil {
return "", err
}
@ -320,7 +321,7 @@ type postTokenResponse struct {
Scope string `json:"scope"`
}
func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
func (th *tokenHandler) fetchTokenWithOAuth(ctx context.Context, realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
form := url.Values{}
form.Set("scope", strings.Join(scopes, " "))
form.Set("service", service)
@ -348,7 +349,12 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic
return "", time.Time{}, fmt.Errorf("no supported grant type")
}
resp, err := th.client().PostForm(realm.String(), form)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm.String(), strings.NewReader(form.Encode()))
if err != nil {
return "", time.Time{}, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := th.client().Do(req)
if err != nil {
return "", time.Time{}, err
}
@ -396,9 +402,8 @@ type getTokenResponse struct {
RefreshToken string `json:"refresh_token"`
}
func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequest("GET", realm.String(), nil)
func (th *tokenHandler) fetchTokenWithBasicAuth(ctx context.Context, realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm.String(), nil)
if err != nil {
return "", time.Time{}, err
}
@ -479,7 +484,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string,
return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
}
func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
func (th *tokenHandler) fetchToken(ctx context.Context, params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
realm, ok := params["realm"]
if !ok {
return "", time.Time{}, errors.New("no realm specified for token auth challenge")
@ -500,10 +505,10 @@ func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (t
}
if refreshToken != "" || th.forceOAuth {
return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes)
return th.fetchTokenWithOAuth(ctx, realmURL, refreshToken, service, scopes)
}
return th.fetchTokenWithBasicAuth(realmURL, service, scopes)
return th.fetchTokenWithBasicAuth(ctx, realmURL, service, scopes)
}
type basicHandler struct {

View file

@ -13,6 +13,8 @@ import (
)
type httpBlobUpload struct {
ctx context.Context
statter distribution.BlobStatter
client *http.Client
@ -36,7 +38,7 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error {
}
func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r))
req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, ioutil.NopCloser(r))
if err != nil {
return 0, err
}
@ -69,7 +71,7 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
}
func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) {
req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p))
req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, bytes.NewReader(p))
if err != nil {
return 0, err
}
@ -117,7 +119,7 @@ func (hbu *httpBlobUpload) StartedAt() time.Time {
func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
// TODO(dmcgowan): Check if already finished, if so just fetch
req, err := http.NewRequest("PUT", hbu.location, nil)
req, err := http.NewRequestWithContext(hbu.ctx, "PUT", hbu.location, nil)
if err != nil {
return distribution.Descriptor{}, err
}
@ -140,7 +142,7 @@ func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descrip
}
func (hbu *httpBlobUpload) Cancel(ctx context.Context) error {
req, err := http.NewRequest("DELETE", hbu.location, nil)
req, err := http.NewRequestWithContext(hbu.ctx, "DELETE", hbu.location, nil)
if err != nil {
return err
}

View file

@ -2,6 +2,7 @@ package client
import (
"bytes"
"context"
"fmt"
"net/http"
"testing"
@ -126,6 +127,7 @@ func TestUploadReadFrom(t *testing.T) {
defer c()
blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
}
@ -265,6 +267,7 @@ func TestUploadSize(t *testing.T) {
// Writing with ReadFrom
blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
location: e + readFromLocationPath,
}
@ -284,6 +287,7 @@ func TestUploadSize(t *testing.T) {
// Writing with Write
blobUpload = &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
location: e + writeLocationPath,
}
@ -409,6 +413,7 @@ func TestUploadWrite(t *testing.T) {
defer c()
blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
}

View file

@ -118,9 +118,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri
return 0, err
}
for cnt := range ctlg.Repositories {
entries[cnt] = ctlg.Repositories[cnt]
}
copy(entries, ctlg.Repositories)
numFilled = len(ctlg.Repositories)
link := resp.Header.Get("Link")
@ -373,7 +371,7 @@ func (t *tags) Untag(ctx context.Context, tag string) error {
return err
}
req, err := http.NewRequest("DELETE", u, nil)
req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil)
if err != nil {
return err
}
@ -792,7 +790,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
return nil, err
}
req, err := http.NewRequest("POST", u, nil)
req, err := http.NewRequestWithContext(ctx, "POST", u, nil)
if err != nil {
return nil, err
}
@ -827,6 +825,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
}
return &httpBlobUpload{
ctx: ctx,
statter: bs.statter,
client: bs.client,
uuid: uuid,
@ -845,6 +844,7 @@ func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter
}
return &httpBlobUpload{
ctx: ctx,
statter: bs.statter,
client: bs.client,
uuid: id,