diff --git a/plugin/forward/dnstap.go b/plugin/forward/dnstap.go index edbee8715..e9962d268 100644 --- a/plugin/forward/dnstap.go +++ b/plugin/forward/dnstap.go @@ -6,6 +6,7 @@ import ( "time" "github.com/coredns/coredns/plugin/dnstap/msg" + "github.com/coredns/coredns/plugin/pkg/proxy" "github.com/coredns/coredns/request" tap "github.com/dnstap/golang-dnstap" @@ -13,7 +14,7 @@ import ( ) // toDnstap will send the forward and received message to the dnstap plugin. -func toDnstap(f *Forward, host string, state request.Request, opts options, reply *dns.Msg, start time.Time) { +func toDnstap(f *Forward, host string, state request.Request, opts proxy.Options, reply *dns.Msg, start time.Time) { h, p, _ := net.SplitHostPort(host) // this is preparsed and can't err here port, _ := strconv.ParseUint(p, 10, 32) // same here ip := net.ParseIP(h) @@ -21,9 +22,9 @@ func toDnstap(f *Forward, host string, state request.Request, opts options, repl var ta net.Addr = &net.UDPAddr{IP: ip, Port: int(port)} t := state.Proto() switch { - case opts.forceTCP: + case opts.ForceTCP: t = "tcp" - case opts.preferUDP: + case opts.PreferUDP: t = "udp" } diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index 223d7e398..927a6e21f 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -16,6 +16,7 @@ import ( "github.com/coredns/coredns/plugin/dnstap" "github.com/coredns/coredns/plugin/metadata" clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/proxy" "github.com/coredns/coredns/request" "github.com/miekg/dns" @@ -25,12 +26,17 @@ import ( var log = clog.NewWithPlugin("forward") +const ( + defaultExpire = 10 * time.Second + hcInterval = 500 * time.Millisecond +) + // Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list // of proxies each representing one upstream proxy. type Forward struct { concurrent int64 // atomic counters need to be first in struct for proper alignment - proxies []*Proxy + proxies []*proxy.Proxy p Policy hcInterval time.Duration @@ -43,7 +49,7 @@ type Forward struct { expire time.Duration maxConcurrent int64 - opts options // also here for testing + opts proxy.Options // also here for testing // ErrLimitExceeded indicates that a query was rejected because the number of concurrent queries has exceeded // the maximum allowed (maxConcurrent) @@ -56,14 +62,14 @@ type Forward struct { // New returns a new Forward. func New() *Forward { - f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcInterval, opts: options{forceTCP: false, preferUDP: false, hcRecursionDesired: true, hcDomain: "."}} + f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcInterval, opts: proxy.Options{ForceTCP: false, PreferUDP: false, HCRecursionDesired: true, HCDomain: "."}} return f } // SetProxy appends p to the proxy list and starts healthchecking. -func (f *Forward) SetProxy(p *Proxy) { +func (f *Forward) SetProxy(p *proxy.Proxy) { f.proxies = append(f.proxies, p) - p.start(f.hcInterval) + p.Start(f.hcInterval) } // SetTapPlugin appends one or more dnstap plugins to the tap plugin list. @@ -128,12 +134,12 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg if span != nil { child = span.Tracer().StartSpan("connect", ot.ChildOf(span.Context())) - otext.PeerAddress.Set(child, proxy.addr) + otext.PeerAddress.Set(child, proxy.Addr()) ctx = ot.ContextWithSpan(ctx, child) } metadata.SetValueFunc(ctx, "forward/upstream", func() string { - return proxy.addr + return proxy.Addr() }) var ( @@ -141,14 +147,15 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg err error ) opts := f.opts + for { ret, err = proxy.Connect(ctx, state, opts) if err == ErrCachedClosed { // Remote side closed conn, can only happen with TCP. continue } // Retry with TCP if truncated and prefer_udp configured. - if ret != nil && ret.Truncated && !opts.forceTCP && opts.preferUDP { - opts.forceTCP = true + if ret != nil && ret.Truncated && !opts.ForceTCP && opts.PreferUDP { + opts.ForceTCP = true continue } break @@ -159,7 +166,7 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg } if len(f.tapPlugins) != 0 { - toDnstap(f, proxy.addr, state, opts, ret, start) + toDnstap(f, proxy.Addr(), state, opts, ret, start) } upstreamErr = err @@ -219,13 +226,13 @@ func (f *Forward) isAllowedDomain(name string) bool { } // ForceTCP returns if TCP is forced to be used even when the request comes in over UDP. -func (f *Forward) ForceTCP() bool { return f.opts.forceTCP } +func (f *Forward) ForceTCP() bool { return f.opts.ForceTCP } // PreferUDP returns if UDP is preferred to be used even when the request comes in over TCP. -func (f *Forward) PreferUDP() bool { return f.opts.preferUDP } +func (f *Forward) PreferUDP() bool { return f.opts.PreferUDP } // List returns a set of proxies to be used for this client depending on the policy in f. -func (f *Forward) List() []*Proxy { return f.p.List(f.proxies) } +func (f *Forward) List() []*proxy.Proxy { return f.p.List(f.proxies) } var ( // ErrNoHealthy means no healthy proxies left. @@ -236,12 +243,16 @@ var ( ErrCachedClosed = errors.New("cached connection was closed by peer") ) -// options holds various options that can be set. -type options struct { - forceTCP bool - preferUDP bool - hcRecursionDesired bool - hcDomain string +// Options holds various Options that can be set. +type Options struct { + // ForceTCP use TCP protocol for upstream DNS request. Has precedence over PreferUDP flag + ForceTCP bool + // PreferUDP use UDP protocol for upstream DNS request. + PreferUDP bool + // HCRecursionDesired sets recursion desired flag for Proxy healthcheck requests + HCRecursionDesired bool + // HCDomain sets domain for Proxy healthcheck requests + HCDomain string } var defaultTimeout = 5 * time.Second diff --git a/plugin/forward/forward_test.go b/plugin/forward/forward_test.go index b50f4ff22..9ea859826 100644 --- a/plugin/forward/forward_test.go +++ b/plugin/forward/forward_test.go @@ -8,23 +8,33 @@ import ( "github.com/coredns/caddy/caddyfile" "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin/dnstap" + "github.com/coredns/coredns/plugin/pkg/proxy" + "github.com/coredns/coredns/plugin/pkg/transport" ) func TestList(t *testing.T) { f := Forward{ - proxies: []*Proxy{{addr: "1.1.1.1:53"}, {addr: "2.2.2.2:53"}, {addr: "3.3.3.3:53"}}, - p: &roundRobin{}, + proxies: []*proxy.Proxy{ + proxy.NewProxy("1.1.1.1:53", transport.DNS), + proxy.NewProxy("2.2.2.2:53", transport.DNS), + proxy.NewProxy("3.3.3.3:53", transport.DNS), + }, + p: &roundRobin{}, } - expect := []*Proxy{{addr: "2.2.2.2:53"}, {addr: "1.1.1.1:53"}, {addr: "3.3.3.3:53"}} + expect := []*proxy.Proxy{ + proxy.NewProxy("2.2.2.2:53", transport.DNS), + proxy.NewProxy("1.1.1.1:53", transport.DNS), + proxy.NewProxy("3.3.3.3:53", transport.DNS), + } got := f.List() if len(got) != len(expect) { t.Fatalf("Expected: %v results, got: %v", len(expect), len(got)) } for i, p := range got { - if p.addr != expect[i].addr { - t.Fatalf("Expected proxy %v to be '%v', got: '%v'", i, expect[i].addr, p.addr) + if p.Addr() != expect[i].Addr() { + t.Fatalf("Expected proxy %v to be '%v', got: '%v'", i, expect[i].Addr(), p.Addr()) } } } diff --git a/plugin/forward/health_test.go b/plugin/forward/health_test.go index 9917b3a37..7cb928d22 100644 --- a/plugin/forward/health_test.go +++ b/plugin/forward/health_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/proxy" "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/plugin/test" @@ -14,9 +15,6 @@ import ( ) func TestHealth(t *testing.T) { - hcReadTimeout = 10 * time.Millisecond - hcWriteTimeout = 10 * time.Millisecond - readTimeout = 10 * time.Millisecond defaultTimeout = 10 * time.Millisecond i := uint32(0) @@ -35,7 +33,9 @@ func TestHealth(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, transport.DNS) + p := proxy.NewProxy(s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) f := New() f.SetProxy(p) defer f.OnShutdown() @@ -53,9 +53,6 @@ func TestHealth(t *testing.T) { } func TestHealthTCP(t *testing.T) { - hcReadTimeout = 10 * time.Millisecond - hcWriteTimeout = 10 * time.Millisecond - readTimeout = 10 * time.Millisecond defaultTimeout = 10 * time.Millisecond i := uint32(0) @@ -74,8 +71,10 @@ func TestHealthTCP(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, transport.DNS) - p.health.SetTCPTransport() + p := proxy.NewProxy(s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetTCPTransport() f := New() f.SetProxy(p) defer f.OnShutdown() @@ -93,10 +92,7 @@ func TestHealthTCP(t *testing.T) { } func TestHealthNoRecursion(t *testing.T) { - hcReadTimeout = 10 * time.Millisecond - readTimeout = 10 * time.Millisecond defaultTimeout = 10 * time.Millisecond - hcWriteTimeout = 10 * time.Millisecond i := uint32(0) q := uint32(0) @@ -114,8 +110,10 @@ func TestHealthNoRecursion(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, transport.DNS) - p.health.SetRecursionDesired(false) + p := proxy.NewProxy(s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetRecursionDesired(false) f := New() f.SetProxy(p) defer f.OnShutdown() @@ -133,9 +131,6 @@ func TestHealthNoRecursion(t *testing.T) { } func TestHealthTimeout(t *testing.T) { - hcReadTimeout = 10 * time.Millisecond - hcWriteTimeout = 10 * time.Millisecond - readTimeout = 10 * time.Millisecond defaultTimeout = 10 * time.Millisecond i := uint32(0) @@ -159,7 +154,9 @@ func TestHealthTimeout(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, transport.DNS) + p := proxy.NewProxy(s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) f := New() f.SetProxy(p) defer f.OnShutdown() @@ -177,19 +174,20 @@ func TestHealthTimeout(t *testing.T) { } func TestHealthMaxFails(t *testing.T) { - hcReadTimeout = 10 * time.Millisecond - hcWriteTimeout = 10 * time.Millisecond - readTimeout = 10 * time.Millisecond defaultTimeout = 10 * time.Millisecond - hcInterval = 10 * time.Millisecond + //,hcInterval = 10 * time.Millisecond s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { // timeout }) defer s.Close() - p := NewProxy(s.Addr, transport.DNS) + p := proxy.NewProxy(s.Addr, transport.DNS) + p.SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) f := New() + f.hcInterval = 10 * time.Millisecond f.maxfails = 2 f.SetProxy(p) defer f.OnShutdown() @@ -200,18 +198,14 @@ func TestHealthMaxFails(t *testing.T) { f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) time.Sleep(100 * time.Millisecond) - fails := atomic.LoadUint32(&p.fails) + fails := p.Fails() if !p.Down(f.maxfails) { t.Errorf("Expected Proxy fails to be greater than %d, got %d", f.maxfails, fails) } } func TestHealthNoMaxFails(t *testing.T) { - hcReadTimeout = 10 * time.Millisecond - hcWriteTimeout = 10 * time.Millisecond - readTimeout = 10 * time.Millisecond defaultTimeout = 10 * time.Millisecond - hcInterval = 10 * time.Millisecond i := uint32(0) s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { @@ -225,7 +219,9 @@ func TestHealthNoMaxFails(t *testing.T) { }) defer s.Close() - p := NewProxy(s.Addr, transport.DNS) + p := proxy.NewProxy(s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) f := New() f.maxfails = 0 f.SetProxy(p) @@ -244,10 +240,8 @@ func TestHealthNoMaxFails(t *testing.T) { } func TestHealthDomain(t *testing.T) { - hcReadTimeout = 10 * time.Millisecond - readTimeout = 10 * time.Millisecond defaultTimeout = 10 * time.Millisecond - hcWriteTimeout = 10 * time.Millisecond + hcDomain := "example.org." i := uint32(0) q := uint32(0) @@ -264,8 +258,10 @@ func TestHealthDomain(t *testing.T) { w.WriteMsg(ret) }) defer s.Close() - p := NewProxy(s.Addr, transport.DNS) - p.health.SetDomain(hcDomain) + p := proxy.NewProxy(s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetDomain(hcDomain) f := New() f.SetProxy(p) defer f.OnShutdown() diff --git a/plugin/forward/metrics.go b/plugin/forward/metrics.go index f1f0c48d6..da0905525 100644 --- a/plugin/forward/metrics.go +++ b/plugin/forward/metrics.go @@ -9,31 +9,6 @@ import ( // Variables declared for monitoring. var ( - RequestCount = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "forward", - Name: "requests_total", - Help: "Counter of requests made per upstream.", - }, []string{"to"}) - RcodeCount = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "forward", - Name: "responses_total", - Help: "Counter of responses received per upstream.", - }, []string{"rcode", "to"}) - RequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ - Namespace: plugin.Namespace, - Subsystem: "forward", - Name: "request_duration_seconds", - Buckets: plugin.TimeBuckets, - Help: "Histogram of the time each request took.", - }, []string{"to", "rcode"}) - HealthcheckFailureCount = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "forward", - Name: "healthcheck_failures_total", - Help: "Counter of the number of failed healthchecks.", - }, []string{"to"}) HealthcheckBrokenCount = promauto.NewCounter(prometheus.CounterOpts{ Namespace: plugin.Namespace, Subsystem: "forward", @@ -46,16 +21,4 @@ var ( Name: "max_concurrent_rejects_total", Help: "Counter of the number of queries rejected because the concurrent queries were at maximum.", }) - ConnCacheHitsCount = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "forward", - Name: "conn_cache_hits_total", - Help: "Counter of connection cache hits per upstream and protocol.", - }, []string{"to", "proto"}) - ConnCacheMissesCount = promauto.NewCounterVec(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "forward", - Name: "conn_cache_misses_total", - Help: "Counter of connection cache misses per upstream and protocol.", - }, []string{"to", "proto"}) ) diff --git a/plugin/forward/policy.go b/plugin/forward/policy.go index e81e4ab91..7bd1f316a 100644 --- a/plugin/forward/policy.go +++ b/plugin/forward/policy.go @@ -4,12 +4,13 @@ import ( "sync/atomic" "time" + "github.com/coredns/coredns/plugin/pkg/proxy" "github.com/coredns/coredns/plugin/pkg/rand" ) // Policy defines a policy we use for selecting upstreams. type Policy interface { - List([]*Proxy) []*Proxy + List([]*proxy.Proxy) []*proxy.Proxy String() string } @@ -18,19 +19,19 @@ type random struct{} func (r *random) String() string { return "random" } -func (r *random) List(p []*Proxy) []*Proxy { +func (r *random) List(p []*proxy.Proxy) []*proxy.Proxy { switch len(p) { case 1: return p case 2: if rn.Int()%2 == 0 { - return []*Proxy{p[1], p[0]} // swap + return []*proxy.Proxy{p[1], p[0]} // swap } return p } perms := rn.Perm(len(p)) - rnd := make([]*Proxy, len(p)) + rnd := make([]*proxy.Proxy, len(p)) for i, p1 := range perms { rnd[i] = p[p1] @@ -45,11 +46,11 @@ type roundRobin struct { func (r *roundRobin) String() string { return "round_robin" } -func (r *roundRobin) List(p []*Proxy) []*Proxy { +func (r *roundRobin) List(p []*proxy.Proxy) []*proxy.Proxy { poolLen := uint32(len(p)) i := atomic.AddUint32(&r.robin, 1) % poolLen - robin := []*Proxy{p[i]} + robin := []*proxy.Proxy{p[i]} robin = append(robin, p[:i]...) robin = append(robin, p[i+1:]...) @@ -61,7 +62,7 @@ type sequential struct{} func (r *sequential) String() string { return "sequential" } -func (r *sequential) List(p []*Proxy) []*Proxy { +func (r *sequential) List(p []*proxy.Proxy) []*proxy.Proxy { return p } diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go index 74a0b5c4b..daf5f964c 100644 --- a/plugin/forward/proxy_test.go +++ b/plugin/forward/proxy_test.go @@ -6,9 +6,7 @@ import ( "github.com/coredns/caddy" "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/plugin/pkg/transport" "github.com/coredns/coredns/plugin/test" - "github.com/coredns/coredns/request" "github.com/miekg/dns" ) @@ -70,30 +68,3 @@ func TestProxyTLSFail(t *testing.T) { t.Fatal("Expected *not* to receive reply, but got one") } } - -func TestProtocolSelection(t *testing.T) { - p := NewProxy("bad_address", transport.DNS) - - stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} - stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} - ctx := context.TODO() - - go func() { - p.Connect(ctx, stateUDP, options{}) - p.Connect(ctx, stateUDP, options{forceTCP: true}) - p.Connect(ctx, stateUDP, options{preferUDP: true}) - p.Connect(ctx, stateUDP, options{preferUDP: true, forceTCP: true}) - p.Connect(ctx, stateTCP, options{}) - p.Connect(ctx, stateTCP, options{forceTCP: true}) - p.Connect(ctx, stateTCP, options{preferUDP: true}) - p.Connect(ctx, stateTCP, options{preferUDP: true, forceTCP: true}) - }() - - for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} { - proto := <-p.transport.dial - p.transport.ret <- nil - if proto != exp { - t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto) - } - } -} diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 7ca24df4d..6de0c870f 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -12,6 +12,7 @@ import ( "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/dnstap" "github.com/coredns/coredns/plugin/pkg/parse" + "github.com/coredns/coredns/plugin/pkg/proxy" pkgtls "github.com/coredns/coredns/plugin/pkg/tls" "github.com/coredns/coredns/plugin/pkg/transport" @@ -67,7 +68,7 @@ func setup(c *caddy.Controller) error { // OnStartup starts a goroutines for all proxies. func (f *Forward) OnStartup() (err error) { for _, p := range f.proxies { - p.start(f.hcInterval) + p.Start(f.hcInterval) } return nil } @@ -75,7 +76,7 @@ func (f *Forward) OnStartup() (err error) { // OnShutdown stops all configured proxies. func (f *Forward) OnShutdown() error { for _, p := range f.proxies { - p.stop() + p.Stop() } return nil } @@ -127,7 +128,7 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { if !allowedTrans[trans] { return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) } - p := NewProxy(h, trans) + p := proxy.NewProxy(h, trans) f.proxies = append(f.proxies, p) transports[i] = trans } @@ -152,12 +153,12 @@ func parseStanza(c *caddy.Controller) (*Forward, error) { f.proxies[i].SetTLSConfig(f.tlsConfig) } f.proxies[i].SetExpire(f.expire) - f.proxies[i].health.SetRecursionDesired(f.opts.hcRecursionDesired) + f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired) // when TLS is used, checks are set to tcp-tls - if f.opts.forceTCP && transports[i] != transport.TLS { - f.proxies[i].health.SetTCPTransport() + if f.opts.ForceTCP && transports[i] != transport.TLS { + f.proxies[i].GetHealthchecker().SetTCPTransport() } - f.proxies[i].health.SetDomain(f.opts.hcDomain) + f.proxies[i].GetHealthchecker().SetDomain(f.opts.HCDomain) } return f, nil @@ -194,12 +195,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { return fmt.Errorf("health_check can't be negative: %d", dur) } f.hcInterval = dur - f.opts.hcDomain = "." + f.opts.HCDomain = "." for c.NextArg() { switch hcOpts := c.Val(); hcOpts { case "no_rec": - f.opts.hcRecursionDesired = false + f.opts.HCRecursionDesired = false case "domain": if !c.NextArg() { return c.ArgErr() @@ -208,7 +209,7 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if _, ok := dns.IsDomainName(hcDomain); !ok { return fmt.Errorf("health_check: invalid domain name %s", hcDomain) } - f.opts.hcDomain = plugin.Name(hcDomain).Normalize() + f.opts.HCDomain = plugin.Name(hcDomain).Normalize() default: return fmt.Errorf("health_check: unknown option %s", hcOpts) } @@ -218,12 +219,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if c.NextArg() { return c.ArgErr() } - f.opts.forceTCP = true + f.opts.ForceTCP = true case "prefer_udp": if c.NextArg() { return c.ArgErr() } - f.opts.preferUDP = true + f.opts.PreferUDP = true case "tls": args := c.RemainingArgs() if len(args) > 3 { diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index 4b1743098..cf046b486 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -8,6 +8,7 @@ import ( "github.com/coredns/caddy" "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin/pkg/proxy" "github.com/miekg/dns" ) @@ -19,31 +20,31 @@ func TestSetup(t *testing.T) { expectedFrom string expectedIgnored []string expectedFails uint32 - expectedOpts options + expectedOpts proxy.Options expectedErr string }{ // positive - {"forward . 127.0.0.1", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . 127.0.0.1 {\nhealth_check 0.5s domain example.org\n}\n", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "example.org."}, ""}, - {"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, options{hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, options{forceTCP: true, hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . 127.0.0.1 {\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true, hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . 127.0.0.1 {\nforce_tcp\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true, forceTCP: true, hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . 127.0.0.1:53", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . 127.0.0.1:8080", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . [::1]:53", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . [2003::1]:53", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward . 127.0.0.1 \n", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""}, - {"forward 10.9.3.0/18 127.0.0.1", false, "0.9.10.in-addr.arpa.", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""}, + {"forward . 127.0.0.1", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s domain example.org\n}\n", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "example.org."}, ""}, + {"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, proxy.Options{ForceTCP: true, HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 {\nprefer_udp\n}\n", false, ".", nil, 2, proxy.Options{PreferUDP: true, HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 {\nforce_tcp\nprefer_udp\n}\n", false, ".", nil, 2, proxy.Options{PreferUDP: true, ForceTCP: true, HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1:53", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1:8080", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . [::1]:53", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . [2003::1]:53", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 \n", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward 10.9.3.0/18 127.0.0.1", false, "0.9.10.in-addr.arpa.", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, {`forward . ::1 - forward com ::2`, false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, "plugin"}, + forward com ::2`, false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "plugin"}, // negative - {"forward . a27.0.0.1", true, "", nil, 0, options{hcRecursionDesired: true, hcDomain: "."}, "not an IP"}, - {"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, options{hcRecursionDesired: true, hcDomain: "."}, "unknown property"}, - {"forward . 127.0.0.1 {\nhealth_check 0.5s domain\n}\n", true, "", nil, 0, options{hcRecursionDesired: true, hcDomain: "."}, "Wrong argument count or unexpected line ending after 'domain'"}, - {"forward . https://127.0.0.1 \n", true, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, "'https' is not supported as a destination protocol in forward: https://127.0.0.1"}, - {"forward xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx 127.0.0.1 \n", true, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, "unable to normalize 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'"}, + {"forward . a27.0.0.1", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "not an IP"}, + {"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "unknown property"}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s domain\n}\n", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "Wrong argument count or unexpected line ending after 'domain'"}, + {"forward . https://127.0.0.1 \n", true, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "'https' is not supported as a destination protocol in forward: https://127.0.0.1"}, + {"forward xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx 127.0.0.1 \n", true, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "unable to normalize 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'"}, } for i, test := range tests { @@ -127,8 +128,8 @@ func TestSetupTLS(t *testing.T) { t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName) } - if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName { - t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName) + if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName { + t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName) } } } @@ -179,14 +180,14 @@ nameserver 10.10.255.253`), 0666); err != nil { f := fs[0] for j, n := range test.expectedNames { - addr := f.proxies[j].addr + addr := f.proxies[j].Addr() if n != addr { t.Errorf("Test %d, expected %q, got %q", j, n, addr) } } for _, p := range f.proxies { - p.health.Check(p) // this should almost always err, we don't care it shouldn't crash + p.Healthcheck() // this should almost always err, we don't care it shouldn't crash } } } @@ -279,9 +280,9 @@ func TestSetupHealthCheck(t *testing.T) { } f := fs[0] - if f.opts.hcRecursionDesired != test.expectedRecVal || f.proxies[0].health.GetRecursionDesired() != test.expectedRecVal || - f.opts.hcDomain != test.expectedDomain || f.proxies[0].health.GetDomain() != test.expectedDomain || !dns.IsFqdn(f.proxies[0].health.GetDomain()) { - t.Errorf("Test %d: expectedRec: %v, got: %v. expectedDomain: %s, got: %s. ", i, test.expectedRecVal, f.opts.hcRecursionDesired, test.expectedDomain, f.opts.hcDomain) + if f.opts.HCRecursionDesired != test.expectedRecVal || f.proxies[0].GetHealthchecker().GetRecursionDesired() != test.expectedRecVal || + f.opts.HCDomain != test.expectedDomain || f.proxies[0].GetHealthchecker().GetDomain() != test.expectedDomain || !dns.IsFqdn(f.proxies[0].GetHealthchecker().GetDomain()) { + t.Errorf("Test %d: expectedRec: %v, got: %v. expectedDomain: %s, got: %s. ", i, test.expectedRecVal, f.opts.HCRecursionDesired, test.expectedDomain, f.opts.HCDomain) } } } diff --git a/plugin/forward/connect.go b/plugin/pkg/proxy/connect.go similarity index 92% rename from plugin/forward/connect.go rename to plugin/pkg/proxy/connect.go index 3d53044e5..29274d92d 100644 --- a/plugin/forward/connect.go +++ b/plugin/pkg/proxy/connect.go @@ -1,8 +1,8 @@ -// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same +// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same // client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be // 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses // inband healthchecking. -package forward +package proxy import ( "context" @@ -72,14 +72,14 @@ func (t *Transport) Dial(proto string) (*persistConn, bool, error) { } // Connect selects an upstream, sends the request and waits for a response. -func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options) (*dns.Msg, error) { +func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) { start := time.Now() proto := "" switch { - case opts.forceTCP: // TCP flag has precedence over UDP flag + case opts.ForceTCP: // TCP flag has precedence over UDP flag proto = "tcp" - case opts.preferUDP: + case opts.PreferUDP: proto = "udp" default: proto = state.Proto() @@ -113,7 +113,7 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options } var ret *dns.Msg - pc.c.SetReadDeadline(time.Now().Add(readTimeout)) + pc.c.SetReadDeadline(time.Now().Add(p.readTimeout)) for { ret, err = pc.c.ReadMsg() if err != nil { diff --git a/plugin/pkg/proxy/errors.go b/plugin/pkg/proxy/errors.go new file mode 100644 index 000000000..461236423 --- /dev/null +++ b/plugin/pkg/proxy/errors.go @@ -0,0 +1,26 @@ +package proxy + +import ( + "errors" +) + +var ( + // ErrNoHealthy means no healthy proxies left. + ErrNoHealthy = errors.New("no healthy proxies") + // ErrNoForward means no forwarder defined. + ErrNoForward = errors.New("no forwarder defined") + // ErrCachedClosed means cached connection was closed by peer. + ErrCachedClosed = errors.New("cached connection was closed by peer") +) + +// Options holds various Options that can be set. +type Options struct { + // ForceTCP use TCP protocol for upstream DNS request. Has precedence over PreferUDP flag + ForceTCP bool + // PreferUDP use UDP protocol for upstream DNS request. + PreferUDP bool + // HCRecursionDesired sets recursion desired flag for Proxy healthcheck requests + HCRecursionDesired bool + // HCDomain sets domain for Proxy healthcheck requests + HCDomain string +} diff --git a/plugin/forward/health.go b/plugin/pkg/proxy/health.go similarity index 71% rename from plugin/forward/health.go rename to plugin/pkg/proxy/health.go index ec0b48143..e87104a13 100644 --- a/plugin/forward/health.go +++ b/plugin/pkg/proxy/health.go @@ -1,10 +1,11 @@ -package forward +package proxy import ( "crypto/tls" "sync/atomic" "time" + "github.com/coredns/coredns/plugin/pkg/log" "github.com/coredns/coredns/plugin/pkg/transport" "github.com/miekg/dns" @@ -14,11 +15,16 @@ import ( type HealthChecker interface { Check(*Proxy) error SetTLSConfig(*tls.Config) + GetTLSConfig() *tls.Config SetRecursionDesired(bool) GetRecursionDesired() bool SetDomain(domain string) GetDomain() string SetTCPTransport() + GetReadTimeout() time.Duration + SetReadTimeout(time.Duration) + GetWriteTimeout() time.Duration + SetWriteTimeout(time.Duration) } // dnsHc is a health checker for a DNS endpoint (DNS, and DoT). @@ -28,21 +34,20 @@ type dnsHc struct { domain string } -var ( - hcReadTimeout = 1 * time.Second - hcWriteTimeout = 1 * time.Second -) - // NewHealthChecker returns a new HealthChecker based on transport. func NewHealthChecker(trans string, recursionDesired bool, domain string) HealthChecker { switch trans { case transport.DNS, transport.TLS: c := new(dns.Client) c.Net = "udp" - c.ReadTimeout = hcReadTimeout - c.WriteTimeout = hcWriteTimeout + c.ReadTimeout = 1 * time.Second + c.WriteTimeout = 1 * time.Second - return &dnsHc{c: c, recursionDesired: recursionDesired, domain: domain} + return &dnsHc{ + c: c, + recursionDesired: recursionDesired, + domain: domain, + } } log.Warningf("No healthchecker for transport %q", trans) @@ -54,6 +59,10 @@ func (h *dnsHc) SetTLSConfig(cfg *tls.Config) { h.c.TLSConfig = cfg } +func (h *dnsHc) GetTLSConfig() *tls.Config { + return h.c.TLSConfig +} + func (h *dnsHc) SetRecursionDesired(recursionDesired bool) { h.recursionDesired = recursionDesired } @@ -72,7 +81,23 @@ func (h *dnsHc) SetTCPTransport() { h.c.Net = "tcp" } -// For HC we send to . IN NS +[no]rec message to the upstream. Dial timeouts and empty +func (h *dnsHc) GetReadTimeout() time.Duration { + return h.c.ReadTimeout +} + +func (h *dnsHc) SetReadTimeout(t time.Duration) { + h.c.ReadTimeout = t +} + +func (h *dnsHc) GetWriteTimeout() time.Duration { + return h.c.WriteTimeout +} + +func (h *dnsHc) SetWriteTimeout(t time.Duration) { + h.c.WriteTimeout = t +} + +// For HC, we send to . IN NS +[no]rec message to the upstream. Dial timeouts and empty // replies are considered fails, basically anything else constitutes a healthy upstream. // Check is used as the up.Func in the up.Probe. diff --git a/plugin/pkg/proxy/health_test.go b/plugin/pkg/proxy/health_test.go new file mode 100644 index 000000000..c1b5270ad --- /dev/null +++ b/plugin/pkg/proxy/health_test.go @@ -0,0 +1,153 @@ +package proxy + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" + + "github.com/miekg/dns" +) + +func TestHealth(t *testing.T) { + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + hc := NewHealthChecker(transport.DNS, true, "") + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy(s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err != nil { + t.Errorf("check failed: %v", err) + } + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + +func TestHealthTCP(t *testing.T) { + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + hc := NewHealthChecker(transport.DNS, true, "") + hc.SetTCPTransport() + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy(s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err != nil { + t.Errorf("check failed: %v", err) + } + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + +func TestHealthNoRecursion(t *testing.T) { + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." && r.RecursionDesired == false { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + hc := NewHealthChecker(transport.DNS, false, "") + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy(s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err != nil { + t.Errorf("check failed: %v", err) + } + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==false to be %d, got %d", 1, i1) + } +} + +func TestHealthTimeout(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + // timeout + }) + defer s.Close() + + hc := NewHealthChecker(transport.DNS, false, "") + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy(s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err == nil { + t.Errorf("expected error") + } +} + +func TestHealthDomain(t *testing.T) { + hcDomain := "example.org." + + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == hcDomain && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + hc := NewHealthChecker(transport.DNS, true, hcDomain) + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy(s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err != nil { + t.Errorf("check failed: %v", err) + } + + time.Sleep(12 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1) + } +} diff --git a/plugin/pkg/proxy/metrics.go b/plugin/pkg/proxy/metrics.go new file mode 100644 index 000000000..148bc6edd --- /dev/null +++ b/plugin/pkg/proxy/metrics.go @@ -0,0 +1,49 @@ +package proxy + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Variables declared for monitoring. +var ( + RequestCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "requests_total", + Help: "Counter of requests made per upstream.", + }, []string{"to"}) + RcodeCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "responses_total", + Help: "Counter of responses received per upstream.", + }, []string{"rcode", "to"}) + RequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "request_duration_seconds", + Buckets: plugin.TimeBuckets, + Help: "Histogram of the time each request took.", + }, []string{"to", "rcode"}) + HealthcheckFailureCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "healthcheck_failures_total", + Help: "Counter of the number of failed healthchecks.", + }, []string{"to"}) + ConnCacheHitsCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "conn_cache_hits_total", + Help: "Counter of connection cache hits per upstream and protocol.", + }, []string{"to", "proto"}) + ConnCacheMissesCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "conn_cache_misses_total", + Help: "Counter of connection cache misses per upstream and protocol.", + }, []string{"to", "proto"}) +) diff --git a/plugin/forward/persistent.go b/plugin/pkg/proxy/persistent.go similarity index 94% rename from plugin/forward/persistent.go rename to plugin/pkg/proxy/persistent.go index c53dea82f..0908ce96c 100644 --- a/plugin/forward/persistent.go +++ b/plugin/pkg/proxy/persistent.go @@ -1,4 +1,4 @@ -package forward +package proxy import ( "crypto/tls" @@ -154,9 +154,3 @@ const ( minDialTimeout = 1 * time.Second maxDialTimeout = 30 * time.Second ) - -// Make a var for minimizing this value in tests. -var ( - // Some resolves might take quite a while, usually (cached) responses are fast. Set to 2s to give us some time to retry a different upstream. - readTimeout = 2 * time.Second -) diff --git a/plugin/forward/persistent_test.go b/plugin/pkg/proxy/persistent_test.go similarity index 99% rename from plugin/forward/persistent_test.go rename to plugin/pkg/proxy/persistent_test.go index 633696ac0..c78bd7f1f 100644 --- a/plugin/forward/persistent_test.go +++ b/plugin/pkg/proxy/persistent_test.go @@ -1,4 +1,4 @@ -package forward +package proxy import ( "testing" diff --git a/plugin/forward/proxy.go b/plugin/pkg/proxy/proxy.go similarity index 67% rename from plugin/forward/proxy.go rename to plugin/pkg/proxy/proxy.go index 6a4b5693e..be521fe05 100644 --- a/plugin/forward/proxy.go +++ b/plugin/pkg/proxy/proxy.go @@ -1,4 +1,4 @@ -package forward +package proxy import ( "crypto/tls" @@ -6,6 +6,7 @@ import ( "sync/atomic" "time" + "github.com/coredns/coredns/plugin/pkg/log" "github.com/coredns/coredns/plugin/pkg/up" ) @@ -16,6 +17,8 @@ type Proxy struct { transport *Transport + readTimeout time.Duration + // health checking probe *up.Probe health HealthChecker @@ -24,16 +27,19 @@ type Proxy struct { // NewProxy returns a new proxy. func NewProxy(addr, trans string) *Proxy { p := &Proxy{ - addr: addr, - fails: 0, - probe: up.New(), - transport: newTransport(addr), + addr: addr, + fails: 0, + probe: up.New(), + readTimeout: 2 * time.Second, + transport: newTransport(addr), } p.health = NewHealthChecker(trans, true, ".") runtime.SetFinalizer(p, (*Proxy).finalizer) return p } +func (p *Proxy) Addr() string { return p.addr } + // SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client. func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.transport.SetTLSConfig(cfg) @@ -43,6 +49,14 @@ func (p *Proxy) SetTLSConfig(cfg *tls.Config) { // SetExpire sets the expire duration in the lower p.transport. func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) } +func (p *Proxy) GetHealthchecker() HealthChecker { + return p.health +} + +func (p *Proxy) Fails() uint32 { + return atomic.LoadUint32(&p.fails) +} + // Healthcheck kicks of a round of health checks for this proxy. func (p *Proxy) Healthcheck() { if p.health == nil { @@ -65,18 +79,20 @@ func (p *Proxy) Down(maxfails uint32) bool { return fails > maxfails } -// close stops the health checking goroutine. -func (p *Proxy) stop() { p.probe.Stop() } +// Stop close stops the health checking goroutine. +func (p *Proxy) Stop() { p.probe.Stop() } func (p *Proxy) finalizer() { p.transport.Stop() } -// start starts the proxy's healthchecking. -func (p *Proxy) start(duration time.Duration) { +// Start starts the proxy's healthchecking. +func (p *Proxy) Start(duration time.Duration) { p.probe.Start(duration) p.transport.Start() } +func (p *Proxy) SetReadTimeout(duration time.Duration) { + p.readTimeout = duration +} + const ( maxTimeout = 2 * time.Second ) - -var hcInterval = 500 * time.Millisecond diff --git a/plugin/pkg/proxy/proxy_test.go b/plugin/pkg/proxy/proxy_test.go new file mode 100644 index 000000000..274e9679d --- /dev/null +++ b/plugin/pkg/proxy/proxy_test.go @@ -0,0 +1,99 @@ +package proxy + +import ( + "context" + "crypto/tls" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestProxy(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + p.Start(5 * time.Second) + m := new(dns.Msg) + + m.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + req := request.Request{Req: m, W: rec} + + resp, err := p.Connect(context.Background(), req, Options{PreferUDP: true}) + if err != nil { + t.Errorf("Failed to connect to testdnsserver: %s", err) + } + + if x := resp.Answer[0].Header().Name; x != "example.org." { + t.Errorf("Expected %s, got %s", "example.org.", x) + } +} + +func TestProxyTLSFail(t *testing.T) { + // This is an udp/tcp test server, so we shouldn't reach it with TLS. + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr, transport.TLS) + p.readTimeout = 10 * time.Millisecond + p.SetTLSConfig(&tls.Config{}) + p.Start(5 * time.Second) + m := new(dns.Msg) + + m.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + req := request.Request{Req: m, W: rec} + + _, err := p.Connect(context.Background(), req, Options{}) + if err == nil { + t.Fatal("Expected *not* to receive reply, but got one") + } +} + +func TestProtocolSelection(t *testing.T) { + p := NewProxy("bad_address", transport.DNS) + p.readTimeout = 10 * time.Millisecond + + stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} + ctx := context.TODO() + + go func() { + p.Connect(ctx, stateUDP, Options{}) + p.Connect(ctx, stateUDP, Options{ForceTCP: true}) + p.Connect(ctx, stateUDP, Options{PreferUDP: true}) + p.Connect(ctx, stateUDP, Options{PreferUDP: true, ForceTCP: true}) + p.Connect(ctx, stateTCP, Options{}) + p.Connect(ctx, stateTCP, Options{ForceTCP: true}) + p.Connect(ctx, stateTCP, Options{PreferUDP: true}) + p.Connect(ctx, stateTCP, Options{PreferUDP: true, ForceTCP: true}) + }() + + for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} { + proto := <-p.transport.dial + p.transport.ret <- nil + if proto != exp { + t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto) + } + } +} diff --git a/plugin/forward/type.go b/plugin/pkg/proxy/type.go similarity index 94% rename from plugin/forward/type.go rename to plugin/pkg/proxy/type.go index 9de842fbe..10f3a4639 100644 --- a/plugin/forward/type.go +++ b/plugin/pkg/proxy/type.go @@ -1,6 +1,8 @@ -package forward +package proxy -import "net" +import ( + "net" +) type transportType int