diff --git a/docs/registry.go b/docs/registry.go index 3d0a3ed2..7bcf0660 100644 --- a/docs/registry.go +++ b/docs/registry.go @@ -256,12 +256,43 @@ func (r *Registry) GetRemoteImageJSON(imgID, registry string, token []string) ([ return jsonString, imageSize, nil } -func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string) (io.ReadCloser, error) { - req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/layer", nil) +func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string, imgSize int64) (io.ReadCloser, error) { + var ( + retries = 5 + headRes *http.Response + hasResume bool = false + imageURL = fmt.Sprintf("%simages/%s/layer", registry, imgID) + ) + headReq, err := r.reqFactory.NewRequest("HEAD", imageURL, nil) + if err != nil { + return nil, fmt.Errorf("Error while getting from the server: %s\n", err) + } + setTokenAuth(headReq, token) + for i := 1; i <= retries; i++ { + headRes, err = r.client.Do(headReq) + if err != nil && i == retries { + return nil, fmt.Errorf("Eror while making head request: %s\n", err) + } else if err != nil { + time.Sleep(time.Duration(i) * 5 * time.Second) + continue + } + break + } + + if headRes.Header.Get("Accept-Ranges") == "bytes" && imgSize > 0 { + hasResume = true + } + + req, err := r.reqFactory.NewRequest("GET", imageURL, nil) if err != nil { return nil, fmt.Errorf("Error while getting from the server: %s\n", err) } setTokenAuth(req, token) + if hasResume { + utils.Debugf("server supports resume") + return utils.ResumableRequestReader(r.client, req, 5, imgSize), nil + } + utils.Debugf("server doesn't support resume") res, err := r.client.Do(req) if err != nil { return nil, err @@ -725,6 +756,13 @@ type Registry struct { indexEndpoint string } +func AddRequiredHeadersToRedirectedRequests(req *http.Request, via []*http.Request) error { + if via != nil && via[0] != nil { + req.Header = via[0].Header + } + return nil +} + func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string) (r *Registry, err error) { httpDial := func(proto string, addr string) (net.Conn, error) { conn, err := net.Dial(proto, addr) @@ -744,7 +782,8 @@ func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, inde r = &Registry{ authConfig: authConfig, client: &http.Client{ - Transport: httpTransport, + Transport: httpTransport, + CheckRedirect: AddRequiredHeadersToRedirectedRequests, }, indexEndpoint: indexEndpoint, } diff --git a/docs/registry_test.go b/docs/registry_test.go index 0a5be5e5..e207359e 100644 --- a/docs/registry_test.go +++ b/docs/registry_test.go @@ -70,7 +70,7 @@ func TestGetRemoteImageJSON(t *testing.T) { func TestGetRemoteImageLayer(t *testing.T) { r := spawnTestRegistry(t) - data, err := r.GetRemoteImageLayer(IMAGE_ID, makeURL("/v1/"), TOKEN) + data, err := r.GetRemoteImageLayer(IMAGE_ID, makeURL("/v1/"), TOKEN, 0) if err != nil { t.Fatal(err) } @@ -78,7 +78,7 @@ func TestGetRemoteImageLayer(t *testing.T) { t.Fatal("Expected non-nil data result") } - _, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), TOKEN) + _, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), TOKEN, 0) if err == nil { t.Fatal("Expected image not found error") }