diff --git a/plugin/forward/README.md b/plugin/forward/README.md index d0cf89e88..fb199366d 100644 --- a/plugin/forward/README.md +++ b/plugin/forward/README.md @@ -47,6 +47,7 @@ Extra knobs are available with an expanded syntax: forward FROM TO... { except IGNORED_NAMES... force_tcp + prefer_udp expire DURATION max_fails INTEGER tls CERT KEY CA @@ -60,6 +61,9 @@ forward FROM TO... { * **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding. Requests that match none of these names will be passed through. * `force_tcp`, use TCP even when the request comes in over UDP. +* `prefer_udp`, try first using UDP even when the request comes in over TCP. If response is truncated + (TC flag set in response) then do another attempt over TCP. In case if both `force_tcp` and `prefer_udp` + options specified the `force_tcp` takes precedence. * `max_fails` is the number of subsequent failed health checks that are needed before considering an upstream to be down. If 0, the upstream will never be marked as down (nor health checked). Default is 2. diff --git a/plugin/forward/connect.go b/plugin/forward/connect.go index 3259fda5e..4a0a7141e 100644 --- a/plugin/forward/connect.go +++ b/plugin/forward/connect.go @@ -78,12 +78,17 @@ func (p *Proxy) updateRtt(newRtt time.Duration) { } // Connect selects an upstream, sends the request and waits for a response. -func (p *Proxy) Connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) { +func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options) (*dns.Msg, error) { start := time.Now() - proto := state.Proto() - if forceTCP { + proto := "" + switch { + case opts.forceTCP: // TCP flag has precedence over UDP flag proto = "tcp" + case opts.preferUDP: + proto = "udp" + default: + proto = state.Proto() } conn, cached, err := p.Dial(proto) @@ -122,17 +127,15 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, forceTCP, me p.Yield(conn) - if metric { - rc, ok := dns.RcodeToString[ret.Rcode] - if !ok { - rc = strconv.Itoa(ret.Rcode) - } - - RequestCount.WithLabelValues(p.addr).Add(1) - RcodeCount.WithLabelValues(rc, p.addr).Add(1) - RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds()) + rc, ok := dns.RcodeToString[ret.Rcode] + if !ok { + rc = strconv.Itoa(ret.Rcode) } + RequestCount.WithLabelValues(p.addr).Add(1) + RcodeCount.WithLabelValues(rc, p.addr).Add(1) + RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds()) + return ret, nil } diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go index e901572f9..861ff61a8 100644 --- a/plugin/forward/forward.go +++ b/plugin/forward/forward.go @@ -33,7 +33,7 @@ type Forward struct { maxfails uint32 expire time.Duration - forceTCP bool // also here for testing + opts options // also here for testing Next plugin.Handler } @@ -103,9 +103,18 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg ret *dns.Msg err error ) + opts := f.opts for { - ret, err = proxy.Connect(ctx, state, f.forceTCP, true) - if err != nil && err == ErrCachedClosed { // Remote side closed conn, can only happen with TCP. + ret, err = proxy.Connect(ctx, state, opts) + if err == nil { + break + } + if err == ErrCachedClosed { // Remote side closed conn, can only happen with TCP. + continue + } + // Retry with TCP if truncated and prefer_udp configured + if err == dns.ErrTruncated && !opts.forceTCP && f.opts.preferUDP { + opts.forceTCP = true continue } break @@ -183,7 +192,10 @@ func (f *Forward) isAllowedDomain(name string) bool { func (f *Forward) From() string { return f.from } // ForceTCP returns if TCP is forced to be used even when the request comes in over UDP. -func (f *Forward) ForceTCP() bool { return f.forceTCP } +func (f *Forward) ForceTCP() bool { return f.opts.forceTCP } + +// PreferUDP returns if UDP is preferred to be used even when the request comes in over TCP. +func (f *Forward) PreferUDP() bool { return f.opts.preferUDP } // List returns a set of proxies to be used for this client depending on the policy in f. func (f *Forward) List() []*Proxy { return f.p.List(f.proxies) } @@ -206,4 +218,9 @@ const ( sequentialPolicy ) +type options struct { + forceTCP bool + preferUDP bool +} + const defaultTimeout = 5 * time.Second diff --git a/plugin/forward/lookup.go b/plugin/forward/lookup.go index 65ee593f0..96eceab84 100644 --- a/plugin/forward/lookup.go +++ b/plugin/forward/lookup.go @@ -32,7 +32,7 @@ func (f *Forward) Forward(state request.Request) (*dns.Msg, error) { proxy = f.List()[0] } - ret, err := proxy.Connect(context.Background(), state, f.forceTCP, true) + ret, err := proxy.Connect(context.Background(), state, f.opts) ret, err = truncated(state, ret, err) upstreamErr = err diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go index 01e8ef6fd..d68d1def2 100644 --- a/plugin/forward/proxy_test.go +++ b/plugin/forward/proxy_test.go @@ -29,10 +29,10 @@ func TestProxyClose(t *testing.T) { p := NewProxy(s.Addr, nil) p.start(hcInterval) - go func() { p.Connect(ctx, state, false, false) }() - go func() { p.Connect(ctx, state, true, false) }() - go func() { p.Connect(ctx, state, false, false) }() - go func() { p.Connect(ctx, state, true, false) }() + go func() { p.Connect(ctx, state, options{}) }() + go func() { p.Connect(ctx, state, options{forceTCP: true}) }() + go func() { p.Connect(ctx, state, options{}) }() + go func() { p.Connect(ctx, state, options{forceTCP: true}) }() p.close() } @@ -93,3 +93,30 @@ func TestProxyTLSFail(t *testing.T) { t.Fatal("Expected *not* to receive reply, but got one") } } + +func TestProtocolSelection(t *testing.T) { + p := NewProxy("bad_address", nil) + + stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} + ctx := context.TODO() + + go func() { + p.Connect(ctx, stateUDP, options{}) + p.Connect(ctx, stateUDP, options{forceTCP: true}) + p.Connect(ctx, stateUDP, options{preferUDP: true}) + p.Connect(ctx, stateUDP, options{preferUDP: true, forceTCP: true}) + p.Connect(ctx, stateTCP, options{}) + p.Connect(ctx, stateTCP, options{forceTCP: true}) + p.Connect(ctx, stateTCP, options{preferUDP: true}) + p.Connect(ctx, stateTCP, options{preferUDP: true, forceTCP: true}) + }() + + for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} { + proto := <-p.transport.dial + p.transport.ret <- nil + if proto != exp { + t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto) + } + } +} diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go index 7afafc8a7..152ba36c5 100644 --- a/plugin/forward/setup.go +++ b/plugin/forward/setup.go @@ -187,7 +187,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error { if c.NextArg() { return c.ArgErr() } - f.forceTCP = true + f.opts.forceTCP = true + case "prefer_udp": + if c.NextArg() { + return c.ArgErr() + } + f.opts.preferUDP = true case "tls": args := c.RemainingArgs() if len(args) > 3 { diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index fba2359b9..a8140d410 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -10,28 +10,30 @@ import ( func TestSetup(t *testing.T) { tests := []struct { - input string - shouldErr bool - expectedFrom string - expectedIgnored []string - expectedFails uint32 - expectedForceTCP bool - expectedErr string + input string + shouldErr bool + expectedFrom string + expectedIgnored []string + expectedFails uint32 + expectedOpts options + expectedErr string }{ // positive - {"forward . 127.0.0.1", false, ".", nil, 2, false, ""}, - {"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, false, ""}, - {"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, false, ""}, - {"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, true, ""}, - {"forward . 127.0.0.1:53", false, ".", nil, 2, false, ""}, - {"forward . 127.0.0.1:8080", false, ".", nil, 2, false, ""}, - {"forward . [::1]:53", false, ".", nil, 2, false, ""}, - {"forward . [2003::1]:53", false, ".", nil, 2, false, ""}, + {"forward . 127.0.0.1", false, ".", nil, 2, options{}, ""}, + {"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, options{}, ""}, + {"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, options{}, ""}, + {"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, options{forceTCP: true}, ""}, + {"forward . 127.0.0.1 {\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true}, ""}, + {"forward . 127.0.0.1 {\nforce_tcp\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true, forceTCP: true}, ""}, + {"forward . 127.0.0.1:53", false, ".", nil, 2, options{}, ""}, + {"forward . 127.0.0.1:8080", false, ".", nil, 2, options{}, ""}, + {"forward . [::1]:53", false, ".", nil, 2, options{}, ""}, + {"forward . [2003::1]:53", false, ".", nil, 2, options{}, ""}, // negative - {"forward . a27.0.0.1", true, "", nil, 0, false, "not an IP"}, - {"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, false, "unknown property"}, + {"forward . a27.0.0.1", true, "", nil, 0, options{}, "not an IP"}, + {"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, options{}, "unknown property"}, {`forward . ::1 - forward com ::2`, true, "", nil, 0, false, "plugin"}, + forward com ::2`, true, "", nil, 0, options{}, "plugin"}, } for i, test := range tests { @@ -63,8 +65,8 @@ func TestSetup(t *testing.T) { if !test.shouldErr && f.maxfails != test.expectedFails { t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedFails, f.maxfails) } - if !test.shouldErr && f.forceTCP != test.expectedForceTCP { - t.Errorf("Test %d: expected: %t, got: %t", i, test.expectedForceTCP, f.forceTCP) + if !test.shouldErr && f.opts != test.expectedOpts { + t.Errorf("Test %d: expected: %v, got: %v", i, test.expectedOpts, f.opts) } } } diff --git a/plugin/test/responsewriter.go b/plugin/test/responsewriter.go index 4db5728e4..32796249b 100644 --- a/plugin/test/responsewriter.go +++ b/plugin/test/responsewriter.go @@ -9,12 +9,17 @@ import ( // ResponseWriter is useful for writing tests. It uses some fixed values for the client. The // remote will always be 10.240.0.1 and port 40212. The local address is always 127.0.0.1 and // port 53. -type ResponseWriter struct{} +type ResponseWriter struct { + TCP bool +} // LocalAddr returns the local address, always 127.0.0.1:53 (UDP). func (t *ResponseWriter) LocalAddr() net.Addr { ip := net.ParseIP("127.0.0.1") port := 53 + if t.TCP { + return &net.TCPAddr{IP: ip, Port: port, Zone: ""} + } return &net.UDPAddr{IP: ip, Port: port, Zone: ""} } @@ -22,6 +27,9 @@ func (t *ResponseWriter) LocalAddr() net.Addr { func (t *ResponseWriter) RemoteAddr() net.Addr { ip := net.ParseIP("10.240.0.1") port := 40212 + if t.TCP { + return &net.TCPAddr{IP: ip, Port: port, Zone: ""} + } return &net.UDPAddr{IP: ip, Port: port, Zone: ""} } @@ -52,10 +60,16 @@ type ResponseWriter6 struct { // LocalAddr returns the local address, always ::1, port 53 (UDP). func (t *ResponseWriter6) LocalAddr() net.Addr { + if t.TCP { + return &net.TCPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""} + } return &net.UDPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""} } // RemoteAddr returns the remote address, always fe80::42:ff:feca:4c65 port 40212 (UDP). func (t *ResponseWriter6) RemoteAddr() net.Addr { + if t.TCP { + return &net.TCPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""} + } return &net.UDPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""} }