plugin/forward: add it (#1447)

* plugin/forward: add it

This moves coredns/forward into CoreDNS. It fixes a few bugs, adds a
policy option, and adds more tests to the plugin.

Update the documentation, test IPv6 address and add persistent tests.

* Always use random policy when spraying

* include scrub fix here as well

* use correct var name

* Code review

* go vet

* Move logging to metrics

* Small readme updates

* Fix readme
Miek Gieben 2018-02-05 22:00:47 +00:00 committed by GitHub
parent fb1cafe5fa
commit 5b844b5017
19 changed files with 1431 additions and 4 deletions


@ -29,20 +29,17 @@ godeps:
(cd $(GOPATH)/src/github.com/prometheus/client_golang 2>/dev/null && git checkout -q master 2>/dev/null || true)
(cd $(GOPATH)/src/golang.org/x/net 2>/dev/null && git checkout -q master 2>/dev/null || true)
(cd $(GOPATH)/src/golang.org/x/text 2>/dev/null && git checkout -q master 2>/dev/null || true)
-(cd $(GOPATH)/src/github.com/coredns/forward 2>/dev/null && git checkout -q master 2>/dev/null || true)
go get -u github.com/mholt/caddy
go get -u github.com/miekg/dns
go get -u github.com/prometheus/client_golang/prometheus/promhttp
go get -u github.com/prometheus/client_golang/prometheus
go get -u golang.org/x/net/context
go get -u golang.org/x/text
-go get -f -u github.com/coredns/forward
(cd $(GOPATH)/src/github.com/mholt/caddy && git checkout -q v0.10.10)
(cd $(GOPATH)/src/github.com/miekg/dns && git checkout -q v1.0.4)
(cd $(GOPATH)/src/github.com/prometheus/client_golang && git checkout -q v0.8.0)
(cd $(GOPATH)/src/golang.org/x/net && git checkout -q release-branch.go1.9)
(cd $(GOPATH)/src/golang.org/x/text && git checkout -q e19ae1496984b1c655b8044a65c0300a3c878dd3)
-(cd $(GOPATH)/src/github.com/coredns/forward && git checkout -q v0.0.2)
.PHONY: travis
travis: check


@ -48,7 +48,7 @@ file:file
auto:auto
secondary:secondary
etcd:etcd
-forward:github.com/coredns/forward
+forward:forward
proxy:proxy
erratic:erratic
whoami:whoami

plugin/forward/README.md (new file, 156 lines)

@ -0,0 +1,156 @@
# forward
## Name
*forward* facilitates proxying DNS messages to upstream resolvers.
## Description
The *forward* plugin is generally faster (~30+%) than *proxy* as it re-uses already opened sockets
to the upstreams. It supports UDP, TCP and DNS-over-TLS and uses inband health checking that is
enabled by default.
When *all* upstreams are down it assumes health checking as a mechanism has failed and will try to
connect to a random upstream (which may or may not work).
## Syntax
In its most basic form, a simple forwarder uses this syntax:
~~~
forward FROM TO...
~~~
* **FROM** is the base domain to match for the request to be forwarded.
* **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify
a protocol, `tls://9.9.9.9` or `dns://` for plain DNS. The number of upstreams is limited to 15.
The health checks are done every *0.5s*. After *two* failed checks the upstream is considered
unhealthy. The health checks use a non-recursive DNS query (`. IN NS`) to get upstream health. Any
response that is not an error (REFUSED, NOTIMPL, SERVFAIL, etc.) is taken to indicate a healthy
upstream. The health check uses the same protocol as specified in **TO**. On startup each upstream
is marked unhealthy until it passes a health check. A duration of 0 disables health checks entirely.
Multiple upstreams are randomized (default policy) on first use. When a healthy proxy returns an
error during the exchange the next upstream in the list is tried.
Extra knobs are available with an expanded syntax:
~~~
forward FROM TO... {
except IGNORED_NAMES...
force_tcp
health_check DURATION
expire DURATION
max_fails INTEGER
tls CERT KEY CA
tls_servername NAME
policy random|round_robin
}
~~~
* **FROM** and **TO...** as above.
* **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding.
Requests matching one of these names are not forwarded and are handled by the next plugin.
* `force_tcp`, use TCP even when the request comes in over UDP.
* `health_check`, use a different **DURATION** for health checking, the default duration is 0.5s.
A value of 0 disables the health checks completely.
* `max_fails` is the number of subsequent failed health checks that are needed before considering
a backend to be down. If 0, the backend will never be marked as down. Default is 2.
* `expire` **DURATION**, expire (cached) connections after this time, the default is 10s.
* `tls` **CERT** **KEY** **CA** define the TLS properties for TLS; if you leave this out the
system's configuration will be used.
* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
needs this to be set to `dns.quad9.net`.
* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
Upstream selection follows the configured policy (`random` by default). With `random`, *forward*
randomly chooses an upstream; if that one turns out to be unhealthy, the next one is tried. If
*all* hosts are down, we assume health checking is broken and select a *random* upstream to
try anyway.
Also note that the TLS config is "global" for the whole forwarding proxy; if you need a different
`tls_servername` for different upstreams you're out of luck.
## Metrics
If monitoring is enabled (via the *prometheus* directive) then the following metrics are exported:
* `coredns_forward_request_duration_seconds{to}` - duration per upstream interaction.
* `coredns_forward_request_count_total{to}` - query count per upstream.
* `coredns_forward_response_rcode_total{to, rcode}` - count of RCODEs per upstream.
* `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream.
* `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy,
and we are randomly spraying to a target.
* `coredns_forward_socket_count_total{to}` - number of cached sockets per upstream.
Where `to` is one of the upstream servers (**TO** from the config) and `rcode` is the DNS response
code of the reply.
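For example, metrics collection is enabled by adding the *prometheus* plugin to the server block (the listen address is a placeholder):

~~~ corefile
. {
    prometheus localhost:9153
    forward . 10.0.0.10:53
}
~~~

The counters above are then available on the `/metrics` endpoint of that address.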
## Examples
Proxy all requests within example.org. to a nameserver running on a different port:
~~~ corefile
example.org {
forward . 127.0.0.1:9005
}
~~~
Load balance all requests between three resolvers, one of which has an IPv6 address.
~~~ corefile
. {
forward . 10.0.0.10:53 10.0.0.11:1053 [2003::1]:53
}
~~~
Forward everything except requests to `example.org`:
~~~ corefile
. {
forward . 10.0.0.10:1234 {
except example.org
}
}
~~~
Proxy everything except `example.org` using the host's `resolv.conf`'s nameservers:
~~~ corefile
. {
forward . /etc/resolv.conf {
except example.org
}
}
~~~
Forward to an IPv6 host:
~~~ corefile
. {
forward . [::1]:1053
}
~~~
Proxy all requests to 9.9.9.9 using the DNS-over-TLS protocol, and cache every answer for up to 30
seconds.
~~~ corefile
. {
forward . tls://9.9.9.9 {
tls_servername dns.quad9.net
health_check 5s
}
cache 30
}
~~~
## Bugs
The TLS config is global for the whole forwarding proxy; if you need a different `tls_servername`
for different upstreams you're out of luck.
## Also See
[RFC 7858](https://tools.ietf.org/html/rfc7858) for DNS over TLS.

plugin/forward/connect.go (new file, 66 lines)

@ -0,0 +1,66 @@
// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
// inband health checking.
package forward
import (
"strconv"
"time"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
start := time.Now()
proto := state.Proto()
if forceTCP {
proto = "tcp"
}
if p.host.tlsConfig != nil {
proto = "tcp-tls"
}
conn, err := p.Dial(proto)
if err != nil {
return nil, err
}
// Set buffer size correctly for this client.
conn.UDPSize = uint16(state.Size())
if conn.UDPSize < 512 {
conn.UDPSize = 512
}
conn.SetWriteDeadline(time.Now().Add(timeout))
if err := conn.WriteMsg(state.Req); err != nil {
conn.Close() // not giving it back
return nil, err
}
conn.SetReadDeadline(time.Now().Add(timeout))
ret, err := conn.ReadMsg()
if err != nil {
conn.Close() // not giving it back
return nil, err
}
p.Yield(conn)
if metric {
rc, ok := dns.RcodeToString[ret.Rcode]
if !ok {
rc = strconv.Itoa(ret.Rcode)
}
RequestCount.WithLabelValues(p.host.addr).Add(1)
RcodeCount.WithLabelValues(rc, p.host.addr).Add(1)
RequestDuration.WithLabelValues(p.host.addr).Observe(time.Since(start).Seconds())
}
return ret, nil
}

plugin/forward/forward.go (new file, 154 lines)

@ -0,0 +1,154 @@
// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
// inband health checking.
package forward
import (
"crypto/tls"
"errors"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
ot "github.com/opentracing/opentracing-go"
"golang.org/x/net/context"
)
// Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list
// of proxies each representing one upstream proxy.
type Forward struct {
proxies []*Proxy
p Policy
from string
ignored []string
tlsConfig *tls.Config
tlsServerName string
maxfails uint32
expire time.Duration
forceTCP bool // also here for testing
hcInterval time.Duration // also here for testing
Next plugin.Handler
}
// New returns a new Forward.
func New() *Forward {
f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: hcDuration, p: new(random)}
return f
}
// SetProxy appends p to the proxy list and starts healthchecking.
func (f *Forward) SetProxy(p *Proxy) {
f.proxies = append(f.proxies, p)
go p.healthCheck()
}
// Len returns the number of configured proxies.
func (f *Forward) Len() int { return len(f.proxies) }
// Name implements plugin.Handler.
func (f *Forward) Name() string { return "forward" }
// ServeDNS implements plugin.Handler.
func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
if !f.match(state) {
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
}
fails := 0
var span, child ot.Span
span = ot.SpanFromContext(ctx)
for _, proxy := range f.list() {
if proxy.Down(f.maxfails) {
fails++
if fails < len(f.proxies) {
continue
}
// All upstream proxies are dead; assume the health checks are completely broken and randomly
// select an upstream to connect to.
r := new(random)
proxy = r.List(f.proxies)[0]
HealthcheckBrokenCount.Add(1)
}
if span != nil {
child = span.Tracer().StartSpan("connect", ot.ChildOf(span.Context()))
ctx = ot.ContextWithSpan(ctx, child)
}
ret, err := proxy.connect(ctx, state, f.forceTCP, true)
if child != nil {
child.Finish()
}
if err != nil {
if fails < len(f.proxies) {
continue
}
break
}
ret.Compress = true
// When using force_tcp the upstream can send a message that is too big for
// the udp buffer, hence we need to truncate the message to at least make it
// fit the udp buffer.
ret, _ = state.Scrub(ret)
w.WriteMsg(ret)
return 0, nil
}
return dns.RcodeServerFailure, errNoHealthy
}
func (f *Forward) match(state request.Request) bool {
from := f.from
if !plugin.Name(from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) {
return false
}
return true
}
func (f *Forward) isAllowedDomain(name string) bool {
if dns.Name(name) == dns.Name(f.from) {
return true
}
for _, ignore := range f.ignored {
if plugin.Name(ignore).Matches(name) {
return false
}
}
return true
}
// List returns a set of proxies to be used for this client depending on the policy in f.
func (f *Forward) list() []*Proxy { return f.p.List(f.proxies) }
var (
errInvalidDomain = errors.New("invalid domain for proxy")
errNoHealthy = errors.New("no healthy proxies")
errNoForward = errors.New("no forwarder defined")
)
// policy tells forward what policy for selecting upstream it uses.
type policy int
const (
randomPolicy policy = iota
roundRobinPolicy
)


@ -0,0 +1,42 @@
package forward
import (
"testing"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
func TestForward(t *testing.T) {
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
w.WriteMsg(ret)
})
defer s.Close()
p := NewProxy(s.Addr)
f := New()
f.SetProxy(p)
defer f.Close()
state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
state.Req.SetQuestion("example.org.", dns.TypeA)
resp, err := f.Forward(state)
if err != nil {
t.Fatal("Expected to receive reply, but didn't")
}
// expect answer section with A record in it
if len(resp.Answer) == 0 {
t.Fatalf("Expected at least one RR in the answer section, got none: %s", resp)
}
if resp.Answer[0].Header().Rrtype != dns.TypeA {
t.Errorf("Expected RR to be an A record, got: %d", resp.Answer[0].Header().Rrtype)
}
if resp.Answer[0].(*dns.A).A.String() != "127.0.0.1" {
t.Errorf("Expected 127.0.0.1, got: %s", resp.Answer[0].(*dns.A).A.String())
}
}

plugin/forward/health.go (new file, 67 lines)

@ -0,0 +1,67 @@
package forward
import (
"log"
"sync/atomic"
"github.com/miekg/dns"
)
// For a health check we send a `. IN NS` message with recursion disabled to the upstream. Dial timeouts and
// empty replies are considered failures; basically anything else constitutes a healthy upstream.
func (h *host) Check() {
h.Lock()
if h.checking {
h.Unlock()
return
}
h.checking = true
h.Unlock()
err := h.send()
if err != nil {
log.Printf("[INFO] health check of %s failed with %s", h.addr, err)
HealthcheckFailureCount.WithLabelValues(h.addr).Add(1)
atomic.AddUint32(&h.fails, 1)
} else {
atomic.StoreUint32(&h.fails, 0)
}
h.Lock()
h.checking = false
h.Unlock()
return
}
func (h *host) send() error {
hcping := new(dns.Msg)
hcping.SetQuestion(".", dns.TypeNS)
hcping.RecursionDesired = false
m, _, err := h.client.Exchange(hcping, h.addr)
// If we got a header, we're alright, basically only care about I/O errors 'n stuff
if err != nil && m != nil {
// Silly check, something sane came back
if m.Response || m.Opcode == dns.OpcodeQuery {
err = nil
}
}
return err
}
// down returns true if this host has more than maxfails failed health checks.
func (h *host) down(maxfails uint32) bool {
if maxfails == 0 {
return false
}
fails := atomic.LoadUint32(&h.fails)
return fails > maxfails
}

plugin/forward/host.go (new file, 44 lines)

@ -0,0 +1,44 @@
package forward
import (
"crypto/tls"
"sync"
"time"
"github.com/miekg/dns"
)
type host struct {
addr string
client *dns.Client
tlsConfig *tls.Config
expire time.Duration
fails uint32
sync.RWMutex
checking bool
}
// newHost returns a new host, the fails are set to 1, i.e.
// the first healthcheck must succeed before we use this host.
func newHost(addr string) *host {
return &host{addr: addr, fails: 1, expire: defaultExpire}
}
// SetClient sets and configures the dns.Client in host.
func (h *host) SetClient() {
c := new(dns.Client)
c.Net = "udp"
c.ReadTimeout = 2 * time.Second
c.WriteTimeout = 2 * time.Second
if h.tlsConfig != nil {
c.Net = "tcp-tls"
c.TLSConfig = h.tlsConfig
}
h.client = c
}
const defaultExpire = 10 * time.Second

plugin/forward/lookup.go (new file, 78 lines)

@ -0,0 +1,78 @@
// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses
// inband health checking.
package forward
import (
"crypto/tls"
"log"
"time"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// Forward forwards the request in state as-is. Unlike Lookup, it does not add any EDNS0 options to the message.
// Forward may be called with a nil f; an error is returned in that case.
func (f *Forward) Forward(state request.Request) (*dns.Msg, error) {
if f == nil {
return nil, errNoForward
}
fails := 0
for _, proxy := range f.list() {
if proxy.Down(f.maxfails) {
fails++
if fails < len(f.proxies) {
continue
}
// All upstream proxies are dead; assume the health checks are completely broken and randomly
// select an upstream to connect to.
proxy = f.list()[0]
log.Printf("[WARNING] All upstreams down, picking random one to connect to %s", proxy.host.addr)
}
ret, err := proxy.connect(context.Background(), state, f.forceTCP, true)
if err != nil {
log.Printf("[WARNING] Failed to connect to %s: %s", proxy.host.addr, err)
if fails < len(f.proxies) {
continue
}
break
}
return ret, nil
}
return nil, errNoHealthy
}
// Lookup will use name and type to forge a new message and will send that upstream. It will
// set any EDNS0 options correctly so that downstream will be able to process the reply.
// Lookup may be called with a nil f; an error is returned in that case.
func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.Msg, error) {
if f == nil {
return nil, errNoForward
}
req := new(dns.Msg)
req.SetQuestion(name, typ)
state.SizeAndDo(req)
state2 := request.Request{W: state.W, Req: req}
return f.Forward(state2)
}
// NewLookup returns a Forward that can be used by plugins that need an upstream to resolve external names.
func NewLookup(addr []string) *Forward {
f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: 2 * time.Second}
for i := range addr {
p := NewProxy(addr[i])
f.SetProxy(p)
}
return f
}


@ -0,0 +1,41 @@
package forward
import (
"testing"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
func TestLookup(t *testing.T) {
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
w.WriteMsg(ret)
})
defer s.Close()
p := NewProxy(s.Addr)
f := New()
f.SetProxy(p)
defer f.Close()
state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
resp, err := f.Lookup(state, "example.org.", dns.TypeA)
if err != nil {
t.Fatal("Expected to receive reply, but didn't")
}
// expect answer section with A record in it
if len(resp.Answer) == 0 {
t.Fatalf("Expected at least one RR in the answer section, got none: %s", resp)
}
if resp.Answer[0].Header().Rrtype != dns.TypeA {
t.Errorf("Expected RR to be an A record, got: %d", resp.Answer[0].Header().Rrtype)
}
if resp.Answer[0].(*dns.A).A.String() != "127.0.0.1" {
t.Errorf("Expected 127.0.0.1, got: %s", resp.Answer[0].(*dns.A).A.String())
}
}

plugin/forward/metrics.go (new file, 52 lines)

@ -0,0 +1,52 @@
package forward
import (
"sync"
"github.com/coredns/coredns/plugin"
"github.com/prometheus/client_golang/prometheus"
)
// Variables declared for monitoring.
var (
RequestCount = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: plugin.Namespace,
Subsystem: "forward",
Name: "request_count_total",
Help: "Counter of requests made per upstream.",
}, []string{"to"})
RcodeCount = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: plugin.Namespace,
Subsystem: "forward",
Name: "response_rcode_count_total",
Help: "Counter of RCODEs per upstream.",
}, []string{"rcode", "to"})
RequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: plugin.Namespace,
Subsystem: "forward",
Name: "request_duration_seconds",
Buckets: plugin.TimeBuckets,
Help: "Histogram of the time each request took.",
}, []string{"to"})
HealthcheckFailureCount = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: plugin.Namespace,
Subsystem: "forward",
Name: "healthcheck_failure_count_total",
Help: "Counter of the number of failed health checks.",
}, []string{"to"})
HealthcheckBrokenCount = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: plugin.Namespace,
Subsystem: "forward",
Name: "healthcheck_broken_count_total",
Help: "Counter of the number of complete failures of the health checks.",
})
SocketGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: plugin.Namespace,
Subsystem: "forward",
Name: "socket_count_total",
Help: "Gauge of open sockets per upstream.",
}, []string{"to"})
)
var once sync.Once


@ -0,0 +1,148 @@
package forward
import (
"net"
"time"
"github.com/miekg/dns"
)
// a persistConn holds the dns.Conn and the last used time.
type persistConn struct {
c *dns.Conn
used time.Time
}
// connErr is used to communicate with the connection manager.
type connErr struct {
c *dns.Conn
err error
}
// transport holds the persistent cache.
type transport struct {
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
host *host
dial chan string
yield chan connErr
ret chan connErr
// Aid in testing, gets length of cache in data-race safe manner.
lenc chan bool
lencOut chan int
stop chan bool
}
func newTransport(h *host) *transport {
t := &transport{
conns: make(map[string][]*persistConn),
host: h,
dial: make(chan string),
yield: make(chan connErr),
ret: make(chan connErr),
stop: make(chan bool),
lenc: make(chan bool),
lencOut: make(chan int),
}
go t.connManager()
return t
}
// len returns the number of connections; used for metrics. Can only be safely
// used inside connManager() because of races.
func (t *transport) len() int {
l := 0
for _, conns := range t.conns {
l += len(conns)
}
return l
}
// Len returns the number of connections in the cache.
func (t *transport) Len() int {
t.lenc <- true
l := <-t.lencOut
return l
}
// connManager manages the persistent connection cache for UDP and TCP.
func (t *transport) connManager() {
Wait:
for {
select {
case proto := <-t.dial:
// Yes, O(n); we shouldn't put millions in here. We walk all connections until we find the first
// one that is usable.
i := 0
for i = 0; i < len(t.conns[proto]); i++ {
pc := t.conns[proto][i]
if time.Since(pc.used) < t.host.expire {
// Found one, remove from pool and return this conn.
t.conns[proto] = t.conns[proto][i+1:]
t.ret <- connErr{pc.c, nil}
continue Wait
}
// This conn has expired. Close it.
pc.c.Close()
}
// No connections were found. Connect to the upstream to create one.
t.conns[proto] = t.conns[proto][i:]
SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len()))
go func() {
if proto != "tcp-tls" {
c, err := dns.DialTimeout(proto, t.host.addr, dialTimeout)
t.ret <- connErr{c, err}
return
}
c, err := dns.DialTimeoutWithTLS("tcp", t.host.addr, t.host.tlsConfig, dialTimeout)
t.ret <- connErr{c, err}
}()
case conn := <-t.yield:
SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len() + 1))
// no proto here, infer from config and conn
if _, ok := conn.c.Conn.(*net.UDPConn); ok {
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()})
continue Wait
}
if t.host.tlsConfig == nil {
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()})
continue Wait
}
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()})
case <-t.stop:
return
case <-t.lenc:
l := 0
for _, conns := range t.conns {
l += len(conns)
}
t.lencOut <- l
}
}
}
func (t *transport) Dial(proto string) (*dns.Conn, error) {
t.dial <- proto
c := <-t.ret
return c.c, c.err
}
func (t *transport) Yield(c *dns.Conn) {
t.yield <- connErr{c, nil}
}
// Stop stops the transports.
func (t *transport) Stop() { t.stop <- true }


@ -0,0 +1,44 @@
package forward
import (
"testing"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/miekg/dns"
)
func TestPersistent(t *testing.T) {
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
w.WriteMsg(ret)
})
defer s.Close()
h := newHost(s.Addr)
tr := newTransport(h)
defer tr.Stop()
c1, _ := tr.Dial("udp")
c2, _ := tr.Dial("udp")
c3, _ := tr.Dial("udp")
tr.Yield(c1)
tr.Yield(c2)
tr.Yield(c3)
if x := tr.Len(); x != 3 {
t.Errorf("Expected cache size to be 3, got %d", x)
}
tr.Dial("udp")
if x := tr.Len(); x != 2 {
t.Errorf("Expected cache size to be 2, got %d", x)
}
tr.Dial("udp")
if x := tr.Len(); x != 1 {
t.Errorf("Expected cache size to be 1, got %d", x)
}
}

plugin/forward/policy.go (new file, 55 lines)

@ -0,0 +1,55 @@
package forward
import (
"math/rand"
"sync/atomic"
)
// Policy defines a policy we use for selecting upstreams.
type Policy interface {
List([]*Proxy) []*Proxy
String() string
}
// random is a policy that implements random upstream selection.
type random struct{}
func (r *random) String() string { return "random" }
func (r *random) List(p []*Proxy) []*Proxy {
switch len(p) {
case 1:
return p
case 2:
if rand.Int()%2 == 0 {
return []*Proxy{p[1], p[0]} // swap
}
return p
}
perms := rand.Perm(len(p))
rnd := make([]*Proxy, len(p))
for i, p1 := range perms {
rnd[i] = p[p1]
}
return rnd
}
// roundRobin is a policy that selects hosts based on round robin ordering.
type roundRobin struct {
robin uint32
}
func (r *roundRobin) String() string { return "round_robin" }
func (r *roundRobin) List(p []*Proxy) []*Proxy {
poolLen := uint32(len(p))
i := atomic.AddUint32(&r.robin, 1) % poolLen
robin := []*Proxy{p[i]}
robin = append(robin, p[:i]...)
robin = append(robin, p[i+1:]...)
return robin
}
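The round-robin rotation above can be sketched standalone; this is a simplified mirror of the `roundRobin.List` logic, using strings instead of `*Proxy`:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// rr mirrors roundRobin: atomically advance a counter and rotate the
// slice so a different element comes first on every call.
type rr struct{ robin uint32 }

func (r *rr) list(p []string) []string {
	i := atomic.AddUint32(&r.robin, 1) % uint32(len(p))
	out := []string{p[i]}
	out = append(out, p[:i]...)
	return append(out, p[i+1:]...)
}

func main() {
	r := &rr{}
	up := []string{"a", "b", "c"}
	fmt.Println(r.list(up)) // [b a c]
	fmt.Println(r.list(up)) // [c a b]
	fmt.Println(r.list(up)) // [a b c]
}
```

Each call moves the starting index forward by one, so queries are spread evenly over the upstreams.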


@ -0,0 +1,30 @@
package forward
// Copied from coredns/core/dnsserver/address.go
import (
"strings"
)
// protocol returns the protocol of the string s. The second return value is s
// with the prefix chopped off.
func protocol(s string) (int, string) {
switch {
case strings.HasPrefix(s, _tls+"://"):
return TLS, s[len(_tls)+3:]
case strings.HasPrefix(s, _dns+"://"):
return DNS, s[len(_dns)+3:]
}
return DNS, s
}
// Supported protocols.
const (
DNS = iota + 1
TLS
)
const (
_dns = "dns"
_tls = "tls"
)
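A quick way to see `protocol` in action; this is a self-contained copy of the function and constants above:

```go
package main

import (
	"fmt"
	"strings"
)

// Supported protocols, as in plugin/forward.
const (
	DNS = iota + 1
	TLS
)

// protocol returns the transport constant for s and s with any
// scheme prefix stripped; addresses without a scheme default to DNS.
func protocol(s string) (int, string) {
	switch {
	case strings.HasPrefix(s, "tls://"):
		return TLS, s[len("tls://"):]
	case strings.HasPrefix(s, "dns://"):
		return DNS, s[len("dns://"):]
	}
	return DNS, s
}

func main() {
	for _, s := range []string{"tls://9.9.9.9", "dns://8.8.8.8", "10.0.0.1:53"} {
		proto, addr := protocol(s)
		fmt.Println(proto, addr)
	}
}
```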

plugin/forward/proxy.go (new file, 77 lines)

@ -0,0 +1,77 @@
package forward
import (
"crypto/tls"
"sync"
"time"
"github.com/miekg/dns"
)
// Proxy defines an upstream host.
type Proxy struct {
host *host
transport *transport
// copied from Forward.
hcInterval time.Duration
forceTCP bool
stop chan bool
sync.RWMutex
}
// NewProxy returns a new proxy.
func NewProxy(addr string) *Proxy {
host := newHost(addr)
p := &Proxy{
host: host,
hcInterval: hcDuration,
stop: make(chan bool),
transport: newTransport(host),
}
return p
}
// SetTLSConfig sets the TLS config in the lower p.host.
func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.host.tlsConfig = cfg }
// SetExpire sets the expire duration in the lower p.host.
func (p *Proxy) SetExpire(expire time.Duration) { p.host.expire = expire }
func (p *Proxy) close() { p.stop <- true }
// Dial connects to the host in p with the configured transport.
func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) }
// Yield returns the connection to the pool.
func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }
// Down returns true if this proxy is considered down, i.e. has failed more than maxfails health checks.
func (p *Proxy) Down(maxfails uint32) bool { return p.host.down(maxfails) }
func (p *Proxy) healthCheck() {
// stop channel
p.host.SetClient()
p.host.Check()
tick := time.NewTicker(p.hcInterval)
for {
select {
case <-tick.C:
p.host.Check()
case <-p.stop:
return
}
}
}
const (
dialTimeout = 4 * time.Second
timeout = 2 * time.Second
hcDuration = 500 * time.Millisecond
)

plugin/forward/setup.go (new file, 262 lines)

@ -0,0 +1,262 @@
package forward
import (
"fmt"
"net"
"strconv"
"time"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
"github.com/mholt/caddy"
)
func init() {
caddy.RegisterPlugin("forward", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
func setup(c *caddy.Controller) error {
f, err := parseForward(c)
if err != nil {
return plugin.Error("forward", err)
}
if f.Len() > max {
return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len()))
}
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
f.Next = next
return f
})
c.OnStartup(func() error {
once.Do(func() {
m := dnsserver.GetConfig(c).Handler("prometheus")
if m == nil {
return
}
if x, ok := m.(*metrics.Metrics); ok {
x.MustRegister(RequestCount)
x.MustRegister(RcodeCount)
x.MustRegister(RequestDuration)
x.MustRegister(HealthcheckFailureCount)
x.MustRegister(SocketGauge)
}
})
return f.OnStartup()
})
c.OnShutdown(func() error {
return f.OnShutdown()
})
return nil
}
// OnStartup starts health-check goroutines for all proxies.
func (f *Forward) OnStartup() (err error) {
if f.hcInterval == 0 {
for _, p := range f.proxies {
p.host.fails = 0
}
return nil
}
for _, p := range f.proxies {
go p.healthCheck()
}
return nil
}
// OnShutdown stops all configured proxies.
func (f *Forward) OnShutdown() error {
if f.hcInterval == 0 {
return nil
}
for _, p := range f.proxies {
p.close()
}
return nil
}
// Close is a synonym for OnShutdown().
func (f *Forward) Close() {
f.OnShutdown()
}
func parseForward(c *caddy.Controller) (*Forward, error) {
f := New()
protocols := map[int]int{}
for c.Next() {
if !c.Args(&f.from) {
return f, c.ArgErr()
}
f.from = plugin.Host(f.from).Normalize()
to := c.RemainingArgs()
if len(to) == 0 {
return f, c.ArgErr()
}
// A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies.
protocols = make(map[int]int)
for i := range to {
protocols[i], to[i] = protocol(to[i])
}
// If parseHostPortOrFile expands a file with a lot of nameservers, our accounting in protocols doesn't make
// any sense anymore... For now: let's not care.
toHosts, err := dnsutil.ParseHostPortOrFile(to...)
if err != nil {
return f, err
}
for i, h := range toHosts {
// Double check the port, if e.g. is 53 and the transport is TLS make it 853.
// This can be somewhat annoying because you *can't* have TLS on port 53 then.
switch protocols[i] {
case TLS:
h1, p, err := net.SplitHostPort(h)
if err != nil {
break
}
// This is more of a bug in dnsutil.ParseHostPortOrFile, which defaults to
// port 53 because it doesn't know about tls:// and friends (that should be fixed).
// Fix the port number here, back to what the user intended.
if p == "53" {
h = net.JoinHostPort(h1, "853")
}
}
// We can't set tlsConfig here, because we haven't parsed it yet.
// We set it below at the end of parseBlock.
p := NewProxy(h)
f.proxies = append(f.proxies, p)
}
for c.NextBlock() {
if err := parseBlock(c, f); err != nil {
return f, err
}
}
}
if f.tlsServerName != "" {
f.tlsConfig.ServerName = f.tlsServerName
}
for i := range f.proxies {
// Only set this for proxies that need it.
if protocols[i] == TLS {
f.proxies[i].SetTLSConfig(f.tlsConfig)
}
f.proxies[i].SetExpire(f.expire)
}
return f, nil
}

func parseBlock(c *caddy.Controller, f *Forward) error {
switch c.Val() {
case "except":
ignore := c.RemainingArgs()
if len(ignore) == 0 {
return c.ArgErr()
}
for i := 0; i < len(ignore); i++ {
ignore[i] = plugin.Host(ignore[i]).Normalize()
}
f.ignored = ignore
case "max_fails":
if !c.NextArg() {
return c.ArgErr()
}
n, err := strconv.Atoi(c.Val())
if err != nil {
return err
}
if n < 0 {
return fmt.Errorf("max_fails can't be negative: %d", n)
}
f.maxfails = uint32(n)
case "health_check":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return err
}
if dur < 0 {
return fmt.Errorf("health_check can't be negative: %s", dur)
}
f.hcInterval = dur
for i := range f.proxies {
f.proxies[i].hcInterval = dur
}
case "force_tcp":
if c.NextArg() {
return c.ArgErr()
}
f.forceTCP = true
for i := range f.proxies {
f.proxies[i].forceTCP = true
}
case "tls":
args := c.RemainingArgs()
if len(args) != 3 {
return c.ArgErr()
}
tlsConfig, err := pkgtls.NewTLSConfig(args[0], args[1], args[2])
if err != nil {
return err
}
f.tlsConfig = tlsConfig
case "tls_servername":
if !c.NextArg() {
return c.ArgErr()
}
f.tlsServerName = c.Val()
case "expire":
if !c.NextArg() {
return c.ArgErr()
}
dur, err := time.ParseDuration(c.Val())
if err != nil {
return err
}
if dur < 0 {
return fmt.Errorf("expire can't be negative: %s", dur)
}
f.expire = dur
case "policy":
if !c.NextArg() {
return c.ArgErr()
}
switch x := c.Val(); x {
case "random":
f.p = &random{}
case "round_robin":
f.p = &roundRobin{}
default:
return c.Errf("unknown policy '%s'", x)
}
default:
return c.Errf("unknown property '%s'", c.Val())
}
return nil
}

const max = 15 // Maximum number of upstreams.
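The TLS branch in parseForward above rewrites the default port 53 to 853, the DNS-over-TLS port, because dnsutil.ParseHostPortOrFile doesn't know about tls://. A minimal, self-contained sketch of that fixup (fixTLSPort is a hypothetical helper for illustration; the plugin does this inline):

```go
package main

import (
	"fmt"
	"net"
)

// fixTLSPort rewrites an upstream's port from the DNS default 53 to the
// DNS-over-TLS port 853, leaving any explicitly chosen port untouched.
func fixTLSPort(hostport string) string {
	host, port, err := net.SplitHostPort(hostport)
	if err != nil {
		return hostport
	}
	if port == "53" {
		return net.JoinHostPort(host, "853")
	}
	return hostport
}

func main() {
	fmt.Println(fixTLSPort("8.8.8.8:53"))    // → 8.8.8.8:853
	fmt.Println(fixTLSPort("[::1]:53"))      // → [::1]:853
	fmt.Println(fixTLSPort("9.9.9.9:5353"))  // left as-is: user picked the port
}
```

Note this also means you *can't* run DNS-over-TLS on port 53 with this scheme, as the comment in parseForward points out.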


@@ -0,0 +1,46 @@
package forward

import (
	"strings"
	"testing"

	"github.com/mholt/caddy"
)

func TestSetupPolicy(t *testing.T) {
tests := []struct {
input string
shouldErr bool
expectedPolicy string
expectedErr string
}{
// positive
{"forward . 127.0.0.1 {\npolicy random\n}\n", false, "random", ""},
{"forward . 127.0.0.1 {\npolicy round_robin\n}\n", false, "round_robin", ""},
// negative
{"forward . 127.0.0.1 {\npolicy random2\n}\n", true, "random", "unknown policy"},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
f, err := parseForward(c)
if test.shouldErr && err == nil {
t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
}
if err != nil {
if !test.shouldErr {
t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
}
if !strings.Contains(err.Error(), test.expectedErr) {
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
}
}
if !test.shouldErr && f.p.String() != test.expectedPolicy {
t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedPolicy, f.p.String())
}
}
}


@@ -0,0 +1,68 @@
package forward

import (
	"reflect"
	"strings"
	"testing"

	"github.com/mholt/caddy"
)

func TestSetup(t *testing.T) {
tests := []struct {
input string
shouldErr bool
expectedFrom string
expectedIgnored []string
expectedFails uint32
expectedForceTCP bool
expectedErr string
}{
// positive
{"forward . 127.0.0.1", false, ".", nil, 2, false, ""},
{"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, false, ""},
{"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, false, ""},
{"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, true, ""},
{"forward . 127.0.0.1:53", false, ".", nil, 2, false, ""},
{"forward . 127.0.0.1:8080", false, ".", nil, 2, false, ""},
{"forward . [::1]:53", false, ".", nil, 2, false, ""},
{"forward . [2003::1]:53", false, ".", nil, 2, false, ""},
// negative
{"forward . a27.0.0.1", true, "", nil, 0, false, "not an IP"},
{"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, false, "unknown property"},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
f, err := parseForward(c)
if test.shouldErr && err == nil {
t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
}
if err != nil {
if !test.shouldErr {
t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
}
if !strings.Contains(err.Error(), test.expectedErr) {
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
}
}
if !test.shouldErr && f.from != test.expectedFrom {
t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedFrom, f.from)
}
if !test.shouldErr && test.expectedIgnored != nil {
if !reflect.DeepEqual(f.ignored, test.expectedIgnored) {
t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedIgnored, f.ignored)
}
}
if !test.shouldErr && f.maxfails != test.expectedFails {
t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedFails, f.maxfails)
}
if !test.shouldErr && f.forceTCP != test.expectedForceTCP {
t.Errorf("Test %d: expected: %t, got: %t", i, test.expectedForceTCP, f.forceTCP)
}
}
}
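For reference, the options exercised by these tests and by parseBlock combine into a Corefile stanza like the following (illustrative values only):

```
forward . 127.0.0.1 {
    except miek.nl
    max_fails 3
    force_tcp
    policy round_robin
}
```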