plugin/forward: add it (#1447)
* plugin/forward: add it This moves coredns/forward into CoreDNS. Fixes as a few bugs, adds a policy option and more tests to the plugin. Update the documentation, test IPv6 address and add persistent tests. * Always use random policy when spraying * include scrub fix here as well * use correct var name * Code review * go vet * Move logging to metrcs * Small readme updates * Fix readme
This commit is contained in:
parent
fb1cafe5fa
commit
5b844b5017
19 changed files with 1431 additions and 4 deletions
3
Makefile
3
Makefile
|
@ -29,20 +29,17 @@ godeps:
|
|||
(cd $(GOPATH)/src/github.com/prometheus/client_golang 2>/dev/null && git checkout -q master 2>/dev/null || true)
|
||||
(cd $(GOPATH)/src/golang.org/x/net 2>/dev/null && git checkout -q master 2>/dev/null || true)
|
||||
(cd $(GOPATH)/src/golang.org/x/text 2>/dev/null && git checkout -q master 2>/dev/null || true)
|
||||
(cd $(GOPATH)/src/github.com/coredns/forward 2>/dev/null && git checkout -q master 2>/dev/null || true)
|
||||
go get -u github.com/mholt/caddy
|
||||
go get -u github.com/miekg/dns
|
||||
go get -u github.com/prometheus/client_golang/prometheus/promhttp
|
||||
go get -u github.com/prometheus/client_golang/prometheus
|
||||
go get -u golang.org/x/net/context
|
||||
go get -u golang.org/x/text
|
||||
-go get -f -u github.com/coredns/forward
|
||||
(cd $(GOPATH)/src/github.com/mholt/caddy && git checkout -q v0.10.10)
|
||||
(cd $(GOPATH)/src/github.com/miekg/dns && git checkout -q v1.0.4)
|
||||
(cd $(GOPATH)/src/github.com/prometheus/client_golang && git checkout -q v0.8.0)
|
||||
(cd $(GOPATH)/src/golang.org/x/net && git checkout -q release-branch.go1.9)
|
||||
(cd $(GOPATH)/src/golang.org/x/text && git checkout -q e19ae1496984b1c655b8044a65c0300a3c878dd3)
|
||||
(cd $(GOPATH)/src/github.com/coredns/forward && git checkout -q v0.0.2)
|
||||
|
||||
.PHONY: travis
|
||||
travis: check
|
||||
|
|
|
@ -48,7 +48,7 @@ file:file
|
|||
auto:auto
|
||||
secondary:secondary
|
||||
etcd:etcd
|
||||
forward:github.com/coredns/forward
|
||||
forward:forward
|
||||
proxy:proxy
|
||||
erratic:erratic
|
||||
whoami:whoami
|
||||
|
|
156
plugin/forward/README.md
Normal file
156
plugin/forward/README.md
Normal file
|
@ -0,0 +1,156 @@
|
|||
# forward
|
||||
|
||||
## Name
|
||||
|
||||
*forward* facilitates proxying DNS messages to upstream resolvers.
|
||||
|
||||
## Description
|
||||
|
||||
The *forward* plugin is generally faster (~30+%) than *proxy* as it re-uses already opened sockets
|
||||
to the upstreams. It supports UDP, TCP and DNS-over-TLS and uses inband health checking that is
|
||||
enabled by default.
|
||||
When *all* upstreams are down it assumes healtchecking as a mechanism has failed and will try to
|
||||
connect to a random upstream (which may or may not work).
|
||||
|
||||
## Syntax
|
||||
|
||||
In its most basic form, a simple forwarder uses this syntax:
|
||||
|
||||
~~~
|
||||
forward FROM TO...
|
||||
~~~
|
||||
|
||||
* **FROM** is the base domain to match for the request to be forwarded.
|
||||
* **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify
|
||||
a protocol, `tls://9.9.9.9` or `dns://` for plain DNS. The number of upstreams is limited to 15.
|
||||
|
||||
The health checks are done every *0.5s*. After *two* failed checks the upstream is considered
|
||||
unhealthy. The health checks use a recursive DNS query (`. IN NS`) to get upstream health. Any
|
||||
response that is not an error (REFUSED, NOTIMPL, SERVFAIL, etc) is taken as a healthy upstream. The
|
||||
health check uses the same protocol as specific in the **TO**. On startup each upstream is marked
|
||||
unhealthy until it passes a health check. A 0 duration will disable any health checks.
|
||||
|
||||
Multiple upstreams are randomized (default policy) on first use. When a healthy proxy returns an
|
||||
error during the exchange the next upstream in the list is tried.
|
||||
|
||||
Extra knobs are available with an expanded syntax:
|
||||
|
||||
~~~
|
||||
forward FROM TO... {
|
||||
except IGNORED_NAMES...
|
||||
force_tcp
|
||||
health_check DURATION
|
||||
expire DURATION
|
||||
max_fails INTEGER
|
||||
tls CERT KEY CA
|
||||
tls_servername NAME
|
||||
policy random|round_robin
|
||||
}
|
||||
~~~
|
||||
|
||||
* **FROM** and **TO...** as above.
|
||||
* **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding.
|
||||
Requests that match none of these names will be passed through.
|
||||
* `force_tcp`, use TCP even when the request comes in over UDP.
|
||||
* `health_checks`, use a different **DURATION** for health checking, the default duration is 0.5s.
|
||||
A value of 0 disables the health checks completely.
|
||||
* `max_fails` is the number of subsequent failed health checks that are needed before considering
|
||||
a backend to be down. If 0, the backend will never be marked as down. Default is 2.
|
||||
* `expire` **DURATION**, expire (cached) connections after this time, the default is 10s.
|
||||
* `tls` **CERT** **KEY** **CA** define the TLS properties for TLS; if you leave this out the
|
||||
system's configuration will be used.
|
||||
* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
|
||||
needs this to be set to `dns.quad9.net`.
|
||||
* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
|
||||
|
||||
The upstream selection is done via random (default policy) selection. If the socket for this client
|
||||
isn't known *forward* will randomly choose one. If this turns out to be unhealthy, the next one is
|
||||
tried. If *all* hosts are down, we assume health checking is broken and select a *random* upstream to
|
||||
try.
|
||||
|
||||
Also note the TLS config is "global" for the whole forwarding proxy if you need a different
|
||||
`tls-name` for different upstreams you're out of luck.
|
||||
|
||||
## Metrics
|
||||
|
||||
If monitoring is enabled (via the *prometheus* directive) then the following metric are exported:
|
||||
|
||||
* `coredns_forward_request_duration_seconds{to}` - duration per upstream interaction.
|
||||
* `coredns_forward_request_count_total{to}` - query count per upstream.
|
||||
* `coredns_forward_response_rcode_total{to, rcode}` - count of RCODEs per upstream.
|
||||
* `coredns_forward_healthcheck_failure_count_total{to}` - number of failed health checks per upstream.
|
||||
* `coredns_forward_healthcheck_broken_count_total{}` - counter of when all upstreams are unhealthy,
|
||||
and we are randomly spraying to a target.
|
||||
* `coredns_forward_socket_count_total{to}` - number of cached sockets per upstream.
|
||||
|
||||
Where `to` is one of the upstream servers (**TO** from the config), `proto` is the protocol used by
|
||||
the incoming query ("tcp" or "udp"), and family the transport family ("1" for IPv4, and "2" for
|
||||
IPv6).
|
||||
|
||||
## Examples
|
||||
|
||||
Proxy all requests within example.org. to a nameserver running on a different port:
|
||||
|
||||
~~~ corefile
|
||||
example.org {
|
||||
forward . 127.0.0.1:9005
|
||||
}
|
||||
~~~
|
||||
|
||||
Load balance all requests between three resolvers, one of which has a IPv6 address.
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
forward . 10.0.0.10:53 10.0.0.11:1053 [2003::1]:53
|
||||
}
|
||||
~~~
|
||||
|
||||
Forward everything except requests to `example.org`
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
forward . 10.0.0.10:1234 {
|
||||
except example.org
|
||||
}
|
||||
}
|
||||
~~~
|
||||
|
||||
Proxy everything except `example.org` using the host's `resolv.conf`'s nameservers:
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
forward . /etc/resolv.conf {
|
||||
except example.org
|
||||
}
|
||||
}
|
||||
~~~
|
||||
|
||||
Forward to a IPv6 host:
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
forward . [::1]:1053
|
||||
}
|
||||
~~~
|
||||
|
||||
Proxy all requests to 9.9.9.9 using the DNS-over-TLS protocol, and cache every answer for up to 30
|
||||
seconds.
|
||||
|
||||
~~~ corefile
|
||||
. {
|
||||
forward . tls://9.9.9.9 {
|
||||
tls_servername dns.quad9.net
|
||||
health_check 5s
|
||||
}
|
||||
cache 30
|
||||
}
|
||||
~~~
|
||||
|
||||
## Bugs
|
||||
|
||||
The TLS config is global for the whole forwarding proxy if you need a different `tls-name` for
|
||||
different upstreams you're out of luck.
|
||||
|
||||
## Also See
|
||||
|
||||
[RFC 7858](https://tools.ietf.org/html/rfc7858) for DNS over TLS.
|
66
plugin/forward/connect.go
Normal file
66
plugin/forward/connect.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
|
||||
// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
|
||||
// 50% faster than just openening a new connection for every client. It works with UDP and TCP and uses
|
||||
// inband healthchecking.
|
||||
package forward
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func (p *Proxy) connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
|
||||
start := time.Now()
|
||||
|
||||
proto := state.Proto()
|
||||
if forceTCP {
|
||||
proto = "tcp"
|
||||
}
|
||||
if p.host.tlsConfig != nil {
|
||||
proto = "tcp-tls"
|
||||
}
|
||||
|
||||
conn, err := p.Dial(proto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set buffer size correctly for this client.
|
||||
conn.UDPSize = uint16(state.Size())
|
||||
if conn.UDPSize < 512 {
|
||||
conn.UDPSize = 512
|
||||
}
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
if err := conn.WriteMsg(state.Req); err != nil {
|
||||
conn.Close() // not giving it back
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
ret, err := conn.ReadMsg()
|
||||
if err != nil {
|
||||
conn.Close() // not giving it back
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.Yield(conn)
|
||||
|
||||
if metric {
|
||||
rc, ok := dns.RcodeToString[ret.Rcode]
|
||||
if !ok {
|
||||
rc = strconv.Itoa(ret.Rcode)
|
||||
}
|
||||
|
||||
RequestCount.WithLabelValues(p.host.addr).Add(1)
|
||||
RcodeCount.WithLabelValues(rc, p.host.addr).Add(1)
|
||||
RequestDuration.WithLabelValues(p.host.addr).Observe(time.Since(start).Seconds())
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
154
plugin/forward/forward.go
Normal file
154
plugin/forward/forward.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
|
||||
// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
|
||||
// 50% faster than just openening a new connection for every client. It works with UDP and TCP and uses
|
||||
// inband healthchecking.
|
||||
package forward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
ot "github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list
|
||||
// of proxies each representing one upstream proxy.
|
||||
type Forward struct {
|
||||
proxies []*Proxy
|
||||
p Policy
|
||||
|
||||
from string
|
||||
ignored []string
|
||||
|
||||
tlsConfig *tls.Config
|
||||
tlsServerName string
|
||||
maxfails uint32
|
||||
expire time.Duration
|
||||
|
||||
forceTCP bool // also here for testing
|
||||
hcInterval time.Duration // also here for testing
|
||||
|
||||
Next plugin.Handler
|
||||
}
|
||||
|
||||
// New returns a new Forward.
|
||||
func New() *Forward {
|
||||
f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: hcDuration, p: new(random)}
|
||||
return f
|
||||
}
|
||||
|
||||
// SetProxy appends p to the proxy list and starts healthchecking.
|
||||
func (f *Forward) SetProxy(p *Proxy) {
|
||||
f.proxies = append(f.proxies, p)
|
||||
go p.healthCheck()
|
||||
}
|
||||
|
||||
// Len returns the number of configured proxies.
|
||||
func (f *Forward) Len() int { return len(f.proxies) }
|
||||
|
||||
// Name implements plugin.Handler.
|
||||
func (f *Forward) Name() string { return "forward" }
|
||||
|
||||
// ServeDNS implements plugin.Handler.
|
||||
func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
|
||||
state := request.Request{W: w, Req: r}
|
||||
if !f.match(state) {
|
||||
return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r)
|
||||
}
|
||||
|
||||
fails := 0
|
||||
var span, child ot.Span
|
||||
span = ot.SpanFromContext(ctx)
|
||||
|
||||
for _, proxy := range f.list() {
|
||||
if proxy.Down(f.maxfails) {
|
||||
fails++
|
||||
if fails < len(f.proxies) {
|
||||
continue
|
||||
}
|
||||
// All upstream proxies are dead, assume healtcheck is completely broken and randomly
|
||||
// select an upstream to connect to.
|
||||
r := new(random)
|
||||
proxy = r.List(f.proxies)[0]
|
||||
|
||||
HealthcheckBrokenCount.Add(1)
|
||||
}
|
||||
|
||||
if span != nil {
|
||||
child = span.Tracer().StartSpan("connect", ot.ChildOf(span.Context()))
|
||||
ctx = ot.ContextWithSpan(ctx, child)
|
||||
}
|
||||
|
||||
ret, err := proxy.connect(ctx, state, f.forceTCP, true)
|
||||
|
||||
if child != nil {
|
||||
child.Finish()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if fails < len(f.proxies) {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
ret.Compress = true
|
||||
// When using force_tcp the upstream can send a message that is too big for
|
||||
// the udp buffer, hence we need to truncate the message to at least make it
|
||||
// fit the udp buffer.
|
||||
ret, _ = state.Scrub(ret)
|
||||
|
||||
w.WriteMsg(ret)
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return dns.RcodeServerFailure, errNoHealthy
|
||||
}
|
||||
|
||||
func (f *Forward) match(state request.Request) bool {
|
||||
from := f.from
|
||||
|
||||
if !plugin.Name(from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forward) isAllowedDomain(name string) bool {
|
||||
if dns.Name(name) == dns.Name(f.from) {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, ignore := range f.ignored {
|
||||
if plugin.Name(ignore).Matches(name) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// List returns a set of proxies to be used for this client depending on the policy in f.
|
||||
func (f *Forward) list() []*Proxy { return f.p.List(f.proxies) }
|
||||
|
||||
var (
|
||||
errInvalidDomain = errors.New("invalid domain for proxy")
|
||||
errNoHealthy = errors.New("no healthy proxies")
|
||||
errNoForward = errors.New("no forwarder defined")
|
||||
)
|
||||
|
||||
// policy tells forward what policy for selecting upstream it uses.
|
||||
type policy int
|
||||
|
||||
const (
|
||||
randomPolicy policy = iota
|
||||
roundRobinPolicy
|
||||
)
|
42
plugin/forward/forward_test.go
Normal file
42
plugin/forward/forward_test.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/coredns/coredns/request"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestForward(t *testing.T) {
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr)
|
||||
f := New()
|
||||
f.SetProxy(p)
|
||||
defer f.Close()
|
||||
|
||||
state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
|
||||
state.Req.SetQuestion("example.org.", dns.TypeA)
|
||||
resp, err := f.Forward(state)
|
||||
if err != nil {
|
||||
t.Fatal("Expected to receive reply, but didn't")
|
||||
}
|
||||
// expect answer section with A record in it
|
||||
if len(resp.Answer) == 0 {
|
||||
t.Fatalf("Expected to at least one RR in the answer section, got none: %s", resp)
|
||||
}
|
||||
if resp.Answer[0].Header().Rrtype != dns.TypeA {
|
||||
t.Errorf("Expected RR to A, got: %d", resp.Answer[0].Header().Rrtype)
|
||||
}
|
||||
if resp.Answer[0].(*dns.A).A.String() != "127.0.0.1" {
|
||||
t.Errorf("Expected 127.0.0.1, got: %s", resp.Answer[0].(*dns.A).A.String())
|
||||
}
|
||||
}
|
67
plugin/forward/health.go
Normal file
67
plugin/forward/health.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// For HC we send to . IN NS +norec message to the upstream. Dial timeouts and empty
|
||||
// replies are considered fails, basically anything else constitutes a healthy upstream.
|
||||
|
||||
func (h *host) Check() {
|
||||
h.Lock()
|
||||
|
||||
if h.checking {
|
||||
h.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
h.checking = true
|
||||
h.Unlock()
|
||||
|
||||
err := h.send()
|
||||
if err != nil {
|
||||
log.Printf("[INFO] healtheck of %s failed with %s", h.addr, err)
|
||||
|
||||
HealthcheckFailureCount.WithLabelValues(h.addr).Add(1)
|
||||
|
||||
atomic.AddUint32(&h.fails, 1)
|
||||
} else {
|
||||
atomic.StoreUint32(&h.fails, 0)
|
||||
}
|
||||
|
||||
h.Lock()
|
||||
h.checking = false
|
||||
h.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (h *host) send() error {
|
||||
hcping := new(dns.Msg)
|
||||
hcping.SetQuestion(".", dns.TypeNS)
|
||||
hcping.RecursionDesired = false
|
||||
|
||||
m, _, err := h.client.Exchange(hcping, h.addr)
|
||||
// If we got a header, we're alright, basically only care about I/O errors 'n stuff
|
||||
if err != nil && m != nil {
|
||||
// Silly check, something sane came back
|
||||
if m.Response || m.Opcode == dns.OpcodeQuery {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// down returns true is this host has more than maxfails fails.
|
||||
func (h *host) down(maxfails uint32) bool {
|
||||
if maxfails == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fails := atomic.LoadUint32(&h.fails)
|
||||
return fails > maxfails
|
||||
}
|
44
plugin/forward/host.go
Normal file
44
plugin/forward/host.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type host struct {
|
||||
addr string
|
||||
client *dns.Client
|
||||
|
||||
tlsConfig *tls.Config
|
||||
expire time.Duration
|
||||
|
||||
fails uint32
|
||||
sync.RWMutex
|
||||
checking bool
|
||||
}
|
||||
|
||||
// newHost returns a new host, the fails are set to 1, i.e.
|
||||
// the first healthcheck must succeed before we use this host.
|
||||
func newHost(addr string) *host {
|
||||
return &host{addr: addr, fails: 1, expire: defaultExpire}
|
||||
}
|
||||
|
||||
// setClient sets and configures the dns.Client in host.
|
||||
func (h *host) SetClient() {
|
||||
c := new(dns.Client)
|
||||
c.Net = "udp"
|
||||
c.ReadTimeout = 2 * time.Second
|
||||
c.WriteTimeout = 2 * time.Second
|
||||
|
||||
if h.tlsConfig != nil {
|
||||
c.Net = "tcp-tls"
|
||||
c.TLSConfig = h.tlsConfig
|
||||
}
|
||||
|
||||
h.client = c
|
||||
}
|
||||
|
||||
const defaultExpire = 10 * time.Second
|
78
plugin/forward/lookup.go
Normal file
78
plugin/forward/lookup.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same
|
||||
// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be
|
||||
// 50% faster than just openening a new connection for every client. It works with UDP and TCP and uses
|
||||
// inband healthchecking.
|
||||
package forward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/request"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Forward forward the request in state as-is. Unlike Lookup that adds EDNS0 suffix to the message.
|
||||
// Forward may be called with a nil f, an error is returned in that case.
|
||||
func (f *Forward) Forward(state request.Request) (*dns.Msg, error) {
|
||||
if f == nil {
|
||||
return nil, errNoForward
|
||||
}
|
||||
|
||||
fails := 0
|
||||
for _, proxy := range f.list() {
|
||||
if proxy.Down(f.maxfails) {
|
||||
fails++
|
||||
if fails < len(f.proxies) {
|
||||
continue
|
||||
}
|
||||
// All upstream proxies are dead, assume healtcheck is complete broken and randomly
|
||||
// select an upstream to connect to.
|
||||
proxy = f.list()[0]
|
||||
log.Printf("[WARNING] All upstreams down, picking random one to connect to %s", proxy.host.addr)
|
||||
}
|
||||
|
||||
ret, err := proxy.connect(context.Background(), state, f.forceTCP, true)
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] Failed to connect to %s: %s", proxy.host.addr, err)
|
||||
if fails < len(f.proxies) {
|
||||
continue
|
||||
}
|
||||
break
|
||||
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
return nil, errNoHealthy
|
||||
}
|
||||
|
||||
// Lookup will use name and type to forge a new message and will send that upstream. It will
|
||||
// set any EDNS0 options correctly so that downstream will be able to process the reply.
|
||||
// Lookup may be called with a nil f, an error is returned in that case.
|
||||
func (f *Forward) Lookup(state request.Request, name string, typ uint16) (*dns.Msg, error) {
|
||||
if f == nil {
|
||||
return nil, errNoForward
|
||||
}
|
||||
|
||||
req := new(dns.Msg)
|
||||
req.SetQuestion(name, typ)
|
||||
state.SizeAndDo(req)
|
||||
|
||||
state2 := request.Request{W: state.W, Req: req}
|
||||
|
||||
return f.Forward(state2)
|
||||
}
|
||||
|
||||
// NewLookup returns a Forward that can be used for plugin that need an upstream to resolve external names.
|
||||
func NewLookup(addr []string) *Forward {
|
||||
f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, hcInterval: 2 * time.Second}
|
||||
for i := range addr {
|
||||
p := NewProxy(addr[i])
|
||||
f.SetProxy(p)
|
||||
}
|
||||
return f
|
||||
}
|
41
plugin/forward/lookup_test.go
Normal file
41
plugin/forward/lookup_test.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
"github.com/coredns/coredns/plugin/test"
|
||||
"github.com/coredns/coredns/request"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestLookup(t *testing.T) {
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1"))
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
p := NewProxy(s.Addr)
|
||||
f := New()
|
||||
f.SetProxy(p)
|
||||
defer f.Close()
|
||||
|
||||
state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
|
||||
resp, err := f.Lookup(state, "example.org.", dns.TypeA)
|
||||
if err != nil {
|
||||
t.Fatal("Expected to receive reply, but didn't")
|
||||
}
|
||||
// expect answer section with A record in it
|
||||
if len(resp.Answer) == 0 {
|
||||
t.Fatalf("Expected to at least one RR in the answer section, got none: %s", resp)
|
||||
}
|
||||
if resp.Answer[0].Header().Rrtype != dns.TypeA {
|
||||
t.Errorf("Expected RR to A, got: %d", resp.Answer[0].Header().Rrtype)
|
||||
}
|
||||
if resp.Answer[0].(*dns.A).A.String() != "127.0.0.1" {
|
||||
t.Errorf("Expected 127.0.0.1, got: %s", resp.Answer[0].(*dns.A).A.String())
|
||||
}
|
||||
}
|
52
plugin/forward/metrics.go
Normal file
52
plugin/forward/metrics.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// Variables declared for monitoring.
|
||||
var (
|
||||
RequestCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "forward",
|
||||
Name: "request_count_total",
|
||||
Help: "Counter of requests made per upstream.",
|
||||
}, []string{"to"})
|
||||
RcodeCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "forward",
|
||||
Name: "response_rcode_count_total",
|
||||
Help: "Counter of requests made per upstream.",
|
||||
}, []string{"rcode", "to"})
|
||||
RequestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "forward",
|
||||
Name: "request_duration_seconds",
|
||||
Buckets: plugin.TimeBuckets,
|
||||
Help: "Histogram of the time each request took.",
|
||||
}, []string{"to"})
|
||||
HealthcheckFailureCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "forward",
|
||||
Name: "healthcheck_failure_count_total",
|
||||
Help: "Counter of the number of failed healtchecks.",
|
||||
}, []string{"to"})
|
||||
HealthcheckBrokenCount = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "forward",
|
||||
Name: "healthcheck_broken_count_total",
|
||||
Help: "Counter of the number of complete failures of the healtchecks.",
|
||||
})
|
||||
SocketGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: plugin.Namespace,
|
||||
Subsystem: "forward",
|
||||
Name: "socket_count_total",
|
||||
Help: "Guage of open sockets per upstream.",
|
||||
}, []string{"to"})
|
||||
)
|
||||
|
||||
var once sync.Once
|
148
plugin/forward/persistent.go
Normal file
148
plugin/forward/persistent.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// a persistConn hold the dns.Conn and the last used time.
|
||||
type persistConn struct {
|
||||
c *dns.Conn
|
||||
used time.Time
|
||||
}
|
||||
|
||||
// connErr is used to communicate the connection manager.
|
||||
type connErr struct {
|
||||
c *dns.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
// transport hold the persistent cache.
|
||||
type transport struct {
|
||||
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||
host *host
|
||||
|
||||
dial chan string
|
||||
yield chan connErr
|
||||
ret chan connErr
|
||||
|
||||
// Aid in testing, gets length of cache in data-race safe manner.
|
||||
lenc chan bool
|
||||
lencOut chan int
|
||||
|
||||
stop chan bool
|
||||
}
|
||||
|
||||
func newTransport(h *host) *transport {
|
||||
t := &transport{
|
||||
conns: make(map[string][]*persistConn),
|
||||
host: h,
|
||||
dial: make(chan string),
|
||||
yield: make(chan connErr),
|
||||
ret: make(chan connErr),
|
||||
stop: make(chan bool),
|
||||
lenc: make(chan bool),
|
||||
lencOut: make(chan int),
|
||||
}
|
||||
go t.connManager()
|
||||
return t
|
||||
}
|
||||
|
||||
// len returns the number of connection, used for metrics. Can only be safely
|
||||
// used inside connManager() because of races.
|
||||
func (t *transport) len() int {
|
||||
l := 0
|
||||
for _, conns := range t.conns {
|
||||
l += len(conns)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// Len returns the number of connections in the cache.
|
||||
func (t *transport) Len() int {
|
||||
t.lenc <- true
|
||||
l := <-t.lencOut
|
||||
return l
|
||||
}
|
||||
|
||||
// connManagers manages the persistent connection cache for UDP and TCP.
|
||||
func (t *transport) connManager() {
|
||||
|
||||
Wait:
|
||||
for {
|
||||
select {
|
||||
case proto := <-t.dial:
|
||||
// Yes O(n), shouldn't put millions in here. We walk all connection until we find the first
|
||||
// one that is usuable.
|
||||
i := 0
|
||||
for i = 0; i < len(t.conns[proto]); i++ {
|
||||
pc := t.conns[proto][i]
|
||||
if time.Since(pc.used) < t.host.expire {
|
||||
// Found one, remove from pool and return this conn.
|
||||
t.conns[proto] = t.conns[proto][i+1:]
|
||||
t.ret <- connErr{pc.c, nil}
|
||||
continue Wait
|
||||
}
|
||||
// This conn has expired. Close it.
|
||||
pc.c.Close()
|
||||
}
|
||||
|
||||
// Not conns were found. Connect to the upstream to create one.
|
||||
t.conns[proto] = t.conns[proto][i:]
|
||||
SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len()))
|
||||
|
||||
go func() {
|
||||
if proto != "tcp-tls" {
|
||||
c, err := dns.DialTimeout(proto, t.host.addr, dialTimeout)
|
||||
t.ret <- connErr{c, err}
|
||||
return
|
||||
}
|
||||
|
||||
c, err := dns.DialTimeoutWithTLS("tcp", t.host.addr, t.host.tlsConfig, dialTimeout)
|
||||
t.ret <- connErr{c, err}
|
||||
}()
|
||||
|
||||
case conn := <-t.yield:
|
||||
|
||||
SocketGauge.WithLabelValues(t.host.addr).Set(float64(t.len() + 1))
|
||||
|
||||
// no proto here, infer from config and conn
|
||||
if _, ok := conn.c.Conn.(*net.UDPConn); ok {
|
||||
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn.c, time.Now()})
|
||||
continue Wait
|
||||
}
|
||||
|
||||
if t.host.tlsConfig == nil {
|
||||
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn.c, time.Now()})
|
||||
continue Wait
|
||||
}
|
||||
|
||||
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn.c, time.Now()})
|
||||
|
||||
case <-t.stop:
|
||||
return
|
||||
|
||||
case <-t.lenc:
|
||||
l := 0
|
||||
for _, conns := range t.conns {
|
||||
l += len(conns)
|
||||
}
|
||||
t.lencOut <- l
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *transport) Dial(proto string) (*dns.Conn, error) {
|
||||
t.dial <- proto
|
||||
c := <-t.ret
|
||||
return c.c, c.err
|
||||
}
|
||||
|
||||
func (t *transport) Yield(c *dns.Conn) {
|
||||
t.yield <- connErr{c, nil}
|
||||
}
|
||||
|
||||
// Stop stops the transports.
|
||||
func (t *transport) Stop() { t.stop <- true }
|
44
plugin/forward/persistent_test.go
Normal file
44
plugin/forward/persistent_test.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestPersistent(t *testing.T) {
|
||||
s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ret := new(dns.Msg)
|
||||
ret.SetReply(r)
|
||||
w.WriteMsg(ret)
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
h := newHost(s.Addr)
|
||||
tr := newTransport(h)
|
||||
defer tr.Stop()
|
||||
|
||||
c1, _ := tr.Dial("udp")
|
||||
c2, _ := tr.Dial("udp")
|
||||
c3, _ := tr.Dial("udp")
|
||||
|
||||
tr.Yield(c1)
|
||||
tr.Yield(c2)
|
||||
tr.Yield(c3)
|
||||
|
||||
if x := tr.Len(); x != 3 {
|
||||
t.Errorf("Expected cache size to be 3, got %d", x)
|
||||
}
|
||||
|
||||
tr.Dial("udp")
|
||||
if x := tr.Len(); x != 2 {
|
||||
t.Errorf("Expected cache size to be 2, got %d", x)
|
||||
}
|
||||
|
||||
tr.Dial("udp")
|
||||
if x := tr.Len(); x != 1 {
|
||||
t.Errorf("Expected cache size to be 2, got %d", x)
|
||||
}
|
||||
}
|
55
plugin/forward/policy.go
Normal file
55
plugin/forward/policy.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Policy defines a policy we use for selecting upstreams.
|
||||
type Policy interface {
|
||||
List([]*Proxy) []*Proxy
|
||||
String() string
|
||||
}
|
||||
|
||||
// random is a policy that implements random upstream selection.
|
||||
type random struct{}
|
||||
|
||||
func (r *random) String() string { return "random" }
|
||||
|
||||
func (r *random) List(p []*Proxy) []*Proxy {
|
||||
switch len(p) {
|
||||
case 1:
|
||||
return p
|
||||
case 2:
|
||||
if rand.Int()%2 == 0 {
|
||||
return []*Proxy{p[1], p[0]} // swap
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
perms := rand.Perm(len(p))
|
||||
rnd := make([]*Proxy, len(p))
|
||||
|
||||
for i, p1 := range perms {
|
||||
rnd[i] = p[p1]
|
||||
}
|
||||
return rnd
|
||||
}
|
||||
|
||||
// roundRobin is a policy that selects hosts based on round robin ordering.
|
||||
type roundRobin struct {
|
||||
robin uint32
|
||||
}
|
||||
|
||||
func (r *roundRobin) String() string { return "round_robin" }
|
||||
|
||||
func (r *roundRobin) List(p []*Proxy) []*Proxy {
|
||||
poolLen := uint32(len(p))
|
||||
i := atomic.AddUint32(&r.robin, 1) % poolLen
|
||||
|
||||
robin := []*Proxy{p[i]}
|
||||
robin = append(robin, p[:i]...)
|
||||
robin = append(robin, p[i+1:]...)
|
||||
|
||||
return robin
|
||||
}
|
30
plugin/forward/protocol.go
Normal file
30
plugin/forward/protocol.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package forward
|
||||
|
||||
// Copied from coredns/core/dnsserver/address.go
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// protocol returns the protocol of the string s. The second string returns s
|
||||
// with the prefix chopped off.
|
||||
func protocol(s string) (int, string) {
|
||||
switch {
|
||||
case strings.HasPrefix(s, _tls+"://"):
|
||||
return TLS, s[len(_tls)+3:]
|
||||
case strings.HasPrefix(s, _dns+"://"):
|
||||
return DNS, s[len(_dns)+3:]
|
||||
}
|
||||
return DNS, s
|
||||
}
|
||||
|
||||
// Supported protocols.
|
||||
const (
|
||||
DNS = iota + 1
|
||||
TLS
|
||||
)
|
||||
|
||||
const (
|
||||
_dns = "dns"
|
||||
_tls = "tls"
|
||||
)
|
77
plugin/forward/proxy.go
Normal file
77
plugin/forward/proxy.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Proxy defines an upstream host.
|
||||
type Proxy struct {
|
||||
host *host
|
||||
|
||||
transport *transport
|
||||
|
||||
// copied from Forward.
|
||||
hcInterval time.Duration
|
||||
forceTCP bool
|
||||
|
||||
stop chan bool
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// NewProxy returns a new proxy.
|
||||
func NewProxy(addr string) *Proxy {
|
||||
host := newHost(addr)
|
||||
|
||||
p := &Proxy{
|
||||
host: host,
|
||||
hcInterval: hcDuration,
|
||||
stop: make(chan bool),
|
||||
transport: newTransport(host),
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// SetTLSConfig sets the TLS config in the lower p.host.
|
||||
func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.host.tlsConfig = cfg }
|
||||
|
||||
// SetExpire sets the expire duration in the lower p.host.
|
||||
func (p *Proxy) SetExpire(expire time.Duration) { p.host.expire = expire }
|
||||
|
||||
func (p *Proxy) close() { p.stop <- true }
|
||||
|
||||
// Dial connects to the host in p with the configured transport.
|
||||
func (p *Proxy) Dial(proto string) (*dns.Conn, error) { return p.transport.Dial(proto) }
|
||||
|
||||
// Yield returns the connection to the pool.
|
||||
func (p *Proxy) Yield(c *dns.Conn) { p.transport.Yield(c) }
|
||||
|
||||
// Down returns if this proxy is up or down.
|
||||
func (p *Proxy) Down(maxfails uint32) bool { return p.host.down(maxfails) }
|
||||
|
||||
func (p *Proxy) healthCheck() {
|
||||
|
||||
// stop channel
|
||||
p.host.SetClient()
|
||||
|
||||
p.host.Check()
|
||||
tick := time.NewTicker(p.hcInterval)
|
||||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
p.host.Check()
|
||||
case <-p.stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
dialTimeout = 4 * time.Second
|
||||
timeout = 2 * time.Second
|
||||
hcDuration = 500 * time.Millisecond
|
||||
)
|
262
plugin/forward/setup.go
Normal file
262
plugin/forward/setup.go
Normal file
|
@ -0,0 +1,262 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/coredns/coredns/plugin/metrics"
|
||||
"github.com/coredns/coredns/plugin/pkg/dnsutil"
|
||||
pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("forward", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
func setup(c *caddy.Controller) error {
|
||||
f, err := parseForward(c)
|
||||
if err != nil {
|
||||
return plugin.Error("foward", err)
|
||||
}
|
||||
if f.Len() > max {
|
||||
return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len()))
|
||||
}
|
||||
|
||||
dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
f.Next = next
|
||||
return f
|
||||
})
|
||||
|
||||
c.OnStartup(func() error {
|
||||
once.Do(func() {
|
||||
m := dnsserver.GetConfig(c).Handler("prometheus")
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if x, ok := m.(*metrics.Metrics); ok {
|
||||
x.MustRegister(RequestCount)
|
||||
x.MustRegister(RcodeCount)
|
||||
x.MustRegister(RequestDuration)
|
||||
x.MustRegister(HealthcheckFailureCount)
|
||||
x.MustRegister(SocketGauge)
|
||||
}
|
||||
})
|
||||
return f.OnStartup()
|
||||
})
|
||||
|
||||
c.OnShutdown(func() error {
|
||||
return f.OnShutdown()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnStartup starts a goroutines for all proxies.
|
||||
func (f *Forward) OnStartup() (err error) {
|
||||
if f.hcInterval == 0 {
|
||||
for _, p := range f.proxies {
|
||||
p.host.fails = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, p := range f.proxies {
|
||||
go p.healthCheck()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnShutdown stops all configured proxies.
|
||||
func (f *Forward) OnShutdown() error {
|
||||
if f.hcInterval == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, p := range f.proxies {
|
||||
p.close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a synonym for OnShutdown().
|
||||
func (f *Forward) Close() {
|
||||
f.OnShutdown()
|
||||
}
|
||||
|
||||
func parseForward(c *caddy.Controller) (*Forward, error) {
|
||||
f := New()
|
||||
|
||||
protocols := map[int]int{}
|
||||
|
||||
for c.Next() {
|
||||
if !c.Args(&f.from) {
|
||||
return f, c.ArgErr()
|
||||
}
|
||||
f.from = plugin.Host(f.from).Normalize()
|
||||
|
||||
to := c.RemainingArgs()
|
||||
if len(to) == 0 {
|
||||
return f, c.ArgErr()
|
||||
}
|
||||
|
||||
// A bit fiddly, but first check if we've got protocols and if so add them back in when we create the proxies.
|
||||
protocols = make(map[int]int)
|
||||
for i := range to {
|
||||
protocols[i], to[i] = protocol(to[i])
|
||||
}
|
||||
|
||||
// If parseHostPortOrFile expands a file with a lot of nameserver our accounting in protocols doesn't make
|
||||
// any sense anymore... For now: lets don't care.
|
||||
toHosts, err := dnsutil.ParseHostPortOrFile(to...)
|
||||
if err != nil {
|
||||
return f, err
|
||||
}
|
||||
|
||||
for i, h := range toHosts {
|
||||
// Double check the port, if e.g. is 53 and the transport is TLS make it 853.
|
||||
// This can be somewhat annoying because you *can't* have TLS on port 53 then.
|
||||
switch protocols[i] {
|
||||
case TLS:
|
||||
h1, p, err := net.SplitHostPort(h)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// This is more of a bug in // dnsutil.ParseHostPortOrFile that defaults to
|
||||
// 53 because it doesn't know about the tls:// // and friends (that should be fixed). Hence
|
||||
// Fix the port number here, back to what the user intended.
|
||||
if p == "53" {
|
||||
h = net.JoinHostPort(h1, "853")
|
||||
}
|
||||
}
|
||||
|
||||
// We can't set tlsConfig here, because we haven't parsed it yet.
|
||||
// We set it below at the end of parseBlock.
|
||||
p := NewProxy(h)
|
||||
f.proxies = append(f.proxies, p)
|
||||
}
|
||||
|
||||
for c.NextBlock() {
|
||||
if err := parseBlock(c, f); err != nil {
|
||||
return f, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if f.tlsServerName != "" {
|
||||
f.tlsConfig.ServerName = f.tlsServerName
|
||||
}
|
||||
for i := range f.proxies {
|
||||
// Only set this for proxies that need it.
|
||||
if protocols[i] == TLS {
|
||||
f.proxies[i].SetTLSConfig(f.tlsConfig)
|
||||
}
|
||||
f.proxies[i].SetExpire(f.expire)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func parseBlock(c *caddy.Controller, f *Forward) error {
|
||||
switch c.Val() {
|
||||
case "except":
|
||||
ignore := c.RemainingArgs()
|
||||
if len(ignore) == 0 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
for i := 0; i < len(ignore); i++ {
|
||||
ignore[i] = plugin.Host(ignore[i]).Normalize()
|
||||
}
|
||||
f.ignored = ignore
|
||||
case "max_fails":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
n, err := strconv.Atoi(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n < 0 {
|
||||
return fmt.Errorf("max_fails can't be negative: %d", n)
|
||||
}
|
||||
f.maxfails = uint32(n)
|
||||
case "health_check":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dur < 0 {
|
||||
return fmt.Errorf("health_check can't be negative: %d", dur)
|
||||
}
|
||||
f.hcInterval = dur
|
||||
for i := range f.proxies {
|
||||
f.proxies[i].hcInterval = dur
|
||||
}
|
||||
case "force_tcp":
|
||||
if c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
f.forceTCP = true
|
||||
for i := range f.proxies {
|
||||
f.proxies[i].forceTCP = true
|
||||
}
|
||||
case "tls":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 3 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
tlsConfig, err := pkgtls.NewTLSConfig(args[0], args[1], args[2])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.tlsConfig = tlsConfig
|
||||
case "tls_servername":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
f.tlsServerName = c.Val()
|
||||
case "expire":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dur < 0 {
|
||||
return fmt.Errorf("expire can't be negative: %s", dur)
|
||||
}
|
||||
f.expire = dur
|
||||
case "policy":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
switch x := c.Val(); x {
|
||||
case "random":
|
||||
f.p = &random{}
|
||||
case "round_robin":
|
||||
f.p = &roundRobin{}
|
||||
default:
|
||||
return c.Errf("unknown policy '%s'", x)
|
||||
}
|
||||
|
||||
default:
|
||||
return c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const max = 15 // Maximum number of upstreams.
|
46
plugin/forward/setup_policy_test.go
Normal file
46
plugin/forward/setup_policy_test.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestSetupPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedPolicy string
|
||||
expectedErr string
|
||||
}{
|
||||
// positive
|
||||
{"forward . 127.0.0.1 {\npolicy random\n}\n", false, "random", ""},
|
||||
{"forward . 127.0.0.1 {\npolicy round_robin\n}\n", false, "round_robin", ""},
|
||||
// negative
|
||||
{"forward . 127.0.0.1 {\npolicy random2\n}\n", true, "random", "unknown policy"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
c := caddy.NewTestController("dns", test.input)
|
||||
f, err := parseForward(c)
|
||||
|
||||
if test.shouldErr && err == nil {
|
||||
t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !test.shouldErr {
|
||||
t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), test.expectedErr) {
|
||||
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
|
||||
}
|
||||
}
|
||||
|
||||
if !test.shouldErr && f.p.String() != test.expectedPolicy {
|
||||
t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedPolicy, f.p.String())
|
||||
}
|
||||
}
|
||||
}
|
68
plugin/forward/setup_test.go
Normal file
68
plugin/forward/setup_test.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package forward
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedFrom string
|
||||
expectedIgnored []string
|
||||
expectedFails uint32
|
||||
expectedForceTCP bool
|
||||
expectedErr string
|
||||
}{
|
||||
// positive
|
||||
{"forward . 127.0.0.1", false, ".", nil, 2, false, ""},
|
||||
{"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, false, ""},
|
||||
{"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, false, ""},
|
||||
{"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, true, ""},
|
||||
{"forward . 127.0.0.1:53", false, ".", nil, 2, false, ""},
|
||||
{"forward . 127.0.0.1:8080", false, ".", nil, 2, false, ""},
|
||||
{"forward . [::1]:53", false, ".", nil, 2, false, ""},
|
||||
{"forward . [2003::1]:53", false, ".", nil, 2, false, ""},
|
||||
// negative
|
||||
{"forward . a27.0.0.1", true, "", nil, 0, false, "not an IP"},
|
||||
{"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, false, "unknown property"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
c := caddy.NewTestController("dns", test.input)
|
||||
f, err := parseForward(c)
|
||||
|
||||
if test.shouldErr && err == nil {
|
||||
t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !test.shouldErr {
|
||||
t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), test.expectedErr) {
|
||||
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
|
||||
}
|
||||
}
|
||||
|
||||
if !test.shouldErr && f.from != test.expectedFrom {
|
||||
t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedFrom, f.from)
|
||||
}
|
||||
if !test.shouldErr && test.expectedIgnored != nil {
|
||||
if !reflect.DeepEqual(f.ignored, test.expectedIgnored) {
|
||||
t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedIgnored, f.ignored)
|
||||
}
|
||||
}
|
||||
if !test.shouldErr && f.maxfails != test.expectedFails {
|
||||
t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedFails, f.maxfails)
|
||||
}
|
||||
if !test.shouldErr && f.forceTCP != test.expectedForceTCP {
|
||||
t.Errorf("Test %d: expected: %t, got: %t", i, test.expectedForceTCP, f.forceTCP)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue