diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go index 4a0a7141e..439caf932 100644 --- a/plugin/forward/connect.go +++ b/plugin/forward/connect.go @@ -91,7 +91,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options proto = state.Proto() } - conn, cached, err := p.Dial(proto) + conn, cached, err := p.transport.Dial(proto) if err != nil { return nil, err } @@ -125,7 +125,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options p.updateRtt(time.Since(reqTime)) - p.Yield(conn) + p.transport.Yield(conn) rc, ok := dns.RcodeToString[ret.Rcode] if !ok { diff --git a/plugin/forward/forward_test.go b/plugin/forward/forward_test.go index 96f5fa0ce..82844f811 100644 --- a/plugin/forward/forward_test.go +++ b/plugin/forward/forward_test.go @@ -19,7 +19,7 @@ func TestForward(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, nil /* not TLS */) + p := NewProxy(s.Addr, DNS) f := New() f.SetProxy(p) defer f.Close() @@ -51,7 +51,7 @@ func TestForwardRefused(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, nil) + p := NewProxy(s.Addr, DNS) f := New() f.SetProxy(p) defer f.Close() diff --git a/plugin/forward/health.go b/plugin/forward/health.go index 03322e92e..4d3278f6d 100644 --- a/plugin/forward/health.go +++ b/plugin/forward/health.go @@ -1,17 +1,48 @@ package forward import ( + "crypto/tls" "sync/atomic" + "time" "github.com/miekg/dns" ) +// HealthChecker checks the upstream health. +type HealthChecker interface { + Check(*Proxy) error + SetTLSConfig(*tls.Config) +} + +// dnsHc is a health checker for a DNS endpoint (DNS, and DoT). +type dnsHc struct{ c *dns.Client } + +// NewHealthChecker returns a new HealthChecker based on protocol. +func NewHealthChecker(protocol int) HealthChecker { + switch protocol { + case DNS, TLS: + c := new(dns.Client) + c.Net = "udp" + c.ReadTimeout = 1 * time.Second + c.WriteTimeout = 1 * time.Second + + return &dnsHc{c: c} + } + + return nil +} + +func (h *dnsHc) SetTLSConfig(cfg *tls.Config) { + h.c.Net = "tcp-tls" + h.c.TLSConfig = cfg +} + // For HC we send to . IN NS +norec message to the upstream. Dial timeouts and empty // replies are considered fails, basically anything else constitutes a healthy upstream. // Check is used as the up.Func in the up.Probe. -func (p *Proxy) Check() error { - err := p.send() +func (h *dnsHc) Check(p *Proxy) error { + err := h.send(p.addr) if err != nil { HealthcheckFailureCount.WithLabelValues(p.addr).Add(1) atomic.AddUint32(&p.fails, 1) @@ -22,14 +53,14 @@ func (p *Proxy) Check() error { return nil } -func (p *Proxy) send() error { - hcping := new(dns.Msg) - hcping.SetQuestion(".", dns.TypeNS) +func (h *dnsHc) send(addr string) error { + ping := new(dns.Msg) + ping.SetQuestion(".", dns.TypeNS) - m, _, err := p.client.Exchange(hcping, p.addr) - // If we got a header, we're alright, basically only care about I/O errors 'n stuff + m, _, err := h.c.Exchange(ping, addr) + // If we got a header, we're alright, basically only care about I/O errors 'n stuff. if err != nil && m != nil { - // Silly check, something sane came back + // Silly check, something sane came back. if m.Response || m.Opcode == dns.OpcodeQuery { err = nil } diff --git a/plugin/forward/health_test.go b/plugin/forward/health_test.go index 0588f1454..75d57f285 100644 --- a/plugin/forward/health_test.go +++ b/plugin/forward/health_test.go @@ -25,7 +25,7 @@ func TestHealth(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, nil /* no TLS */) + p := NewProxy(s.Addr, DNS) f := New() f.SetProxy(p) defer f.Close() @@ -65,7 +65,7 @@ func TestHealthTimeout(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, nil /* no TLS */) + p := NewProxy(s.Addr, DNS) f := New() f.SetProxy(p) defer f.Close() @@ -109,7 +109,7 @@ func TestHealthFailTwice(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, nil /* no TLS */) + p := NewProxy(s.Addr, DNS) f := New() f.SetProxy(p) defer f.Close() @@ -132,7 +132,7 @@ func TestHealthMaxFails(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, nil /* no TLS */) + p := NewProxy(s.Addr, DNS) f := New() f.maxfails = 2 f.SetProxy(p) @@ -163,7 +163,7 @@ func TestHealthNoMaxFails(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, nil /* no TLS */) + p := NewProxy(s.Addr, DNS) f := New() f.maxfails = 0 f.SetProxy(p) diff --git a/plugin/forward/lookup.go b/plugin/forward/lookup.go index 96eceab84..94114647c 100644 --- a/plugin/forward/lookup.go +++ b/plugin/forward/lookup.go @@ -81,7 +81,7 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M func NewLookup(addr []string) *Forward { f := New() for i := range addr { - p := NewProxy(addr[i], nil) + p := NewProxy(addr[i], DNS) f.SetProxy(p) } return f diff --git a/plugin/forward/lookup_test.go b/plugin/forward/lookup_test.go index e37a0c5d7..1968ef979 100644 --- a/plugin/forward/lookup_test.go +++ b/plugin/forward/lookup_test.go @@ -19,7 +19,7 @@ func TestLookup(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, nil /* no TLS */) + p := NewProxy(s.Addr, DNS) f := New() f.SetProxy(p) defer f.Close() diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go index 52bd24918..4da1514fe 100644 --- a/plugin/forward/persistent.go +++ b/plugin/forward/persistent.go @@ -29,7 +29,7 @@ type transport struct { stop chan bool } -func newTransport(addr string, tlsConfig *tls.Config) *transport { +func newTransport(addr string) *transport { t := &transport{ avgDialTime: int64(defaultDialTimeout / 2), conns: make(map[string][]*persistConn), diff --git a/plugin/forward/persistent_test.go b/plugin/forward/persistent_test.go index e046cf4de..271a80c0b 100644 --- a/plugin/forward/persistent_test.go +++ b/plugin/forward/persistent_test.go @@ -17,7 +17,7 @@ func TestCached(t *testing.T) { }) defer s.Close() - tr := newTransport(s.Addr, nil /* no TLS */) + tr := newTransport(s.Addr) tr.Start() defer tr.Stop() @@ -56,7 +56,7 @@ func TestCleanupByTimer(t *testing.T) { }) defer s.Close() - tr := newTransport(s.Addr, nil /* no TLS */) + tr := newTransport(s.Addr) tr.SetExpire(100 * time.Millisecond) tr.Start() defer tr.Stop() @@ -90,7 +90,7 @@ func TestPartialCleanup(t *testing.T) { }) defer s.Close() - tr := newTransport(s.Addr, nil /* no TLS */) + tr := newTransport(s.Addr) tr.SetExpire(100 * time.Millisecond) tr.Start() defer tr.Stop() @@ -138,7 +138,7 @@ func TestCleanupAll(t *testing.T) { }) defer s.Close() - tr := newTransport(s.Addr, nil /* no TLS */) + tr := newTransport(s.Addr) c1, _ := dns.DialTimeout("udp", tr.addr, defaultDialTimeout) c2, _ := dns.DialTimeout("udp", tr.addr, defaultDialTimeout) diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go index 91d7c38b1..ac74bf0f8 100644 --- a/plugin/forward/proxy.go +++ b/plugin/forward/proxy.go @@ -7,8 +7,6 @@ import ( "time" "github.com/coredns/coredns/plugin/pkg/up" - - "github.com/miekg/dns" ) // Proxy defines an upstream host. @@ -16,69 +14,46 @@ type Proxy struct { avgRtt int64 fails uint32 - addr string - client *dns.Client + addr string // Connection caching expire time.Duration transport *transport // health checking - probe *up.Probe + probe *up.Probe + health HealthChecker } // NewProxy returns a new proxy. -func NewProxy(addr string, tlsConfig *tls.Config) *Proxy { +func NewProxy(addr string, protocol int) *Proxy { p := &Proxy{ addr: addr, fails: 0, probe: up.New(), - transport: newTransport(addr, tlsConfig), + transport: newTransport(addr), avgRtt: int64(maxTimeout / 2), } - p.client = dnsClient(tlsConfig) + p.health = NewHealthChecker(protocol) runtime.SetFinalizer(p, (*Proxy).finalizer) return p } -// Addr returns the address to forward to. -func (p *Proxy) Addr() (addr string) { return p.addr } - -// dnsClient returns a client used for health checking. -func dnsClient(tlsConfig *tls.Config) *dns.Client { - c := new(dns.Client) - c.Net = "udp" - // TODO(miek): this should be half of hcDuration? - c.ReadTimeout = 1 * time.Second - c.WriteTimeout = 1 * time.Second - - if tlsConfig != nil { - c.Net = "tcp-tls" - c.TLSConfig = tlsConfig - } - return c -} - // SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client. func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.transport.SetTLSConfig(cfg) - p.client = dnsClient(cfg) + p.health.SetTLSConfig(cfg) } -// IsTLS returns true if proxy uses tls. -func (p *Proxy) IsTLS() bool { return p.transport.tlsConfig != nil } - // SetExpire sets the expire duration in the lower p.transport. func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) } -// Dial connects to the host in p with the configured transport. -func (p *Proxy) Dial(proto string) (*dns.Conn, bool, error) { return p.transport.Dial(proto) } - -// Yield returns the connection to the pool. -func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) } - // Healthcheck kicks of a round of health checks for this proxy. -func (p *Proxy) Healthcheck() { p.probe.Do(p.Check) } +func (p *Proxy) Healthcheck() { + p.probe.Do(func() error { + return p.health.Check(p) + }) +} // Down returns true if this proxy is down, i.e. has *more* fails than maxfails. func (p *Proxy) Down(maxfails uint32) bool { @@ -91,13 +66,8 @@ func (p *Proxy) Down(maxfails uint32) bool { } // close stops the health checking goroutine. -func (p *Proxy) close() { - p.probe.Stop() -} - -func (p *Proxy) finalizer() { - p.transport.Stop() -} +func (p *Proxy) close() { p.probe.Stop() } +func (p *Proxy) finalizer() { p.transport.Stop() } // start starts the proxy's healthchecking. func (p *Proxy) start(duration time.Duration) { diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go index d68d1def2..d7af25aa0 100644 --- a/plugin/forward/proxy_test.go +++ b/plugin/forward/proxy_test.go @@ -26,7 +26,7 @@ func TestProxyClose(t *testing.T) { ctx := context.TODO() for i := 0; i < 100; i++ { - p := NewProxy(s.Addr, nil) + p := NewProxy(s.Addr, DNS) p.start(hcInterval) go func() { p.Connect(ctx, state, options{}) }() @@ -95,7 +95,7 @@ func TestProxyTLSFail(t *testing.T) { } func TestProtocolSelection(t *testing.T) { - p := NewProxy("bad_address", nil) + p := NewProxy("bad_address", DNS) stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 152ba36c5..ee48fdaf6 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -124,7 +124,7 @@ func parseForward(c *caddy.Controller) (*Forward, error) { // We can't set tlsConfig here, because we haven't parsed it yet. // We set it below at the end of parseBlock, use nil now. - p := NewProxy(h, nil /* no TLS */) + p := NewProxy(h, protocols[i]) f.proxies = append(f.proxies, p) } diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index a8140d410..c72cd1106 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -113,8 +113,8 @@ func TestSetupTLS(t *testing.T) { t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName) } - if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].client.TLSConfig.ServerName { - t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].client.TLSConfig.ServerName) + if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName { + t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName) } } } diff --git a/plugin/forward/truncated_test.go b/plugin/forward/truncated_test.go index 1c9e92a07..b7ff47c14 100644 --- a/plugin/forward/truncated_test.go +++ b/plugin/forward/truncated_test.go @@ -34,7 +34,7 @@ func TestLookupTruncated(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, nil /* no TLS */) + p := NewProxy(s.Addr, DNS) f := New() f.SetProxy(p) defer f.Close() @@ -88,9 +88,9 @@ func TestForwardTruncated(t *testing.T) { f := New() - p1 := NewProxy(s.Addr, nil /* no TLS */) + p1 := NewProxy(s.Addr, DNS) f.SetProxy(p1) - p2 := NewProxy(s.Addr, nil /* no TLS */) + p2 := NewProxy(s.Addr, DNS) f.SetProxy(p2) defer f.Close()