diff --git a/docs/registry.go b/docs/registry.go index 24c55125c..748636dca 100644 --- a/docs/registry.go +++ b/docs/registry.go @@ -4,6 +4,8 @@ import ( "bytes" "crypto/sha256" _ "crypto/sha512" + "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" @@ -13,6 +15,8 @@ import ( "net/http" "net/http/cookiejar" "net/url" + "os" + "path" "regexp" "runtime" "strconv" @@ -29,31 +33,155 @@ var ( errLoginRequired = errors.New("Authentication is required.") ) +type TimeoutType uint32 + +const ( + NoTimeout TimeoutType = iota + ReceiveTimeout + ConnectTimeout +) + +func newClient(jar http.CookieJar, roots *x509.CertPool, cert *tls.Certificate, timeout TimeoutType) *http.Client { + tlsConfig := tls.Config{RootCAs: roots} + + if cert != nil { + tlsConfig.Certificates = append(tlsConfig.Certificates, *cert) + } + + httpTransport := &http.Transport{ + DisableKeepAlives: true, + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tlsConfig, + } + + switch timeout { + case ConnectTimeout: + httpTransport.Dial = func(proto string, addr string) (net.Conn, error) { + // Set the connect timeout to 5 seconds + conn, err := net.DialTimeout(proto, addr, 5*time.Second) + if err != nil { + return nil, err + } + // Set the recv timeout to 10 seconds + conn.SetDeadline(time.Now().Add(10 * time.Second)) + return conn, nil + } + case ReceiveTimeout: + httpTransport.Dial = func(proto string, addr string) (net.Conn, error) { + conn, err := net.Dial(proto, addr) + if err != nil { + return nil, err + } + conn = utils.NewTimeoutConn(conn, 1*time.Minute) + return conn, nil + } + } + + return &http.Client{ + Transport: httpTransport, + CheckRedirect: AddRequiredHeadersToRedirectedRequests, + Jar: jar, + } +} + +func doRequest(req *http.Request, jar http.CookieJar, timeout TimeoutType) (*http.Response, *http.Client, error) { + hasFile := func(files []os.FileInfo, name string) bool { + for _, f := range files { + if f.Name() == name { + return true + } + } + return false + } + + hostDir := path.Join("/etc/docker/certs.d", req.URL.Host) + fs, err := ioutil.ReadDir(hostDir) + if err != nil && !os.IsNotExist(err) { + return nil, nil, err + } + + var ( + pool *x509.CertPool + certs []*tls.Certificate + ) + + for _, f := range fs { + if strings.HasSuffix(f.Name(), ".crt") { + if pool == nil { + pool = x509.NewCertPool() + } + data, err := ioutil.ReadFile(path.Join(hostDir, f.Name())) + if err != nil { + return nil, nil, err + } else { + pool.AppendCertsFromPEM(data) + } + } + if strings.HasSuffix(f.Name(), ".cert") { + certName := f.Name() + keyName := certName[:len(certName)-5] + ".key" + if !hasFile(fs, keyName) { + return nil, nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName) + } else { + cert, err := tls.LoadX509KeyPair(path.Join(hostDir, certName), path.Join(hostDir, keyName)) + if err != nil { + return nil, nil, err + } + certs = append(certs, &cert) + } + } + if strings.HasSuffix(f.Name(), ".key") { + keyName := f.Name() + certName := keyName[:len(keyName)-4] + ".cert" + if !hasFile(fs, certName) { + return nil, nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName) + } + } + } + + if len(certs) == 0 { + client := newClient(jar, pool, nil, timeout) + res, err := client.Do(req) + if err != nil { + return nil, nil, err + } + return res, client, nil + } else { + for i, cert := range certs { + client := newClient(jar, pool, cert, timeout) + res, err := client.Do(req) + if i == len(certs)-1 { + // If this is the last cert, always return the result + return res, client, err + } else { + // Otherwise, continue to next cert if 403 or 5xx + if err == nil && res.StatusCode != 403 && !(res.StatusCode >= 500 && res.StatusCode < 600) { + return res, client, err + } + } + } + } + + return nil, nil, nil +} + func pingRegistryEndpoint(endpoint string) (RegistryInfo, error) { if endpoint == IndexServerAddress() { // Skip the check, we now this one is valid // (and we never want to fallback to http in case of error) return RegistryInfo{Standalone: false}, nil } - httpDial := func(proto string, addr string) (net.Conn, error) { - // Set the connect timeout to 5 seconds - conn, err := net.DialTimeout(proto, addr, 5*time.Second) - if err != nil { - return nil, err - } - // Set the recv timeout to 10 seconds - conn.SetDeadline(time.Now().Add(10 * time.Second)) - return conn, nil - } - httpTransport := &http.Transport{ - Dial: httpDial, - Proxy: http.ProxyFromEnvironment, - } - client := &http.Client{Transport: httpTransport} - resp, err := client.Get(endpoint + "_ping") + + req, err := http.NewRequest("GET", endpoint+"_ping", nil) if err != nil { return RegistryInfo{Standalone: false}, err } + + resp, _, err := doRequest(req, nil, ConnectTimeout) + if err != nil { + return RegistryInfo{Standalone: false}, err + } + defer resp.Body.Close() jsonString, err := ioutil.ReadAll(resp.Body) @@ -171,6 +299,10 @@ func setTokenAuth(req *http.Request, token []string) { } } +func (r *Registry) doRequest(req *http.Request) (*http.Response, *http.Client, error) { + return doRequest(req, r.jar, r.timeout) +} + // Retrieve the history of a given image from the Registry. // Return a list of the parent's json (requested image included) func (r *Registry) GetRemoteHistory(imgID, registry string, token []string) ([]string, error) { @@ -179,7 +311,7 @@ func (r *Registry) GetRemoteHistory(imgID, registry string, token []string) ([]s return nil, err } setTokenAuth(req, token) - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return nil, err } @@ -214,7 +346,7 @@ func (r *Registry) LookupRemoteImage(imgID, registry string, token []string) boo return false } setTokenAuth(req, token) - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { utils.Errorf("Error in LookupRemoteImage %s", err) return false @@ -231,7 +363,7 @@ func (r *Registry) GetRemoteImageJSON(imgID, registry string, token []string) ([ return nil, -1, fmt.Errorf("Failed to download json: %s", err) } setTokenAuth(req, token) - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return nil, -1, fmt.Errorf("Failed to download json: %s", err) } @@ -260,6 +392,7 @@ func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string, i var ( retries = 5 headRes *http.Response + client *http.Client hasResume bool = false imageURL = fmt.Sprintf("%simages/%s/layer", registry, imgID) ) @@ -267,9 +400,10 @@ func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string, i if err != nil { return nil, fmt.Errorf("Error while getting from the server: %s\n", err) } + setTokenAuth(headReq, token) for i := 1; i <= retries; i++ { - headRes, err = r.client.Do(headReq) + headRes, client, err = r.doRequest(headReq) if err != nil && i == retries { return nil, fmt.Errorf("Eror while making head request: %s\n", err) } else if err != nil { @@ -290,10 +424,10 @@ func (r *Registry) GetRemoteImageLayer(imgID, registry string, token []string, i setTokenAuth(req, token) if hasResume { utils.Debugf("server supports resume") - return utils.ResumableRequestReader(r.client, req, 5, imgSize), nil + return utils.ResumableRequestReader(client, req, 5, imgSize), nil } utils.Debugf("server doesn't support resume") - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return nil, err } @@ -319,7 +453,7 @@ func (r *Registry) GetRemoteTags(registries []string, repository string, token [ return nil, err } setTokenAuth(req, token) - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return nil, err } @@ -380,7 +514,7 @@ func (r *Registry) GetRepositoryData(remote string) (*RepositoryData, error) { } req.Header.Set("X-Docker-Token", "true") - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return nil, err } @@ -448,13 +582,13 @@ func (r *Registry) PushImageChecksumRegistry(imgData *ImgData, registry string, req.Header.Set("X-Docker-Checksum", imgData.Checksum) req.Header.Set("X-Docker-Checksum-Payload", imgData.ChecksumPayload) - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return fmt.Errorf("Failed to upload metadata: %s", err) } defer res.Body.Close() if len(res.Cookies()) > 0 { - r.client.Jar.SetCookies(req.URL, res.Cookies()) + r.jar.SetCookies(req.URL, res.Cookies()) } if res.StatusCode != 200 { errBody, err := ioutil.ReadAll(res.Body) @@ -484,7 +618,7 @@ func (r *Registry) PushImageJSONRegistry(imgData *ImgData, jsonRaw []byte, regis req.Header.Add("Content-type", "application/json") setTokenAuth(req, token) - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return fmt.Errorf("Failed to upload metadata: %s", err) } @@ -525,7 +659,7 @@ func (r *Registry) PushImageLayerRegistry(imgID string, layer io.Reader, registr req.ContentLength = -1 req.TransferEncoding = []string{"chunked"} setTokenAuth(req, token) - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return "", "", fmt.Errorf("Failed to upload layer: %s", err) } @@ -562,7 +696,7 @@ func (r *Registry) PushRegistryTag(remote, revision, tag, registry string, token req.Header.Add("Content-type", "application/json") setTokenAuth(req, token) req.ContentLength = int64(len(revision)) - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return err } @@ -610,7 +744,7 @@ func (r *Registry) PushImageJSONIndex(remote string, imgList []*ImgData, validat req.Header["X-Docker-Endpoints"] = regs } - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return nil, err } @@ -629,7 +763,7 @@ func (r *Registry) PushImageJSONIndex(remote string, imgList []*ImgData, validat if validate { req.Header["X-Docker-Endpoints"] = regs } - res, err = r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return nil, err } @@ -688,7 +822,7 @@ func (r *Registry) SearchRepositories(term string) (*SearchResults, error) { req.SetBasicAuth(r.authConfig.Username, r.authConfig.Password) } req.Header.Set("X-Docker-Token", "true") - res, err := r.client.Do(req) + res, _, err := r.doRequest(req) if err != nil { return nil, err } @@ -750,10 +884,11 @@ type RegistryInfo struct { } type Registry struct { - client *http.Client authConfig *AuthConfig reqFactory *utils.HTTPRequestFactory indexEndpoint string + jar *cookiejar.Jar + timeout TimeoutType } func trustedLocation(req *http.Request) bool { @@ -791,30 +926,16 @@ func AddRequiredHeadersToRedirectedRequests(req *http.Request, via []*http.Reque } func NewRegistry(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string, timeout bool) (r *Registry, err error) { - httpTransport := &http.Transport{ - DisableKeepAlives: true, - Proxy: http.ProxyFromEnvironment, - } - if timeout { - httpTransport.Dial = func(proto string, addr string) (net.Conn, error) { - conn, err := net.Dial(proto, addr) - if err != nil { - return nil, err - } - conn = utils.NewTimeoutConn(conn, 1*time.Minute) - return conn, nil - } - } r = &Registry{ - authConfig: authConfig, - client: &http.Client{ - Transport: httpTransport, - CheckRedirect: AddRequiredHeadersToRedirectedRequests, - }, + authConfig: authConfig, indexEndpoint: indexEndpoint, } - r.client.Jar, err = cookiejar.New(nil) + if timeout { + r.timeout = ReceiveTimeout + } + + r.jar, err = cookiejar.New(nil) if err != nil { return nil, err }