plugin/forward: add HealthChecker interface (#1950)

* plugin/forward: add HealthChecker interface

Make the HealthChecker interface and morph the current DNS health
checker into that interface.

Remove all whole bunch of method on Forward that didn't make sense.

This is done in preparation of adding a DoH client to forward - which
requires a completely different healthcheck implementation (and more,
but lets start here)

Signed-off-by: Miek Gieben <miek@miek.nl>

* Use protocol

Signed-off-by: Miek Gieben <miek@miek.nl>

* Dial doesnt need to be method an Forward either

Signed-off-by: Miek Gieben <miek@miek.nl>

* Address comments

Address various comments on the PR.

Signed-off-by: Miek Gieben <miek@miek.nl>
This commit is contained in:
Miek Gieben 2018-07-09 15:14:55 +01:00 committed by GitHub
parent 4083852b70
commit a536833546
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 77 additions and 76 deletions

View file

@ -91,7 +91,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
proto = state.Proto() proto = state.Proto()
} }
conn, cached, err := p.Dial(proto) conn, cached, err := p.transport.Dial(proto)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -125,7 +125,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options
p.updateRtt(time.Since(reqTime)) p.updateRtt(time.Since(reqTime))
p.Yield(conn) p.transport.Yield(conn)
rc, ok := dns.RcodeToString[ret.Rcode] rc, ok := dns.RcodeToString[ret.Rcode]
if !ok { if !ok {

View file

@ -19,7 +19,7 @@ func TestForward(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, nil /* not TLS */) p := NewProxy(s.Addr, DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -51,7 +51,7 @@ func TestForwardRefused(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, nil) p := NewProxy(s.Addr, DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()

View file

@ -1,17 +1,48 @@
package forward package forward
import ( import (
"crypto/tls"
"sync/atomic" "sync/atomic"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// HealthChecker checks the upstream health.
type HealthChecker interface {
Check(*Proxy) error
SetTLSConfig(*tls.Config)
}
// dnsHc is a health checker for a DNS endpoint (DNS, and DoT).
type dnsHc struct{ c *dns.Client }
// NewHealthChecker returns a new HealthChecker based on protocol.
func NewHealthChecker(protocol int) HealthChecker {
switch protocol {
case DNS, TLS:
c := new(dns.Client)
c.Net = "udp"
c.ReadTimeout = 1 * time.Second
c.WriteTimeout = 1 * time.Second
return &dnsHc{c: c}
}
return nil
}
func (h *dnsHc) SetTLSConfig(cfg *tls.Config) {
h.c.Net = "tcp-tls"
h.c.TLSConfig = cfg
}
// For HC we send to . IN NS +norec message to the upstream. Dial timeouts and empty // For HC we send to . IN NS +norec message to the upstream. Dial timeouts and empty
// replies are considered fails, basically anything else constitutes a healthy upstream. // replies are considered fails, basically anything else constitutes a healthy upstream.
// Check is used as the up.Func in the up.Probe. // Check is used as the up.Func in the up.Probe.
func (p *Proxy) Check() error { func (h *dnsHc) Check(p *Proxy) error {
err := p.send() err := h.send(p.addr)
if err != nil { if err != nil {
HealthcheckFailureCount.WithLabelValues(p.addr).Add(1) HealthcheckFailureCount.WithLabelValues(p.addr).Add(1)
atomic.AddUint32(&p.fails, 1) atomic.AddUint32(&p.fails, 1)
@ -22,14 +53,14 @@ func (p *Proxy) Check() error {
return nil return nil
} }
func (p *Proxy) send() error { func (h *dnsHc) send(addr string) error {
hcping := new(dns.Msg) ping := new(dns.Msg)
hcping.SetQuestion(".", dns.TypeNS) ping.SetQuestion(".", dns.TypeNS)
m, _, err := p.client.Exchange(hcping, p.addr) m, _, err := h.c.Exchange(ping, addr)
// If we got a header, we're alright, basically only care about I/O errors 'n stuff // If we got a header, we're alright, basically only care about I/O errors 'n stuff.
if err != nil && m != nil { if err != nil && m != nil {
// Silly check, something sane came back // Silly check, something sane came back.
if m.Response || m.Opcode == dns.OpcodeQuery { if m.Response || m.Opcode == dns.OpcodeQuery {
err = nil err = nil
} }

View file

@ -25,7 +25,7 @@ func TestHealth(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, nil /* no TLS */) p := NewProxy(s.Addr, DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -65,7 +65,7 @@ func TestHealthTimeout(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, nil /* no TLS */) p := NewProxy(s.Addr, DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -109,7 +109,7 @@ func TestHealthFailTwice(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, nil /* no TLS */) p := NewProxy(s.Addr, DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -132,7 +132,7 @@ func TestHealthMaxFails(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, nil /* no TLS */) p := NewProxy(s.Addr, DNS)
f := New() f := New()
f.maxfails = 2 f.maxfails = 2
f.SetProxy(p) f.SetProxy(p)
@ -163,7 +163,7 @@ func TestHealthNoMaxFails(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, nil /* no TLS */) p := NewProxy(s.Addr, DNS)
f := New() f := New()
f.maxfails = 0 f.maxfails = 0
f.SetProxy(p) f.SetProxy(p)

View file

@ -81,7 +81,7 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.M
func NewLookup(addr []string) *Forward { func NewLookup(addr []string) *Forward {
f := New() f := New()
for i := range addr { for i := range addr {
p := NewProxy(addr[i], nil) p := NewProxy(addr[i], DNS)
f.SetProxy(p) f.SetProxy(p)
} }
return f return f

View file

@ -19,7 +19,7 @@ func TestLookup(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, nil /* no TLS */) p := NewProxy(s.Addr, DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()

View file

@ -29,7 +29,7 @@ type transport struct {
stop chan bool stop chan bool
} }
func newTransport(addr string, tlsConfig *tls.Config) *transport { func newTransport(addr string) *transport {
t := &transport{ t := &transport{
avgDialTime: int64(defaultDialTimeout / 2), avgDialTime: int64(defaultDialTimeout / 2),
conns: make(map[string][]*persistConn), conns: make(map[string][]*persistConn),

View file

@ -17,7 +17,7 @@ func TestCached(t *testing.T) {
}) })
defer s.Close() defer s.Close()
tr := newTransport(s.Addr, nil /* no TLS */) tr := newTransport(s.Addr)
tr.Start() tr.Start()
defer tr.Stop() defer tr.Stop()
@ -56,7 +56,7 @@ func TestCleanupByTimer(t *testing.T) {
}) })
defer s.Close() defer s.Close()
tr := newTransport(s.Addr, nil /* no TLS */) tr := newTransport(s.Addr)
tr.SetExpire(100 * time.Millisecond) tr.SetExpire(100 * time.Millisecond)
tr.Start() tr.Start()
defer tr.Stop() defer tr.Stop()
@ -90,7 +90,7 @@ func TestPartialCleanup(t *testing.T) {
}) })
defer s.Close() defer s.Close()
tr := newTransport(s.Addr, nil /* no TLS */) tr := newTransport(s.Addr)
tr.SetExpire(100 * time.Millisecond) tr.SetExpire(100 * time.Millisecond)
tr.Start() tr.Start()
defer tr.Stop() defer tr.Stop()
@ -138,7 +138,7 @@ func TestCleanupAll(t *testing.T) {
}) })
defer s.Close() defer s.Close()
tr := newTransport(s.Addr, nil /* no TLS */) tr := newTransport(s.Addr)
c1, _ := dns.DialTimeout("udp", tr.addr, defaultDialTimeout) c1, _ := dns.DialTimeout("udp", tr.addr, defaultDialTimeout)
c2, _ := dns.DialTimeout("udp", tr.addr, defaultDialTimeout) c2, _ := dns.DialTimeout("udp", tr.addr, defaultDialTimeout)

View file

@ -7,8 +7,6 @@ import (
"time" "time"
"github.com/coredns/coredns/plugin/pkg/up" "github.com/coredns/coredns/plugin/pkg/up"
"github.com/miekg/dns"
) )
// Proxy defines an upstream host. // Proxy defines an upstream host.
@ -16,69 +14,46 @@ type Proxy struct {
avgRtt int64 avgRtt int64
fails uint32 fails uint32
addr string addr string
client *dns.Client
// Connection caching // Connection caching
expire time.Duration expire time.Duration
transport *transport transport *transport
// health checking // health checking
probe *up.Probe probe *up.Probe
health HealthChecker
} }
// NewProxy returns a new proxy. // NewProxy returns a new proxy.
func NewProxy(addr string, tlsConfig *tls.Config) *Proxy { func NewProxy(addr string, protocol int) *Proxy {
p := &Proxy{ p := &Proxy{
addr: addr, addr: addr,
fails: 0, fails: 0,
probe: up.New(), probe: up.New(),
transport: newTransport(addr, tlsConfig), transport: newTransport(addr),
avgRtt: int64(maxTimeout / 2), avgRtt: int64(maxTimeout / 2),
} }
p.client = dnsClient(tlsConfig) p.health = NewHealthChecker(protocol)
runtime.SetFinalizer(p, (*Proxy).finalizer) runtime.SetFinalizer(p, (*Proxy).finalizer)
return p return p
} }
// Addr returns the address to forward to.
func (p *Proxy) Addr() (addr string) { return p.addr }
// dnsClient returns a client used for health checking.
func dnsClient(tlsConfig *tls.Config) *dns.Client {
c := new(dns.Client)
c.Net = "udp"
// TODO(miek): this should be half of hcDuration?
c.ReadTimeout = 1 * time.Second
c.WriteTimeout = 1 * time.Second
if tlsConfig != nil {
c.Net = "tcp-tls"
c.TLSConfig = tlsConfig
}
return c
}
// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client. // SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client.
func (p *Proxy) SetTLSConfig(cfg *tls.Config) { func (p *Proxy) SetTLSConfig(cfg *tls.Config) {
p.transport.SetTLSConfig(cfg) p.transport.SetTLSConfig(cfg)
p.client = dnsClient(cfg) p.health.SetTLSConfig(cfg)
} }
// IsTLS returns true if proxy uses tls.
func (p *Proxy) IsTLS() bool { return p.transport.tlsConfig != nil }
// SetExpire sets the expire duration in the lower p.transport. // SetExpire sets the expire duration in the lower p.transport.
func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) } func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }
// Dial connects to the host in p with the configured transport.
func (p *Proxy) Dial(proto string) (*dns.Conn, bool, error) { return p.transport.Dial(proto) }
// Yield returns the connection to the pool.
func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }
// Healthcheck kicks of a round of health checks for this proxy. // Healthcheck kicks of a round of health checks for this proxy.
func (p *Proxy) Healthcheck() { p.probe.Do(p.Check) } func (p *Proxy) Healthcheck() {
p.probe.Do(func() error {
return p.health.Check(p)
})
}
// Down returns true if this proxy is down, i.e. has *more* fails than maxfails. // Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
func (p *Proxy) Down(maxfails uint32) bool { func (p *Proxy) Down(maxfails uint32) bool {
@ -91,13 +66,8 @@ func (p *Proxy) Down(maxfails uint32) bool {
} }
// close stops the health checking goroutine. // close stops the health checking goroutine.
func (p *Proxy) close() { func (p *Proxy) close() { p.probe.Stop() }
p.probe.Stop() func (p *Proxy) finalizer() { p.transport.Stop() }
}
func (p *Proxy) finalizer() {
p.transport.Stop()
}
// start starts the proxy's healthchecking. // start starts the proxy's healthchecking.
func (p *Proxy) start(duration time.Duration) { func (p *Proxy) start(duration time.Duration) {

View file

@ -26,7 +26,7 @@ func TestProxyClose(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
p := NewProxy(s.Addr, nil) p := NewProxy(s.Addr, DNS)
p.start(hcInterval) p.start(hcInterval)
go func() { p.Connect(ctx, state, options{}) }() go func() { p.Connect(ctx, state, options{}) }()
@ -95,7 +95,7 @@ func TestProxyTLSFail(t *testing.T) {
} }
func TestProtocolSelection(t *testing.T) { func TestProtocolSelection(t *testing.T) {
p := NewProxy("bad_address", nil) p := NewProxy("bad_address", DNS)
stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}

View file

@ -124,7 +124,7 @@ func parseForward(c *caddy.Controller) (*Forward, error) {
// We can't set tlsConfig here, because we haven't parsed it yet. // We can't set tlsConfig here, because we haven't parsed it yet.
// We set it below at the end of parseBlock, use nil now. // We set it below at the end of parseBlock, use nil now.
p := NewProxy(h, nil /* no TLS */) p := NewProxy(h, protocols[i])
f.proxies = append(f.proxies, p) f.proxies = append(f.proxies, p)
} }

View file

@ -113,8 +113,8 @@ func TestSetupTLS(t *testing.T) {
t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName) t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName)
} }
if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].client.TLSConfig.ServerName { if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName {
t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].client.TLSConfig.ServerName) t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName)
} }
} }
} }

View file

@ -34,7 +34,7 @@ func TestLookupTruncated(t *testing.T) {
}) })
defer s.Close() defer s.Close()
p := NewProxy(s.Addr, nil /* no TLS */) p := NewProxy(s.Addr, DNS)
f := New() f := New()
f.SetProxy(p) f.SetProxy(p)
defer f.Close() defer f.Close()
@ -88,9 +88,9 @@ func TestForwardTruncated(t *testing.T) {
f := New() f := New()
p1 := NewProxy(s.Addr, nil /* no TLS */) p1 := NewProxy(s.Addr, DNS)
f.SetProxy(p1) f.SetProxy(p1)
p2 := NewProxy(s.Addr, nil /* no TLS */) p2 := NewProxy(s.Addr, DNS)
f.SetProxy(p2) f.SetProxy(p2)
defer f.Close() defer f.Close()