Use http.NewRequestWithContext for outgoing HTTP requests

This simple change mainly affects the distribution client. By respecting
the context the caller passes in, timeouts and cancellations will work
as expected. Also, transports which rely on the context (such as tracing
transports that retrieve a span from the context) will work properly.

Signed-off-by: Aaron Lehmann <alehmann@netflix.com>
This commit is contained in:
Aaron Lehmann 2022-08-10 10:38:30 -07:00
parent 26163d8256
commit fbdfd1ac35
4 changed files with 32 additions and 20 deletions

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -258,7 +259,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
}.String()) }.String())
} }
token, err := th.getToken(params, additionalScopes...) token, err := th.getToken(req.Context(), params, additionalScopes...)
if err != nil { if err != nil {
return err return err
} }
@ -268,7 +269,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
return nil return nil
} }
func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) { func (th *tokenHandler) getToken(ctx context.Context, params map[string]string, additionalScopes ...string) (string, error) {
th.tokenLock.Lock() th.tokenLock.Lock()
defer th.tokenLock.Unlock() defer th.tokenLock.Unlock()
scopes := make([]string, 0, len(th.scopes)+len(additionalScopes)) scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
@ -286,7 +287,7 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s
now := th.clock.Now() now := th.clock.Now()
if now.After(th.tokenExpiration) || addedScopes { if now.After(th.tokenExpiration) || addedScopes {
token, expiration, err := th.fetchToken(params, scopes) token, expiration, err := th.fetchToken(ctx, params, scopes)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -320,7 +321,7 @@ type postTokenResponse struct {
Scope string `json:"scope"` Scope string `json:"scope"`
} }
func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) { func (th *tokenHandler) fetchTokenWithOAuth(ctx context.Context, realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
form := url.Values{} form := url.Values{}
form.Set("scope", strings.Join(scopes, " ")) form.Set("scope", strings.Join(scopes, " "))
form.Set("service", service) form.Set("service", service)
@ -348,7 +349,12 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic
return "", time.Time{}, fmt.Errorf("no supported grant type") return "", time.Time{}, fmt.Errorf("no supported grant type")
} }
resp, err := th.client().PostForm(realm.String(), form) req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm.String(), strings.NewReader(form.Encode()))
if err != nil {
return "", time.Time{}, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := th.client().Do(req)
if err != nil { if err != nil {
return "", time.Time{}, err return "", time.Time{}, err
} }
@ -396,9 +402,8 @@ type getTokenResponse struct {
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
} }
func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) { func (th *tokenHandler) fetchTokenWithBasicAuth(ctx context.Context, realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm.String(), nil)
req, err := http.NewRequest("GET", realm.String(), nil)
if err != nil { if err != nil {
return "", time.Time{}, err return "", time.Time{}, err
} }
@ -479,7 +484,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string,
return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
} }
func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) { func (th *tokenHandler) fetchToken(ctx context.Context, params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
realm, ok := params["realm"] realm, ok := params["realm"]
if !ok { if !ok {
return "", time.Time{}, errors.New("no realm specified for token auth challenge") return "", time.Time{}, errors.New("no realm specified for token auth challenge")
@ -500,10 +505,10 @@ func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (t
} }
if refreshToken != "" || th.forceOAuth { if refreshToken != "" || th.forceOAuth {
return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes) return th.fetchTokenWithOAuth(ctx, realmURL, refreshToken, service, scopes)
} }
return th.fetchTokenWithBasicAuth(realmURL, service, scopes) return th.fetchTokenWithBasicAuth(ctx, realmURL, service, scopes)
} }
type basicHandler struct { type basicHandler struct {

View file

@ -13,6 +13,8 @@ import (
) )
type httpBlobUpload struct { type httpBlobUpload struct {
ctx context.Context
statter distribution.BlobStatter statter distribution.BlobStatter
client *http.Client client *http.Client
@ -36,7 +38,7 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error {
} }
func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) { func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r)) req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, ioutil.NopCloser(r))
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -69,7 +71,7 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
} }
func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) { func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) {
req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p)) req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, bytes.NewReader(p))
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -117,7 +119,7 @@ func (hbu *httpBlobUpload) StartedAt() time.Time {
func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
// TODO(dmcgowan): Check if already finished, if so just fetch // TODO(dmcgowan): Check if already finished, if so just fetch
req, err := http.NewRequest("PUT", hbu.location, nil) req, err := http.NewRequestWithContext(hbu.ctx, "PUT", hbu.location, nil)
if err != nil { if err != nil {
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
} }
@ -140,7 +142,7 @@ func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descrip
} }
func (hbu *httpBlobUpload) Cancel(ctx context.Context) error { func (hbu *httpBlobUpload) Cancel(ctx context.Context) error {
req, err := http.NewRequest("DELETE", hbu.location, nil) req, err := http.NewRequestWithContext(hbu.ctx, "DELETE", hbu.location, nil)
if err != nil { if err != nil {
return err return err
} }

View file

@ -2,6 +2,7 @@ package client
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net/http" "net/http"
"testing" "testing"
@ -126,6 +127,7 @@ func TestUploadReadFrom(t *testing.T) {
defer c() defer c()
blobUpload := &httpBlobUpload{ blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{}, client: &http.Client{},
} }
@ -265,6 +267,7 @@ func TestUploadSize(t *testing.T) {
// Writing with ReadFrom // Writing with ReadFrom
blobUpload := &httpBlobUpload{ blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{}, client: &http.Client{},
location: e + readFromLocationPath, location: e + readFromLocationPath,
} }
@ -284,6 +287,7 @@ func TestUploadSize(t *testing.T) {
// Writing with Write // Writing with Write
blobUpload = &httpBlobUpload{ blobUpload = &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{}, client: &http.Client{},
location: e + writeLocationPath, location: e + writeLocationPath,
} }
@ -409,6 +413,7 @@ func TestUploadWrite(t *testing.T) {
defer c() defer c()
blobUpload := &httpBlobUpload{ blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{}, client: &http.Client{},
} }

View file

@ -118,9 +118,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri
return 0, err return 0, err
} }
for cnt := range ctlg.Repositories { copy(entries, ctlg.Repositories)
entries[cnt] = ctlg.Repositories[cnt]
}
numFilled = len(ctlg.Repositories) numFilled = len(ctlg.Repositories)
link := resp.Header.Get("Link") link := resp.Header.Get("Link")
@ -373,7 +371,7 @@ func (t *tags) Untag(ctx context.Context, tag string) error {
return err return err
} }
req, err := http.NewRequest("DELETE", u, nil) req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil)
if err != nil { if err != nil {
return err return err
} }
@ -792,7 +790,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
return nil, err return nil, err
} }
req, err := http.NewRequest("POST", u, nil) req, err := http.NewRequestWithContext(ctx, "POST", u, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -827,6 +825,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
} }
return &httpBlobUpload{ return &httpBlobUpload{
ctx: ctx,
statter: bs.statter, statter: bs.statter,
client: bs.client, client: bs.client,
uuid: uuid, uuid: uuid,
@ -845,6 +844,7 @@ func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter
} }
return &httpBlobUpload{ return &httpBlobUpload{
ctx: ctx,
statter: bs.statter, statter: bs.statter,
client: bs.client, client: bs.client,
uuid: id, uuid: id,