diff --git a/Makefile b/Makefile index 245607484..8765bb37c 100644 --- a/Makefile +++ b/Makefile @@ -29,20 +29,17 @@ godeps: (cd $(GOPATH)/src/github.com/prometheus/client_golang 2>/dev/null && git checkout -q master 2>/dev/null || true) (cd $(GOPATH)/src/golang.org/x/net 2>/dev/null && git checkout -q master 2>/dev/null || true) (cd $(GOPATH)/src/golang.org/x/text 2>/dev/null && git checkout -q master 2>/dev/null || true) - (cd $(GOPATH)/src/github.com/coredns/forward 2>/dev/null && git checkout -q master 2>/dev/null || true) go get -u github.com/mholt/caddy go get -u github.com/miekg/dns go get -u github.com/prometheus/client_golang/prometheus/promhttp go get -u github.com/prometheus/client_golang/prometheus go get -u golang.org/x/net/context go get -u golang.org/x/text - -go get -f -u github.com/coredns/forward (cd $(GOPATH)/src/github.com/mholt/caddy && git checkout -q v0.10.10) (cd $(GOPATH)/src/github.com/miekg/dns && git checkout -q v1.0.4) (cd $(GOPATH)/src/github.com/prometheus/client_golang && git checkout -q v0.8.0) (cd $(GOPATH)/src/golang.org/x/net && git checkout -q release-branch.go1.9) (cd $(GOPATH)/src/golang.org/x/text && git checkout -q e19ae1496984b1c655b8044a65c0300a3c878dd3) - (cd $(GOPATH)/src/github.com/coredns/forward && git checkout -q v0.0.2) .PHONY: travis travis: check diff --git a/plugin.cfg b/plugin.cfg index 2b34faa63..60193990d 100644 --- a/plugin.cfg +++ b/plugin.cfg @@ -48,7 +48,7 @@ file:file auto:auto secondary:secondary etcd:etcd -forward:github.com/coredns/forward +forward:forward proxy:proxy erratic:erratic whoami:whoami diff --git a/plugin/forward/README.md b/plugin/forward/README.md new file mode 100644 index 000000000..bbef305db --- /dev/null +++ b/plugin/forward/README.md @@ -0,0 +1,156 @@ +# forward + +## Name + +*forward* facilitates proxying DNS messages to upstream resolvers. + +## Description + +The *forward* plugin is generally faster (~30+%) than *proxy* as it re-uses already opened sockets +to the upstreams. 
It supports UDP, TCP and DNS-over-TLS and uses inband health checking that is +enabled by default. +When *all* upstreams are down it assumes health checking as a mechanism has failed and will try to +connect to a random upstream (which may or may not work). + +## Syntax + +In its most basic form, a simple forwarder uses this syntax: + +~~~ +forward FROM TO... +~~~ + +* **FROM** is the base domain to match for the request to be forwarded. +* **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify + a protocol, `tls://9.9.9.9` or `dns://` for plain DNS. The number of upstreams is limited to 15. + +The health checks are done every *0.5s*. After *two* failed checks the upstream is considered +unhealthy. The health checks use a non-recursive DNS query (`. IN NS`) to get upstream health. Any +response that is not an error (REFUSED, NOTIMPL, SERVFAIL, etc) is taken as a healthy upstream. The +health check uses the same protocol as specified in the **TO**. On startup each upstream is marked +unhealthy until it passes a health check. A 0 duration will disable any health checks. + +Multiple upstreams are randomized (default policy) on first use. When a healthy proxy returns an +error during the exchange the next upstream in the list is tried. + +Extra knobs are available with an expanded syntax: + +~~~ +forward FROM TO... { + except IGNORED_NAMES... + force_tcp + health_check DURATION + expire DURATION + max_fails INTEGER + tls CERT KEY CA + tls_servername NAME + policy random|round_robin +} +~~~ + +* **FROM** and **TO...** as above. +* **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding. + Requests that match none of these names will be passed through. +* `force_tcp`, use TCP even when the request comes in over UDP. +* `health_check`, use a different **DURATION** for health checking, the default duration is 0.5s. + A value of 0 disables the health checks completely. 
+* `max_fails` is the number of subsequent failed health checks that are needed before considering + a backend to be down. If 0, the backend will never be marked as down. Default is 2. +* `expire` **DURATION**, expire (cached) connections after this time, the default is 10s. +* `tls` **CERT** **KEY** **CA** define the TLS properties for TLS; if you leave this out the + system's configuration will be used. +* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9 + needs this to be set to `dns.quad9.net`. +* `policy` specifies the policy to use for selecting upstream servers. The default is `random`. + +The upstream selection is done via random (default policy) selection. If the socket for this client +isn't known *forward* will randomly choose one. If this turns out to be unhealthy, the next one is +tried. If *all* hosts are down, we assume health checking is broken and select a *random* upstream to +try. + +Also note the TLS config is "global" for the whole forwarding proxy; if you need a different +`tls-name` for different upstreams you're out of luck. + +## Metrics + +If monitoring is enabled (via the *prometheus* directive) then the following metrics are exported: + +* `coredns_forward_request_duration_seconds{to}` - duration per upstream interaction. +* `coredns_forward_request_count_total{to}` - query count per upstream. +* `coredns_forward_response_rcode_total{to, rcode}` - count of RCODEs per upstream. +* `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream. +* `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy, + and we are randomly spraying to a target. +* `coredns_forward_socket_count_total{to}` - number of cached sockets per upstream. 
+ +Where `to` is one of the upstream servers (**TO** from the config), `proto` is the protocol used by +the incoming query ("tcp" or "udp"), and family the transport family ("1" for IPv4, and "2" for +IPv6). + +## Examples + +Proxy all requests within example.org. to a nameserver running on a different port: + +~~~ corefile +example.org { + forward . 127.0.0.1:9005 +} +~~~ + +Load balance all requests between three resolvers, one of which has a IPv6 address. + +~~~ corefile +. { + forward . 10.0.0.10:53 10.0.0.11:1053 [2003::1]:53 +} +~~~ + +Forward everything except requests to `example.org` + +~~~ corefile +. { + forward . 10.0.0.10:1234 { + except example.org + } +} +~~~ + +Proxy everything except `example.org` using the host's `resolv.conf`'s nameservers: + +~~~ corefile +. { + forward . /etc/resolv.conf { + except example.org + } +} +~~~ + +Forward to a IPv6 host: + +~~~ corefile +. { + forward . [::1]:1053 +} +~~~ + +Proxy all requests to 9.9.9.9 using the DNS-over-TLS protocol, and cache every answer for up to 30 +seconds. + +~~~ corefile +. { + forward . tls://9.9.9.9 { + tls_servername dns.quad9.net + health_check 5s + } + cache 30 +} +~~~ + +## Bugs + +The TLS config is global for the whole forwarding proxy if you need a different `tls-name` for +different upstreams you're out of luck. + +## Also See + +[RFC 7858](https://tools.ietf.org/html/rfc7858) for DNS over TLS. diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go new file mode 100644 index 000000000..cdad29ed1 --- /dev/null +++ b/plugin/forward/connect.go @@ -0,0 +1,66 @@ +// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same +// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be +// 50% faster than just openening a new connection for every client. It works with UDP and TCP and uses +// inband healthchecking. 
+package forward + +import ( + "strconv" + "time" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) { + start := time.Now() + + proto := state.Proto() + if forceTCP { + proto = "tcp" + } + if p.host.tlsConfig != nil { + proto = "tcp-tls" + } + + conn, err := p.Dial(proto) + if err != nil { + return nil, err + } + + // Set buffer size correctly for this client. + conn.UDPSize = uint16(state.Size()) + if conn.UDPSize < 512 { + conn.UDPSize = 512 + } + + conn.SetWriteDeadline(time.Now().Add(timeout)) + if err := conn.WriteMsg(state.Req); err != nil { + conn.Close() // not giving it back + return nil, err + } + + conn.SetReadDeadline(time.Now().Add(timeout)) + ret, err := conn.ReadMsg() + if err != nil { + conn.Close() // not giving it back + return nil, err + } + + p.Yield(conn) + + if metric { + rc, ok := dns.RcodeToString[ret.Rcode] + if !ok { + rc = strconv.Itoa(ret.Rcode) + } + + RequestCount.WithLabelValues(p.host.addr).Add(1) + RcodeCount.WithLabelValues(rc, p.host.addr).Add(1) + RequestDuration.WithLabelValues(p.host.addr).Observe(time.Since(start).Seconds()) + } + + return ret, nil +} diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go new file mode 100644 index 000000000..35885008e --- /dev/null +++ b/plugin/forward/forward.go @@ -0,0 +1,154 @@ +// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same +// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be +// 50% faster than just openening a new connection for every client. It works with UDP and TCP and uses +// inband healthchecking. 
+package forward + +import ( + "crypto/tls" + "errors" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + ot "github.com/opentracing/opentracing-go" + "golang.org/x/net/context" +) + +// Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list +// of proxies each representing one upstream proxy. +type Forward struct { + proxies []*Proxy + p Policy + + from string + ignored []string + + tlsConfig *tls.Config + tlsServerName string + maxfails uint32 + expire time.Duration + + forceTCP bool // also here for testing + hcInterval time.Duration // also here for testing + + Next plugin.Handler +} + +// New returns a new Forward. +func New() *Forward { + f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: hcDuration, p: new(random)} + return f +} + +// SetProxy appends p to the proxy list and starts healthchecking. +func (f *Forward) SetProxy(p *Proxy) { + f.proxies = append(f.proxies, p) + go p.healthCheck() +} + +// Len returns the number of configured proxies. +func (f *Forward) Len() int { return len(f.proxies) } + +// Name implements plugin.Handler. +func (f *Forward) Name() string { return "forward" } + +// ServeDNS implements plugin.Handler. +func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + + state := request.Request{W: w, Req: r} + if !f.match(state) { + return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r) + } + + fails := 0 + var span, child ot.Span + span = ot.SpanFromContext(ctx) + + for _, proxy := range f.list() { + if proxy.Down(f.maxfails) { + fails++ + if fails < len(f.proxies) { + continue + } + // All upstream proxies are dead, assume healtcheck is completely broken and randomly + // select an upstream to connect to. 
+ r := new(random) + proxy = r.List(f.proxies)[0] + + HealthcheckBrokenCount.Add(1) + } + + if span != nil { + child = span.Tracer().StartSpan("connect", ot.ChildOf(span.Context())) + ctx = ot.ContextWithSpan(ctx, child) + } + + ret, err := proxy.connect(ctx, state, f.forceTCP, true) + + if child != nil { + child.Finish() + } + + if err != nil { + if fails < len(f.proxies) { + continue + } + break + } + + ret.Compress = true + // When using force_tcp the upstream can send a message that is too big for + // the udp buffer, hence we need to truncate the message to at least make it + // fit the udp buffer. + ret, _ = state.Scrub(ret) + + w.WriteMsg(ret) + + return 0, nil + } + + return dns.RcodeServerFailure, errNoHealthy +} + +func (f *Forward) match(state request.Request) bool { + from := f.from + + if !plugin.Name(from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) { + return false + } + + return true +} + +func (f *Forward) isAllowedDomain(name string) bool { + if dns.Name(name) == dns.Name(f.from) { + return true + } + + for _, ignore := range f.ignored { + if plugin.Name(ignore).Matches(name) { + return false + } + } + return true +} + +// List returns a set of proxies to be used for this client depending on the policy in f. +func (f *Forward) list() []*Proxy { return f.p.List(f.proxies) } + +var ( + errInvalidDomain = errors.New("invalid domain for proxy") + errNoHealthy = errors.New("no healthy proxies") + errNoForward = errors.New("no forwarder defined") +) + +// policy tells forward what policy for selecting upstream it uses. 
+type policy int + +const ( + randomPolicy policy = iota + roundRobinPolicy +) diff --git a/plugin/forward/forward_test.go b/plugin/forward/forward_test.go new file mode 100644 index 000000000..d467a0efa --- /dev/null +++ b/plugin/forward/forward_test.go @@ -0,0 +1,42 @@ +package forward + +import ( + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + "github.com/miekg/dns" +) + +func TestForward(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr) + f := New() + f.SetProxy(p) + defer f.Close() + + state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + state.Req.SetQuestion("example.org.", dns.TypeA) + resp, err := f.Forward(state) + if err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + // expect answer section with A record in it + if len(resp.Answer) == 0 { + t.Fatalf("Expected to at least one RR in the answer section, got none: %s", resp) + } + if resp.Answer[0].Header().Rrtype != dns.TypeA { + t.Errorf("Expected RR to A, got: %d", resp.Answer[0].Header().Rrtype) + } + if resp.Answer[0].(*dns.A).A.String() != "127.0.0.1" { + t.Errorf("Expected 127.0.0.1, got: %s", resp.Answer[0].(*dns.A).A.String()) + } +} diff --git a/plugin/forward/health.go b/plugin/forward/health.go new file mode 100644 index 000000000..e277f30a6 --- /dev/null +++ b/plugin/forward/health.go @@ -0,0 +1,67 @@ +package forward + +import ( + "log" + "sync/atomic" + + "github.com/miekg/dns" +) + +// For HC we send to . IN NS +norec message to the upstream. Dial timeouts and empty +// replies are considered fails, basically anything else constitutes a healthy upstream. 
+ +func (h *host) Check() { + h.Lock() + + if h.checking { + h.Unlock() + return + } + + h.checking = true + h.Unlock() + + err := h.send() + if err != nil { + log.Printf("[INFO] healtheck of %s failed with %s", h.addr, err) + + HealthcheckFailureCount.WithLabelValues(h.addr).Add(1) + + atomic.AddUint32(&h.fails, 1) + } else { + atomic.StoreUint32(&h.fails, 0) + } + + h.Lock() + h.checking = false + h.Unlock() + + return +} + +func (h *host) send() error { + hcping := new(dns.Msg) + hcping.SetQuestion(".", dns.TypeNS) + hcping.RecursionDesired = false + + m, _, err := h.client.Exchange(hcping, h.addr) + // If we got a header, we're alright, basically only care about I/O errors 'n stuff + if err != nil && m != nil { + // Silly check, something sane came back + if m.Response || m.Opcode == dns.OpcodeQuery { + err = nil + } + } + + return err +} + +// down returns true is this host has more than maxfails fails. +func (h *host) down(maxfails uint32) bool { + if maxfails == 0 { + return false + } + + fails := atomic.LoadUint32(&h.fails) + return fails > maxfails +} diff --git a/plugin/forward/host.go b/plugin/forward/host.go new file mode 100644 index 000000000..48d6c7d6e --- /dev/null +++ b/plugin/forward/host.go @@ -0,0 +1,44 @@ +package forward + +import ( + "crypto/tls" + "sync" + "time" + + "github.com/miekg/dns" +) + +type host struct { + addr string + client *dns.Client + + tlsConfig *tls.Config + expire time.Duration + + fails uint32 + sync.RWMutex + checking bool +} + +// newHost returns a new host, the fails are set to 1, i.e. +// the first healthcheck must succeed before we use this host. +func newHost(addr string) *host { + return &host{addr: addr, fails: 1, expire: defaultExpire} +} + +// setClient sets and configures the dns.Client in host. 
+func (h *host) SetClient() { + c := new(dns.Client) + c.Net = "udp" + c.ReadTimeout = 2 * time.Second + c.WriteTimeout = 2 * time.Second + + if h.tlsConfig != nil { + c.Net = "tcp-tls" + c.TLSConfig = h.tlsConfig + } + + h.client = c +} + +const defaultExpire = 10 * time.Second diff --git a/plugin/forward/lookup.go b/plugin/forward/lookup.go new file mode 100644 index 000000000..47c4319cf --- /dev/null +++ b/plugin/forward/lookup.go @@ -0,0 +1,78 @@ +// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same +// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be +// 50% faster than just openening a new connection for every client. It works with UDP and TCP and uses +// inband healthchecking. +package forward + +import ( + "crypto/tls" + "log" + "time" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +// Forward forward the request in state as-is. Unlike Lookup that adds EDNS0 suffix to the message. +// Forward may be called with a nil f, an error is returned in that case. +func (f *Forward) Forward(state request.Request) (*dns.Msg, error) { + if f == nil { + return nil, errNoForward + } + + fails := 0 + for _, proxy := range f.list() { + if proxy.Down(f.maxfails) { + fails++ + if fails < len(f.proxies) { + continue + } + // All upstream proxies are dead, assume healtcheck is complete broken and randomly + // select an upstream to connect to. 
+ proxy = f.list()[0] + log.Printf("[WARNING] All upstreams down, picking random one to connect to %s", proxy.host.addr) + } + + ret, err := proxy.connect(context.Background(), state, f.forceTCP, true) + if err != nil { + log.Printf("[WARNING] Failed to connect to %s: %s", proxy.host.addr, err) + if fails < len(f.proxies) { + continue + } + break + + } + + return ret, nil + } + return nil, errNoHealthy +} + +// Lookup will use name and type to forge a new message and will send that upstream. It will +// set any EDNS0 options correctly so that downstream will be able to process the reply. +// Lookup may be called with a nil f, an error is returned in that case. +func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.Msg, error) { + if f == nil { + return nil, errNoForward + } + + req := new(dns.Msg) + req.SetQuestion(name, typ) + state.SizeAndDo(req) + + state2 := request.Request{W: state.W, Req: req} + + return f.Forward(state2) +} + +// NewLookup returns a Forward that can be used for plugin that need an upstream to resolve external names. +func NewLookup(addr []string) *Forward { + f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: 2 * time.Second} + for i := range addr { + p := NewProxy(addr[i]) + f.SetProxy(p) + } + return f +} diff --git a/plugin/forward/lookup_test.go b/plugin/forward/lookup_test.go new file mode 100644 index 000000000..69c7a1949 --- /dev/null +++ b/plugin/forward/lookup_test.go @@ -0,0 +1,41 @@ +package forward + +import ( + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + "github.com/miekg/dns" +) + +func TestLookup(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. 
IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr) + f := New() + f.SetProxy(p) + defer f.Close() + + state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + resp, err := f.Lookup(state, "example.org.", dns.TypeA) + if err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + // expect answer section with A record in it + if len(resp.Answer) == 0 { + t.Fatalf("Expected to at least one RR in the answer section, got none: %s", resp) + } + if resp.Answer[0].Header().Rrtype != dns.TypeA { + t.Errorf("Expected RR to A, got: %d", resp.Answer[0].Header().Rrtype) + } + if resp.Answer[0].(*dns.A).A.String() != "127.0.0.1" { + t.Errorf("Expected 127.0.0.1, got: %s", resp.Answer[0].(*dns.A).A.String()) + } +} diff --git a/plugin/forward/metrics.go b/plugin/forward/metrics.go new file mode 100644 index 000000000..1e72454e0 --- /dev/null +++ b/plugin/forward/metrics.go @@ -0,0 +1,52 @@ +package forward + +import ( + "sync" + + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" +) + +// Variables declared for monitoring. 
+var ( + RequestCount = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "request_count_total", + Help: "Counter of requests made per upstream.", + }, []string{"to"}) + RcodeCount = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "response_rcode_count_total", + Help: "Counter of requests made per upstream.", + }, []string{"rcode", "to"}) + RequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "request_duration_seconds", + Buckets: plugin.TimeBuckets, + Help: "Histogram of the time each request took.", + }, []string{"to"}) + HealthcheckFailureCount = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "healthcheck_failure_count_total", + Help: "Counter of the number of failed healtchecks.", + }, []string{"to"}) + HealthcheckBrokenCount = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "healthcheck_broken_count_total", + Help: "Counter of the number of complete failures of the healtchecks.", + }) + SocketGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "socket_count_total", + Help: "Guage of open sockets per upstream.", + }, []string{"to"}) +) + +var once sync.Once diff --git a/plugin/forward/persistent.go b/plugin/forward/persistent.go new file mode 100644 index 000000000..6a7c4464e --- /dev/null +++ b/plugin/forward/persistent.go @@ -0,0 +1,148 @@ +package forward + +import ( + "net" + "time" + + "github.com/miekg/dns" +) + +// a persistConn hold the dns.Conn and the last used time. +type persistConn struct { + c *dns.Conn + used time.Time +} + +// connErr is used to communicate the connection manager. 
+type connErr struct { + c *dns.Conn + err error +} + +// transport hold the persistent cache. +type transport struct { + conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. + host *host + + dial chan string + yield chan connErr + ret chan connErr + + // Aid in testing, gets length of cache in data-race safe manner. + lenc chan bool + lencOut chan int + + stop chan bool +} + +func newTransport(h *host) *transport { + t := &transport{ + conns: make(map[string][]*persistConn), + host: h, + dial: make(chan string), + yield: make(chan connErr), + ret: make(chan connErr), + stop: make(chan bool), + lenc: make(chan bool), + lencOut: make(chan int), + } + go t.connManager() + return t +} + +// len returns the number of connection, used for metrics. Can only be safely +// used inside connManager() because of races. +func (t *transport) len() int { + l := 0 + for _, conns := range t.conns { + l += len(conns) + } + return l +} + +// Len returns the number of connections in the cache. +func (t *transport) Len() int { + t.lenc <- true + l := <-t.lencOut + return l +} + +// connManagers manages the persistent connection cache for UDP and TCP. +func (t *transport) connManager() { + +Wait: + for { + select { + case proto := <-t.dial: + // Yes O(n), shouldn't put millions in here. We walk all connection until we find the first + // one that is usuable. + i := 0 + for i = 0; i < len(t.conns[proto]); i++ { + pc := t.conns[proto][i] + if time.Since(pc.used) < t.host.expire { + // Found one, remove from pool and return this conn. + t.conns[proto] = t.conns[proto][i+1:] + t.ret <- connErr{pc.c, nil} + continue Wait + } + // This conn has expired. Close it. + pc.c.Close() + } + + // Not conns were found. Connect to the upstream to create one. 
+ t.conns[proto] = t.conns[proto][i:] + SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len())) + + go func() { + if proto != "tcp-tls" { + c, err := dns.DialTimeout(proto, t.host.addr, dialTimeout) + t.ret <- connErr{c, err} + return + } + + c, err := dns.DialTimeoutWithTLS("tcp", t.host.addr, t.host.tlsConfig, dialTimeout) + t.ret <- connErr{c, err} + }() + + case conn := <-t.yield: + + SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len() + 1)) + + // no proto here, infer from config and conn + if _, ok := conn.c.Conn.(*net.UDPConn); ok { + t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()}) + continue Wait + } + + if t.host.tlsConfig == nil { + t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()}) + continue Wait + } + + t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()}) + + case <-t.stop: + return + + case <-t.lenc: + l := 0 + for _, conns := range t.conns { + l += len(conns) + } + t.lencOut <- l + } + } +} + +func (t *transport) Dial(proto string) (*dns.Conn, error) { + t.dial <- proto + c := <-t.ret + return c.c, c.err +} + +func (t *transport) Yield(c *dns.Conn) { + t.yield <- connErr{c, nil} +} + +// Stop stops the transports. 
+func (t *transport) Stop() { t.stop <- true } diff --git a/plugin/forward/persistent_test.go b/plugin/forward/persistent_test.go new file mode 100644 index 000000000..5674658e6 --- /dev/null +++ b/plugin/forward/persistent_test.go @@ -0,0 +1,44 @@ +package forward + +import ( + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + + "github.com/miekg/dns" +) + +func TestPersistent(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + h := newHost(s.Addr) + tr := newTransport(h) + defer tr.Stop() + + c1, _ := tr.Dial("udp") + c2, _ := tr.Dial("udp") + c3, _ := tr.Dial("udp") + + tr.Yield(c1) + tr.Yield(c2) + tr.Yield(c3) + + if x := tr.Len(); x != 3 { + t.Errorf("Expected cache size to be 3, got %d", x) + } + + tr.Dial("udp") + if x := tr.Len(); x != 2 { + t.Errorf("Expected cache size to be 2, got %d", x) + } + + tr.Dial("udp") + if x := tr.Len(); x != 1 { + t.Errorf("Expected cache size to be 2, got %d", x) + } +} diff --git a/plugin/forward/policy.go b/plugin/forward/policy.go new file mode 100644 index 000000000..f39a14105 --- /dev/null +++ b/plugin/forward/policy.go @@ -0,0 +1,55 @@ +package forward + +import ( + "math/rand" + "sync/atomic" +) + +// Policy defines a policy we use for selecting upstreams. +type Policy interface { + List([]*Proxy) []*Proxy + String() string +} + +// random is a policy that implements random upstream selection. +type random struct{} + +func (r *random) String() string { return "random" } + +func (r *random) List(p []*Proxy) []*Proxy { + switch len(p) { + case 1: + return p + case 2: + if rand.Int()%2 == 0 { + return []*Proxy{p[1], p[0]} // swap + } + return p + } + + perms := rand.Perm(len(p)) + rnd := make([]*Proxy, len(p)) + + for i, p1 := range perms { + rnd[i] = p[p1] + } + return rnd +} + +// roundRobin is a policy that selects hosts based on round robin ordering. 
+type roundRobin struct { + robin uint32 +} + +func (r *roundRobin) String() string { return "round_robin" } + +func (r *roundRobin) List(p []*Proxy) []*Proxy { + poolLen := uint32(len(p)) + i := atomic.AddUint32(&r.robin, 1) % poolLen + + robin := []*Proxy{p[i]} + robin = append(robin, p[:i]...) + robin = append(robin, p[i+1:]...) + + return robin +} diff --git a/plugin/forward/protocol.go b/plugin/forward/protocol.go new file mode 100644 index 000000000..338b60116 --- /dev/null +++ b/plugin/forward/protocol.go @@ -0,0 +1,30 @@ +package forward + +// Copied from coredns/core/dnsserver/address.go + +import ( + "strings" +) + +// protocol returns the protocol of the string s. The second string returns s +// with the prefix chopped off. +func protocol(s string) (int, string) { + switch { + case strings.HasPrefix(s, _tls+"://"): + return TLS, s[len(_tls)+3:] + case strings.HasPrefix(s, _dns+"://"): + return DNS, s[len(_dns)+3:] + } + return DNS, s +} + +// Supported protocols. +const ( + DNS = iota + 1 + TLS +) + +const ( + _dns = "dns" + _tls = "tls" +) diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go new file mode 100644 index 000000000..c89490374 --- /dev/null +++ b/plugin/forward/proxy.go @@ -0,0 +1,77 @@ +package forward + +import ( + "crypto/tls" + "sync" + "time" + + "github.com/miekg/dns" +) + +// Proxy defines an upstream host. +type Proxy struct { + host *host + + transport *transport + + // copied from Forward. + hcInterval time.Duration + forceTCP bool + + stop chan bool + + sync.RWMutex +} + +// NewProxy returns a new proxy. +func NewProxy(addr string) *Proxy { + host := newHost(addr) + + p := &Proxy{ + host: host, + hcInterval: hcDuration, + stop: make(chan bool), + transport: newTransport(host), + } + return p +} + +// SetTLSConfig sets the TLS config in the lower p.host. +func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.host.tlsConfig = cfg } + +// SetExpire sets the expire duration in the lower p.host. 
+func (p *Proxy) SetExpire(expire time.Duration) { p.host.expire = expire } + +func (p *Proxy) close() { p.stop <- true } + +// Dial connects to the host in p with the configured transport. +func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) } + +// Yield returns the connection to the pool. +func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) } + +// Down returns if this proxy is up or down. +func (p *Proxy) Down(maxfails uint32) bool { return p.host.down(maxfails) } + +func (p *Proxy) healthCheck() { + + // stop channel + p.host.SetClient() + + p.host.Check() + tick := time.NewTicker(p.hcInterval) + for { + select { + case <-tick.C: + p.host.Check() + case <-p.stop: + return + } + } +} + +const ( + dialTimeout = 4 * time.Second + timeout = 2 * time.Second + hcDuration = 500 * time.Millisecond +) diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go new file mode 100644 index 000000000..bed20f0c7 --- /dev/null +++ b/plugin/forward/setup.go @@ -0,0 +1,262 @@ +package forward + +import ( + "fmt" + "net" + "strconv" + "time" + + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + pkgtls "github.com/coredns/coredns/plugin/pkg/tls" + + "github.com/mholt/caddy" +) + +func init() { + caddy.RegisterPlugin("forward", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + f, err := parseForward(c) + if err != nil { + return plugin.Error("foward", err) + } + if f.Len() > max { + return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len())) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + f.Next = next + return f + }) + + c.OnStartup(func() error { + once.Do(func() { + m := dnsserver.GetConfig(c).Handler("prometheus") + if m == nil { + return + } + if x, ok := m.(*metrics.Metrics); ok { 
+				x.MustRegister(RequestCount)
+				x.MustRegister(RcodeCount)
+				x.MustRegister(RequestDuration)
+				x.MustRegister(HealthcheckFailureCount)
+				x.MustRegister(SocketGauge)
+			}
+		})
+		return f.OnStartup()
+	})
+
+	c.OnShutdown(func() error {
+		return f.OnShutdown()
+	})
+
+	return nil
+}
+
+// OnStartup starts a health-check goroutine for each proxy.
+func (f *Forward) OnStartup() (err error) {
+	if f.hcInterval == 0 {
+		for _, p := range f.proxies {
+			p.host.fails = 0
+		}
+		return nil
+	}
+
+	for _, p := range f.proxies {
+		go p.healthCheck()
+	}
+	return nil
+}
+
+// OnShutdown stops all configured proxies.
+func (f *Forward) OnShutdown() error {
+	if f.hcInterval == 0 {
+		return nil
+	}
+
+	for _, p := range f.proxies {
+		p.close()
+	}
+	return nil
+}
+
+// Close is a synonym for OnShutdown().
+func (f *Forward) Close() {
+	f.OnShutdown()
+}
+
+func parseForward(c *caddy.Controller) (*Forward, error) {
+	f := New()
+
+	protocols := map[int]int{}
+
+	for c.Next() {
+		if !c.Args(&f.from) {
+			return f, c.ArgErr()
+		}
+		f.from = plugin.Host(f.from).Normalize()
+
+		to := c.RemainingArgs()
+		if len(to) == 0 {
+			return f, c.ArgErr()
+		}
+
+		// A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies.
+		protocols = make(map[int]int)
+		for i := range to {
+			protocols[i], to[i] = protocol(to[i])
+		}
+
+		// If parseHostPortOrFile expands a file with a lot of nameservers, our accounting in protocols doesn't make
+		// any sense anymore... For now: let's not care.
+		toHosts, err := dnsutil.ParseHostPortOrFile(to...)
+		if err != nil {
+			return f, err
+		}
+
+		for i, h := range toHosts {
+			// Double check the port: if it e.g. is 53 and the transport is TLS make it 853.
+			// This can be somewhat annoying because you *can't* have TLS on port 53 then.
+			switch protocols[i] {
+			case TLS:
+				h1, p, err := net.SplitHostPort(h)
+				if err != nil {
+					break
+				}
+
+				// This is more of a bug in dnsutil.ParseHostPortOrFile that defaults to
+				// 53 because it doesn't know about tls:// and friends (that should be fixed).
+				// Hence fix the port number here, back to what the user intended.
+				if p == "53" {
+					h = net.JoinHostPort(h1, "853")
+				}
+			}
+
+			// We can't set tlsConfig here, because we haven't parsed it yet.
+			// We set it below at the end of parseBlock.
+			p := NewProxy(h)
+			f.proxies = append(f.proxies, p)
+		}
+
+		for c.NextBlock() {
+			if err := parseBlock(c, f); err != nil {
+				return f, err
+			}
+		}
+	}
+
+	if f.tlsServerName != "" {
+		f.tlsConfig.ServerName = f.tlsServerName
+	}
+	for i := range f.proxies {
+		// Only set this for proxies that need it.
+		if protocols[i] == TLS {
+			f.proxies[i].SetTLSConfig(f.tlsConfig)
+		}
+		f.proxies[i].SetExpire(f.expire)
+	}
+	return f, nil
+}
+
+func parseBlock(c *caddy.Controller, f *Forward) error {
+	switch c.Val() {
+	case "except":
+		ignore := c.RemainingArgs()
+		if len(ignore) == 0 {
+			return c.ArgErr()
+		}
+		for i := 0; i < len(ignore); i++ {
+			ignore[i] = plugin.Host(ignore[i]).Normalize()
+		}
+		f.ignored = ignore
+	case "max_fails":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		n, err := strconv.Atoi(c.Val())
+		if err != nil {
+			return err
+		}
+		if n < 0 {
+			return fmt.Errorf("max_fails can't be negative: %d", n)
+		}
+		f.maxfails = uint32(n)
+	case "health_check":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		dur, err := time.ParseDuration(c.Val())
+		if err != nil {
+			return err
+		}
+		if dur < 0 {
+			return fmt.Errorf("health_check can't be negative: %s", dur)
+		}
+		f.hcInterval = dur
+		for i := range f.proxies {
+			f.proxies[i].hcInterval = dur
+		}
+	case "force_tcp":
+		if c.NextArg() {
+			return c.ArgErr()
+		}
+		f.forceTCP = true
+		for i := range f.proxies {
+			f.proxies[i].forceTCP = true
+		}
+	case "tls":
+		args := c.RemainingArgs()
+		if len(args) != 3 {
+			return c.ArgErr()
+		}
+
+		tlsConfig, err := pkgtls.NewTLSConfig(args[0], args[1], args[2])
+		if err != nil {
+			return err
+		}
+		f.tlsConfig = tlsConfig
+	case "tls_servername":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		f.tlsServerName = c.Val()
+	case "expire":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		dur, err := time.ParseDuration(c.Val())
+		if err != nil {
+			return err
+		}
+		if dur < 0 {
+			return fmt.Errorf("expire can't be negative: %s", dur)
+		}
+		f.expire = dur
+	case "policy":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		switch x := c.Val(); x {
+		case "random":
+			f.p = &random{}
+		case "round_robin":
+			f.p = &roundRobin{}
+		default:
+			return c.Errf("unknown policy '%s'", x)
+		}
+
+	default:
+		return c.Errf("unknown property '%s'", c.Val())
+	}
+
+	return nil
+}
+
+const max = 15 // Maximum number of upstreams.
diff --git a/plugin/forward/setup_policy_test.go b/plugin/forward/setup_policy_test.go
new file mode 100644
index 000000000..8c40b9fdd
--- /dev/null
+++ b/plugin/forward/setup_policy_test.go
@@ -0,0 +1,46 @@
+package forward
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/mholt/caddy"
+)
+
+func TestSetupPolicy(t *testing.T) {
+	tests := []struct {
+		input          string
+		shouldErr      bool
+		expectedPolicy string
+		expectedErr    string
+	}{
+		// positive
+		{"forward . 127.0.0.1 {\npolicy random\n}\n", false, "random", ""},
+		{"forward . 127.0.0.1 {\npolicy round_robin\n}\n", false, "round_robin", ""},
+		// negative
+		{"forward . 127.0.0.1 {\npolicy random2\n}\n", true, "random", "unknown policy"},
+	}
+
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", test.input)
+		f, err := parseForward(c)
+
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input)
+		}
+
+		if err != nil {
+			if !test.shouldErr {
+				t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+			}
+
+			if !strings.Contains(err.Error(), test.expectedErr) {
+				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+			}
+		}
+
+		if !test.shouldErr && f.p.String() != test.expectedPolicy {
+			t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedPolicy, f.p.String())
+		}
+	}
+}
diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go
new file mode 100644
index 000000000..f1776222f
--- /dev/null
+++ b/plugin/forward/setup_test.go
@@ -0,0 +1,68 @@
+package forward
+
+import (
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/mholt/caddy"
+)
+
+func TestSetup(t *testing.T) {
+	tests := []struct {
+		input            string
+		shouldErr        bool
+		expectedFrom     string
+		expectedIgnored  []string
+		expectedFails    uint32
+		expectedForceTCP bool
+		expectedErr      string
+	}{
+		// positive
+		{"forward . 127.0.0.1", false, ".", nil, 2, false, ""},
+		{"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, false, ""},
+		{"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, false, ""},
+		{"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, true, ""},
+		{"forward . 127.0.0.1:53", false, ".", nil, 2, false, ""},
+		{"forward . 127.0.0.1:8080", false, ".", nil, 2, false, ""},
+		{"forward . [::1]:53", false, ".", nil, 2, false, ""},
+		{"forward . [2003::1]:53", false, ".", nil, 2, false, ""},
+		// negative
+		{"forward . a27.0.0.1", true, "", nil, 0, false, "not an IP"},
+		{"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, false, "unknown property"},
+	}
+
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", test.input)
+		f, err := parseForward(c)
+
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input)
+		}
+
+		if err != nil {
+			if !test.shouldErr {
+				t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+			}
+
+			if !strings.Contains(err.Error(), test.expectedErr) {
+				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+			}
+		}
+
+		if !test.shouldErr && f.from != test.expectedFrom {
+			t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedFrom, f.from)
+		}
+		if !test.shouldErr && test.expectedIgnored != nil {
+			if !reflect.DeepEqual(f.ignored, test.expectedIgnored) {
+				t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedIgnored, f.ignored)
+			}
+		}
+		if !test.shouldErr && f.maxfails != test.expectedFails {
+			t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedFails, f.maxfails)
+		}
+		if !test.shouldErr && f.forceTCP != test.expectedForceTCP {
+			t.Errorf("Test %d: expected: %t, got: %t", i, test.expectedForceTCP, f.forceTCP)
+		}
+	}
+}