Merge pull request #3711 from aaronlehmann/request-with-context

Use http.NewRequestWithContext for outgoing HTTP requests
This commit is contained in:
Milos Gajdos 2022-08-16 16:03:28 +01:00 committed by GitHub
commit 6c237953cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 20 deletions

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -258,7 +259,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
}.String()) }.String())
} }
token, err := th.getToken(params, additionalScopes...) token, err := th.getToken(req.Context(), params, additionalScopes...)
if err != nil { if err != nil {
return err return err
} }
@ -268,7 +269,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
return nil return nil
} }
func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) { func (th *tokenHandler) getToken(ctx context.Context, params map[string]string, additionalScopes ...string) (string, error) {
th.tokenLock.Lock() th.tokenLock.Lock()
defer th.tokenLock.Unlock() defer th.tokenLock.Unlock()
scopes := make([]string, 0, len(th.scopes)+len(additionalScopes)) scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
@ -286,7 +287,7 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s
now := th.clock.Now() now := th.clock.Now()
if now.After(th.tokenExpiration) || addedScopes { if now.After(th.tokenExpiration) || addedScopes {
token, expiration, err := th.fetchToken(params, scopes) token, expiration, err := th.fetchToken(ctx, params, scopes)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -320,7 +321,7 @@ type postTokenResponse struct {
Scope string `json:"scope"` Scope string `json:"scope"`
} }
func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) { func (th *tokenHandler) fetchTokenWithOAuth(ctx context.Context, realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
form := url.Values{} form := url.Values{}
form.Set("scope", strings.Join(scopes, " ")) form.Set("scope", strings.Join(scopes, " "))
form.Set("service", service) form.Set("service", service)
@ -348,7 +349,12 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic
return "", time.Time{}, fmt.Errorf("no supported grant type") return "", time.Time{}, fmt.Errorf("no supported grant type")
} }
resp, err := th.client().PostForm(realm.String(), form) req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm.String(), strings.NewReader(form.Encode()))
if err != nil {
return "", time.Time{}, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := th.client().Do(req)
if err != nil { if err != nil {
return "", time.Time{}, err return "", time.Time{}, err
} }
@ -396,9 +402,8 @@ type getTokenResponse struct {
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
} }
func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) { func (th *tokenHandler) fetchTokenWithBasicAuth(ctx context.Context, realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm.String(), nil)
req, err := http.NewRequest("GET", realm.String(), nil)
if err != nil { if err != nil {
return "", time.Time{}, err return "", time.Time{}, err
} }
@ -479,7 +484,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string,
return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
} }
func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) { func (th *tokenHandler) fetchToken(ctx context.Context, params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
realm, ok := params["realm"] realm, ok := params["realm"]
if !ok { if !ok {
return "", time.Time{}, errors.New("no realm specified for token auth challenge") return "", time.Time{}, errors.New("no realm specified for token auth challenge")
@ -500,10 +505,10 @@ func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (t
} }
if refreshToken != "" || th.forceOAuth { if refreshToken != "" || th.forceOAuth {
return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes) return th.fetchTokenWithOAuth(ctx, realmURL, refreshToken, service, scopes)
} }
return th.fetchTokenWithBasicAuth(realmURL, service, scopes) return th.fetchTokenWithBasicAuth(ctx, realmURL, service, scopes)
} }
type basicHandler struct { type basicHandler struct {

View file

@ -13,6 +13,8 @@ import (
) )
type httpBlobUpload struct { type httpBlobUpload struct {
ctx context.Context
statter distribution.BlobStatter statter distribution.BlobStatter
client *http.Client client *http.Client
@ -36,7 +38,7 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error {
} }
func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) { func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r)) req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, ioutil.NopCloser(r))
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -69,7 +71,7 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
} }
func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) { func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) {
req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p)) req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, bytes.NewReader(p))
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -117,7 +119,7 @@ func (hbu *httpBlobUpload) StartedAt() time.Time {
func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
// TODO(dmcgowan): Check if already finished, if so just fetch // TODO(dmcgowan): Check if already finished, if so just fetch
req, err := http.NewRequest("PUT", hbu.location, nil) req, err := http.NewRequestWithContext(hbu.ctx, "PUT", hbu.location, nil)
if err != nil { if err != nil {
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
} }
@ -140,7 +142,7 @@ func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descrip
} }
func (hbu *httpBlobUpload) Cancel(ctx context.Context) error { func (hbu *httpBlobUpload) Cancel(ctx context.Context) error {
req, err := http.NewRequest("DELETE", hbu.location, nil) req, err := http.NewRequestWithContext(hbu.ctx, "DELETE", hbu.location, nil)
if err != nil { if err != nil {
return err return err
} }

View file

@ -2,6 +2,7 @@ package client
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"net/http" "net/http"
"testing" "testing"
@ -126,6 +127,7 @@ func TestUploadReadFrom(t *testing.T) {
defer c() defer c()
blobUpload := &httpBlobUpload{ blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{}, client: &http.Client{},
} }
@ -265,6 +267,7 @@ func TestUploadSize(t *testing.T) {
// Writing with ReadFrom // Writing with ReadFrom
blobUpload := &httpBlobUpload{ blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{}, client: &http.Client{},
location: e + readFromLocationPath, location: e + readFromLocationPath,
} }
@ -284,6 +287,7 @@ func TestUploadSize(t *testing.T) {
// Writing with Write // Writing with Write
blobUpload = &httpBlobUpload{ blobUpload = &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{}, client: &http.Client{},
location: e + writeLocationPath, location: e + writeLocationPath,
} }
@ -409,6 +413,7 @@ func TestUploadWrite(t *testing.T) {
defer c() defer c()
blobUpload := &httpBlobUpload{ blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{}, client: &http.Client{},
} }

View file

@ -118,9 +118,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri
return 0, err return 0, err
} }
for cnt := range ctlg.Repositories { copy(entries, ctlg.Repositories)
entries[cnt] = ctlg.Repositories[cnt]
}
numFilled = len(ctlg.Repositories) numFilled = len(ctlg.Repositories)
link := resp.Header.Get("Link") link := resp.Header.Get("Link")
@ -373,7 +371,7 @@ func (t *tags) Untag(ctx context.Context, tag string) error {
return err return err
} }
req, err := http.NewRequest("DELETE", u, nil) req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil)
if err != nil { if err != nil {
return err return err
} }
@ -792,7 +790,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
return nil, err return nil, err
} }
req, err := http.NewRequest("POST", u, nil) req, err := http.NewRequestWithContext(ctx, "POST", u, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -827,6 +825,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
} }
return &httpBlobUpload{ return &httpBlobUpload{
ctx: ctx,
statter: bs.statter, statter: bs.statter,
client: bs.client, client: bs.client,
uuid: uuid, uuid: uuid,
@ -845,6 +844,7 @@ func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter
} }
return &httpBlobUpload{ return &httpBlobUpload{
ctx: ctx,
statter: bs.statter, statter: bs.statter,
client: bs.client, client: bs.client,
uuid: id, uuid: id,