diff --git a/notifications/endpoint.go b/notifications/endpoint.go index dfdb111c5..b5ed955d1 100644 --- a/notifications/endpoint.go +++ b/notifications/endpoint.go @@ -12,6 +12,7 @@ type EndpointConfig struct { Timeout time.Duration Threshold int Backoff time.Duration + Transport *http.Transport } // defaults set any zero-valued fields to a reasonable default. @@ -27,6 +28,10 @@ func (ec *EndpointConfig) defaults() { if ec.Backoff <= 0 { ec.Backoff = time.Second } + + if ec.Transport == nil { + ec.Transport = http.DefaultTransport.(*http.Transport) + } } // Endpoint is a reliable, queued, thread-safe sink that notify external http @@ -54,7 +59,7 @@ func NewEndpoint(name, url string, config EndpointConfig) *Endpoint { // Configures the inmemory queue, retry, http pipeline. endpoint.Sink = newHTTPSink( endpoint.url, endpoint.Timeout, endpoint.Headers, - endpoint.metrics.httpStatusListener()) + endpoint.Transport, endpoint.metrics.httpStatusListener()) endpoint.Sink = newRetryingSink(endpoint.Sink, endpoint.Threshold, endpoint.Backoff) endpoint.Sink = newEventQueue(endpoint.Sink, endpoint.metrics.eventQueueListener()) diff --git a/notifications/http.go b/notifications/http.go index 465434f1c..15751619b 100644 --- a/notifications/http.go +++ b/notifications/http.go @@ -26,13 +26,16 @@ type httpSink struct { // newHTTPSink returns an unreliable, single-flight http sink. Wrap in other // sinks for increased reliability. -func newHTTPSink(u string, timeout time.Duration, headers http.Header, listeners ...httpStatusListener) *httpSink { +func newHTTPSink(u string, timeout time.Duration, headers http.Header, transport *http.Transport, listeners ...httpStatusListener) *httpSink { + if transport == nil { + transport = http.DefaultTransport.(*http.Transport) + } return &httpSink{ url: u, listeners: listeners, client: &http.Client{ Transport: &headerRoundTripper{ - Transport: http.DefaultTransport.(*http.Transport), + Transport: transport, headers: headers, }, Timeout: timeout, diff --git a/notifications/http_test.go b/notifications/http_test.go index 854dd404d..e04693621 100644 --- a/notifications/http_test.go +++ b/notifications/http_test.go @@ -1,6 +1,7 @@ package notifications import ( + "crypto/tls" "encoding/json" "fmt" "mime" @@ -8,6 +9,7 @@ import ( "net/http/httptest" "reflect" "strconv" + "strings" "testing" "github.com/docker/distribution/manifest/schema1" @@ -16,7 +18,7 @@ import ( // TestHTTPSink mocks out an http endpoint and notifies it under a couple of // conditions, ensuring correct behavior. func TestHTTPSink(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + serverHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if r.Method != "POST" { w.WriteHeader(http.StatusMethodNotAllowed) @@ -57,12 +59,38 @@ func TestHTTPSink(t *testing.T) { } w.WriteHeader(status) - })) + }) + server := httptest.NewTLSServer(serverHandler) metrics := newSafeMetrics() - sink := newHTTPSink(server.URL, 0, nil, + sink := newHTTPSink(server.URL, 0, nil, nil, &endpointMetricsHTTPStatusListener{safeMetrics: metrics}) + // first make sure that the default transport gives x509 untrusted cert error + events := []Event{} + err := sink.Write(events...) + if !strings.Contains(err.Error(), "x509") { + t.Fatal("TLS server with default transport should give unknown CA error") + } + if err := sink.Close(); err != nil { + t.Fatalf("unexpected error closing http sink: %v", err) + } + + // make sure that passing in the transport no longer gives this error + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + sink = newHTTPSink(server.URL, 0, nil, tr, + &endpointMetricsHTTPStatusListener{safeMetrics: metrics}) + err = sink.Write(events...) + if err != nil { + t.Fatalf("unexpected error writing events: %v", err) + } + + // reset server to standard http server and sink to a basic sink + server = httptest.NewServer(serverHandler) + sink = newHTTPSink(server.URL, 0, nil, nil, + &endpointMetricsHTTPStatusListener{safeMetrics: metrics}) var expectedMetrics EndpointMetrics expectedMetrics.Statuses = make(map[string]int)