plugin/forward: add prefer_udp option (#1944)
* plugin/forward: add prefer_udp option * updated according to code review - fixed linter warning - removed metric parameter in Proxy.Connect()
This commit is contained in:
parent
7c41f2ce9f
commit
bc50901234
8 changed files with 115 additions and 43 deletions
|
@ -47,6 +47,7 @@ Extra knobs are available with an expanded syntax:
|
|||
forward FROM TO... {
|
||||
except IGNORED_NAMES...
|
||||
force_tcp
|
||||
prefer_udp
|
||||
expire DURATION
|
||||
max_fails INTEGER
|
||||
tls CERT KEY CA
|
||||
|
@ -60,6 +61,9 @@ forward FROM TO... {
|
|||
* **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding.
|
||||
Requests that match none of these names will be passed through.
|
||||
* `force_tcp`, use TCP even when the request comes in over UDP.
|
||||
* `prefer_udp`, try first using UDP even when the request comes in over TCP. If response is truncated
|
||||
(TC flag set in response) then do another attempt over TCP. In case if both `force_tcp` and `prefer_udp`
|
||||
options specified the `force_tcp` takes precedence.
|
||||
* `max_fails` is the number of subsequent failed health checks that are needed before considering
|
||||
an upstream to be down. If 0, the upstream will never be marked as down (nor health checked).
|
||||
Default is 2.
|
||||
|
|
|
@ -78,12 +78,17 @@ func (p *Proxy) updateRtt(newRtt time.Duration) {
|
|||
}
|
||||
|
||||
// Connect selects an upstream, sends the request and waits for a response.
|
||||
func (p *Proxy) Connect(ctx context.Context, state request.Request, forceTCP, metric bool) (*dns.Msg, error) {
|
||||
func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options) (*dns.Msg, error) {
|
||||
start := time.Now()
|
||||
|
||||
proto := state.Proto()
|
||||
if forceTCP {
|
||||
proto := ""
|
||||
switch {
|
||||
case opts.forceTCP: // TCP flag has precedence over UDP flag
|
||||
proto = "tcp"
|
||||
case opts.preferUDP:
|
||||
proto = "udp"
|
||||
default:
|
||||
proto = state.Proto()
|
||||
}
|
||||
|
||||
conn, cached, err := p.Dial(proto)
|
||||
|
@ -122,17 +127,15 @@ func (p *Proxy) Connect(ctx context.Context, state request.Request, forceTCP, me
|
|||
|
||||
p.Yield(conn)
|
||||
|
||||
if metric {
|
||||
rc, ok := dns.RcodeToString[ret.Rcode]
|
||||
if !ok {
|
||||
rc = strconv.Itoa(ret.Rcode)
|
||||
}
|
||||
|
||||
RequestCount.WithLabelValues(p.addr).Add(1)
|
||||
RcodeCount.WithLabelValues(rc, p.addr).Add(1)
|
||||
RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds())
|
||||
rc, ok := dns.RcodeToString[ret.Rcode]
|
||||
if !ok {
|
||||
rc = strconv.Itoa(ret.Rcode)
|
||||
}
|
||||
|
||||
RequestCount.WithLabelValues(p.addr).Add(1)
|
||||
RcodeCount.WithLabelValues(rc, p.addr).Add(1)
|
||||
RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds())
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ type Forward struct {
|
|||
maxfails uint32
|
||||
expire time.Duration
|
||||
|
||||
forceTCP bool // also here for testing
|
||||
opts options // also here for testing
|
||||
|
||||
Next plugin.Handler
|
||||
}
|
||||
|
@ -103,9 +103,18 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
|
|||
ret *dns.Msg
|
||||
err error
|
||||
)
|
||||
opts := f.opts
|
||||
for {
|
||||
ret, err = proxy.Connect(ctx, state, f.forceTCP, true)
|
||||
if err != nil && err == ErrCachedClosed { // Remote side closed conn, can only happen with TCP.
|
||||
ret, err = proxy.Connect(ctx, state, opts)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if err == ErrCachedClosed { // Remote side closed conn, can only happen with TCP.
|
||||
continue
|
||||
}
|
||||
// Retry with TCP if truncated and prefer_udp configured
|
||||
if err == dns.ErrTruncated && !opts.forceTCP && f.opts.preferUDP {
|
||||
opts.forceTCP = true
|
||||
continue
|
||||
}
|
||||
break
|
||||
|
@ -183,7 +192,10 @@ func (f *Forward) isAllowedDomain(name string) bool {
|
|||
func (f *Forward) From() string { return f.from }
|
||||
|
||||
// ForceTCP returns if TCP is forced to be used even when the request comes in over UDP.
|
||||
func (f *Forward) ForceTCP() bool { return f.forceTCP }
|
||||
func (f *Forward) ForceTCP() bool { return f.opts.forceTCP }
|
||||
|
||||
// PreferUDP returns if UDP is preferred to be used even when the request comes in over TCP.
|
||||
func (f *Forward) PreferUDP() bool { return f.opts.preferUDP }
|
||||
|
||||
// List returns a set of proxies to be used for this client depending on the policy in f.
|
||||
func (f *Forward) List() []*Proxy { return f.p.List(f.proxies) }
|
||||
|
@ -206,4 +218,9 @@ const (
|
|||
sequentialPolicy
|
||||
)
|
||||
|
||||
type options struct {
|
||||
forceTCP bool
|
||||
preferUDP bool
|
||||
}
|
||||
|
||||
const defaultTimeout = 5 * time.Second
|
||||
|
|
|
@ -32,7 +32,7 @@ func (f *Forward) Forward(state request.Request) (*dns.Msg, error) {
|
|||
proxy = f.List()[0]
|
||||
}
|
||||
|
||||
ret, err := proxy.Connect(context.Background(), state, f.forceTCP, true)
|
||||
ret, err := proxy.Connect(context.Background(), state, f.opts)
|
||||
|
||||
ret, err = truncated(state, ret, err)
|
||||
upstreamErr = err
|
||||
|
|
|
@ -29,10 +29,10 @@ func TestProxyClose(t *testing.T) {
|
|||
p := NewProxy(s.Addr, nil)
|
||||
p.start(hcInterval)
|
||||
|
||||
go func() { p.Connect(ctx, state, false, false) }()
|
||||
go func() { p.Connect(ctx, state, true, false) }()
|
||||
go func() { p.Connect(ctx, state, false, false) }()
|
||||
go func() { p.Connect(ctx, state, true, false) }()
|
||||
go func() { p.Connect(ctx, state, options{}) }()
|
||||
go func() { p.Connect(ctx, state, options{forceTCP: true}) }()
|
||||
go func() { p.Connect(ctx, state, options{}) }()
|
||||
go func() { p.Connect(ctx, state, options{forceTCP: true}) }()
|
||||
|
||||
p.close()
|
||||
}
|
||||
|
@ -93,3 +93,30 @@ func TestProxyTLSFail(t *testing.T) {
|
|||
t.Fatal("Expected *not* to receive reply, but got one")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtocolSelection(t *testing.T) {
|
||||
p := NewProxy("bad_address", nil)
|
||||
|
||||
stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)}
|
||||
stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)}
|
||||
ctx := context.TODO()
|
||||
|
||||
go func() {
|
||||
p.Connect(ctx, stateUDP, options{})
|
||||
p.Connect(ctx, stateUDP, options{forceTCP: true})
|
||||
p.Connect(ctx, stateUDP, options{preferUDP: true})
|
||||
p.Connect(ctx, stateUDP, options{preferUDP: true, forceTCP: true})
|
||||
p.Connect(ctx, stateTCP, options{})
|
||||
p.Connect(ctx, stateTCP, options{forceTCP: true})
|
||||
p.Connect(ctx, stateTCP, options{preferUDP: true})
|
||||
p.Connect(ctx, stateTCP, options{preferUDP: true, forceTCP: true})
|
||||
}()
|
||||
|
||||
for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} {
|
||||
proto := <-p.transport.dial
|
||||
p.transport.ret <- nil
|
||||
if proto != exp {
|
||||
t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -187,7 +187,12 @@ func parseBlock(c *caddy.Controller, f *Forward) error {
|
|||
if c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
f.forceTCP = true
|
||||
f.opts.forceTCP = true
|
||||
case "prefer_udp":
|
||||
if c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
f.opts.preferUDP = true
|
||||
case "tls":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) > 3 {
|
||||
|
|
|
@ -10,28 +10,30 @@ import (
|
|||
|
||||
func TestSetup(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedFrom string
|
||||
expectedIgnored []string
|
||||
expectedFails uint32
|
||||
expectedForceTCP bool
|
||||
expectedErr string
|
||||
input string
|
||||
shouldErr bool
|
||||
expectedFrom string
|
||||
expectedIgnored []string
|
||||
expectedFails uint32
|
||||
expectedOpts options
|
||||
expectedErr string
|
||||
}{
|
||||
// positive
|
||||
{"forward . 127.0.0.1", false, ".", nil, 2, false, ""},
|
||||
{"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, false, ""},
|
||||
{"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, false, ""},
|
||||
{"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, true, ""},
|
||||
{"forward . 127.0.0.1:53", false, ".", nil, 2, false, ""},
|
||||
{"forward . 127.0.0.1:8080", false, ".", nil, 2, false, ""},
|
||||
{"forward . [::1]:53", false, ".", nil, 2, false, ""},
|
||||
{"forward . [2003::1]:53", false, ".", nil, 2, false, ""},
|
||||
{"forward . 127.0.0.1", false, ".", nil, 2, options{}, ""},
|
||||
{"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, options{}, ""},
|
||||
{"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, options{}, ""},
|
||||
{"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, options{forceTCP: true}, ""},
|
||||
{"forward . 127.0.0.1 {\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true}, ""},
|
||||
{"forward . 127.0.0.1 {\nforce_tcp\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true, forceTCP: true}, ""},
|
||||
{"forward . 127.0.0.1:53", false, ".", nil, 2, options{}, ""},
|
||||
{"forward . 127.0.0.1:8080", false, ".", nil, 2, options{}, ""},
|
||||
{"forward . [::1]:53", false, ".", nil, 2, options{}, ""},
|
||||
{"forward . [2003::1]:53", false, ".", nil, 2, options{}, ""},
|
||||
// negative
|
||||
{"forward . a27.0.0.1", true, "", nil, 0, false, "not an IP"},
|
||||
{"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, false, "unknown property"},
|
||||
{"forward . a27.0.0.1", true, "", nil, 0, options{}, "not an IP"},
|
||||
{"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, options{}, "unknown property"},
|
||||
{`forward . ::1
|
||||
forward com ::2`, true, "", nil, 0, false, "plugin"},
|
||||
forward com ::2`, true, "", nil, 0, options{}, "plugin"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
|
@ -63,8 +65,8 @@ func TestSetup(t *testing.T) {
|
|||
if !test.shouldErr && f.maxfails != test.expectedFails {
|
||||
t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedFails, f.maxfails)
|
||||
}
|
||||
if !test.shouldErr && f.forceTCP != test.expectedForceTCP {
|
||||
t.Errorf("Test %d: expected: %t, got: %t", i, test.expectedForceTCP, f.forceTCP)
|
||||
if !test.shouldErr && f.opts != test.expectedOpts {
|
||||
t.Errorf("Test %d: expected: %v, got: %v", i, test.expectedOpts, f.opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,12 +9,17 @@ import (
|
|||
// ResponseWriter is useful for writing tests. It uses some fixed values for the client. The
|
||||
// remote will always be 10.240.0.1 and port 40212. The local address is always 127.0.0.1 and
|
||||
// port 53.
|
||||
type ResponseWriter struct{}
|
||||
type ResponseWriter struct {
|
||||
TCP bool
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address, always 127.0.0.1:53 (UDP).
|
||||
func (t *ResponseWriter) LocalAddr() net.Addr {
|
||||
ip := net.ParseIP("127.0.0.1")
|
||||
port := 53
|
||||
if t.TCP {
|
||||
return &net.TCPAddr{IP: ip, Port: port, Zone: ""}
|
||||
}
|
||||
return &net.UDPAddr{IP: ip, Port: port, Zone: ""}
|
||||
}
|
||||
|
||||
|
@ -22,6 +27,9 @@ func (t *ResponseWriter) LocalAddr() net.Addr {
|
|||
func (t *ResponseWriter) RemoteAddr() net.Addr {
|
||||
ip := net.ParseIP("10.240.0.1")
|
||||
port := 40212
|
||||
if t.TCP {
|
||||
return &net.TCPAddr{IP: ip, Port: port, Zone: ""}
|
||||
}
|
||||
return &net.UDPAddr{IP: ip, Port: port, Zone: ""}
|
||||
}
|
||||
|
||||
|
@ -52,10 +60,16 @@ type ResponseWriter6 struct {
|
|||
|
||||
// LocalAddr returns the local address, always ::1, port 53 (UDP).
|
||||
func (t *ResponseWriter6) LocalAddr() net.Addr {
|
||||
if t.TCP {
|
||||
return &net.TCPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""}
|
||||
}
|
||||
return &net.UDPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""}
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address, always fe80::42:ff:feca:4c65 port 40212 (UDP).
|
||||
func (t *ResponseWriter6) RemoteAddr() net.Addr {
|
||||
if t.TCP {
|
||||
return &net.TCPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""}
|
||||
}
|
||||
return &net.UDPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue