plugin/forward using pkg/up (#1493)

* plugin/forward: on-demand health checking

Only start doing health checks when we encounter an error (any error).
This uses the new plugin/pkg/up package to abstract away the actual
checking. This reduces the LOC quite a bit; it still needs more testing,
unit testing and a bit of tcpdumping.

* fix tests

* Fix readme

* Use pkg/up for healthchecks

* remove unused channel

* more cleanups

* update readme

* Again do go generate and go build; still referencing the wrong forward
  repo? Anyway, fixed.
* Use pkg/up for doing the health checks to cut back on unwanted queries
  * Change up.Func to return an error instead of a boolean.
  * Drop the string target argument as it doesn't make sense.
* Add a health check test on failing to get an upstream answer.

TODO(miek): double check Forward and Lookup and how they interact with
HC, and if we correctly call close() on those

* actual test

* Tests here

* more tests

* try getting rid of host

* Get rid of the host indirection

* Finish removing hosts

* moar testing

* import fmt

* field is not used

* docs

* move some stuff

* bring back health_check

* maxfails=0 test

* git and merging, bah

* review
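
For reference, the `plugin/pkg/up` API this change builds on boils down to four calls: `up.New`, `Probe.Start(interval)`, `Probe.Do(Func)` and `Probe.Stop`. Below is a minimal, self-contained sketch (not part of this commit) of how they fit together after the `up.Func` signature change; the check function, counter and timings are invented for illustration:

~~~ go
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/coredns/coredns/plugin/pkg/up"
)

func main() {
	var attempts uint32

	// After this change up.Func takes no target argument and returns an
	// error: nil means the upstream is healthy and the probe loop stops.
	check := func() error {
		if atomic.AddUint32(&attempts, 1) < 2 {
			return errors.New("upstream not ready") // probe retries after interval
		}
		return nil
	}

	p := up.New()
	p.Start(500 * time.Millisecond) // retry interval while unhealthy
	defer p.Stop()

	// The forward plugin calls Do only when a query to an upstream fails,
	// so checks run on demand instead of on a permanent ticker.
	p.Do(check)

	time.Sleep(2 * time.Second)
	fmt.Println("health check attempts:", atomic.LoadUint32(&attempts))
}
~~~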
Miek Gieben 2018-02-15 10:21:57 +01:00 committed by GitHub
parent 8b035fa938
commit 16504234e5
15 changed files with 306 additions and 221 deletions

File: plugin/forward/README.md

@@ -6,10 +6,17 @@
 ## Description
-The *forward* plugin is generally faster (~30+%) than *proxy* as it re-uses already opened sockets
-to the upstreams. It supports UDP, TCP and DNS-over-TLS and uses inband health checking that is
-enabled by default.
-When *all* upstreams are down it assumes healtchecking as a mechanism has failed and will try to
+The *forward* plugin re-uses already opened sockets to the upstreams. It supports UDP, TCP and
+DNS-over-TLS and uses in-band health checking.
+When it detects an error a health check is performed. This check runs in a loop, every *0.5s*, for
+as long as the upstream reports unhealthy. Once healthy we stop health checking (until the next
+error). The health checks use a non-recursive DNS query (`. IN NS`) to get upstream health. Any
+response that is not a network error (REFUSED, NOTIMPL, SERVFAIL, etc.) is taken as a healthy
+upstream. The health check uses the same protocol as specified in **TO**. If `max_fails` is set to
+0, no checking is performed and upstreams will always be considered healthy.
+When *all* upstreams are down it assumes health checking as a mechanism has failed and will try to
 connect to a random upstream (which may or may not work).
 ## Syntax

@@ -22,16 +29,11 @@ forward FROM TO...
 * **FROM** is the base domain to match for the request to be forwarded.
 * **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify
-  a protocol, `tls://9.9.9.9` or `dns://` for plain DNS. The number of upstreams is limited to 15.
+  a protocol, `tls://9.9.9.9` or `dns://` (or no protocol) for plain DNS. The number of upstreams is
+  limited to 15.
-The health checks are done every *0.5s*. After *two* failed checks the upstream is considered
-unhealthy. The health checks use a recursive DNS query (`. IN NS`) to get upstream health. Any
-response that is not an error (REFUSED, NOTIMPL, SERVFAIL, etc) is taken as a healthy upstream. The
-health check uses the same protocol as specific in the **TO**. On startup each upstream is marked
-unhealthy until it passes a health check. A 0 duration will disable any health checks.
-Multiple upstreams are randomized (default policy) on first use. When a healthy proxy returns an
-error during the exchange the next upstream in the list is tried.
+Multiple upstreams are randomized (see `policy`) on first use. When a healthy proxy returns an error
+during the exchange the next upstream in the list is tried.
 Extra knobs are available with an expanded syntax:

@@ -39,12 +41,12 @@ Extra knobs are available with an expanded syntax:
 forward FROM TO... {
     except IGNORED_NAMES...
     force_tcp
-    health_check DURATION
     expire DURATION
     max_fails INTEGER
     tls CERT KEY CA
     tls_servername NAME
     policy random|round_robin
+    health_checks DURATION
 }
 ~~~

@@ -52,21 +54,16 @@ forward FROM TO... {
 * **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding.
   Requests that match none of these names will be passed through.
 * `force_tcp`, use TCP even when the request comes in over UDP.
-* `health_checks`, use a different **DURATION** for health checking, the default duration is 0.5s.
-  A value of 0 disables the health checks completely.
 * `max_fails` is the number of subsequent failed health checks that are needed before considering
-  a backend to be down. If 0, the backend will never be marked as down. Default is 2.
+  an upstream to be down. If 0, the upstream will never be marked as down (nor health checked).
+  Default is 2.
 * `expire` **DURATION**, expire (cached) connections after this time, the default is 10s.
 * `tls` **CERT** **KEY** **CA** define the TLS properties for TLS; if you leave this out the
   system's configuration will be used.
 * `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
   needs this to be set to `dns.quad9.net`.
 * `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
+* `health_checks`, use a different **DURATION** for health checking, the default duration is 0.5s.
-The upstream selection is done via random (default policy) selection. If the socket for this client
-isn't known *forward* will randomly choose one. If this turns out to be unhealthy, the next one is
-tried. If *all* hosts are down, we assume health checking is broken and select a *random* upstream
-to try.
 Also note the TLS config is "global" for the whole forwarding proxy; if you need a different
 `tls-name` for different upstreams you're out of luck.

@@ -80,7 +77,7 @@ If monitoring is enabled (via the *prometheus* directive) then the following metrics are exported:
 * `coredns_forward_response_rcode_total{to, rcode}` - count of RCODEs per upstream.
 * `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream.
 * `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy,
-  and we are randomly spraying to a target.
+  and we are randomly (this always uses the `random` policy) spraying to an upstream.
 * `coredns_forward_socket_count_total{to}` - number of cached sockets per upstream.
 Where `to` is one of the upstream servers (**TO** from the config), `proto` is the protocol used by

@@ -125,16 +122,10 @@ Proxy everything except `example.org` using the host's `resolv.conf`'s nameservers:
 }
 ~~~
-Forward to a IPv6 host:
-~~~ corefile
-. {
-    forward . [::1]:1053
-}
-~~~
 Proxy all requests to 9.9.9.9 using the DNS-over-TLS protocol, and cache every answer for up to 30
-seconds.
+seconds. Note the `tls_servername` is mandatory if you want a working setup, as 9.9.9.9 can't be
+used in the TLS negotiation. Also set the health check duration to 5s to not completely swamp the
+service with health checks.
 ~~~ corefile
 . {

@@ -148,7 +139,7 @@ seconds.
 ## Bugs
-The TLS config is global for the whole forwarding proxy if you need a different `tls-name` for
+The TLS config is global for the whole forwarding proxy; if you need a different `tls_servername` for
 different upstreams you're out of luck.
 ## Also See
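
To make the README's health-check description concrete, here is a standalone sketch (not part of the commit) of the probe it describes: a non-recursive `. IN NS` query where only transport-level failures count as unhealthy. The `9.9.9.9:53` target and the 1-second timeouts are placeholders:

~~~ go
package main

import (
	"fmt"
	"time"

	"github.com/miekg/dns"
)

func main() {
	// The probe query: ". IN NS" with recursion disabled.
	ping := new(dns.Msg)
	ping.SetQuestion(".", dns.TypeNS)
	ping.RecursionDesired = false

	c := new(dns.Client)
	c.Net = "udp" // the real check uses the protocol given in TO (udp, tcp or tcp-tls)
	c.ReadTimeout = 1 * time.Second
	c.WriteTimeout = 1 * time.Second

	m, rtt, err := c.Exchange(ping, "9.9.9.9:53")
	if m == nil {
		// Only I/O problems (dial timeout, empty reply) mark the upstream failed.
		fmt.Println("unhealthy:", err)
		return
	}
	// Any real response, even REFUSED or SERVFAIL, counts as healthy.
	fmt.Printf("healthy: %s in %s\n", dns.RcodeToString[m.Rcode], rtt)
}
~~~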

File: plugin/forward/connect.go

@@ -21,9 +21,6 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
     if forceTCP {
         proto = "tcp"
     }
-    if p.host.tlsConfig != nil {
-        proto = "tcp-tls"
-    }
     conn, err := p.Dial(proto)
     if err != nil {

@@ -57,9 +54,9 @@ func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
             rc = strconv.Itoa(ret.Rcode)
         }
-        RequestCount.WithLabelValues(p.host.addr).Add(1)
-        RcodeCount.WithLabelValues(rc, p.host.addr).Add(1)
-        RequestDuration.WithLabelValues(p.host.addr).Observe(time.Since(start).Seconds())
+        RequestCount.WithLabelValues(p.addr).Add(1)
+        RcodeCount.WithLabelValues(rc, p.addr).Add(1)
+        RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds())
     }
     return ret, nil

File: plugin/forward/forward.go

@@ -20,8 +20,9 @@ import (
 // Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list
 // of proxies each representing one upstream proxy.
 type Forward struct {
     proxies []*Proxy
     p       Policy
+    hcInterval time.Duration
     from    string
     ignored []string

@@ -31,22 +32,21 @@ type Forward struct {
     maxfails uint32
     expire   time.Duration
     forceTCP bool // also here for testing
-    hcInterval time.Duration // also here for testing
     Next plugin.Handler
 }

 // New returns a new Forward.
 func New() *Forward {
-    f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: hcDuration, p: new(random)}
+    f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcDuration}
     return f
 }

 // SetProxy appends p to the proxy list and starts healthchecking.
 func (f *Forward) SetProxy(p *Proxy) {
     f.proxies = append(f.proxies, p)
-    go p.healthCheck()
+    p.start(f.hcInterval)
 }

 // Len returns the number of configured proxies.

@@ -92,7 +92,27 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
             child.Finish()
         }

+        // If you query for instance ANY isc.org; you get a truncated query back which miekg/dns fails to unpack
+        // because the RRs are not finished. The returned message can be useful or useless. Return the original
+        // query with some header bits set that they should retry with TCP.
+        if err == dns.ErrTruncated {
+            // We may or may not have something sensible... if not reassemble something to send to the client.
+            if ret == nil {
+                ret = new(dns.Msg)
+                ret.SetReply(r)
+                ret.Truncated = true
+                ret.Authoritative = true
+                ret.Rcode = dns.RcodeSuccess
+            }
+            err = nil // and reset err to pass this back to the client.
+        }
+
         if err != nil {
+            // Kick off health check to see if *our* upstream is broken.
+            if f.maxfails != 0 {
+                proxy.Healthcheck()
+            }
+
             if fails < len(f.proxies) {
                 continue
             }

@@ -140,8 +160,8 @@ func (f *Forward) isAllowedDomain(name string) bool {
 func (f *Forward) list() []*Proxy { return f.p.List(f.proxies) }

 var (
-    errInvalidDomain = errors.New("invalid domain for proxy")
-    errNoHealthy     = errors.New("no healthy proxies")
+    errInvalidDomain = errors.New("invalid domain for forward")
+    errNoHealthy     = errors.New("no healthy proxies or upstream error")
     errNoForward     = errors.New("no forwarder defined")
 )

File: plugin/forward/forward_test.go

@@ -6,6 +6,7 @@ import (
     "github.com/coredns/coredns/plugin/pkg/dnstest"
     "github.com/coredns/coredns/plugin/test"
     "github.com/coredns/coredns/request"
+
     "github.com/miekg/dns"
 )

@@ -18,7 +19,7 @@ func TestForward(t *testing.T) {
     })
     defer s.Close()

-    p := NewProxy(s.Addr)
+    p := NewProxy(s.Addr, nil /* not TLS */)
     f := New()
     f.SetProxy(p)
     defer f.Close()

File: plugin/forward/health.go

@@ -1,7 +1,6 @@
 package forward

 import (
-    "log"
     "sync/atomic"

     "github.com/miekg/dns"

@@ -10,41 +9,25 @@ import (
 // For HC we send to . IN NS +norec message to the upstream. Dial timeouts and empty
 // replies are considered fails, basically anything else constitutes a healthy upstream.
-func (h *host) Check() {
-    h.Lock()
-    if h.checking {
-        h.Unlock()
-        return
-    }
-    h.checking = true
-    h.Unlock()
-    err := h.send()
+// Check is used as the up.Func in the up.Probe.
+func (p *Proxy) Check() error {
+    err := p.send()
     if err != nil {
-        log.Printf("[INFO] healtheck of %s failed with %s", h.addr, err)
-        HealthcheckFailureCount.WithLabelValues(h.addr).Add(1)
-        atomic.AddUint32(&h.fails, 1)
-    } else {
-        atomic.StoreUint32(&h.fails, 0)
+        HealthcheckFailureCount.WithLabelValues(p.addr).Add(1)
+        atomic.AddUint32(&p.fails, 1)
+        return err
     }
-    h.Lock()
-    h.checking = false
-    h.Unlock()
-    return
+    atomic.StoreUint32(&p.fails, 0)
+    return nil
 }

-func (h *host) send() error {
+func (p *Proxy) send() error {
     hcping := new(dns.Msg)
     hcping.SetQuestion(".", dns.TypeNS)
     hcping.RecursionDesired = false

-    m, _, err := h.client.Exchange(hcping, h.addr)
+    m, _, err := p.client.Exchange(hcping, p.addr)
     // If we got a header, we're alright, basically only care about I/O errors 'n stuff
     if err != nil && m != nil {
         // Silly check, something sane came back

@@ -55,13 +38,3 @@ func (h *host) send() error {
     return err
 }
-
-// down returns true is this host has more than maxfails fails.
-func (h *host) down(maxfails uint32) bool {
-    if maxfails == 0 {
-        return false
-    }
-
-    fails := atomic.LoadUint32(&h.fails)
-    return fails > maxfails
-}

File: plugin/forward/health_test.go (new file)

@@ -0,0 +1,136 @@
+package forward
+
+import (
+    "sync/atomic"
+    "testing"
+    "time"
+
+    "github.com/coredns/coredns/plugin/pkg/dnstest"
+    "github.com/coredns/coredns/plugin/test"
+
+    "github.com/miekg/dns"
+    "golang.org/x/net/context"
+)
+
+func TestHealth(t *testing.T) {
+    const expected = 0
+    i := uint32(0)
+
+    s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+        if r.Question[0].Name == "." {
+            atomic.AddUint32(&i, 1)
+        }
+        ret := new(dns.Msg)
+        ret.SetReply(r)
+        w.WriteMsg(ret)
+    })
+    defer s.Close()
+
+    p := NewProxy(s.Addr, nil /* no TLS */)
+    f := New()
+    f.SetProxy(p)
+    defer f.Close()
+
+    req := new(dns.Msg)
+    req.SetQuestion("example.org.", dns.TypeA)
+
+    f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req)
+
+    time.Sleep(1 * time.Second)
+    i1 := atomic.LoadUint32(&i)
+    if i1 != expected {
+        t.Errorf("Expected number of health checks to be %d, got %d", expected, i1)
+    }
+}
+
+func TestHealthTimeout(t *testing.T) {
+    const expected = 1
+    i := uint32(0)
+
+    s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+        if r.Question[0].Name == "." {
+            // health check, answer
+            atomic.AddUint32(&i, 1)
+            ret := new(dns.Msg)
+            ret.SetReply(r)
+            w.WriteMsg(ret)
+        }
+        // not a health check, do a timeout
+    })
+    defer s.Close()
+
+    p := NewProxy(s.Addr, nil /* no TLS */)
+    f := New()
+    f.SetProxy(p)
+    defer f.Close()
+
+    req := new(dns.Msg)
+    req.SetQuestion("example.org.", dns.TypeA)
+
+    f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req)
+
+    time.Sleep(1 * time.Second)
+    i1 := atomic.LoadUint32(&i)
+    if i1 != expected {
+        t.Errorf("Expected number of health checks to be %d, got %d", expected, i1)
+    }
+}
+
+func TestHealthFailTwice(t *testing.T) {
+    const expected = 2
+    i := uint32(0)
+
+    s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+        if r.Question[0].Name == "." {
+            atomic.AddUint32(&i, 1)
+            i1 := atomic.LoadUint32(&i)
+            // Timeout health until we get the second one
+            if i1 < 2 {
+                return
+            }
+            ret := new(dns.Msg)
+            ret.SetReply(r)
+            w.WriteMsg(ret)
+        }
+    })
+    defer s.Close()
+
+    p := NewProxy(s.Addr, nil /* no TLS */)
+    f := New()
+    f.SetProxy(p)
+    defer f.Close()
+
+    req := new(dns.Msg)
+    req.SetQuestion("example.org.", dns.TypeA)
+
+    f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req)
+
+    time.Sleep(3 * time.Second)
+    i1 := atomic.LoadUint32(&i)
+    if i1 != expected {
+        t.Errorf("Expected number of health checks to be %d, got %d", expected, i1)
+    }
+}
+
+func TestHealthMaxFails(t *testing.T) {
+    const expected = 0
+    i := uint32(0)
+
+    s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
+        // timeout
+    })
+    defer s.Close()
+
+    p := NewProxy(s.Addr, nil /* no TLS */)
+    f := New()
+    f.maxfails = 0
+    f.SetProxy(p)
+    defer f.Close()
+
+    req := new(dns.Msg)
+    req.SetQuestion("example.org.", dns.TypeA)
+
+    f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req)
+
+    time.Sleep(1 * time.Second)
+    i1 := atomic.LoadUint32(&i)
+    if i1 != expected {
+        t.Errorf("Expected number of health checks to be %d, got %d", expected, i1)
+    }
+}

File: plugin/forward/host.go (deleted)

@ -1,44 +0,0 @@
package forward
import (
"crypto/tls"
"sync"
"time"
"github.com/miekg/dns"
)
type host struct {
addr string
client *dns.Client
tlsConfig *tls.Config
expire time.Duration
fails uint32
sync.RWMutex
checking bool
}
// newHost returns a new host, the fails are set to 1, i.e.
// the first healthcheck must succeed before we use this host.
func newHost(addr string) *host {
return &host{addr: addr, fails: 1, expire: defaultExpire}
}
// setClient sets and configures the dns.Client in host.
func (h *host) SetClient() {
c := new(dns.Client)
c.Net = "udp"
c.ReadTimeout = 2 * time.Second
c.WriteTimeout = 2 * time.Second
if h.tlsConfig != nil {
c.Net = "tcp-tls"
c.TLSConfig = h.tlsConfig
}
h.client = c
}
const defaultExpire = 10 * time.Second

File: plugin/forward/lookup.go

@@ -5,10 +5,6 @@
 package forward

 import (
-    "crypto/tls"
-    "log"
-    "time"
-
     "github.com/coredns/coredns/request"
     "github.com/miekg/dns"

@@ -32,12 +28,10 @@ func (f *Forward) Forward(state request.Request) (*dns.Msg, error) {
             // All upstream proxies are dead, assume healtcheck is complete broken and randomly
             // select an upstream to connect to.
             proxy = f.list()[0]
-            log.Printf("[WARNING] All upstreams down, picking random one to connect to %s", proxy.host.addr)
         }

         ret, err := proxy.connect(context.Background(), state, f.forceTCP, true)
         if err != nil {
-            log.Printf("[WARNING] Failed to connect to %s: %s", proxy.host.addr, err)
             if fails < len(f.proxies) {
                 continue
             }

@@ -68,10 +62,11 @@ func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.Msg, error) {
 }

 // NewLookup returns a Forward that can be used for plugin that need an upstream to resolve external names.
+// Note that the caller must run Close on the forward to stop the health checking goroutines.
 func NewLookup(addr []string) *Forward {
-    f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: 2 * time.Second}
+    f := New()
     for i := range addr {
-        p := NewProxy(addr[i])
+        p := NewProxy(addr[i], nil)
         f.SetProxy(p)
     }
     return f

File: plugin/forward/lookup_test.go

@@ -6,6 +6,7 @@ import (
     "github.com/coredns/coredns/plugin/pkg/dnstest"
     "github.com/coredns/coredns/plugin/test"
    "github.com/coredns/coredns/request"
+
     "github.com/miekg/dns"
 )

@@ -18,7 +19,7 @@ func TestLookup(t *testing.T) {
     })
     defer s.Close()

-    p := NewProxy(s.Addr)
+    p := NewProxy(s.Addr, nil /* no TLS */)
     f := New()
     f.SetProxy(p)
     defer f.Close()

File: plugin/forward/persistent.go

@@ -1,6 +1,7 @@
 package forward

 import (
+    "crypto/tls"
     "net"
     "time"

@@ -21,8 +22,10 @@ type connErr struct {
 // transport hold the persistent cache.
 type transport struct {
     conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
-    host  *host
+    expire    time.Duration // After this duration a connection is expired.
+    addr      string
+    tlsConfig *tls.Config

     dial  chan string
     yield chan connErr

@@ -35,10 +38,11 @@ type transport struct {
     stop chan bool
 }

-func newTransport(h *host) *transport {
+func newTransport(addr string, tlsConfig *tls.Config) *transport {
     t := &transport{
         conns:  make(map[string][]*persistConn),
-        host:   h,
+        expire: defaultExpire,
+        addr:   addr,
         dial:   make(chan string),
         yield:  make(chan connErr),
         ret:    make(chan connErr),

@@ -51,7 +55,7 @@ func newTransport(h *host) *transport {
 }

 // len returns the number of connection, used for metrics. Can only be safely
-// used inside connManager() because of races.
+// used inside connManager() because of data races.
 func (t *transport) len() int {
     l := 0
     for _, conns := range t.conns {

@@ -79,7 +83,7 @@ Wait:
             i := 0
             for i = 0; i < len(t.conns[proto]); i++ {
                 pc := t.conns[proto][i]
-                if time.Since(pc.used) < t.host.expire {
+                if time.Since(pc.used) < t.expire {
                     // Found one, remove from pool and return this conn.
                     t.conns[proto] = t.conns[proto][i+1:]
                     t.ret <- connErr{pc.c, nil}

@@ -91,22 +95,22 @@ Wait:
             // Not conns were found. Connect to the upstream to create one.
             t.conns[proto] = t.conns[proto][i:]

-            SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len()))
+            SocketGauge.WithLabelValues(t.addr).Set(float64(t.len()))

             go func() {
                 if proto != "tcp-tls" {
-                    c, err := dns.DialTimeout(proto, t.host.addr, dialTimeout)
+                    c, err := dns.DialTimeout(proto, t.addr, dialTimeout)
                     t.ret <- connErr{c, err}
                     return
                 }

-                c, err := dns.DialTimeoutWithTLS("tcp", t.host.addr, t.host.tlsConfig, dialTimeout)
+                c, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, dialTimeout)
                 t.ret <- connErr{c, err}
             }()

         case conn := <-t.yield:

-            SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len() + 1))
+            SocketGauge.WithLabelValues(t.addr).Set(float64(t.len() + 1))

             // no proto here, infer from config and conn
             if _, ok := conn.c.Conn.(*net.UDPConn); ok {

@@ -114,7 +118,7 @@ Wait:
                 continue Wait
             }

-            if t.host.tlsConfig == nil {
+            if t.tlsConfig == nil {
                 t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()})
                 continue Wait
             }

@@ -134,15 +138,30 @@ Wait:
     }
 }

+// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
 func (t *transport) Dial(proto string) (*dns.Conn, error) {
+    // If tls has been configured; use it.
+    if t.tlsConfig != nil {
+        proto = "tcp-tls"
+    }
+
     t.dial <- proto
     c := <-t.ret
     return c.c, c.err
 }

+// Yield returns the connection to transport for reuse.
 func (t *transport) Yield(c *dns.Conn) {
     t.yield <- connErr{c, nil}
 }

-// Stop stops the transports.
+// Stop stops the transport's connection manager.
 func (t *transport) Stop() { t.stop <- true }
+
+// SetExpire sets the connection expire time in transport.
+func (t *transport) SetExpire(expire time.Duration) { t.expire = expire }
+
+// SetTLSConfig sets the TLS config in transport.
+func (t *transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
+
+const defaultExpire = 10 * time.Second

File: plugin/forward/persistent_test.go

@@ -16,8 +16,7 @@ func TestPersistent(t *testing.T) {
     })
     defer s.Close()

-    h := newHost(s.Addr)
-    tr := newTransport(h)
+    tr := newTransport(s.Addr, nil /* no TLS */)
     defer tr.Stop()

     c1, _ := tr.Dial("udp")

File: plugin/forward/proxy.go

@@ -2,47 +2,60 @@ package forward
 import (
     "crypto/tls"
-    "sync"
+    "sync/atomic"
     "time"

+    "github.com/coredns/coredns/plugin/pkg/up"
+
     "github.com/miekg/dns"
 )

 // Proxy defines an upstream host.
 type Proxy struct {
-    host *host
+    addr   string
+    client *dns.Client

-    // Connection caching
-    expire    time.Duration
     transport *transport

-    // copied from Forward.
-    hcInterval time.Duration
-    forceTCP   bool
-
-    stop chan bool
-
-    sync.RWMutex
+    // health checking
+    probe *up.Probe
+    fails uint32
 }

 // NewProxy returns a new proxy.
-func NewProxy(addr string) *Proxy {
-    host := newHost(addr)
-
+func NewProxy(addr string, tlsConfig *tls.Config) *Proxy {
     p := &Proxy{
-        host:       host,
-        hcInterval: hcDuration,
-        stop:       make(chan bool),
-        transport:  newTransport(host),
+        addr:      addr,
+        fails:     0,
+        probe:     up.New(),
+        transport: newTransport(addr, tlsConfig),
     }
+    p.client = dnsClient(tlsConfig)
     return p
 }

-// SetTLSConfig sets the TLS config in the lower p.host.
-func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.host.tlsConfig = cfg }
+// dnsClient returns a client used for health checking.
+func dnsClient(tlsConfig *tls.Config) *dns.Client {
+    c := new(dns.Client)
+    c.Net = "udp"
+    // TODO(miek): this should be half of hcDuration?
+    c.ReadTimeout = 1 * time.Second
+    c.WriteTimeout = 1 * time.Second

-// SetExpire sets the expire duration in the lower p.host.
-func (p *Proxy) SetExpire(expire time.Duration) { p.host.expire = expire }
+    if tlsConfig != nil {
+        c.Net = "tcp-tls"
+        c.TLSConfig = tlsConfig
+    }
+    return c
+}

-func (p *Proxy) close() { p.stop <- true }
+// SetTLSConfig sets the TLS config in the lower p.transport.
+func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.transport.SetTLSConfig(cfg) }
+
+// SetExpire sets the expire duration in the lower p.transport.
+func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) }

 // Dial connects to the host in p with the configured transport.
 func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) }

@@ -50,26 +63,28 @@ func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) }
 // Yield returns the connection to the pool.
 func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }

-// Down returns if this proxy is up or down.
-func (p *Proxy) Down(maxfails uint32) bool { return p.host.down(maxfails) }
+// Healthcheck kicks off a round of health checks for this proxy.
+func (p *Proxy) Healthcheck() { p.probe.Do(p.Check) }

-func (p *Proxy) healthCheck() {
-
-    // stop channel
-    p.host.SetClient()
-
-    p.host.Check()
-    tick := time.NewTicker(p.hcInterval)
-    for {
-        select {
-        case <-tick.C:
-            p.host.Check()
-        case <-p.stop:
-            return
-        }
+// Down returns true if this proxy is down, i.e. has *more* fails than maxfails.
+func (p *Proxy) Down(maxfails uint32) bool {
+    if maxfails == 0 {
+        return false
     }
+
+    fails := atomic.LoadUint32(&p.fails)
+    return fails > maxfails
 }

+// close stops the health checking goroutine.
+func (p *Proxy) close() {
+    p.probe.Stop()
+    p.transport.Stop()
+}
+
+// start starts the proxy's healthchecking.
+func (p *Proxy) start(duration time.Duration) { p.probe.Start(duration) }
+
 const (
     dialTimeout = 4 * time.Second
     timeout     = 2 * time.Second
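
The fail accounting that replaces `host.down` above is small enough to model in isolation. A hedged sketch (the `upstream`, `record` and `down` names are simplifications, not the commit's code) showing why an upstream is only reported down after *more* than `max_fails` consecutive failed checks, and how a single healthy check resets the counter:

~~~ go
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

// upstream mirrors the fails/Down bookkeeping this PR moves onto Proxy.
type upstream struct{ fails uint32 }

// record is what Check does with a probe result: bump on error, reset on success.
func (u *upstream) record(err error) {
	if err != nil {
		atomic.AddUint32(&u.fails, 1)
		return
	}
	atomic.StoreUint32(&u.fails, 0)
}

// down mirrors Proxy.Down: with maxfails 0 an upstream is never considered down.
func (u *upstream) down(maxfails uint32) bool {
	if maxfails == 0 {
		return false
	}
	return atomic.LoadUint32(&u.fails) > maxfails
}

func main() {
	u := &upstream{}
	errProbe := errors.New("i/o timeout")

	for i := 1; i <= 3; i++ {
		u.record(errProbe)
		fmt.Printf("%d consecutive fails, down(2): %v\n", i, u.down(2)) // true only at 3
	}

	u.record(nil) // one healthy check resets the counter
	fmt.Println("after success, down(2):", u.down(2))
}
~~~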

File: plugin/forward/setup.go

@@ -62,25 +62,14 @@ func setup(c *caddy.Controller) error {
 // OnStartup starts a goroutines for all proxies.
 func (f *Forward) OnStartup() (err error) {
-    if f.hcInterval == 0 {
-        for _, p := range f.proxies {
-            p.host.fails = 0
-        }
-        return nil
-    }
     for _, p := range f.proxies {
-        go p.healthCheck()
+        p.start(f.hcInterval)
     }
     return nil
 }

 // OnShutdown stops all configured proxies.
 func (f *Forward) OnShutdown() error {
-    if f.hcInterval == 0 {
-        return nil
-    }
     for _, p := range f.proxies {
         p.close()
     }

@@ -88,9 +77,7 @@ func (f *Forward) OnShutdown() error {
 }

 // Close is a synonym for OnShutdown().
-func (f *Forward) Close() {
-    f.OnShutdown()
-}
+func (f *Forward) Close() { f.OnShutdown() }

 func parseForward(c *caddy.Controller) (*Forward, error) {
     f := New()

@@ -140,8 +127,8 @@ func parseForward(c *caddy.Controller) (*Forward, error) {
         }

         // We can't set tlsConfig here, because we haven't parsed it yet.
-        // We set it below at the end of parseBlock.
-        p := NewProxy(h)
+        // We set it below at the end of parseBlock, use nil now.
+        p := NewProxy(h, nil /* no TLS */)
         f.proxies = append(f.proxies, p)
     }

@@ -200,17 +187,11 @@ func parseBlock(c *caddy.Controller, f *Forward) error {
             return fmt.Errorf("health_check can't be negative: %d", dur)
         }
         f.hcInterval = dur
-        for i := range f.proxies {
-            f.proxies[i].hcInterval = dur
-        }
     case "force_tcp":
         if c.NextArg() {
             return c.ArgErr()
         }
         f.forceTCP = true
-        for i := range f.proxies {
-            f.proxies[i].forceTCP = true
-        }
     case "tls":
         args := c.RemainingArgs()
         if len(args) != 3 {

File: plugin/pkg/up/up.go

@@ -17,8 +17,8 @@ type Probe struct {
     inprogress bool
 }

-// Func is used to determine if a target is alive. If so this function must return true.
-type Func func(target string) bool
+// Func is used to determine if a target is alive. If so this function must return nil.
+type Func func() error

 // New returns a pointer to an intialized Probe.
 func New() *Probe {

@@ -32,9 +32,9 @@ func (p *Probe) Do(f Func) { p.do <- f }
 func (p *Probe) Stop() { p.stop <- true }

 // Start will start the probe manager, after which probes can be initialized with Do.
-func (p *Probe) Start(target string, interval time.Duration) { go p.start(target, interval) }
+func (p *Probe) Start(interval time.Duration) { go p.start(interval) }

-func (p *Probe) start(target string, interval time.Duration) {
+func (p *Probe) start(interval time.Duration) {
     for {
         select {
         case <-p.stop:

@@ -52,9 +52,10 @@ func (p *Probe) start(target string, interval time.Duration) {
         // we return from the goroutine and we can accept another Func to run.
         go func() {
             for {
-                if ok := f(target); ok {
+                if err := f(); err == nil {
                     break
                 }
+                // TODO(miek): little bit of exponential backoff here?
                 time.Sleep(interval)
             }
             p.Lock()

File: plugin/pkg/up/up_test.go

@@ -12,20 +12,20 @@ func TestUp(t *testing.T) {
     wg := sync.WaitGroup{}
     hits := int32(0)

-    upfunc := func(s string) bool {
+    upfunc := func() error {
         atomic.AddInt32(&hits, 1)
         // Sleep tiny amount so that our other pr.Do() calls hit the lock.
         time.Sleep(3 * time.Millisecond)
         wg.Done()
-        return true
+        return nil
     }

-    pr.Start("nonexistent", 5*time.Millisecond)
+    pr.Start(5 * time.Millisecond)
     defer pr.Stop()

     // These functions AddInt32 to the same hits variable, but we only want to wait when
     // upfunc finishes, as that only calls Done() on the waitgroup.
-    upfuncNoWg := func(s string) bool { atomic.AddInt32(&hits, 1); return true }
+    upfuncNoWg := func() error { atomic.AddInt32(&hits, 1); return nil }
     wg.Add(1)
     pr.Do(upfunc)
     pr.Do(upfuncNoWg)