diff --git a/registry/client/blob_writer.go b/registry/client/blob_writer.go new file mode 100644 index 00000000..9ebd4183 --- /dev/null +++ b/registry/client/blob_writer.go @@ -0,0 +1,174 @@ +package client + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "time" + + "github.com/docker/distribution" + "github.com/docker/distribution/context" +) + +type httpBlobUpload struct { + statter distribution.BlobStatter + client *http.Client + + uuid string + startedAt time.Time + + location string // always the last value of the location header. + offset int64 + closed bool +} + +func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error { + if resp.StatusCode == http.StatusNotFound { + return distribution.ErrBlobUploadUnknown + } + return handleErrorResponse(resp) +} + +func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) { + req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r)) + if err != nil { + return 0, err + } + defer req.Body.Close() + + resp, err := hbu.client.Do(req) + if err != nil { + return 0, err + } + + if resp.StatusCode != http.StatusAccepted { + return 0, hbu.handleErrorResponse(resp) + } + + hbu.uuid = resp.Header.Get("Docker-Upload-UUID") + hbu.location, err = sanitizeLocation(resp.Header.Get("Location"), hbu.location) + if err != nil { + return 0, err + } + rng := resp.Header.Get("Range") + var start, end int64 + if n, err := fmt.Sscanf(rng, "%d-%d", &start, &end); err != nil { + return 0, err + } else if n != 2 || end < start { + return 0, fmt.Errorf("bad range format: %s", rng) + } + + return (end - start + 1), nil + +} + +func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) { + req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p)) + if err != nil { + return 0, err + } + req.Header.Set("Content-Range", fmt.Sprintf("%d-%d", hbu.offset, hbu.offset+int64(len(p)-1))) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(p))) + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := hbu.client.Do(req) + if err != nil { + return 0, err + } + + if resp.StatusCode != http.StatusAccepted { + return 0, hbu.handleErrorResponse(resp) + } + + hbu.uuid = resp.Header.Get("Docker-Upload-UUID") + hbu.location, err = sanitizeLocation(resp.Header.Get("Location"), hbu.location) + if err != nil { + return 0, err + } + rng := resp.Header.Get("Range") + var start, end int + if n, err := fmt.Sscanf(rng, "%d-%d", &start, &end); err != nil { + return 0, err + } else if n != 2 || end < start { + return 0, fmt.Errorf("bad range format: %s", rng) + } + + return (end - start + 1), nil + +} + +func (hbu *httpBlobUpload) Seek(offset int64, whence int) (int64, error) { + newOffset := hbu.offset + + switch whence { + case os.SEEK_CUR: + newOffset += int64(offset) + case os.SEEK_END: + newOffset += int64(offset) + case os.SEEK_SET: + newOffset = int64(offset) + } + + hbu.offset = newOffset + + return hbu.offset, nil +} + +func (hbu *httpBlobUpload) ID() string { + return hbu.uuid +} + +func (hbu *httpBlobUpload) StartedAt() time.Time { + return hbu.startedAt +} + +func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { + // TODO(dmcgowan): Check if already finished, if so just fetch + req, err := http.NewRequest("PUT", hbu.location, nil) + if err != nil { + return distribution.Descriptor{}, err + } + + values := req.URL.Query() + values.Set("digest", desc.Digest.String()) + req.URL.RawQuery = values.Encode() + + resp, err := hbu.client.Do(req) + if err != nil { + return distribution.Descriptor{}, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return distribution.Descriptor{}, hbu.handleErrorResponse(resp) + } + + return hbu.statter.Stat(ctx, desc.Digest) +} + +func (hbu *httpBlobUpload) Cancel(ctx context.Context) error { + req, err := http.NewRequest("DELETE", hbu.location, nil) + if err != nil { + return err + } + resp, err := hbu.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusNoContent, http.StatusNotFound: + return nil + default: + return hbu.handleErrorResponse(resp) + } +} + +func (hbu *httpBlobUpload) Close() error { + hbu.closed = true + return nil +} diff --git a/registry/client/blob_writer_test.go b/registry/client/blob_writer_test.go new file mode 100644 index 00000000..674d6e01 --- /dev/null +++ b/registry/client/blob_writer_test.go @@ -0,0 +1,207 @@ +package client + +import ( + "bytes" + "fmt" + "net/http" + "testing" + + "github.com/docker/distribution" + "github.com/docker/distribution/registry/api/v2" + "github.com/docker/distribution/testutil" +) + +// Test implements distribution.BlobWriter +var _ distribution.BlobWriter = &httpBlobUpload{} + +func TestUploadReadFrom(t *testing.T) { + _, b := newRandomBlob(64) + repo := "test/upload/readfrom" + locationPath := fmt.Sprintf("/v2/%s/uploads/testid", repo) + + m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/", + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Headers: http.Header(map[string][]string{ + "Docker-Distribution-API-Version": {"registry/2.0"}, + }), + }, + }, + // Test Valid case + { + Request: testutil.Request{ + Method: "PATCH", + Route: locationPath, + Body: b, + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + Headers: http.Header(map[string][]string{ + "Docker-Upload-UUID": {"46603072-7a1b-4b41-98f9-fd8a7da89f9b"}, + "Location": {locationPath}, + "Range": {"0-63"}, + }), + }, + }, + // Test invalid range + { + Request: testutil.Request{ + Method: "PATCH", + Route: locationPath, + Body: b, + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + Headers: http.Header(map[string][]string{ + "Docker-Upload-UUID": {"46603072-7a1b-4b41-98f9-fd8a7da89f9b"}, + "Location": {locationPath}, + "Range": {""}, + }), + }, + }, + // Test 404 + { + Request: testutil.Request{ + Method: "PATCH", + Route: locationPath, + Body: b, + }, + Response: testutil.Response{ + StatusCode: http.StatusNotFound, + }, + }, + // Test 400 valid json + { + Request: testutil.Request{ + Method: "PATCH", + Route: locationPath, + Body: b, + }, + Response: testutil.Response{ + StatusCode: http.StatusBadRequest, + Body: []byte(` + { + "errors": [ + { + "code": "BLOB_UPLOAD_INVALID", + "message": "invalid upload identifier", + "detail": "more detail" + } + ] + }`), + }, + }, + // Test 400 invalid json + { + Request: testutil.Request{ + Method: "PATCH", + Route: locationPath, + Body: b, + }, + Response: testutil.Response{ + StatusCode: http.StatusBadRequest, + Body: []byte("something bad happened"), + }, + }, + // Test 500 + { + Request: testutil.Request{ + Method: "PATCH", + Route: locationPath, + Body: b, + }, + Response: testutil.Response{ + StatusCode: http.StatusInternalServerError, + }, + }, + }) + + e, c := testServer(m) + defer c() + + blobUpload := &httpBlobUpload{ + client: &http.Client{}, + } + + // Valid case + blobUpload.location = e + locationPath + n, err := blobUpload.ReadFrom(bytes.NewReader(b)) + if err != nil { + t.Fatalf("Error calling ReadFrom: %s", err) + } + if n != 64 { + t.Fatalf("Wrong length returned from ReadFrom: %d, expected 64", n) + } + + // Bad range + blobUpload.location = e + locationPath + _, err = blobUpload.ReadFrom(bytes.NewReader(b)) + if err == nil { + t.Fatalf("Expected error when bad range received") + } + + // 404 + blobUpload.location = e + locationPath + _, err = blobUpload.ReadFrom(bytes.NewReader(b)) + if err == nil { + t.Fatalf("Expected error when not found") + } + if err != distribution.ErrBlobUploadUnknown { + t.Fatalf("Wrong error thrown: %s, expected %s", err, distribution.ErrBlobUploadUnknown) + } + + // 400 valid json + blobUpload.location = e + locationPath + _, err = blobUpload.ReadFrom(bytes.NewReader(b)) + if err == nil { + t.Fatalf("Expected error when not found") + } + if uploadErr, ok := err.(*v2.Errors); !ok { + t.Fatalf("Wrong error type %T: %s", err, err) + } else if len(uploadErr.Errors) != 1 { + t.Fatalf("Unexpected number of errors: %d, expected 1", len(uploadErr.Errors)) + } else { + v2Err := uploadErr.Errors[0] + if v2Err.Code != v2.ErrorCodeBlobUploadInvalid { + t.Fatalf("Unexpected error code: %s, expected %s", v2Err.Code.String(), v2.ErrorCodeBlobUploadInvalid.String()) + } + if expected := "invalid upload identifier"; v2Err.Message != expected { + t.Fatalf("Unexpected error message: %s, expected %s", v2Err.Message, expected) + } + if expected := "more detail"; v2Err.Detail.(string) != expected { + t.Fatalf("Unexpected error message: %s, expected %s", v2Err.Detail.(string), expected) + } + } + + // 400 invalid json + blobUpload.location = e + locationPath + _, err = blobUpload.ReadFrom(bytes.NewReader(b)) + if err == nil { + t.Fatalf("Expected error when not found") + } + if uploadErr, ok := err.(*UnexpectedHTTPResponseError); !ok { + t.Fatalf("Wrong error type %T: %s", err, err) + } else { + respStr := string(uploadErr.Response) + if expected := "something bad happened"; respStr != expected { + t.Fatalf("Unexpected response string: %s, expected: %s", respStr, expected) + } + } + + // 500 + blobUpload.location = e + locationPath + _, err = blobUpload.ReadFrom(bytes.NewReader(b)) + if err == nil { + t.Fatalf("Expected error when not found") + } + if uploadErr, ok := err.(*UnexpectedHTTPStatusError); !ok { + t.Fatalf("Wrong error type %T: %s", err, err) + } else if expected := "500 " + http.StatusText(http.StatusInternalServerError); uploadErr.Status != expected { + t.Fatalf("Unexpected response status: %s, expected %s", uploadErr.Status, expected) + } +} diff --git a/registry/client/client.go b/registry/client/client.go deleted file mode 100644 index 36be960d..00000000 --- a/registry/client/client.go +++ /dev/null @@ -1,573 +0,0 @@ -package client - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "regexp" - "strconv" - - "github.com/docker/distribution/digest" - "github.com/docker/distribution/manifest" - "github.com/docker/distribution/registry/api/v2" -) - -// Client implements the client interface to the registry http api -type Client interface { - // GetImageManifest returns an image manifest for the image at the given - // name, tag pair. - GetImageManifest(name, tag string) (*manifest.SignedManifest, error) - - // PutImageManifest uploads an image manifest for the image at the given - // name, tag pair. - PutImageManifest(name, tag string, imageManifest *manifest.SignedManifest) error - - // DeleteImage removes the image at the given name, tag pair. - DeleteImage(name, tag string) error - - // ListImageTags returns a list of all image tags with the given repository - // name. - ListImageTags(name string) ([]string, error) - - // BlobLength returns the length of the blob stored at the given name, - // digest pair. - // Returns a length value of -1 on error or if the blob does not exist. - BlobLength(name string, dgst digest.Digest) (int, error) - - // GetBlob returns the blob stored at the given name, digest pair in the - // form of an io.ReadCloser with the length of this blob. - // A nonzero byteOffset can be provided to receive a partial blob beginning - // at the given offset. - GetBlob(name string, dgst digest.Digest, byteOffset int) (io.ReadCloser, int, error) - - // InitiateBlobUpload starts a blob upload in the given repository namespace - // and returns a unique location url to use for other blob upload methods. - InitiateBlobUpload(name string) (string, error) - - // GetBlobUploadStatus returns the byte offset and length of the blob at the - // given upload location. - GetBlobUploadStatus(location string) (int, int, error) - - // UploadBlob uploads a full blob to the registry. - UploadBlob(location string, blob io.ReadCloser, length int, dgst digest.Digest) error - - // UploadBlobChunk uploads a blob chunk with a given length and startByte to - // the registry. - // FinishChunkedBlobUpload must be called to finalize this upload. - UploadBlobChunk(location string, blobChunk io.ReadCloser, length, startByte int) error - - // FinishChunkedBlobUpload completes a chunked blob upload at a given - // location. - FinishChunkedBlobUpload(location string, length int, dgst digest.Digest) error - - // CancelBlobUpload deletes all content at the unfinished blob upload - // location and invalidates any future calls to this blob upload. - CancelBlobUpload(location string) error -} - -var ( - patternRangeHeader = regexp.MustCompile("bytes=0-(\\d+)/(\\d+)") -) - -// New returns a new Client which operates against a registry with the -// given base endpoint -// This endpoint should not include /v2/ or any part of the url after this. -func New(endpoint string) (Client, error) { - ub, err := v2.NewURLBuilderFromString(endpoint) - if err != nil { - return nil, err - } - - return &clientImpl{ - endpoint: endpoint, - ub: ub, - }, nil -} - -// clientImpl is the default implementation of the Client interface -type clientImpl struct { - endpoint string - ub *v2.URLBuilder -} - -// TODO(bbland): use consistent route generation between server and client - -func (r *clientImpl) GetImageManifest(name, tag string) (*manifest.SignedManifest, error) { - manifestURL, err := r.ub.BuildManifestURL(name, tag) - if err != nil { - return nil, err - } - - response, err := http.Get(manifestURL) - if err != nil { - return nil, err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusOK: - break - case response.StatusCode == http.StatusNotFound: - return nil, &ImageManifestNotFoundError{Name: name, Tag: tag} - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return nil, err - } - return nil, &errs - default: - return nil, &UnexpectedHTTPStatusError{Status: response.Status} - } - - decoder := json.NewDecoder(response.Body) - - manifest := new(manifest.SignedManifest) - err = decoder.Decode(manifest) - if err != nil { - return nil, err - } - return manifest, nil -} - -func (r *clientImpl) PutImageManifest(name, tag string, manifest *manifest.SignedManifest) error { - manifestURL, err := r.ub.BuildManifestURL(name, tag) - if err != nil { - return err - } - - putRequest, err := http.NewRequest("PUT", manifestURL, bytes.NewReader(manifest.Raw)) - if err != nil { - return err - } - - response, err := http.DefaultClient.Do(putRequest) - if err != nil { - return err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusOK || response.StatusCode == http.StatusAccepted: - return nil - case response.StatusCode >= 400 && response.StatusCode < 500: - var errors v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errors) - if err != nil { - return err - } - - return &errors - default: - return &UnexpectedHTTPStatusError{Status: response.Status} - } -} - -func (r *clientImpl) DeleteImage(name, tag string) error { - manifestURL, err := r.ub.BuildManifestURL(name, tag) - if err != nil { - return err - } - - deleteRequest, err := http.NewRequest("DELETE", manifestURL, nil) - if err != nil { - return err - } - - response, err := http.DefaultClient.Do(deleteRequest) - if err != nil { - return err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusNoContent: - break - case response.StatusCode == http.StatusNotFound: - return &ImageManifestNotFoundError{Name: name, Tag: tag} - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return err - } - return &errs - default: - return &UnexpectedHTTPStatusError{Status: response.Status} - } - - return nil -} - -func (r *clientImpl) ListImageTags(name string) ([]string, error) { - tagsURL, err := r.ub.BuildTagsURL(name) - if err != nil { - return nil, err - } - - response, err := http.Get(tagsURL) - if err != nil { - return nil, err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusOK: - break - case response.StatusCode == http.StatusNotFound: - return nil, &RepositoryNotFoundError{Name: name} - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return nil, err - } - return nil, &errs - default: - return nil, &UnexpectedHTTPStatusError{Status: response.Status} - } - - tags := struct { - Tags []string `json:"tags"` - }{} - - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&tags) - if err != nil { - return nil, err - } - - return tags.Tags, nil -} - -func (r *clientImpl) BlobLength(name string, dgst digest.Digest) (int, error) { - blobURL, err := r.ub.BuildBlobURL(name, dgst) - if err != nil { - return -1, err - } - - response, err := http.Head(blobURL) - if err != nil { - return -1, err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusOK: - lengthHeader := response.Header.Get("Content-Length") - length, err := strconv.ParseInt(lengthHeader, 10, 64) - if err != nil { - return -1, err - } - return int(length), nil - case response.StatusCode == http.StatusNotFound: - return -1, nil - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return -1, err - } - return -1, &errs - default: - return -1, &UnexpectedHTTPStatusError{Status: response.Status} - } -} - -func (r *clientImpl) GetBlob(name string, dgst digest.Digest, byteOffset int) (io.ReadCloser, int, error) { - blobURL, err := r.ub.BuildBlobURL(name, dgst) - if err != nil { - return nil, 0, err - } - - getRequest, err := http.NewRequest("GET", blobURL, nil) - if err != nil { - return nil, 0, err - } - - getRequest.Header.Add("Range", fmt.Sprintf("%d-", byteOffset)) - response, err := http.DefaultClient.Do(getRequest) - if err != nil { - return nil, 0, err - } - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusOK: - lengthHeader := response.Header.Get("Content-Length") - length, err := strconv.ParseInt(lengthHeader, 10, 0) - if err != nil { - return nil, 0, err - } - return response.Body, int(length), nil - case response.StatusCode == http.StatusNotFound: - response.Body.Close() - return nil, 0, &BlobNotFoundError{Name: name, Digest: dgst} - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return nil, 0, err - } - return nil, 0, &errs - default: - response.Body.Close() - return nil, 0, &UnexpectedHTTPStatusError{Status: response.Status} - } -} - -func (r *clientImpl) InitiateBlobUpload(name string) (string, error) { - uploadURL, err := r.ub.BuildBlobUploadURL(name) - if err != nil { - return "", err - } - - postRequest, err := http.NewRequest("POST", uploadURL, nil) - if err != nil { - return "", err - } - - response, err := http.DefaultClient.Do(postRequest) - if err != nil { - return "", err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusAccepted: - return response.Header.Get("Location"), nil - // case response.StatusCode == http.StatusNotFound: - // return - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return "", err - } - return "", &errs - default: - return "", &UnexpectedHTTPStatusError{Status: response.Status} - } -} - -func (r *clientImpl) GetBlobUploadStatus(location string) (int, int, error) { - response, err := http.Get(location) - if err != nil { - return 0, 0, err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusNoContent: - return parseRangeHeader(response.Header.Get("Range")) - case response.StatusCode == http.StatusNotFound: - return 0, 0, &BlobUploadNotFoundError{Location: location} - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return 0, 0, err - } - return 0, 0, &errs - default: - return 0, 0, &UnexpectedHTTPStatusError{Status: response.Status} - } -} - -func (r *clientImpl) UploadBlob(location string, blob io.ReadCloser, length int, dgst digest.Digest) error { - defer blob.Close() - - putRequest, err := http.NewRequest("PUT", location, blob) - if err != nil { - return err - } - - values := putRequest.URL.Query() - values.Set("digest", dgst.String()) - putRequest.URL.RawQuery = values.Encode() - - putRequest.Header.Set("Content-Type", "application/octet-stream") - putRequest.Header.Set("Content-Length", fmt.Sprint(length)) - - response, err := http.DefaultClient.Do(putRequest) - if err != nil { - return err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusCreated: - return nil - case response.StatusCode == http.StatusNotFound: - return &BlobUploadNotFoundError{Location: location} - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return err - } - return &errs - default: - return &UnexpectedHTTPStatusError{Status: response.Status} - } -} - -func (r *clientImpl) UploadBlobChunk(location string, blobChunk io.ReadCloser, length, startByte int) error { - defer blobChunk.Close() - - putRequest, err := http.NewRequest("PUT", location, blobChunk) - if err != nil { - return err - } - - endByte := startByte + length - - putRequest.Header.Set("Content-Type", "application/octet-stream") - putRequest.Header.Set("Content-Length", fmt.Sprint(length)) - putRequest.Header.Set("Content-Range", - fmt.Sprintf("%d-%d/%d", startByte, endByte, endByte)) - - response, err := http.DefaultClient.Do(putRequest) - if err != nil { - return err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusAccepted: - return nil - case response.StatusCode == http.StatusRequestedRangeNotSatisfiable: - lastValidRange, blobSize, err := parseRangeHeader(response.Header.Get("Range")) - if err != nil { - return err - } - return &BlobUploadInvalidRangeError{ - Location: location, - LastValidRange: lastValidRange, - BlobSize: blobSize, - } - case response.StatusCode == http.StatusNotFound: - return &BlobUploadNotFoundError{Location: location} - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return err - } - return &errs - default: - return &UnexpectedHTTPStatusError{Status: response.Status} - } -} - -func (r *clientImpl) FinishChunkedBlobUpload(location string, length int, dgst digest.Digest) error { - putRequest, err := http.NewRequest("PUT", location, nil) - if err != nil { - return err - } - - values := putRequest.URL.Query() - values.Set("digest", dgst.String()) - putRequest.URL.RawQuery = values.Encode() - - putRequest.Header.Set("Content-Type", "application/octet-stream") - putRequest.Header.Set("Content-Length", "0") - putRequest.Header.Set("Content-Range", - fmt.Sprintf("%d-%d/%d", length, length, length)) - - response, err := http.DefaultClient.Do(putRequest) - if err != nil { - return err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusCreated: - return nil - case response.StatusCode == http.StatusNotFound: - return &BlobUploadNotFoundError{Location: location} - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return err - } - return &errs - default: - return &UnexpectedHTTPStatusError{Status: response.Status} - } -} - -func (r *clientImpl) CancelBlobUpload(location string) error { - deleteRequest, err := http.NewRequest("DELETE", location, nil) - if err != nil { - return err - } - - response, err := http.DefaultClient.Do(deleteRequest) - if err != nil { - return err - } - defer response.Body.Close() - - // TODO(bbland): handle other status codes, like 5xx errors - switch { - case response.StatusCode == http.StatusNoContent: - return nil - case response.StatusCode == http.StatusNotFound: - return &BlobUploadNotFoundError{Location: location} - case response.StatusCode >= 400 && response.StatusCode < 500: - var errs v2.Errors - decoder := json.NewDecoder(response.Body) - err = decoder.Decode(&errs) - if err != nil { - return err - } - return &errs - default: - return &UnexpectedHTTPStatusError{Status: response.Status} - } -} - -// parseRangeHeader parses out the offset and length from a returned Range -// header -func parseRangeHeader(byteRangeHeader string) (int, int, error) { - submatches := patternRangeHeader.FindStringSubmatch(byteRangeHeader) - if submatches == nil || len(submatches) < 3 { - return 0, 0, fmt.Errorf("Malformed Range header") - } - - offset, err := strconv.Atoi(submatches[1]) - if err != nil { - return 0, 0, err - } - length, err := strconv.Atoi(submatches[2]) - if err != nil { - return 0, 0, err - } - return offset, length, nil -} diff --git a/registry/client/client_test.go b/registry/client/client_test.go deleted file mode 100644 index 2c1d1cc2..00000000 --- a/registry/client/client_test.go +++ /dev/null @@ -1,440 +0,0 @@ -package client - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/httptest" - "sync" - "testing" - - "github.com/docker/distribution/digest" - "github.com/docker/distribution/manifest" - "github.com/docker/distribution/testutil" -) - -type testBlob struct { - digest digest.Digest - contents []byte -} - -func TestRangeHeaderParser(t *testing.T) { - const ( - malformedRangeHeader = "bytes=0-A/C" - emptyRangeHeader = "" - rFirst = 100 - rSecond = 200 - ) - - var ( - wellformedRangeHeader = fmt.Sprintf("bytes=0-%d/%d", rFirst, rSecond) - ) - - if _, _, err := parseRangeHeader(malformedRangeHeader); err == nil { - t.Fatalf("malformedRangeHeader: error expected, got nil") - } - - if _, _, err := parseRangeHeader(emptyRangeHeader); err == nil { - t.Fatalf("emptyRangeHeader: error expected, got nil") - } - - first, second, err := parseRangeHeader(wellformedRangeHeader) - if err != nil { - t.Fatalf("wellformedRangeHeader: unexpected error %v", err) - } - - if first != rFirst || second != rSecond { - t.Fatalf("Range has been parsed unproperly: %d/%d", first, second) - } - -} - -func TestPush(t *testing.T) { - name := "hello/world" - tag := "sometag" - testBlobs := []testBlob{ - { - digest: "tarsum.v2+sha256:12345", - contents: []byte("some contents"), - }, - { - digest: "tarsum.v2+sha256:98765", - contents: []byte("some other contents"), - }, - } - uploadLocations := make([]string, len(testBlobs)) - blobs := make([]manifest.FSLayer, len(testBlobs)) - history := make([]manifest.History, len(testBlobs)) - - for i, blob := range testBlobs { - // TODO(bbland): this is returning the same location for all uploads, - // because we can't know which blob will get which location. - // It's sort of okay because we're using unique digests, but this needs - // to change at some point. - uploadLocations[i] = fmt.Sprintf("/v2/%s/blobs/test-uuid", name) - blobs[i] = manifest.FSLayer{BlobSum: blob.digest} - history[i] = manifest.History{V1Compatibility: blob.digest.String()} - } - - m := &manifest.SignedManifest{ - Manifest: manifest.Manifest{ - Name: name, - Tag: tag, - Architecture: "x86", - FSLayers: blobs, - History: history, - Versioned: manifest.Versioned{ - SchemaVersion: 1, - }, - }, - } - var err error - m.Raw, err = json.Marshal(m) - - blobRequestResponseMappings := make([]testutil.RequestResponseMapping, 2*len(testBlobs)) - for i, blob := range testBlobs { - blobRequestResponseMappings[2*i] = testutil.RequestResponseMapping{ - Request: testutil.Request{ - Method: "POST", - Route: "/v2/" + name + "/blobs/uploads/", - }, - Response: testutil.Response{ - StatusCode: http.StatusAccepted, - Headers: http.Header(map[string][]string{ - "Location": {uploadLocations[i]}, - }), - }, - } - blobRequestResponseMappings[2*i+1] = testutil.RequestResponseMapping{ - Request: testutil.Request{ - Method: "PUT", - Route: uploadLocations[i], - QueryParams: map[string][]string{ - "digest": {blob.digest.String()}, - }, - Body: blob.contents, - }, - Response: testutil.Response{ - StatusCode: http.StatusCreated, - }, - } - } - - handler := testutil.NewHandler(append(blobRequestResponseMappings, testutil.RequestResponseMapping{ - Request: testutil.Request{ - Method: "PUT", - Route: "/v2/" + name + "/manifests/" + tag, - Body: m.Raw, - }, - Response: testutil.Response{ - StatusCode: http.StatusOK, - }, - })) - var server *httptest.Server - - // HACK(stevvooe): Super hack to follow: the request response map approach - // above does not let us correctly format the location header to the - // server url. This handler intercepts and re-writes the location header - // to the server url. - - hack := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w = &headerInterceptingResponseWriter{ResponseWriter: w, serverURL: server.URL} - handler.ServeHTTP(w, r) - }) - - server = httptest.NewServer(hack) - client, err := New(server.URL) - if err != nil { - t.Fatalf("error creating client: %v", err) - } - objectStore := &memoryObjectStore{ - mutex: new(sync.Mutex), - manifestStorage: make(map[string]*manifest.SignedManifest), - layerStorage: make(map[digest.Digest]Layer), - } - - for _, blob := range testBlobs { - l, err := objectStore.Layer(blob.digest) - if err != nil { - t.Fatal(err) - } - - writer, err := l.Writer() - if err != nil { - t.Fatal(err) - } - - writer.SetSize(len(blob.contents)) - writer.Write(blob.contents) - writer.Close() - } - - objectStore.WriteManifest(name, tag, m) - - err = Push(client, objectStore, name, tag) - if err != nil { - t.Fatal(err) - } -} - -func TestPull(t *testing.T) { - name := "hello/world" - tag := "sometag" - testBlobs := []testBlob{ - { - digest: "tarsum.v2+sha256:12345", - contents: []byte("some contents"), - }, - { - digest: "tarsum.v2+sha256:98765", - contents: []byte("some other contents"), - }, - } - blobs := make([]manifest.FSLayer, len(testBlobs)) - history := make([]manifest.History, len(testBlobs)) - - for i, blob := range testBlobs { - blobs[i] = manifest.FSLayer{BlobSum: blob.digest} - history[i] = manifest.History{V1Compatibility: blob.digest.String()} - } - - m := &manifest.SignedManifest{ - Manifest: manifest.Manifest{ - Name: name, - Tag: tag, - Architecture: "x86", - FSLayers: blobs, - History: history, - Versioned: manifest.Versioned{ - SchemaVersion: 1, - }, - }, - } - manifestBytes, err := json.Marshal(m) - - blobRequestResponseMappings := make([]testutil.RequestResponseMapping, len(testBlobs)) - for i, blob := range testBlobs { - blobRequestResponseMappings[i] = testutil.RequestResponseMapping{ - Request: testutil.Request{ - Method: "GET", - Route: "/v2/" + name + "/blobs/" + blob.digest.String(), - }, - Response: testutil.Response{ - StatusCode: http.StatusOK, - Body: blob.contents, - }, - } - } - - handler := testutil.NewHandler(append(blobRequestResponseMappings, testutil.RequestResponseMapping{ - Request: testutil.Request{ - Method: "GET", - Route: "/v2/" + name + "/manifests/" + tag, - }, - Response: testutil.Response{ - StatusCode: http.StatusOK, - Body: manifestBytes, - }, - })) - server := httptest.NewServer(handler) - client, err := New(server.URL) - if err != nil { - t.Fatalf("error creating client: %v", err) - } - objectStore := &memoryObjectStore{ - mutex: new(sync.Mutex), - manifestStorage: make(map[string]*manifest.SignedManifest), - layerStorage: make(map[digest.Digest]Layer), - } - - err = Pull(client, objectStore, name, tag) - if err != nil { - t.Fatal(err) - } - - m, err = objectStore.Manifest(name, tag) - if err != nil { - t.Fatal(err) - } - - mBytes, err := json.Marshal(m) - if err != nil { - t.Fatal(err) - } - - if string(mBytes) != string(manifestBytes) { - t.Fatal("Incorrect manifest") - } - - for _, blob := range testBlobs { - l, err := objectStore.Layer(blob.digest) - if err != nil { - t.Fatal(err) - } - - reader, err := l.Reader() - if err != nil { - t.Fatal(err) - } - defer reader.Close() - - blobBytes, err := ioutil.ReadAll(reader) - if err != nil { - t.Fatal(err) - } - - if string(blobBytes) != string(blob.contents) { - t.Fatal("Incorrect blob") - } - } -} - -func TestPullResume(t *testing.T) { - name := "hello/world" - tag := "sometag" - testBlobs := []testBlob{ - { - digest: "tarsum.v2+sha256:12345", - contents: []byte("some contents"), - }, - { - digest: "tarsum.v2+sha256:98765", - contents: []byte("some other contents"), - }, - } - layers := make([]manifest.FSLayer, len(testBlobs)) - history := make([]manifest.History, len(testBlobs)) - - for i, layer := range testBlobs { - layers[i] = manifest.FSLayer{BlobSum: layer.digest} - history[i] = manifest.History{V1Compatibility: layer.digest.String()} - } - - m := &manifest.Manifest{ - Name: name, - Tag: tag, - Architecture: "x86", - FSLayers: layers, - History: history, - Versioned: manifest.Versioned{ - SchemaVersion: 1, - }, - } - manifestBytes, err := json.Marshal(m) - - layerRequestResponseMappings := make([]testutil.RequestResponseMapping, 2*len(testBlobs)) - for i, blob := range testBlobs { - layerRequestResponseMappings[2*i] = testutil.RequestResponseMapping{ - Request: testutil.Request{ - Method: "GET", - Route: "/v2/" + name + "/blobs/" + blob.digest.String(), - }, - Response: testutil.Response{ - StatusCode: http.StatusOK, - Body: blob.contents[:len(blob.contents)/2], - Headers: http.Header(map[string][]string{ - "Content-Length": {fmt.Sprint(len(blob.contents))}, - }), - }, - } - layerRequestResponseMappings[2*i+1] = testutil.RequestResponseMapping{ - Request: testutil.Request{ - Method: "GET", - Route: "/v2/" + name + "/blobs/" + blob.digest.String(), - }, - Response: testutil.Response{ - StatusCode: http.StatusOK, - Body: blob.contents[len(blob.contents)/2:], - }, - } - } - - for i := 0; i < 3; i++ { - layerRequestResponseMappings = append(layerRequestResponseMappings, testutil.RequestResponseMapping{ - Request: testutil.Request{ - Method: "GET", - Route: "/v2/" + name + "/manifests/" + tag, - }, - Response: testutil.Response{ - StatusCode: http.StatusOK, - Body: manifestBytes, - }, - }) - } - - handler := testutil.NewHandler(layerRequestResponseMappings) - server := httptest.NewServer(handler) - client, err := New(server.URL) - if err != nil { - t.Fatalf("error creating client: %v", err) - } - objectStore := &memoryObjectStore{ - mutex: new(sync.Mutex), - manifestStorage: make(map[string]*manifest.SignedManifest), - layerStorage: make(map[digest.Digest]Layer), - } - - for attempts := 0; attempts < 3; attempts++ { - err = Pull(client, objectStore, name, tag) - if err == nil { - break - } - } - - if err != nil { - t.Fatal(err) - } - - sm, err := objectStore.Manifest(name, tag) - if err != nil { - t.Fatal(err) - } - - mBytes, err := json.Marshal(sm) - if err != nil { - t.Fatal(err) - } - - if string(mBytes) != string(manifestBytes) { - t.Fatal("Incorrect manifest") - } - - for _, blob := range testBlobs { - l, err := objectStore.Layer(blob.digest) - if err != nil { - t.Fatal(err) - } - - reader, err := l.Reader() - if err != nil { - t.Fatal(err) - } - defer reader.Close() - - layerBytes, err := ioutil.ReadAll(reader) - if err != nil { - t.Fatal(err) - } - - if string(layerBytes) != string(blob.contents) { - t.Fatal("Incorrect blob") - } - } -} - -// headerInterceptingResponseWriter is a hacky workaround to re-write the -// location header to have the server url. -type headerInterceptingResponseWriter struct { - http.ResponseWriter - serverURL string -} - -func (hirw *headerInterceptingResponseWriter) WriteHeader(status int) { - location := hirw.Header().Get("Location") - if location != "" { - hirw.Header().Set("Location", hirw.serverURL+location) - } - - hirw.ResponseWriter.WriteHeader(status) -} diff --git a/registry/client/errors.go b/registry/client/errors.go index 3e89e674..2638055d 100644 --- a/registry/client/errors.go +++ b/registry/client/errors.go @@ -1,73 +1,15 @@ package client import ( + "encoding/json" "fmt" + "io" + "io/ioutil" + "net/http" - "github.com/docker/distribution/digest" + "github.com/docker/distribution/registry/api/v2" ) -// RepositoryNotFoundError is returned when making an operation against a -// repository that does not exist in the registry. -type RepositoryNotFoundError struct { - Name string -} - -func (e *RepositoryNotFoundError) Error() string { - return fmt.Sprintf("No repository found with Name: %s", e.Name) -} - -// ImageManifestNotFoundError is returned when making an operation against a -// given image manifest that does not exist in the registry. -type ImageManifestNotFoundError struct { - Name string - Tag string -} - -func (e *ImageManifestNotFoundError) Error() string { - return fmt.Sprintf("No manifest found with Name: %s, Tag: %s", - e.Name, e.Tag) -} - -// BlobNotFoundError is returned when making an operation against a given image -// layer that does not exist in the registry. -type BlobNotFoundError struct { - Name string - Digest digest.Digest -} - -func (e *BlobNotFoundError) Error() string { - return fmt.Sprintf("No blob found with Name: %s, Digest: %s", - e.Name, e.Digest) -} - -// BlobUploadNotFoundError is returned when making a blob upload operation against an -// invalid blob upload location url. -// This may be the result of using a cancelled, completed, or stale upload -// location. -type BlobUploadNotFoundError struct { - Location string -} - -func (e *BlobUploadNotFoundError) Error() string { - return fmt.Sprintf("No blob upload found at Location: %s", e.Location) -} - -// BlobUploadInvalidRangeError is returned when attempting to upload an image -// blob chunk that is out of order. -// This provides the known BlobSize and LastValidRange which can be used to -// resume the upload. -type BlobUploadInvalidRangeError struct { - Location string - LastValidRange int - BlobSize int -} - -func (e *BlobUploadInvalidRangeError) Error() string { - return fmt.Sprintf( - "Invalid range provided for upload at Location: %s. Last Valid Range: %d, Blob Size: %d", - e.Location, e.LastValidRange, e.BlobSize) -} - // UnexpectedHTTPStatusError is returned when an unexpected HTTP status is // returned when making a registry api call. type UnexpectedHTTPStatusError struct { @@ -77,3 +19,48 @@ type UnexpectedHTTPStatusError struct { func (e *UnexpectedHTTPStatusError) Error() string { return fmt.Sprintf("Received unexpected HTTP status: %s", e.Status) } + +// UnexpectedHTTPResponseError is returned when an expected HTTP status code +// is returned, but the content was unexpected and failed to be parsed. +type UnexpectedHTTPResponseError struct { + ParseErr error + Response []byte +} + +func (e *UnexpectedHTTPResponseError) Error() string { + return fmt.Sprintf("Error parsing HTTP response: %s: %q", e.ParseErr.Error(), string(e.Response)) +} + +func parseHTTPErrorResponse(r io.Reader) error { + var errors v2.Errors + body, err := ioutil.ReadAll(r) + if err != nil { + return err + } + + if err := json.Unmarshal(body, &errors); err != nil { + return &UnexpectedHTTPResponseError{ + ParseErr: err, + Response: body, + } + } + return &errors +} + +func handleErrorResponse(resp *http.Response) error { + if resp.StatusCode == 401 { + err := parseHTTPErrorResponse(resp.Body) + if uErr, ok := err.(*UnexpectedHTTPResponseError); ok { + return &v2.Error{ + Code: v2.ErrorCodeUnauthorized, + Message: "401 Unauthorized", + Detail: uErr.Response, + } + } + return err + } + if resp.StatusCode >= 400 && resp.StatusCode < 500 { + return parseHTTPErrorResponse(resp.Body) + } + return &UnexpectedHTTPStatusError{Status: resp.Status} +} diff --git a/registry/client/objectstore.go b/registry/client/objectstore.go deleted file mode 100644 index 5969c9d2..00000000 --- a/registry/client/objectstore.go +++ /dev/null @@ -1,239 +0,0 @@ -package client - -import ( - "bytes" - "fmt" - "io" - "sync" - - "github.com/docker/distribution/digest" - "github.com/docker/distribution/manifest" -) - -var ( - // ErrLayerAlreadyExists is returned when attempting to create a layer with - // a tarsum that is already in use. - ErrLayerAlreadyExists = fmt.Errorf("Layer already exists") - - // ErrLayerLocked is returned when attempting to write to a layer which is - // currently being written to. - ErrLayerLocked = fmt.Errorf("Layer locked") -) - -// ObjectStore is an interface which is designed to approximate the docker -// engine storage. This interface is subject to change to conform to the -// future requirements of the engine. -type ObjectStore interface { - // Manifest retrieves the image manifest stored at the given repository name - // and tag - Manifest(name, tag string) (*manifest.SignedManifest, error) - - // WriteManifest stores an image manifest at the given repository name and - // tag - WriteManifest(name, tag string, manifest *manifest.SignedManifest) error - - // Layer returns a handle to a layer for reading and writing - Layer(dgst digest.Digest) (Layer, error) -} - -// Layer is a generic image layer interface. -// A Layer may not be written to if it is already complete. -type Layer interface { - // Reader returns a LayerReader or an error if the layer has not been - // written to or is currently being written to. - Reader() (LayerReader, error) - - // Writer returns a LayerWriter or an error if the layer has been fully - // written to or is currently being written to. - Writer() (LayerWriter, error) - - // Wait blocks until the Layer can be read from. - Wait() error -} - -// LayerReader is a read-only handle to a Layer, which exposes the CurrentSize -// and full Size in addition to implementing the io.ReadCloser interface. -type LayerReader interface { - io.ReadCloser - - // CurrentSize returns the number of bytes written to the underlying Layer - CurrentSize() int - - // Size returns the full size of the underlying Layer - Size() int -} - -// LayerWriter is a write-only handle to a Layer, which exposes the CurrentSize -// and full Size in addition to implementing the io.WriteCloser interface. -// SetSize must be called on this LayerWriter before it can be written to. -type LayerWriter interface { - io.WriteCloser - - // CurrentSize returns the number of bytes written to the underlying Layer - CurrentSize() int - - // Size returns the full size of the underlying Layer - Size() int - - // SetSize sets the full size of the underlying Layer. - // This must be called before any calls to Write - SetSize(int) error -} - -// memoryObjectStore is an in-memory implementation of the ObjectStore interface -type memoryObjectStore struct { - mutex *sync.Mutex - manifestStorage map[string]*manifest.SignedManifest - layerStorage map[digest.Digest]Layer -} - -func (objStore *memoryObjectStore) Manifest(name, tag string) (*manifest.SignedManifest, error) { - objStore.mutex.Lock() - defer objStore.mutex.Unlock() - - manifest, ok := objStore.manifestStorage[name+":"+tag] - if !ok { - return nil, fmt.Errorf("No manifest found with Name: %q, Tag: %q", name, tag) - } - return manifest, nil -} - -func (objStore *memoryObjectStore) WriteManifest(name, tag string, manifest *manifest.SignedManifest) error { - objStore.mutex.Lock() - defer objStore.mutex.Unlock() - - objStore.manifestStorage[name+":"+tag] = manifest - return nil -} - -func (objStore *memoryObjectStore) Layer(dgst digest.Digest) (Layer, error) { - objStore.mutex.Lock() - defer objStore.mutex.Unlock() - - layer, ok := objStore.layerStorage[dgst] - if !ok { - layer = &memoryLayer{cond: sync.NewCond(new(sync.Mutex))} - objStore.layerStorage[dgst] = layer - } - - return layer, nil -} - -type memoryLayer struct { - cond *sync.Cond - contents []byte - expectedSize int - writing bool -} - -func (ml *memoryLayer) Reader() (LayerReader, error) { - ml.cond.L.Lock() - defer ml.cond.L.Unlock() - - if ml.contents == nil { - return nil, fmt.Errorf("Layer has not been written to yet") - } - if ml.writing { - return nil, ErrLayerLocked - } - - return &memoryLayerReader{ml: ml, reader: bytes.NewReader(ml.contents)}, nil -} - -func (ml *memoryLayer) Writer() (LayerWriter, error) { - ml.cond.L.Lock() - defer ml.cond.L.Unlock() - - if ml.contents != nil { - if ml.writing { - return nil, ErrLayerLocked - } - if ml.expectedSize == len(ml.contents) { - return nil, ErrLayerAlreadyExists - } - } else { - ml.contents = make([]byte, 0) - } - - ml.writing = true - return &memoryLayerWriter{ml: ml, buffer: bytes.NewBuffer(ml.contents)}, nil -} - -func (ml *memoryLayer) Wait() error { - ml.cond.L.Lock() - defer ml.cond.L.Unlock() - - if ml.contents == nil { - return fmt.Errorf("No writer to wait on") - } - - for ml.writing { - ml.cond.Wait() - } - - return nil -} - -type memoryLayerReader struct { - ml *memoryLayer - reader *bytes.Reader -} - -func (mlr *memoryLayerReader) Read(p []byte) (int, error) { - return mlr.reader.Read(p) -} - -func (mlr *memoryLayerReader) Close() error { - return nil -} - -func (mlr *memoryLayerReader) CurrentSize() int { - return len(mlr.ml.contents) -} - -func (mlr *memoryLayerReader) Size() int { - return mlr.ml.expectedSize -} - -type memoryLayerWriter struct { - ml *memoryLayer - buffer *bytes.Buffer -} - -func (mlw *memoryLayerWriter) Write(p []byte) (int, error) { - if mlw.ml.expectedSize == 0 { - return 0, fmt.Errorf("Must set size before writing to layer") - } - wrote, err := mlw.buffer.Write(p) - mlw.ml.contents = mlw.buffer.Bytes() - return wrote, err -} - -func (mlw *memoryLayerWriter) Close() error { - mlw.ml.cond.L.Lock() - defer mlw.ml.cond.L.Unlock() - - return mlw.close() -} - -func (mlw *memoryLayerWriter) close() error { - mlw.ml.writing = false - mlw.ml.cond.Broadcast() - return nil -} - -func (mlw *memoryLayerWriter) CurrentSize() int { - return len(mlw.ml.contents) -} - -func (mlw *memoryLayerWriter) Size() int { - return mlw.ml.expectedSize -} - -func (mlw *memoryLayerWriter) SetSize(size int) error { - if !mlw.ml.writing { - return fmt.Errorf("Layer is closed for writing") - } - mlw.ml.expectedSize = size - return nil -} diff --git a/registry/client/pull.go b/registry/client/pull.go deleted file mode 100644 index 385158db..00000000 --- a/registry/client/pull.go +++ /dev/null @@ -1,151 +0,0 @@ -package client - -import ( - "fmt" - "io" - - log "github.com/Sirupsen/logrus" - - "github.com/docker/distribution/manifest" -) - -// simultaneousLayerPullWindow is the size of the parallel layer pull window. -// A layer may not be pulled until the layer preceeding it by the length of the -// pull window has been successfully pulled. -const simultaneousLayerPullWindow = 4 - -// Pull implements a client pull workflow for the image defined by the given -// name and tag pair, using the given ObjectStore for local manifest and layer -// storage -func Pull(c Client, objectStore ObjectStore, name, tag string) error { - manifest, err := c.GetImageManifest(name, tag) - if err != nil { - return err - } - log.WithField("manifest", manifest).Info("Pulled manifest") - - if len(manifest.FSLayers) != len(manifest.History) { - return fmt.Errorf("Length of history not equal to number of layers") - } - if len(manifest.FSLayers) == 0 { - return fmt.Errorf("Image has no layers") - } - - errChans := make([]chan error, len(manifest.FSLayers)) - for i := range manifest.FSLayers { - errChans[i] = make(chan error) - } - - // To avoid leak of goroutines we must notify - // pullLayer goroutines about a cancelation, - // otherwise they will lock forever. - cancelCh := make(chan struct{}) - - // Iterate over each layer in the manifest, simultaneously pulling no more - // than simultaneousLayerPullWindow layers at a time. If an error is - // received from a layer pull, we abort the push. - for i := 0; i < len(manifest.FSLayers)+simultaneousLayerPullWindow; i++ { - dependentLayer := i - simultaneousLayerPullWindow - if dependentLayer >= 0 { - err := <-errChans[dependentLayer] - if err != nil { - log.WithField("error", err).Warn("Pull aborted") - close(cancelCh) - return err - } - } - - if i < len(manifest.FSLayers) { - go func(i int) { - select { - case errChans[i] <- pullLayer(c, objectStore, name, manifest.FSLayers[i]): - case <-cancelCh: // no chance to recv until cancelCh's closed - } - }(i) - } - } - - err = objectStore.WriteManifest(name, tag, manifest) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "manifest": manifest, - }).Warn("Unable to write image manifest") - return err - } - - return nil -} - -func pullLayer(c Client, objectStore ObjectStore, name string, fsLayer manifest.FSLayer) error { - log.WithField("layer", fsLayer).Info("Pulling layer") - - layer, err := objectStore.Layer(fsLayer.BlobSum) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "layer": fsLayer, - }).Warn("Unable to write local layer") - return err - } - - layerWriter, err := layer.Writer() - if err == ErrLayerAlreadyExists { - log.WithField("layer", fsLayer).Info("Layer already exists") - return nil - } - if err == ErrLayerLocked { - log.WithField("layer", fsLayer).Info("Layer download in progress, waiting") - layer.Wait() - return nil - } - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "layer": fsLayer, - }).Warn("Unable to write local layer") - return err - } - defer layerWriter.Close() - - if layerWriter.CurrentSize() > 0 { - log.WithFields(log.Fields{ - "layer": fsLayer, - "currentSize": layerWriter.CurrentSize(), - "size": layerWriter.Size(), - }).Info("Layer partially downloaded, resuming") - } - - layerReader, length, err := c.GetBlob(name, fsLayer.BlobSum, layerWriter.CurrentSize()) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "layer": fsLayer, - }).Warn("Unable to download layer") - return err - } - defer layerReader.Close() - - layerWriter.SetSize(layerWriter.CurrentSize() + length) - - _, err = io.Copy(layerWriter, layerReader) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "layer": fsLayer, - }).Warn("Unable to download layer") - return err - } - if layerWriter.CurrentSize() != layerWriter.Size() { - log.WithFields(log.Fields{ - "size": layerWriter.Size(), - "currentSize": layerWriter.CurrentSize(), - "layer": fsLayer, - }).Warn("Layer invalid size") - return fmt.Errorf( - "Wrote incorrect number of bytes for layer %v. Expected %d, Wrote %d", - fsLayer, layerWriter.Size(), layerWriter.CurrentSize(), - ) - } - return nil -} diff --git a/registry/client/push.go b/registry/client/push.go deleted file mode 100644 index c26bd174..00000000 --- a/registry/client/push.go +++ /dev/null @@ -1,137 +0,0 @@ -package client - -import ( - "fmt" - - log "github.com/Sirupsen/logrus" - "github.com/docker/distribution/manifest" -) - -// simultaneousLayerPushWindow is the size of the parallel layer push window. -// A layer may not be pushed until the layer preceeding it by the length of the -// push window has been successfully pushed. -const simultaneousLayerPushWindow = 4 - -type pushFunction func(fsLayer manifest.FSLayer) error - -// Push implements a client push workflow for the image defined by the given -// name and tag pair, using the given ObjectStore for local manifest and layer -// storage -func Push(c Client, objectStore ObjectStore, name, tag string) error { - manifest, err := objectStore.Manifest(name, tag) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "name": name, - "tag": tag, - }).Info("No image found") - return err - } - - errChans := make([]chan error, len(manifest.FSLayers)) - for i := range manifest.FSLayers { - errChans[i] = make(chan error) - } - - cancelCh := make(chan struct{}) - - // Iterate over each layer in the manifest, simultaneously pushing no more - // than simultaneousLayerPushWindow layers at a time. If an error is - // received from a layer push, we abort the push. - for i := 0; i < len(manifest.FSLayers)+simultaneousLayerPushWindow; i++ { - dependentLayer := i - simultaneousLayerPushWindow - if dependentLayer >= 0 { - err := <-errChans[dependentLayer] - if err != nil { - log.WithField("error", err).Warn("Push aborted") - close(cancelCh) - return err - } - } - - if i < len(manifest.FSLayers) { - go func(i int) { - select { - case errChans[i] <- pushLayer(c, objectStore, name, manifest.FSLayers[i]): - case <-cancelCh: // recv broadcast notification about cancelation - } - }(i) - } - } - - err = c.PutImageManifest(name, tag, manifest) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "manifest": manifest, - }).Warn("Unable to upload manifest") - return err - } - - return nil -} - -func pushLayer(c Client, objectStore ObjectStore, name string, fsLayer manifest.FSLayer) error { - log.WithField("layer", fsLayer).Info("Pushing layer") - - layer, err := objectStore.Layer(fsLayer.BlobSum) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "layer": fsLayer, - }).Warn("Unable to read local layer") - return err - } - - layerReader, err := layer.Reader() - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "layer": fsLayer, - }).Warn("Unable to read local layer") - return err - } - defer layerReader.Close() - - if layerReader.CurrentSize() != layerReader.Size() { - log.WithFields(log.Fields{ - "layer": fsLayer, - "currentSize": layerReader.CurrentSize(), - "size": layerReader.Size(), - }).Warn("Local layer incomplete") - return fmt.Errorf("Local layer incomplete") - } - - length, err := c.BlobLength(name, fsLayer.BlobSum) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "layer": fsLayer, - }).Warn("Unable to check existence of remote layer") - return err - } - if length >= 0 { - log.WithField("layer", fsLayer).Info("Layer already exists") - return nil - } - - location, err := c.InitiateBlobUpload(name) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "layer": fsLayer, - }).Warn("Unable to upload layer") - return err - } - - err = c.UploadBlob(location, layerReader, int(layerReader.CurrentSize()), fsLayer.BlobSum) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - "layer": fsLayer, - }).Warn("Unable to upload layer") - return err - } - - return nil -} diff --git a/registry/client/repository.go b/registry/client/repository.go new file mode 100644 index 00000000..a1117ac2 --- /dev/null +++ b/registry/client/repository.go @@ -0,0 +1,412 @@ +package client + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strconv" + "time" + + "github.com/docker/distribution" + "github.com/docker/distribution/context" + "github.com/docker/distribution/digest" + "github.com/docker/distribution/manifest" + "github.com/docker/distribution/registry/api/v2" + "github.com/docker/distribution/registry/client/transport" + "github.com/docker/distribution/registry/storage/cache" +) + +// NewRepository creates a new Repository for the given repository name and base URL +func NewRepository(ctx context.Context, name, baseURL string, transport http.RoundTripper) (distribution.Repository, error) { + if err := v2.ValidateRespositoryName(name); err != nil { + return nil, err + } + + ub, err := v2.NewURLBuilderFromString(baseURL) + if err != nil { + return nil, err + } + + client := &http.Client{ + Transport: transport, + Timeout: 1 * time.Minute, + // TODO(dmcgowan): create cookie jar + } + + return &repository{ + client: client, + ub: ub, + name: name, + context: ctx, + }, nil +} + +type repository struct { + client *http.Client + ub *v2.URLBuilder + context context.Context + name string +} + +func (r *repository) Name() string { + return r.name +} + +func (r *repository) Blobs(ctx context.Context) distribution.BlobStore { + statter := &blobStatter{ + name: r.Name(), + ub: r.ub, + client: r.client, + } + return &blobs{ + name: r.Name(), + ub: r.ub, + client: r.client, + statter: cache.NewCachedBlobStatter(cache.NewInMemoryBlobDescriptorCacheProvider(), statter), + } +} + +func (r *repository) Manifests() distribution.ManifestService { + return &manifests{ + name: r.Name(), + ub: r.ub, + client: r.client, + } +} + +func (r *repository) Signatures() distribution.SignatureService { + return &signatures{ + manifests: r.Manifests(), + } +} + +type signatures struct { + manifests distribution.ManifestService +} + +func (s *signatures) Get(dgst digest.Digest) ([][]byte, error) { + m, err := s.manifests.Get(dgst) + if err != nil { + return nil, err + } + return m.Signatures() +} + +func (s *signatures) Put(dgst digest.Digest, signatures ...[]byte) error { + panic("not implemented") +} + +type manifests struct { + name string + ub *v2.URLBuilder + client *http.Client +} + +func (ms *manifests) Tags() ([]string, error) { + u, err := ms.ub.BuildTagsURL(ms.name) + if err != nil { + return nil, err + } + + resp, err := ms.client.Get(u) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + tagsResponse := struct { + Tags []string `json:"tags"` + }{} + if err := json.Unmarshal(b, &tagsResponse); err != nil { + return nil, err + } + + return tagsResponse.Tags, nil + case http.StatusNotFound: + return nil, nil + default: + return nil, handleErrorResponse(resp) + } +} + +func (ms *manifests) Exists(dgst digest.Digest) (bool, error) { + // Call by Tag endpoint since the API uses the same + // URL endpoint for tags and digests. + return ms.ExistsByTag(dgst.String()) +} + +func (ms *manifests) ExistsByTag(tag string) (bool, error) { + u, err := ms.ub.BuildManifestURL(ms.name, tag) + if err != nil { + return false, err + } + + resp, err := ms.client.Head(u) + if err != nil { + return false, err + } + + switch resp.StatusCode { + case http.StatusOK: + return true, nil + case http.StatusNotFound: + return false, nil + default: + return false, handleErrorResponse(resp) + } +} + +func (ms *manifests) Get(dgst digest.Digest) (*manifest.SignedManifest, error) { + // Call by Tag endpoint since the API uses the same + // URL endpoint for tags and digests. + return ms.GetByTag(dgst.String()) +} + +func (ms *manifests) GetByTag(tag string) (*manifest.SignedManifest, error) { + u, err := ms.ub.BuildManifestURL(ms.name, tag) + if err != nil { + return nil, err + } + + resp, err := ms.client.Get(u) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + var sm manifest.SignedManifest + decoder := json.NewDecoder(resp.Body) + + if err := decoder.Decode(&sm); err != nil { + return nil, err + } + + return &sm, nil + default: + return nil, handleErrorResponse(resp) + } +} + +func (ms *manifests) Put(m *manifest.SignedManifest) error { + manifestURL, err := ms.ub.BuildManifestURL(ms.name, m.Tag) + if err != nil { + return err + } + + putRequest, err := http.NewRequest("PUT", manifestURL, bytes.NewReader(m.Raw)) + if err != nil { + return err + } + + resp, err := ms.client.Do(putRequest) + if err != nil { + return err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusAccepted: + // TODO(dmcgowan): make use of digest header + return nil + default: + return handleErrorResponse(resp) + } +} + +func (ms *manifests) Delete(dgst digest.Digest) error { + u, err := ms.ub.BuildManifestURL(ms.name, dgst.String()) + if err != nil { + return err + } + req, err := http.NewRequest("DELETE", u, nil) + if err != nil { + return err + } + + resp, err := ms.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + return nil + default: + return handleErrorResponse(resp) + } +} + +type blobs struct { + name string + ub *v2.URLBuilder + client *http.Client + + statter distribution.BlobStatter +} + +func sanitizeLocation(location, source string) (string, error) { + locationURL, err := url.Parse(location) + if err != nil { + return "", err + } + + if locationURL.Scheme == "" { + sourceURL, err := url.Parse(source) + if err != nil { + return "", err + } + locationURL = &url.URL{ + Scheme: sourceURL.Scheme, + Host: sourceURL.Host, + Path: location, + } + location = locationURL.String() + } + return location, nil +} + +func (bs *blobs) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) { + return bs.statter.Stat(ctx, dgst) + +} + +func (bs *blobs) Get(ctx context.Context, dgst digest.Digest) ([]byte, error) { + desc, err := bs.Stat(ctx, dgst) + if err != nil { + return nil, err + } + reader, err := bs.Open(ctx, desc.Digest) + if err != nil { + return nil, err + } + defer reader.Close() + + return ioutil.ReadAll(reader) +} + +func (bs *blobs) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) { + stat, err := bs.statter.Stat(ctx, dgst) + if err != nil { + return nil, err + } + + blobURL, err := bs.ub.BuildBlobURL(bs.name, stat.Digest) + if err != nil { + return nil, err + } + + return transport.NewHTTPReadSeeker(bs.client, blobURL, stat.Length), nil +} + +func (bs *blobs) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error { + panic("not implemented") +} + +func (bs *blobs) Put(ctx context.Context, mediaType string, p []byte) (distribution.Descriptor, error) { + writer, err := bs.Create(ctx) + if err != nil { + return distribution.Descriptor{}, err + } + dgstr := digest.NewCanonicalDigester() + n, err := io.Copy(writer, io.TeeReader(bytes.NewReader(p), dgstr)) + if err != nil { + return distribution.Descriptor{}, err + } + if n < int64(len(p)) { + return distribution.Descriptor{}, fmt.Errorf("short copy: wrote %d of %d", n, len(p)) + } + + desc := distribution.Descriptor{ + MediaType: mediaType, + Length: int64(len(p)), + Digest: dgstr.Digest(), + } + + return writer.Commit(ctx, desc) +} + +func (bs *blobs) Create(ctx context.Context) (distribution.BlobWriter, error) { + u, err := bs.ub.BuildBlobUploadURL(bs.name) + + resp, err := bs.client.Post(u, "", nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusAccepted: + // TODO(dmcgowan): Check for invalid UUID + uuid := resp.Header.Get("Docker-Upload-UUID") + location, err := sanitizeLocation(resp.Header.Get("Location"), u) + if err != nil { + return nil, err + } + + return &httpBlobUpload{ + statter: bs.statter, + client: bs.client, + uuid: uuid, + startedAt: time.Now(), + location: location, + }, nil + default: + return nil, handleErrorResponse(resp) + } +} + +func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) { + panic("not implemented") +} + +type blobStatter struct { + name string + ub *v2.URLBuilder + client *http.Client +} + +func (bs *blobStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) { + u, err := bs.ub.BuildBlobURL(bs.name, dgst) + if err != nil { + return distribution.Descriptor{}, err + } + + resp, err := bs.client.Head(u) + if err != nil { + return distribution.Descriptor{}, err + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + lengthHeader := resp.Header.Get("Content-Length") + length, err := strconv.ParseInt(lengthHeader, 10, 64) + if err != nil { + return distribution.Descriptor{}, fmt.Errorf("error parsing content-length: %v", err) + } + + return distribution.Descriptor{ + MediaType: resp.Header.Get("Content-Type"), + Length: length, + Digest: dgst, + }, nil + case http.StatusNotFound: + return distribution.Descriptor{}, distribution.ErrBlobUnknown + default: + return distribution.Descriptor{}, handleErrorResponse(resp) + } +} diff --git a/registry/client/repository_test.go b/registry/client/repository_test.go new file mode 100644 index 00000000..9530bd37 --- /dev/null +++ b/registry/client/repository_test.go @@ -0,0 +1,681 @@ +package client + +import ( + "bytes" + "crypto/rand" + "encoding/json" + "fmt" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "code.google.com/p/go-uuid/uuid" + + "github.com/docker/distribution" + "github.com/docker/distribution/context" + "github.com/docker/distribution/digest" + "github.com/docker/distribution/manifest" + "github.com/docker/distribution/registry/api/v2" + "github.com/docker/distribution/testutil" +) + +func testServer(rrm testutil.RequestResponseMap) (string, func()) { + h := testutil.NewHandler(rrm) + s := httptest.NewServer(h) + return s.URL, s.Close +} + +func newRandomBlob(size int) (digest.Digest, []byte) { + b := make([]byte, size) + if n, err := rand.Read(b); err != nil { + panic(err) + } else if n != size { + panic("unable to read enough bytes") + } + + dgst, err := digest.FromBytes(b) + if err != nil { + panic(err) + } + + return dgst, b +} + +func addTestFetch(repo string, dgst digest.Digest, content []byte, m *testutil.RequestResponseMap) { + *m = append(*m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "GET", + Route: "/v2/" + repo + "/blobs/" + dgst.String(), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: content, + Headers: http.Header(map[string][]string{ + "Content-Length": {fmt.Sprint(len(content))}, + "Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)}, + }), + }, + }) + *m = append(*m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "HEAD", + Route: "/v2/" + repo + "/blobs/" + dgst.String(), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Headers: http.Header(map[string][]string{ + "Content-Length": {fmt.Sprint(len(content))}, + "Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)}, + }), + }, + }) +} + +func TestBlobFetch(t *testing.T) { + d1, b1 := newRandomBlob(1024) + var m testutil.RequestResponseMap + addTestFetch("test.example.com/repo1", d1, b1, &m) + + e, c := testServer(m) + defer c() + + ctx := context.Background() + r, err := NewRepository(ctx, "test.example.com/repo1", e, nil) + if err != nil { + t.Fatal(err) + } + l := r.Blobs(ctx) + + b, err := l.Get(ctx, d1) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(b, b1) != 0 { + t.Fatalf("Wrong bytes values fetched: [%d]byte != [%d]byte", len(b), len(b1)) + } + + // TODO(dmcgowan): Test for unknown blob case +} + +func TestBlobExists(t *testing.T) { + d1, b1 := newRandomBlob(1024) + var m testutil.RequestResponseMap + addTestFetch("test.example.com/repo1", d1, b1, &m) + + e, c := testServer(m) + defer c() + + ctx := context.Background() + r, err := NewRepository(ctx, "test.example.com/repo1", e, nil) + if err != nil { + t.Fatal(err) + } + l := r.Blobs(ctx) + + stat, err := l.Stat(ctx, d1) + if err != nil { + t.Fatal(err) + } + + if stat.Digest != d1 { + t.Fatalf("Unexpected digest: %s, expected %s", stat.Digest, d1) + } + + if stat.Length != int64(len(b1)) { + t.Fatalf("Unexpected length: %d, expected %d", stat.Length, len(b1)) + } + + // TODO(dmcgowan): Test error cases and ErrBlobUnknown case +} + +func TestBlobUploadChunked(t *testing.T) { + dgst, b1 := newRandomBlob(1024) + var m testutil.RequestResponseMap + chunks := [][]byte{ + b1[0:256], + b1[256:512], + b1[512:513], + b1[513:1024], + } + repo := "test.example.com/uploadrepo" + uuids := []string{uuid.New()} + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "POST", + Route: "/v2/" + repo + "/blobs/uploads/", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + Headers: http.Header(map[string][]string{ + "Content-Length": {"0"}, + "Location": {"/v2/" + repo + "/blobs/uploads/" + uuids[0]}, + "Docker-Upload-UUID": {uuids[0]}, + "Range": {"0-0"}, + }), + }, + }) + offset := 0 + for i, chunk := range chunks { + uuids = append(uuids, uuid.New()) + newOffset := offset + len(chunk) + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "PATCH", + Route: "/v2/" + repo + "/blobs/uploads/" + uuids[i], + Body: chunk, + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + Headers: http.Header(map[string][]string{ + "Content-Length": {"0"}, + "Location": {"/v2/" + repo + "/blobs/uploads/" + uuids[i+1]}, + "Docker-Upload-UUID": {uuids[i+1]}, + "Range": {fmt.Sprintf("%d-%d", offset, newOffset-1)}, + }), + }, + }) + offset = newOffset + } + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "PUT", + Route: "/v2/" + repo + "/blobs/uploads/" + uuids[len(uuids)-1], + QueryParams: map[string][]string{ + "digest": {dgst.String()}, + }, + }, + Response: testutil.Response{ + StatusCode: http.StatusCreated, + Headers: http.Header(map[string][]string{ + "Content-Length": {"0"}, + "Docker-Content-Digest": {dgst.String()}, + "Content-Range": {fmt.Sprintf("0-%d", offset-1)}, + }), + }, + }) + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "HEAD", + Route: "/v2/" + repo + "/blobs/" + dgst.String(), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Headers: http.Header(map[string][]string{ + "Content-Length": {fmt.Sprint(offset)}, + "Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)}, + }), + }, + }) + + e, c := testServer(m) + defer c() + + ctx := context.Background() + r, err := NewRepository(ctx, repo, e, nil) + if err != nil { + t.Fatal(err) + } + l := r.Blobs(ctx) + + upload, err := l.Create(ctx) + if err != nil { + t.Fatal(err) + } + + if upload.ID() != uuids[0] { + log.Fatalf("Unexpected UUID %s; expected %s", upload.ID(), uuids[0]) + } + + for _, chunk := range chunks { + n, err := upload.Write(chunk) + if err != nil { + t.Fatal(err) + } + if n != len(chunk) { + t.Fatalf("Unexpected length returned from write: %d; expected: %d", n, len(chunk)) + } + } + + blob, err := upload.Commit(ctx, distribution.Descriptor{ + Digest: dgst, + Length: int64(len(b1)), + }) + if err != nil { + t.Fatal(err) + } + + if blob.Length != int64(len(b1)) { + t.Fatalf("Unexpected blob size: %d; expected: %d", blob.Length, len(b1)) + } +} + +func TestBlobUploadMonolithic(t *testing.T) { + dgst, b1 := newRandomBlob(1024) + var m testutil.RequestResponseMap + repo := "test.example.com/uploadrepo" + uploadID := uuid.New() + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "POST", + Route: "/v2/" + repo + "/blobs/uploads/", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + Headers: http.Header(map[string][]string{ + "Content-Length": {"0"}, + "Location": {"/v2/" + repo + "/blobs/uploads/" + uploadID}, + "Docker-Upload-UUID": {uploadID}, + "Range": {"0-0"}, + }), + }, + }) + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "PATCH", + Route: "/v2/" + repo + "/blobs/uploads/" + uploadID, + Body: b1, + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + Headers: http.Header(map[string][]string{ + "Location": {"/v2/" + repo + "/blobs/uploads/" + uploadID}, + "Docker-Upload-UUID": {uploadID}, + "Content-Length": {"0"}, + "Docker-Content-Digest": {dgst.String()}, + "Range": {fmt.Sprintf("0-%d", len(b1)-1)}, + }), + }, + }) + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "PUT", + Route: "/v2/" + repo + "/blobs/uploads/" + uploadID, + QueryParams: map[string][]string{ + "digest": {dgst.String()}, + }, + }, + Response: testutil.Response{ + StatusCode: http.StatusCreated, + Headers: http.Header(map[string][]string{ + "Content-Length": {"0"}, + "Docker-Content-Digest": {dgst.String()}, + "Content-Range": {fmt.Sprintf("0-%d", len(b1)-1)}, + }), + }, + }) + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "HEAD", + Route: "/v2/" + repo + "/blobs/" + dgst.String(), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Headers: http.Header(map[string][]string{ + "Content-Length": {fmt.Sprint(len(b1))}, + "Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)}, + }), + }, + }) + + e, c := testServer(m) + defer c() + + ctx := context.Background() + r, err := NewRepository(ctx, repo, e, nil) + if err != nil { + t.Fatal(err) + } + l := r.Blobs(ctx) + + upload, err := l.Create(ctx) + if err != nil { + t.Fatal(err) + } + + if upload.ID() != uploadID { + log.Fatalf("Unexpected UUID %s; expected %s", upload.ID(), uploadID) + } + + n, err := upload.ReadFrom(bytes.NewReader(b1)) + if err != nil { + t.Fatal(err) + } + if n != int64(len(b1)) { + t.Fatalf("Unexpected ReadFrom length: %d; expected: %d", n, len(b1)) + } + + blob, err := upload.Commit(ctx, distribution.Descriptor{ + Digest: dgst, + Length: int64(len(b1)), + }) + if err != nil { + t.Fatal(err) + } + + if blob.Length != int64(len(b1)) { + t.Fatalf("Unexpected blob size: %d; expected: %d", blob.Length, len(b1)) + } +} + +func newRandomSchemaV1Manifest(name, tag string, blobCount int) (*manifest.SignedManifest, digest.Digest) { + blobs := make([]manifest.FSLayer, blobCount) + history := make([]manifest.History, blobCount) + + for i := 0; i < blobCount; i++ { + dgst, blob := newRandomBlob((i % 5) * 16) + + blobs[i] = manifest.FSLayer{BlobSum: dgst} + history[i] = manifest.History{V1Compatibility: fmt.Sprintf("{\"Hex\": \"%x\"}", blob)} + } + + m := &manifest.SignedManifest{ + Manifest: manifest.Manifest{ + Name: name, + Tag: tag, + Architecture: "x86", + FSLayers: blobs, + History: history, + Versioned: manifest.Versioned{ + SchemaVersion: 1, + }, + }, + } + manifestBytes, err := json.Marshal(m) + if err != nil { + panic(err) + } + dgst, err := digest.FromBytes(manifestBytes) + if err != nil { + panic(err) + } + + m.Raw = manifestBytes + + return m, dgst +} + +func addTestManifest(repo, reference string, content []byte, m *testutil.RequestResponseMap) { + *m = append(*m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "GET", + Route: "/v2/" + repo + "/manifests/" + reference, + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: content, + Headers: http.Header(map[string][]string{ + "Content-Length": {fmt.Sprint(len(content))}, + "Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)}, + }), + }, + }) + *m = append(*m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "HEAD", + Route: "/v2/" + repo + "/manifests/" + reference, + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Headers: http.Header(map[string][]string{ + "Content-Length": {fmt.Sprint(len(content))}, + "Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)}, + }), + }, + }) + +} + +func checkEqualManifest(m1, m2 *manifest.SignedManifest) error { + if m1.Name != m2.Name { + return fmt.Errorf("name does not match %q != %q", m1.Name, m2.Name) + } + if m1.Tag != m2.Tag { + return fmt.Errorf("tag does not match %q != %q", m1.Tag, m2.Tag) + } + if len(m1.FSLayers) != len(m2.FSLayers) { + return fmt.Errorf("fs blob length does not match %d != %d", len(m1.FSLayers), len(m2.FSLayers)) + } + for i := range m1.FSLayers { + if m1.FSLayers[i].BlobSum != m2.FSLayers[i].BlobSum { + return fmt.Errorf("blobsum does not match %q != %q", m1.FSLayers[i].BlobSum, m2.FSLayers[i].BlobSum) + } + } + if len(m1.History) != len(m2.History) { + return fmt.Errorf("history length does not match %d != %d", len(m1.History), len(m2.History)) + } + for i := range m1.History { + if m1.History[i].V1Compatibility != m2.History[i].V1Compatibility { + return fmt.Errorf("blobsum does not match %q != %q", m1.History[i].V1Compatibility, m2.History[i].V1Compatibility) + } + } + return nil +} + +func TestManifestFetch(t *testing.T) { + repo := "test.example.com/repo" + m1, dgst := newRandomSchemaV1Manifest(repo, "latest", 6) + var m testutil.RequestResponseMap + addTestManifest(repo, dgst.String(), m1.Raw, &m) + + e, c := testServer(m) + defer c() + + r, err := NewRepository(context.Background(), repo, e, nil) + if err != nil { + t.Fatal(err) + } + ms := r.Manifests() + + ok, err := ms.Exists(dgst) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Manifest does not exist") + } + + manifest, err := ms.Get(dgst) + if err != nil { + t.Fatal(err) + } + if err := checkEqualManifest(manifest, m1); err != nil { + t.Fatal(err) + } +} + +func TestManifestFetchByTag(t *testing.T) { + repo := "test.example.com/repo/by/tag" + m1, _ := newRandomSchemaV1Manifest(repo, "latest", 6) + var m testutil.RequestResponseMap + addTestManifest(repo, "latest", m1.Raw, &m) + + e, c := testServer(m) + defer c() + + r, err := NewRepository(context.Background(), repo, e, nil) + if err != nil { + t.Fatal(err) + } + + ms := r.Manifests() + ok, err := ms.ExistsByTag("latest") + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("Manifest does not exist") + } + + manifest, err := ms.GetByTag("latest") + if err != nil { + t.Fatal(err) + } + if err := checkEqualManifest(manifest, m1); err != nil { + t.Fatal(err) + } +} + +func TestManifestDelete(t *testing.T) { + repo := "test.example.com/repo/delete" + _, dgst1 := newRandomSchemaV1Manifest(repo, "latest", 6) + _, dgst2 := newRandomSchemaV1Manifest(repo, "latest", 6) + var m testutil.RequestResponseMap + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "DELETE", + Route: "/v2/" + repo + "/manifests/" + dgst1.String(), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Headers: http.Header(map[string][]string{ + "Content-Length": {"0"}, + }), + }, + }) + + e, c := testServer(m) + defer c() + + r, err := NewRepository(context.Background(), repo, e, nil) + if err != nil { + t.Fatal(err) + } + + ms := r.Manifests() + if err := ms.Delete(dgst1); err != nil { + t.Fatal(err) + } + if err := ms.Delete(dgst2); err == nil { + t.Fatal("Expected error deleting unknown manifest") + } + // TODO(dmcgowan): Check for specific unknown error +} + +func TestManifestPut(t *testing.T) { + repo := "test.example.com/repo/delete" + m1, dgst := newRandomSchemaV1Manifest(repo, "other", 6) + var m testutil.RequestResponseMap + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "PUT", + Route: "/v2/" + repo + "/manifests/other", + Body: m1.Raw, + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + Headers: http.Header(map[string][]string{ + "Content-Length": {"0"}, + "Docker-Content-Digest": {dgst.String()}, + }), + }, + }) + + e, c := testServer(m) + defer c() + + r, err := NewRepository(context.Background(), repo, e, nil) + if err != nil { + t.Fatal(err) + } + + ms := r.Manifests() + if err := ms.Put(m1); err != nil { + t.Fatal(err) + } + + // TODO(dmcgowan): Check for invalid input error +} + +func TestManifestTags(t *testing.T) { + repo := "test.example.com/repo/tags/list" + tagsList := []byte(strings.TrimSpace(` +{ + "name": "test.example.com/repo/tags/list", + "tags": [ + "tag1", + "tag2", + "funtag" + ] +} + `)) + var m testutil.RequestResponseMap + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "GET", + Route: "/v2/" + repo + "/tags/list", + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: tagsList, + Headers: http.Header(map[string][]string{ + "Content-Length": {fmt.Sprint(len(tagsList))}, + "Last-Modified": {time.Now().Add(-1 * time.Second).Format(time.ANSIC)}, + }), + }, + }) + + e, c := testServer(m) + defer c() + + r, err := NewRepository(context.Background(), repo, e, nil) + if err != nil { + t.Fatal(err) + } + + ms := r.Manifests() + tags, err := ms.Tags() + if err != nil { + t.Fatal(err) + } + + if len(tags) != 3 { + t.Fatalf("Wrong number of tags returned: %d, expected 3", len(tags)) + } + // TODO(dmcgowan): Check array + + // TODO(dmcgowan): Check for error cases +} + +func TestManifestUnauthorized(t *testing.T) { + repo := "test.example.com/repo" + _, dgst := newRandomSchemaV1Manifest(repo, "latest", 6) + var m testutil.RequestResponseMap + + m = append(m, testutil.RequestResponseMapping{ + Request: testutil.Request{ + Method: "GET", + Route: "/v2/" + repo + "/manifests/" + dgst.String(), + }, + Response: testutil.Response{ + StatusCode: http.StatusUnauthorized, + Body: []byte("garbage"), + }, + }) + + e, c := testServer(m) + defer c() + + r, err := NewRepository(context.Background(), repo, e, nil) + if err != nil { + t.Fatal(err) + } + ms := r.Manifests() + + _, err = ms.Get(dgst) + if err == nil { + t.Fatal("Expected error fetching manifest") + } + v2Err, ok := err.(*v2.Error) + if !ok { + t.Fatalf("Unexpected error type: %#v", err) + } + if v2Err.Code != v2.ErrorCodeUnauthorized { + t.Fatalf("Unexpected error code: %s", v2Err.Code.String()) + } + if expected := "401 Unauthorized"; v2Err.Message != expected { + t.Fatalf("Unexpected message value: %s, expected %s", v2Err.Message, expected) + } +} diff --git a/registry/client/transport/authchallenge.go b/registry/client/transport/authchallenge.go new file mode 100644 index 00000000..fffd560b --- /dev/null +++ b/registry/client/transport/authchallenge.go @@ -0,0 +1,150 @@ +package transport + +import ( + "net/http" + "strings" +) + +// Octet types from RFC 2616. +type octetType byte + +// authorizationChallenge carries information +// from a WWW-Authenticate response header. +type authorizationChallenge struct { + Scheme string + Parameters map[string]string +} + +var octetTypes [256]octetType + +const ( + isToken octetType = 1 << iota + isSpace +) + +func init() { + // OCTET = + // CHAR = + // CTL = + // CR = + // LF = + // SP = + // HT = + // <"> = + // CRLF = CR LF + // LWS = [CRLF] 1*( SP | HT ) + // TEXT = + // separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <"> + // | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT + // token = 1* + // qdtext = > + + for c := 0; c < 256; c++ { + var t octetType + isCtl := c <= 31 || c == 127 + isChar := 0 <= c && c <= 127 + isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0 + if strings.IndexRune(" \t\r\n", rune(c)) >= 0 { + t |= isSpace + } + if isChar && !isCtl && !isSeparator { + t |= isToken + } + octetTypes[c] = t + } +} + +func parseAuthHeader(header http.Header) map[string]authorizationChallenge { + challenges := map[string]authorizationChallenge{} + for _, h := range header[http.CanonicalHeaderKey("WWW-Authenticate")] { + v, p := parseValueAndParams(h) + if v != "" { + challenges[v] = authorizationChallenge{Scheme: v, Parameters: p} + } + } + return challenges +} + +func parseValueAndParams(header string) (value string, params map[string]string) { + params = make(map[string]string) + value, s := expectToken(header) + if value == "" { + return + } + value = strings.ToLower(value) + s = "," + skipSpace(s) + for strings.HasPrefix(s, ",") { + var pkey string + pkey, s = expectToken(skipSpace(s[1:])) + if pkey == "" { + return + } + if !strings.HasPrefix(s, "=") { + return + } + var pvalue string + pvalue, s = expectTokenOrQuoted(s[1:]) + if pvalue == "" { + return + } + pkey = strings.ToLower(pkey) + params[pkey] = pvalue + s = skipSpace(s) + } + return +} + +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isSpace == 0 { + break + } + } + return s[i:] +} + +func expectToken(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isToken == 0 { + break + } + } + return s[:i], s[i:] +} + +func expectTokenOrQuoted(s string) (value string, rest string) { + if !strings.HasPrefix(s, "\"") { + return expectToken(s) + } + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j++ + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j++ + } + } + return "", "" + } + } + return "", "" +} diff --git a/registry/client/transport/authchallenge_test.go b/registry/client/transport/authchallenge_test.go new file mode 100644 index 00000000..45c932b9 --- /dev/null +++ b/registry/client/transport/authchallenge_test.go @@ -0,0 +1,38 @@ +package transport + +import ( + "net/http" + "testing" +) + +func TestAuthChallengeParse(t *testing.T) { + header := http.Header{} + header.Add("WWW-Authenticate", `Bearer realm="https://auth.example.com/token",service="registry.example.com",other=fun,slashed="he\"\l\lo"`) + + challenges := parseAuthHeader(header) + if len(challenges) != 1 { + t.Fatalf("Unexpected number of auth challenges: %d, expected 1", len(challenges)) + } + challenge := challenges["bearer"] + + if expected := "bearer"; challenge.Scheme != expected { + t.Fatalf("Unexpected scheme: %s, expected: %s", challenge.Scheme, expected) + } + + if expected := "https://auth.example.com/token"; challenge.Parameters["realm"] != expected { + t.Fatalf("Unexpected param: %s, expected: %s", challenge.Parameters["realm"], expected) + } + + if expected := "registry.example.com"; challenge.Parameters["service"] != expected { + t.Fatalf("Unexpected param: %s, expected: %s", challenge.Parameters["service"], expected) + } + + if expected := "fun"; challenge.Parameters["other"] != expected { + t.Fatalf("Unexpected param: %s, expected: %s", challenge.Parameters["other"], expected) + } + + if expected := "he\"llo"; challenge.Parameters["slashed"] != expected { + t.Fatalf("Unexpected param: %s, expected: %s", challenge.Parameters["slashed"], expected) + } + +} diff --git a/registry/client/transport/http_reader.go b/registry/client/transport/http_reader.go new file mode 100644 index 00000000..e351bdfe --- /dev/null +++ b/registry/client/transport/http_reader.go @@ -0,0 +1,172 @@ +package transport + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" +) + +// ReadSeekCloser combines io.ReadSeeker with io.Closer. +type ReadSeekCloser interface { + io.ReadSeeker + io.Closer +} + +// NewHTTPReadSeeker handles reading from an HTTP endpoint using a GET +// request. When seeking and starting a read from a non-zero offset +// the a "Range" header will be added which sets the offset. +// TODO(dmcgowan): Move this into a separate utility package +func NewHTTPReadSeeker(client *http.Client, url string, size int64) ReadSeekCloser { + return &httpReadSeeker{ + client: client, + url: url, + size: size, + } +} + +type httpReadSeeker struct { + client *http.Client + url string + + size int64 + + rc io.ReadCloser // remote read closer + brd *bufio.Reader // internal buffered io + offset int64 + err error +} + +func (hrs *httpReadSeeker) Read(p []byte) (n int, err error) { + if hrs.err != nil { + return 0, hrs.err + } + + rd, err := hrs.reader() + if err != nil { + return 0, err + } + + n, err = rd.Read(p) + hrs.offset += int64(n) + + // Simulate io.EOF error if we reach filesize. + if err == nil && hrs.offset >= hrs.size { + err = io.EOF + } + + return n, err +} + +func (hrs *httpReadSeeker) Seek(offset int64, whence int) (int64, error) { + if hrs.err != nil { + return 0, hrs.err + } + + var err error + newOffset := hrs.offset + + switch whence { + case os.SEEK_CUR: + newOffset += int64(offset) + case os.SEEK_END: + newOffset = hrs.size + int64(offset) + case os.SEEK_SET: + newOffset = int64(offset) + } + + if newOffset < 0 { + err = errors.New("cannot seek to negative position") + } else { + if hrs.offset != newOffset { + hrs.reset() + } + + // No problems, set the offset. + hrs.offset = newOffset + } + + return hrs.offset, err +} + +func (hrs *httpReadSeeker) Close() error { + if hrs.err != nil { + return hrs.err + } + + // close and release reader chain + if hrs.rc != nil { + hrs.rc.Close() + } + + hrs.rc = nil + hrs.brd = nil + + hrs.err = errors.New("httpLayer: closed") + + return nil +} + +func (hrs *httpReadSeeker) reset() { + if hrs.err != nil { + return + } + if hrs.rc != nil { + hrs.rc.Close() + hrs.rc = nil + } +} + +func (hrs *httpReadSeeker) reader() (io.Reader, error) { + if hrs.err != nil { + return nil, hrs.err + } + + if hrs.rc != nil { + return hrs.brd, nil + } + + // If the offset is great than or equal to size, return a empty, noop reader. + if hrs.offset >= hrs.size { + return ioutil.NopCloser(bytes.NewReader([]byte{})), nil + } + + req, err := http.NewRequest("GET", hrs.url, nil) + if err != nil { + return nil, err + } + + if hrs.offset > 0 { + // TODO(stevvooe): Get this working correctly. + + // If we are at different offset, issue a range request from there. + req.Header.Add("Range", "1-") + // TODO: get context in here + // context.GetLogger(hrs.context).Infof("Range: %s", req.Header.Get("Range")) + } + + resp, err := hrs.client.Do(req) + if err != nil { + return nil, err + } + + switch { + case resp.StatusCode == 200: + hrs.rc = resp.Body + default: + defer resp.Body.Close() + return nil, fmt.Errorf("unexpected status resolving reader: %v", resp.Status) + } + + if hrs.brd == nil { + hrs.brd = bufio.NewReader(hrs.rc) + } else { + hrs.brd.Reset(hrs.rc) + } + + return hrs.brd, nil +} diff --git a/registry/client/transport/session.go b/registry/client/transport/session.go new file mode 100644 index 00000000..90c8082c --- /dev/null +++ b/registry/client/transport/session.go @@ -0,0 +1,297 @@ +package transport + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// AuthenticationHandler is an interface for authorizing a request from +// params from a "WWW-Authenicate" header for a single scheme. +type AuthenticationHandler interface { + // Scheme returns the scheme as expected from the "WWW-Authenicate" header. + Scheme() string + + // AuthorizeRequest adds the authorization header to a request (if needed) + // using the parameters from "WWW-Authenticate" method. The parameters + // values depend on the scheme. + AuthorizeRequest(req *http.Request, params map[string]string) error +} + +// CredentialStore is an interface for getting credentials for +// a given URL +type CredentialStore interface { + // Basic returns basic auth for the given URL + Basic(*url.URL) (string, string) +} + +// NewAuthorizer creates an authorizer which can handle multiple authentication +// schemes. The handlers are tried in order, the higher priority authentication +// methods should be first. +func NewAuthorizer(transport http.RoundTripper, handlers ...AuthenticationHandler) RequestModifier { + return &tokenAuthorizer{ + challenges: map[string]map[string]authorizationChallenge{}, + handlers: handlers, + transport: transport, + } +} + +type tokenAuthorizer struct { + challenges map[string]map[string]authorizationChallenge + handlers []AuthenticationHandler + transport http.RoundTripper +} + +func (ta *tokenAuthorizer) ping(endpoint string) (map[string]authorizationChallenge, error) { + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return nil, err + } + + client := &http.Client{ + Transport: ta.transport, + // Ping should fail fast + Timeout: 5 * time.Second, + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // TODO(dmcgowan): Add version string which would allow skipping this section + var supportsV2 bool +HeaderLoop: + for _, supportedVersions := range resp.Header[http.CanonicalHeaderKey("Docker-Distribution-API-Version")] { + for _, versionName := range strings.Fields(supportedVersions) { + if versionName == "registry/2.0" { + supportsV2 = true + break HeaderLoop + } + } + } + + if !supportsV2 { + return nil, fmt.Errorf("%s does not appear to be a v2 registry endpoint", endpoint) + } + + if resp.StatusCode == http.StatusUnauthorized { + // Parse the WWW-Authenticate Header and store the challenges + // on this endpoint object. + return parseAuthHeader(resp.Header), nil + } else if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unable to get valid ping response: %d", resp.StatusCode) + } + + return nil, nil +} + +func (ta *tokenAuthorizer) ModifyRequest(req *http.Request) error { + v2Root := strings.Index(req.URL.Path, "/v2/") + // Test if /v2/ does not exist or not at beginning + // TODO(dmcgowan) support v2 endpoints which have a prefix before /v2/ + if v2Root == -1 || v2Root > 0 { + return nil + } + + ping := url.URL{ + Host: req.URL.Host, + Scheme: req.URL.Scheme, + Path: req.URL.Path[:v2Root+4], + } + + pingEndpoint := ping.String() + + challenges, ok := ta.challenges[pingEndpoint] + if !ok { + var err error + challenges, err = ta.ping(pingEndpoint) + if err != nil { + return err + } + ta.challenges[pingEndpoint] = challenges + } + + for _, handler := range ta.handlers { + challenge, ok := challenges[handler.Scheme()] + if ok { + if err := handler.AuthorizeRequest(req, challenge.Parameters); err != nil { + return err + } + } + } + + return nil +} + +type tokenHandler struct { + header http.Header + creds CredentialStore + scope TokenScope + transport http.RoundTripper + + tokenLock sync.Mutex + tokenCache string + tokenExpiration time.Time +} + +// TokenScope represents the scope at which a token will be requested. +// This represents a specific action on a registry resource. +type TokenScope struct { + Resource string + Scope string + Actions []string +} + +func (ts TokenScope) String() string { + return fmt.Sprintf("%s:%s:%s", ts.Resource, ts.Scope, strings.Join(ts.Actions, ",")) +} + +// NewTokenHandler creates a new AuthenicationHandler which supports +// fetching tokens from a remote token server. +func NewTokenHandler(transport http.RoundTripper, creds CredentialStore, scope TokenScope) AuthenticationHandler { + return &tokenHandler{ + transport: transport, + creds: creds, + scope: scope, + } +} + +func (th *tokenHandler) client() *http.Client { + return &http.Client{ + Transport: th.transport, + Timeout: 15 * time.Second, + } +} + +func (th *tokenHandler) Scheme() string { + return "bearer" +} + +func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]string) error { + if err := th.refreshToken(params); err != nil { + return err + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", th.tokenCache)) + + return nil +} + +func (th *tokenHandler) refreshToken(params map[string]string) error { + th.tokenLock.Lock() + defer th.tokenLock.Unlock() + now := time.Now() + if now.After(th.tokenExpiration) { + token, err := th.fetchToken(params) + if err != nil { + return err + } + th.tokenCache = token + th.tokenExpiration = now.Add(time.Minute) + } + + return nil +} + +type tokenResponse struct { + Token string `json:"token"` +} + +func (th *tokenHandler) fetchToken(params map[string]string) (token string, err error) { + //log.Debugf("Getting bearer token with %s for %s", challenge.Parameters, ta.auth.Username) + realm, ok := params["realm"] + if !ok { + return "", errors.New("no realm specified for token auth challenge") + } + + // TODO(dmcgowan): Handle empty scheme + + realmURL, err := url.Parse(realm) + if err != nil { + return "", fmt.Errorf("invalid token auth challenge realm: %s", err) + } + + req, err := http.NewRequest("GET", realmURL.String(), nil) + if err != nil { + return "", err + } + + reqParams := req.URL.Query() + service := params["service"] + scope := th.scope.String() + + if service != "" { + reqParams.Add("service", service) + } + + for _, scopeField := range strings.Fields(scope) { + reqParams.Add("scope", scopeField) + } + + if th.creds != nil { + username, password := th.creds.Basic(realmURL) + if username != "" && password != "" { + reqParams.Add("account", username) + req.SetBasicAuth(username, password) + } + } + + req.URL.RawQuery = reqParams.Encode() + + resp, err := th.client().Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("token auth attempt for registry: %s request failed with status: %d %s", req.URL, resp.StatusCode, http.StatusText(resp.StatusCode)) + } + + decoder := json.NewDecoder(resp.Body) + + tr := new(tokenResponse) + if err = decoder.Decode(tr); err != nil { + return "", fmt.Errorf("unable to decode token response: %s", err) + } + + if tr.Token == "" { + return "", errors.New("authorization server did not include a token in the response") + } + + return tr.Token, nil +} + +type basicHandler struct { + creds CredentialStore +} + +// NewBasicHandler creaters a new authentiation handler which adds +// basic authentication credentials to a request. +func NewBasicHandler(creds CredentialStore) AuthenticationHandler { + return &basicHandler{ + creds: creds, + } +} + +func (*basicHandler) Scheme() string { + return "basic" +} + +func (bh *basicHandler) AuthorizeRequest(req *http.Request, params map[string]string) error { + if bh.creds != nil { + username, password := bh.creds.Basic(req.URL) + if username != "" && password != "" { + req.SetBasicAuth(username, password) + return nil + } + } + return errors.New("no basic auth credentials") +} diff --git a/registry/client/transport/session_test.go b/registry/client/transport/session_test.go new file mode 100644 index 00000000..374d6e79 --- /dev/null +++ b/registry/client/transport/session_test.go @@ -0,0 +1,271 @@ +package transport + +import ( + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/docker/distribution/testutil" +) + +func testServer(rrm testutil.RequestResponseMap) (string, func()) { + h := testutil.NewHandler(rrm) + s := httptest.NewServer(h) + return s.URL, s.Close +} + +type testAuthenticationWrapper struct { + headers http.Header + authCheck func(string) bool + next http.Handler +} + +func (w *testAuthenticationWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth == "" || !w.authCheck(auth) { + h := rw.Header() + for k, values := range w.headers { + h[k] = values + } + rw.WriteHeader(http.StatusUnauthorized) + return + } + w.next.ServeHTTP(rw, r) +} + +func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, authCheck func(string) bool) (string, func()) { + h := testutil.NewHandler(rrm) + wrapper := &testAuthenticationWrapper{ + + headers: http.Header(map[string][]string{ + "Docker-Distribution-API-Version": {"registry/2.0"}, + "WWW-Authenticate": {authenticate}, + }), + authCheck: authCheck, + next: h, + } + + s := httptest.NewServer(wrapper) + return s.URL, s.Close +} + +type testCredentialStore struct { + username string + password string +} + +func (tcs *testCredentialStore) Basic(*url.URL) (string, string) { + return tcs.username, tcs.password +} + +func TestEndpointAuthorizeToken(t *testing.T) { + service := "localhost.localdomain" + repo1 := "some/registry" + repo2 := "other/registry" + scope1 := fmt.Sprintf("repository:%s:pull,push", repo1) + scope2 := fmt.Sprintf("repository:%s:pull,push", repo2) + tokenScope1 := TokenScope{ + Resource: "repository", + Scope: repo1, + Actions: []string{"pull", "push"}, + } + tokenScope2 := TokenScope{ + Resource: "repository", + Scope: repo2, + Actions: []string{"pull", "push"}, + } + + tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope1), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"token":"statictoken"}`), + }, + }, + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope2), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"token":"badtoken"}`), + }, + }, + }) + te, tc := testServer(tokenMap) + defer tc() + + m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + }) + + authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service) + validCheck := func(a string) bool { + return a == "Bearer statictoken" + } + e, c := testServerWithAuth(m, authenicate, validCheck) + defer c() + + transport1 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, nil, tokenScope1))) + client := &http.Client{Transport: transport1} + + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } + + badCheck := func(a string) bool { + return a == "Bearer statictoken" + } + e2, c2 := testServerWithAuth(m, authenicate, badCheck) + defer c2() + + transport2 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, nil, tokenScope2))) + client2 := &http.Client{Transport: transport2} + + req, _ = http.NewRequest("GET", e2+"/v2/hello", nil) + resp, err = client2.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized) + } +} + +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func TestEndpointAuthorizeTokenBasic(t *testing.T) { + service := "localhost.localdomain" + repo := "some/fun/registry" + scope := fmt.Sprintf("repository:%s:pull,push", repo) + username := "tokenuser" + password := "superSecretPa$$word" + tokenScope := TokenScope{ + Resource: "repository", + Scope: repo, + Actions: []string{"pull", "push"}, + } + + tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service), + }, + Response: testutil.Response{ + StatusCode: http.StatusOK, + Body: []byte(`{"token":"statictoken"}`), + }, + }, + }) + + authenicate1 := fmt.Sprintf("Basic realm=localhost") + basicCheck := func(a string) bool { + return a == fmt.Sprintf("Basic %s", basicAuth(username, password)) + } + te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck) + defer tc() + + m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + }) + + authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te+"/token", service) + bearerCheck := func(a string) bool { + return a == "Bearer statictoken" + } + e, c := testServerWithAuth(m, authenicate2, bearerCheck) + defer c() + + creds := &testCredentialStore{ + username: username, + password: password, + } + + transport1 := NewTransport(nil, NewAuthorizer(nil, NewTokenHandler(nil, creds, tokenScope), NewBasicHandler(creds))) + client := &http.Client{Transport: transport1} + + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } +} + +func TestEndpointAuthorizeBasic(t *testing.T) { + m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{ + { + Request: testutil.Request{ + Method: "GET", + Route: "/v2/hello", + }, + Response: testutil.Response{ + StatusCode: http.StatusAccepted, + }, + }, + }) + + username := "user1" + password := "funSecretPa$$word" + authenicate := fmt.Sprintf("Basic realm=localhost") + validCheck := func(a string) bool { + return a == fmt.Sprintf("Basic %s", basicAuth(username, password)) + } + e, c := testServerWithAuth(m, authenicate, validCheck) + defer c() + creds := &testCredentialStore{ + username: username, + password: password, + } + + transport1 := NewTransport(nil, NewAuthorizer(nil, NewBasicHandler(creds))) + client := &http.Client{Transport: transport1} + + req, _ := http.NewRequest("GET", e+"/v2/hello", nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error sending get request: %s", err) + } + + if resp.StatusCode != http.StatusAccepted { + t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted) + } +} diff --git a/registry/client/transport/transport.go b/registry/client/transport/transport.go new file mode 100644 index 00000000..30e45fab --- /dev/null +++ b/registry/client/transport/transport.go @@ -0,0 +1,147 @@ +package transport + +import ( + "io" + "net/http" + "sync" +) + +// RequestModifier represents an object which will do an inplace +// modification of an HTTP request. +type RequestModifier interface { + ModifyRequest(*http.Request) error +} + +type headerModifier http.Header + +// NewHeaderRequestModifier returns a new RequestModifier which will +// add the given headers to a request. +func NewHeaderRequestModifier(header http.Header) RequestModifier { + return headerModifier(header) +} + +func (h headerModifier) ModifyRequest(req *http.Request) error { + for k, s := range http.Header(h) { + req.Header[k] = append(req.Header[k], s...) + } + + return nil +} + +// NewTransport creates a new transport which will apply modifiers to +// the request on a RoundTrip call. +func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper { + return &transport{ + Modifiers: modifiers, + Base: base, + } +} + +// transport is an http.RoundTripper that makes HTTP requests after +// copying and modifying the request +type transport struct { + Modifiers []RequestModifier + Base http.RoundTripper + + mu sync.Mutex // guards modReq + modReq map[*http.Request]*http.Request // original -> modified +} + +// RoundTrip authorizes and authenticates the request with an +// access token. If no token exists or token is expired, +// tries to refresh/fetch a new token. +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := cloneRequest(req) + for _, modifier := range t.Modifiers { + if err := modifier.ModifyRequest(req2); err != nil { + return nil, err + } + } + + t.setModReq(req, req2) + res, err := t.base().RoundTrip(req2) + if err != nil { + t.setModReq(req, nil) + return nil, err + } + res.Body = &onEOFReader{ + rc: res.Body, + fn: func() { t.setModReq(req, nil) }, + } + return res, nil +} + +// CancelRequest cancels an in-flight request by closing its connection. +func (t *transport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := t.base().(canceler); ok { + t.mu.Lock() + modReq := t.modReq[req] + delete(t.modReq, req) + t.mu.Unlock() + cr.CancelRequest(modReq) + } +} + +func (t *transport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func (t *transport) setModReq(orig, mod *http.Request) { + t.mu.Lock() + defer t.mu.Unlock() + if t.modReq == nil { + t.modReq = make(map[*http.Request]*http.Request) + } + if mod == nil { + delete(t.modReq, orig) + } else { + t.modReq[orig] = mod + } +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +func cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + + return r2 +} + +type onEOFReader struct { + rc io.ReadCloser + fn func() +} + +func (r *onEOFReader) Read(p []byte) (n int, err error) { + n, err = r.rc.Read(p) + if err == io.EOF { + r.runFunc() + } + return +} + +func (r *onEOFReader) Close() error { + err := r.rc.Close() + r.runFunc() + return err +} + +func (r *onEOFReader) runFunc() { + if fn := r.fn; fn != nil { + fn() + r.fn = nil + } +} diff --git a/registry/storage/blobcachemetrics.go b/registry/storage/blobcachemetrics.go new file mode 100644 index 00000000..fad0a77a --- /dev/null +++ b/registry/storage/blobcachemetrics.go @@ -0,0 +1,60 @@ +package storage + +import ( + "expvar" + "sync/atomic" + + "github.com/docker/distribution/registry/storage/cache" +) + +type blobStatCollector struct { + metrics cache.Metrics +} + +func (bsc *blobStatCollector) Hit() { + atomic.AddUint64(&bsc.metrics.Requests, 1) + atomic.AddUint64(&bsc.metrics.Hits, 1) +} + +func (bsc *blobStatCollector) Miss() { + atomic.AddUint64(&bsc.metrics.Requests, 1) + atomic.AddUint64(&bsc.metrics.Misses, 1) +} + +func (bsc *blobStatCollector) Metrics() cache.Metrics { + return bsc.metrics +} + +// blobStatterCacheMetrics keeps track of cache metrics for blob descriptor +// cache requests. Note this is kept globally and made available via expvar. +// For more detailed metrics, its recommend to instrument a particular cache +// implementation. +var blobStatterCacheMetrics cache.MetricsTracker = &blobStatCollector{} + +func init() { + registry := expvar.Get("registry") + if registry == nil { + registry = expvar.NewMap("registry") + } + + cache := registry.(*expvar.Map).Get("cache") + if cache == nil { + cache = &expvar.Map{} + cache.(*expvar.Map).Init() + registry.(*expvar.Map).Set("cache", cache) + } + + storage := cache.(*expvar.Map).Get("storage") + if storage == nil { + storage = &expvar.Map{} + storage.(*expvar.Map).Init() + cache.(*expvar.Map).Set("storage", storage) + } + + storage.(*expvar.Map).Set("blobdescriptor", expvar.Func(func() interface{} { + // no need for synchronous access: the increments are atomic and + // during reading, we don't care if the data is up to date. The + // numbers will always *eventually* be reported correctly. + return blobStatterCacheMetrics + })) +} diff --git a/registry/storage/cache/cachedblobdescriptorstore.go b/registry/storage/cache/cachedblobdescriptorstore.go new file mode 100644 index 00000000..a095b19a --- /dev/null +++ b/registry/storage/cache/cachedblobdescriptorstore.go @@ -0,0 +1,80 @@ +package cache + +import ( + "github.com/docker/distribution/context" + "github.com/docker/distribution/digest" + + "github.com/docker/distribution" +) + +// Metrics is used to hold metric counters +// related to the number of times a cache was +// hit or missed. +type Metrics struct { + Requests uint64 + Hits uint64 + Misses uint64 +} + +// MetricsTracker represents a metric tracker +// which simply counts the number of hits and misses. +type MetricsTracker interface { + Hit() + Miss() + Metrics() Metrics +} + +type cachedBlobStatter struct { + cache distribution.BlobDescriptorService + backend distribution.BlobStatter + tracker MetricsTracker +} + +// NewCachedBlobStatter creates a new statter which prefers a cache and +// falls back to a backend. +func NewCachedBlobStatter(cache distribution.BlobDescriptorService, backend distribution.BlobStatter) distribution.BlobStatter { + return &cachedBlobStatter{ + cache: cache, + backend: backend, + } +} + +// NewCachedBlobStatterWithMetrics creates a new statter which prefers a cache and +// falls back to a backend. Hits and misses will send to the tracker. +func NewCachedBlobStatterWithMetrics(cache distribution.BlobDescriptorService, backend distribution.BlobStatter, tracker MetricsTracker) distribution.BlobStatter { + return &cachedBlobStatter{ + cache: cache, + backend: backend, + tracker: tracker, + } +} + +func (cbds *cachedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) { + desc, err := cbds.cache.Stat(ctx, dgst) + if err != nil { + if err != distribution.ErrBlobUnknown { + context.GetLogger(ctx).Errorf("error retrieving descriptor from cache: %v", err) + } + + goto fallback + } + + if cbds.tracker != nil { + cbds.tracker.Hit() + } + return desc, nil +fallback: + if cbds.tracker != nil { + cbds.tracker.Miss() + } + desc, err = cbds.backend.Stat(ctx, dgst) + if err != nil { + return desc, err + } + + if err := cbds.cache.SetDescriptor(ctx, dgst, desc); err != nil { + context.GetLogger(ctx).Errorf("error adding descriptor %v to cache: %v", desc.Digest, err) + } + + return desc, err +} diff --git a/registry/storage/cachedblobdescriptorstore.go b/registry/storage/cachedblobdescriptorstore.go deleted file mode 100644 index a0ccd067..00000000 --- a/registry/storage/cachedblobdescriptorstore.go +++ /dev/null @@ -1,84 +0,0 @@ -package storage - -import ( - "expvar" - "sync/atomic" - - "github.com/docker/distribution/context" - "github.com/docker/distribution/digest" - - "github.com/docker/distribution" -) - -type cachedBlobStatter struct { - cache distribution.BlobDescriptorService - backend distribution.BlobStatter -} - -func (cbds *cachedBlobStatter) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) { - atomic.AddUint64(&blobStatterCacheMetrics.Stat.Requests, 1) - desc, err := cbds.cache.Stat(ctx, dgst) - if err != nil { - if err != distribution.ErrBlobUnknown { - context.GetLogger(ctx).Errorf("error retrieving descriptor from cache: %v", err) - } - - goto fallback - } - - atomic.AddUint64(&blobStatterCacheMetrics.Stat.Hits, 1) - return desc, nil -fallback: - atomic.AddUint64(&blobStatterCacheMetrics.Stat.Misses, 1) - desc, err = cbds.backend.Stat(ctx, dgst) - if err != nil { - return desc, err - } - - if err := cbds.cache.SetDescriptor(ctx, dgst, desc); err != nil { - context.GetLogger(ctx).Errorf("error adding descriptor %v to cache: %v", desc.Digest, err) - } - - return desc, err -} - -// blobStatterCacheMetrics keeps track of cache metrics for blob descriptor -// cache requests. Note this is kept globally and made available via expvar. -// For more detailed metrics, its recommend to instrument a particular cache -// implementation. -var blobStatterCacheMetrics struct { - // Stat tracks calls to the caches. - Stat struct { - Requests uint64 - Hits uint64 - Misses uint64 - } -} - -func init() { - registry := expvar.Get("registry") - if registry == nil { - registry = expvar.NewMap("registry") - } - - cache := registry.(*expvar.Map).Get("cache") - if cache == nil { - cache = &expvar.Map{} - cache.(*expvar.Map).Init() - registry.(*expvar.Map).Set("cache", cache) - } - - storage := cache.(*expvar.Map).Get("storage") - if storage == nil { - storage = &expvar.Map{} - storage.(*expvar.Map).Init() - cache.(*expvar.Map).Set("storage", storage) - } - - storage.(*expvar.Map).Set("blobdescriptor", expvar.Func(func() interface{} { - // no need for synchronous access: the increments are atomic and - // during reading, we don't care if the data is up to date. The - // numbers will always *eventually* be reported correctly. - return blobStatterCacheMetrics - })) -} diff --git a/registry/storage/registry.go b/registry/storage/registry.go index 331aba73..ff33f410 100644 --- a/registry/storage/registry.go +++ b/registry/storage/registry.go @@ -29,10 +29,7 @@ func NewRegistryWithDriver(ctx context.Context, driver storagedriver.StorageDriv } if blobDescriptorCacheProvider != nil { - statter = &cachedBlobStatter{ - cache: blobDescriptorCacheProvider, - backend: statter, - } + statter = cache.NewCachedBlobStatter(blobDescriptorCacheProvider, statter) } bs := &blobStore{ @@ -143,10 +140,7 @@ func (repo *repository) Blobs(ctx context.Context) distribution.BlobStore { } if repo.descriptorCache != nil { - statter = &cachedBlobStatter{ - cache: repo.descriptorCache, - backend: statter, - } + statter = cache.NewCachedBlobStatter(repo.descriptorCache, statter) } return &linkedBlobStore{ diff --git a/testutil/handler.go b/testutil/handler.go index fa118cd1..10850e24 100644 --- a/testutil/handler.go +++ b/testutil/handler.go @@ -6,6 +6,7 @@ import ( "io" "io/ioutil" "net/http" + "net/url" "sort" "strings" ) @@ -40,16 +41,18 @@ type Request struct { func (r Request) String() string { queryString := "" if len(r.QueryParams) > 0 { - queryString = "?" keys := make([]string, 0, len(r.QueryParams)) + queryParts := make([]string, 0, len(r.QueryParams)) for k := range r.QueryParams { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { - queryString += strings.Join(r.QueryParams[k], "&") + "&" + for _, val := range r.QueryParams[k] { + queryParts = append(queryParts, fmt.Sprintf("%s=%s", k, url.QueryEscape(val))) + } } - queryString = queryString[:len(queryString)-1] + queryString = "?" + strings.Join(queryParts, "&") } return fmt.Sprintf("%s %s%s\n%s", r.Method, r.Route, queryString, r.Body) }