diff --git a/plugin/forward/proxy.go b/plugin/forward/proxy.go index 8454b296d..c788f98cc 100644 --- a/plugin/forward/proxy.go +++ b/plugin/forward/proxy.go @@ -57,8 +57,11 @@ func dnsClient(tlsConfig *tls.Config) *dns.Client { return c } -// SetTLSConfig sets the TLS config in the lower p.transport. -func (p *Proxy) SetTLSConfig(cfg *tls.Config) { p.transport.SetTLSConfig(cfg) } +// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client. +func (p *Proxy) SetTLSConfig(cfg *tls.Config) { + p.transport.SetTLSConfig(cfg) + p.client = dnsClient(cfg) +} // SetExpire sets the expire duration in the lower p.transport. func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) } diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go index e33e274c0..a46b3f1ee 100644 --- a/plugin/forward/proxy_test.go +++ b/plugin/forward/proxy_test.go @@ -9,6 +9,7 @@ import ( "github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/request" + "github.com/mholt/caddy" "github.com/miekg/dns" ) @@ -61,3 +62,59 @@ func TestProxyClose(t *testing.T) { } } } + +func TestProxy(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + c := caddy.NewTestController("dns", "forward . "+s.Addr) + f, err := parseForward(c) + if err != nil { + t.Errorf("Failed to create forwarder: %s", err) + } + f.OnStartup() + defer f.OnShutdown() + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + if _, err := f.ServeDNS(context.TODO(), rec, m); err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + if x := rec.Msg.Answer[0].Header().Name; x != "example.org." { + t.Errorf("Expected %s, got %s", "example.org.", x) + } +} + +func TestProxyTLSFail(t *testing.T) { + // This is an udp/tcp test server, so we shouldn't reach it with TLS. + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + c := caddy.NewTestController("dns", "forward . tls://"+s.Addr) + f, err := parseForward(c) + if err != nil { + t.Errorf("Failed to create forwarder: %s", err) + } + f.OnStartup() + defer f.OnShutdown() + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + if _, err := f.ServeDNS(context.TODO(), rec, m); err == nil { + t.Fatal("Expected *not* to receive reply, but got one") + } +} diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go index d787a59d0..fba2359b9 100644 --- a/plugin/forward/setup_test.go +++ b/plugin/forward/setup_test.go @@ -77,9 +77,16 @@ func TestSetupTLS(t *testing.T) { expectedErr string }{ // positive + {`forward . tls://127.0.0.1 { + tls_servername dns + }`, false, "dns", ""}, {`forward . 127.0.0.1 { - tls_servername dns - }`, false, "dns", ""}, + tls_servername dns + }`, false, "", ""}, + {`forward . 127.0.0.1 { + tls + }`, false, "", ""}, + {`forward . tls://127.0.0.1`, false, "", ""}, } for i, test := range tests { @@ -100,8 +107,12 @@ func TestSetupTLS(t *testing.T) { } } - if !test.shouldErr && test.expectedServerName != f.tlsConfig.ServerName { + if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.tlsConfig.ServerName { t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName) } + + if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].client.TLSConfig.ServerName { + t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].client.TLSConfig.ServerName) + } } }