diff --git a/middleware/proxy/README.md b/middleware/proxy/README.md index 2dee30052..f2de2b48f 100644 --- a/middleware/proxy/README.md +++ b/middleware/proxy/README.md @@ -26,7 +26,7 @@ proxy FROM TO... { health_check PATH:PORT [DURATION] except IGNORED_NAMES... spray - protocol [dns|https_google [bootstrap ADDRESS...]|grpc [insecure|CA-PEM|KEY-PEM CERT-PEM|KEY-PEM CERT-PEM CA-PEM]] + protocol [dns [force_tcp]|https_google [bootstrap ADDRESS...]|grpc [insecure|CA-PEM|KEY-PEM CERT-PEM|KEY-PEM CERT-PEM CA-PEM]] } ~~~ @@ -71,7 +71,8 @@ Currently `protocol` supports `dns` (i.e., standard DNS over UDP/TCP) and `https payload over HTTPS). Note that with `https_google` the entire transport is encrypted. Only *you* and *Google* can see your DNS activity. -* `dns`: no options can be given at the moment. +* `dns`: uses the standard DNS exchange. You can pass `force_tcp` to make sure that the proxied connection is performed + over TCP, regardless of the inbound request's protocol. * `https_google`: bootstrap **ADDRESS...** is used to (re-)resolve `dns.google.com` to an address to connect to. This happens every 300s. If not specified the default is used: 8.8.8.8:53/8.8.4.4:53. Note that **TO** is *ignored* when `https_google` is used, as its upstream is defined as diff --git a/middleware/proxy/dns.go b/middleware/proxy/dns.go index 7fa975733..78a8c3bfe 100644 --- a/middleware/proxy/dns.go +++ b/middleware/proxy/dns.go @@ -14,10 +14,19 @@ import ( type dnsEx struct { Timeout time.Duration group *singleflight.Group + Options +} + +type Options struct { + ForceTCP bool // If true use TCP for upstream no matter what } func newDNSEx() *dnsEx { - return &dnsEx{group: new(singleflight.Group), Timeout: defaultTimeout * time.Second} + return newDNSExWithOption(Options{}) +} + +func newDNSExWithOption(opt Options) *dnsEx { + return &dnsEx{group: new(singleflight.Group), Timeout: defaultTimeout * time.Second, Options: opt} } func (d *dnsEx) Protocol() string { return "dns" } @@ -26,7 +35,11 @@ func (d *dnsEx) OnStartup(p *Proxy) error { return nil } // Exchange implements the Exchanger interface. func (d *dnsEx) Exchange(ctx context.Context, addr string, state request.Request) (*dns.Msg, error) { - co, err := net.DialTimeout(state.Proto(), addr, d.Timeout) + proto := state.Proto() + if d.Options.ForceTCP { + proto = "tcp" + } + co, err := net.DialTimeout(proto, addr, d.Timeout) if err != nil { return nil, err } @@ -43,7 +56,8 @@ func (d *dnsEx) Exchange(ctx context.Context, addr string, state request.Request if err != nil { return nil, err } - + // Make sure it fits in the DNS response. + reply, _ = state.Scrub(reply) reply.Compress = true reply.Id = state.Req.Id diff --git a/middleware/proxy/lookup.go b/middleware/proxy/lookup.go index af94f25e2..e97741fb5 100644 --- a/middleware/proxy/lookup.go +++ b/middleware/proxy/lookup.go @@ -14,6 +14,11 @@ import ( // NewLookup create a new proxy with the hosts in host and a Random policy. func NewLookup(hosts []string) Proxy { + return NewLookupWithOption(hosts, Options{}) +} + +// NewLookupWithForcedProto process creates a simple round robin forward with potentially forced proto for upstream. +func NewLookupWithOption(hosts []string, opts Options) Proxy { p := Proxy{Next: nil} upstream := &staticUpstream{ @@ -23,7 +28,7 @@ func NewLookup(hosts []string) Proxy { Spray: nil, FailTimeout: 10 * time.Second, MaxFails: 3, // TODO(miek): disable error checking for simple lookups? - ex: newDNSEx(), + ex: newDNSExWithOption(opts), } for i, host := range hosts { diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index fdc04cbc4..59e1a534f 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -190,7 +190,16 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { } switch encArgs[0] { case "dns": - u.ex = newDNSEx() + if len(encArgs) > 1 { + if encArgs[1] == "force_tcp" { + opts := Options{ForceTCP: true} + u.ex = newDNSExWithOption(opts) + } else { + return fmt.Errorf("only force_tcp allowed as parameter to dns") + } + } else { + u.ex = newDNSEx() + } case "https_google": boot := []string{"8.8.8.8:53", "8.8.4.4:53"} if len(encArgs) > 2 && encArgs[1] == "bootstrap" { diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go index fc38f1fae..88fba69c1 100644 --- a/middleware/proxy/upstream_test.go +++ b/middleware/proxy/upstream_test.go @@ -198,6 +198,13 @@ proxy . 8.8.8.8:53 { }, { ` +proxy . 8.8.8.8:53 { + protocol dns force_tcp +}`, + false, + }, + { + ` proxy . 8.8.8.8:53 { protocol grpc a b c d }`, @@ -262,6 +269,13 @@ proxy . 8.8.8.8:53 { ` proxy . 8.8.8.8:53 { health_check +}`, + true, + }, + { + ` +proxy . 8.8.8.8:53 { + protocol dns force }`, true, }, diff --git a/test/proxy_test.go b/test/proxy_test.go index cb9f1b298..7e85a0c87 100644 --- a/test/proxy_test.go +++ b/test/proxy_test.go @@ -56,6 +56,50 @@ func TestLookupProxy(t *testing.T) { } } +func TestLookupDnsWithForcedTcp(t *testing.T) { + t.Parallel() + name, rm, err := test.TempFile(".", exampleOrg) + if err != nil { + t.Fatalf("failed to create zone: %s", err) + } + defer rm() + + corefile := `example.org:0 { + file ` + name + ` +} +` + + i, err := CoreDNSServer(corefile) + if err != nil { + t.Fatalf("Could not get CoreDNS serving instance: %s", err) + } + + _, tcp := CoreDNSServerPorts(i, 0) + if tcp == "" { + t.Fatalf("Could not get TCP listening port") + } + defer i.Stop() + + log.SetOutput(ioutil.Discard) + + p := proxy.NewLookupWithOption([]string{tcp}, proxy.Options{ForceTCP: true}) + state := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + resp, err := p.Lookup(state, "example.org.", dns.TypeA) + if err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + // expect answer section with A record in it + if len(resp.Answer) == 0 { + t.Fatalf("Expected to at least one RR in the answer section, got none: %s", resp) + } + if resp.Answer[0].Header().Rrtype != dns.TypeA { + t.Errorf("Expected RR to A, got: %d", resp.Answer[0].Header().Rrtype) + } + if resp.Answer[0].(*dns.A).A.String() != "127.0.0.1" { + t.Errorf("Expected 127.0.0.1, got: %s", resp.Answer[0].(*dns.A).A.String()) + } +} + func BenchmarkLookupProxy(b *testing.B) { t := new(testing.T) name, rm, err := test.TempFile(".", exampleOrg)