Merge pull request #3711 from aaronlehmann/request-with-context

Use http.NewRequestWithContext for outgoing HTTP requests
This commit is contained in:
Milos Gajdos 2022-08-16 16:03:28 +01:00 committed by GitHub
commit 6c237953cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 20 deletions

View file

@ -1,6 +1,7 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -258,7 +259,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
}.String())
}
token, err := th.getToken(params, additionalScopes...)
token, err := th.getToken(req.Context(), params, additionalScopes...)
if err != nil {
return err
}
@ -268,7 +269,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
return nil
}
func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) {
func (th *tokenHandler) getToken(ctx context.Context, params map[string]string, additionalScopes ...string) (string, error) {
th.tokenLock.Lock()
defer th.tokenLock.Unlock()
scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
@ -286,7 +287,7 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s
now := th.clock.Now()
if now.After(th.tokenExpiration) || addedScopes {
token, expiration, err := th.fetchToken(params, scopes)
token, expiration, err := th.fetchToken(ctx, params, scopes)
if err != nil {
return "", err
}
@ -320,7 +321,7 @@ type postTokenResponse struct {
Scope string `json:"scope"`
}
func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
func (th *tokenHandler) fetchTokenWithOAuth(ctx context.Context, realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
form := url.Values{}
form.Set("scope", strings.Join(scopes, " "))
form.Set("service", service)
@ -348,7 +349,12 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic
return "", time.Time{}, fmt.Errorf("no supported grant type")
}
resp, err := th.client().PostForm(realm.String(), form)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm.String(), strings.NewReader(form.Encode()))
if err != nil {
return "", time.Time{}, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := th.client().Do(req)
if err != nil {
return "", time.Time{}, err
}
@ -396,9 +402,8 @@ type getTokenResponse struct {
RefreshToken string `json:"refresh_token"`
}
func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequest("GET", realm.String(), nil)
func (th *tokenHandler) fetchTokenWithBasicAuth(ctx context.Context, realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm.String(), nil)
if err != nil {
return "", time.Time{}, err
}
@ -479,7 +484,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string,
return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
}
func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
func (th *tokenHandler) fetchToken(ctx context.Context, params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
realm, ok := params["realm"]
if !ok {
return "", time.Time{}, errors.New("no realm specified for token auth challenge")
@ -500,10 +505,10 @@ func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (t
}
if refreshToken != "" || th.forceOAuth {
return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes)
return th.fetchTokenWithOAuth(ctx, realmURL, refreshToken, service, scopes)
}
return th.fetchTokenWithBasicAuth(realmURL, service, scopes)
return th.fetchTokenWithBasicAuth(ctx, realmURL, service, scopes)
}
type basicHandler struct {

View file

@ -13,6 +13,8 @@ import (
)
type httpBlobUpload struct {
ctx context.Context
statter distribution.BlobStatter
client *http.Client
@ -36,7 +38,7 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error {
}
func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r))
req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, ioutil.NopCloser(r))
if err != nil {
return 0, err
}
@ -69,7 +71,7 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
}
func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) {
req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p))
req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, bytes.NewReader(p))
if err != nil {
return 0, err
}
@ -117,7 +119,7 @@ func (hbu *httpBlobUpload) StartedAt() time.Time {
func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
// TODO(dmcgowan): Check if already finished, if so just fetch
req, err := http.NewRequest("PUT", hbu.location, nil)
req, err := http.NewRequestWithContext(hbu.ctx, "PUT", hbu.location, nil)
if err != nil {
return distribution.Descriptor{}, err
}
@ -140,7 +142,7 @@ func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descrip
}
func (hbu *httpBlobUpload) Cancel(ctx context.Context) error {
req, err := http.NewRequest("DELETE", hbu.location, nil)
req, err := http.NewRequestWithContext(hbu.ctx, "DELETE", hbu.location, nil)
if err != nil {
return err
}

View file

@ -2,6 +2,7 @@ package client
import (
"bytes"
"context"
"fmt"
"net/http"
"testing"
@ -126,6 +127,7 @@ func TestUploadReadFrom(t *testing.T) {
defer c()
blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
}
@ -265,6 +267,7 @@ func TestUploadSize(t *testing.T) {
// Writing with ReadFrom
blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
location: e + readFromLocationPath,
}
@ -284,6 +287,7 @@ func TestUploadSize(t *testing.T) {
// Writing with Write
blobUpload = &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
location: e + writeLocationPath,
}
@ -409,6 +413,7 @@ func TestUploadWrite(t *testing.T) {
defer c()
blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
}

View file

@ -118,9 +118,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri
return 0, err
}
for cnt := range ctlg.Repositories {
entries[cnt] = ctlg.Repositories[cnt]
}
copy(entries, ctlg.Repositories)
numFilled = len(ctlg.Repositories)
link := resp.Header.Get("Link")
@ -373,7 +371,7 @@ func (t *tags) Untag(ctx context.Context, tag string) error {
return err
}
req, err := http.NewRequest("DELETE", u, nil)
req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil)
if err != nil {
return err
}
@ -792,7 +790,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
return nil, err
}
req, err := http.NewRequest("POST", u, nil)
req, err := http.NewRequestWithContext(ctx, "POST", u, nil)
if err != nil {
return nil, err
}
@ -827,6 +825,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
}
return &httpBlobUpload{
ctx: ctx,
statter: bs.statter,
client: bs.client,
uuid: uuid,
@ -845,6 +844,7 @@ func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter
}
return &httpBlobUpload{
ctx: ctx,
statter: bs.statter,
client: bs.client,
uuid: id,