diff --git a/docs/client/repository.go b/docs/client/repository.go index c1e8e07f1..56d86df01 100644 --- a/docs/client/repository.go +++ b/docs/client/repository.go @@ -358,25 +358,18 @@ type blobs struct { distribution.BlobDeleter } -func sanitizeLocation(location, source string) (string, error) { +func sanitizeLocation(location, base string) (string, error) { + baseURL, err := url.Parse(base) + if err != nil { + return "", err + } + locationURL, err := url.Parse(location) if err != nil { return "", err } - if locationURL.Scheme == "" { - sourceURL, err := url.Parse(source) - if err != nil { - return "", err - } - locationURL = &url.URL{ - Scheme: sourceURL.Scheme, - Host: sourceURL.Host, - Path: location, - } - location = locationURL.String() - } - return location, nil + return baseURL.ResolveReference(locationURL).String(), nil } func (bs *blobs) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) { diff --git a/docs/client/repository_test.go b/docs/client/repository_test.go index 8a7a598e6..384a2311f 100644 --- a/docs/client/repository_test.go +++ b/docs/client/repository_test.go @@ -856,3 +856,49 @@ func TestCatalogInParts(t *testing.T) { t.Fatalf("Got wrong number of repos") } } + +func TestSanitizeLocation(t *testing.T) { + for _, testcase := range []struct { + description string + location string + source string + expected string + err error + }{ + { + description: "ensure relative location correctly resolved", + location: "/v2/foo/baasdf", + source: "http://blahalaja.com/v1", + expected: "http://blahalaja.com/v2/foo/baasdf", + }, + { + description: "ensure parameters are preserved", + location: "/v2/foo/baasdf?_state=asdfasfdasdfasdf&digest=foo", + source: "http://blahalaja.com/v1", + expected: "http://blahalaja.com/v2/foo/baasdf?_state=asdfasfdasdfasdf&digest=foo", + }, + { + description: "ensure new hostname overidden", + location: "https://mwhahaha.com/v2/foo/baasdf?_state=asdfasfdasdfasdf", + source: "http://blahalaja.com/v1", + expected: "https://mwhahaha.com/v2/foo/baasdf?_state=asdfasfdasdfasdf", + }, + } { + fatalf := func(format string, args ...interface{}) { + t.Fatalf(testcase.description+": "+format, args...) + } + + s, err := sanitizeLocation(testcase.location, testcase.source) + if err != testcase.err { + if testcase.err != nil { + fatalf("expected error: %v != %v", err, testcase) + } else { + fatalf("unexpected error sanitizing: %v", err) + } + } + + if s != testcase.expected { + fatalf("bad sanitize: %q != %q", s, testcase.expected) + } + } +}