diff --git a/docs/registry.go b/docs/registry.go index 7bcf06601..8d1a9f228 100644 --- a/docs/registry.go +++ b/docs/registry.go @@ -756,9 +756,36 @@ type Registry struct { indexEndpoint string } +func trustedLocation(req *http.Request) bool { + var ( + trusteds = []string{"docker.com", "docker.io"} + hostname = strings.SplitN(req.Host, ":", 2)[0] + ) + if req.URL.Scheme != "https" { + return false + } + + for _, trusted := range trusteds { + if strings.HasSuffix(hostname, trusted) { + return true + } + } + return false +} + func AddRequiredHeadersToRedirectedRequests(req *http.Request, via []*http.Request) error { if via != nil && via[0] != nil { - req.Header = via[0].Header + if trustedLocation(req) && trustedLocation(via[0]) { + req.Header = via[0].Header + } else { + for k, v := range via[0].Header { + if k != "Authorization" { + for _, vv := range v { + req.Header.Add(k, vv) + } + } + } + } } return nil } diff --git a/docs/registry_test.go b/docs/registry_test.go index e207359e6..2857ab4a4 100644 --- a/docs/registry_test.go +++ b/docs/registry_test.go @@ -2,10 +2,12 @@ package registry import ( "fmt" - "github.com/dotcloud/docker/utils" + "net/http" "net/url" "strings" "testing" + + "github.com/dotcloud/docker/utils" ) var ( @@ -231,3 +233,70 @@ func TestValidRepositoryName(t *testing.T) { t.Fail() } } + +func TestTrustedLocation(t *testing.T) { + for _, url := range []string{"http://example.com", "https://example.com:7777", "http://docker.io", "http://test.docker.io"} { + req, _ := http.NewRequest("GET", url, nil) + if trustedLocation(req) == true { + t.Fatalf("'%s' shouldn't be detected as a trusted location", url) + } + } + + for _, url := range []string{"https://docker.io", "https://test.docker.io:80"} { + req, _ := http.NewRequest("GET", url, nil) + if trustedLocation(req) == false { + t.Fatalf("'%s' should be detected as a trusted location", url) + } + } +} + +func TestAddRequiredHeadersToRedirectedRequests(t *testing.T) { + for _, urls := range [][]string{ + {"http://docker.io", "https://docker.com"}, + {"https://foo.docker.io:7777", "http://bar.docker.com"}, + {"https://foo.docker.io", "https://example.com"}, + } { + reqFrom, _ := http.NewRequest("GET", urls[0], nil) + reqFrom.Header.Add("Content-Type", "application/json") + reqFrom.Header.Add("Authorization", "super_secret") + reqTo, _ := http.NewRequest("GET", urls[1], nil) + + AddRequiredHeadersToRedirectedRequests(reqTo, []*http.Request{reqFrom}) + + if len(reqTo.Header) != 1 { + t.Fatal("Expected 1 headers, got %d", len(reqTo.Header)) + } + + if reqTo.Header.Get("Content-Type") != "application/json" { + t.Fatal("'Content-Type' should be 'application/json'") + } + + if reqTo.Header.Get("Authorization") != "" { + t.Fatal("'Authorization' should be empty") + } + } + + for _, urls := range [][]string{ + {"https://docker.io", "https://docker.com"}, + {"https://foo.docker.io:7777", "https://bar.docker.com"}, + } { + reqFrom, _ := http.NewRequest("GET", urls[0], nil) + reqFrom.Header.Add("Content-Type", "application/json") + reqFrom.Header.Add("Authorization", "super_secret") + reqTo, _ := http.NewRequest("GET", urls[1], nil) + + AddRequiredHeadersToRedirectedRequests(reqTo, []*http.Request{reqFrom}) + + if len(reqTo.Header) != 2 { + t.Fatal("Expected 2 headers, got %d", len(reqTo.Header)) + } + + if reqTo.Header.Get("Content-Type") != "application/json" { + t.Fatal("'Content-Type' should be 'application/json'") + } + + if reqTo.Header.Get("Authorization") != "super_secret" { + t.Fatal("'Authorization' should be 'super_secret'") + } + } +}