From c74c212bdfbae4ce66a711d6dc17f22a4cd7be5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20Benkovsk=C3=BD?= Date: Fri, 2 Jun 2023 15:33:34 +0200 Subject: [PATCH] prevent panics when using DoHWriter (#6120) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ondřej Benkovský --- core/dnsserver/https.go | 51 ++++++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/core/dnsserver/https.go b/core/dnsserver/https.go index 382e06efe..015c52ec5 100644 --- a/core/dnsserver/https.go +++ b/core/dnsserver/https.go @@ -4,13 +4,11 @@ import ( "net" "net/http" - "github.com/coredns/coredns/plugin/pkg/nonwriter" + "github.com/miekg/dns" ) -// DoHWriter is a nonwriter.Writer that adds more specific LocalAddr and RemoteAddr methods. +// DoHWriter is a dns.ResponseWriter that adds more specific LocalAddr and RemoteAddr methods. type DoHWriter struct { - nonwriter.Writer - // raddr is the remote's address. This can be optionally set. raddr net.Addr // laddr is our address. This can be optionally set. @@ -18,13 +16,50 @@ type DoHWriter struct { // request is the HTTP request we're currently handling. request *http.Request + + // Msg is a response to be written to the client. + Msg *dns.Msg +} + +// WriteMsg stores the message to be written to the client. +func (d *DoHWriter) WriteMsg(m *dns.Msg) error { + d.Msg = m + return nil +} + +// Write stores the message to be written to the client. +func (d *DoHWriter) Write(b []byte) (int, error) { + d.Msg = new(dns.Msg) + return len(b), d.Msg.Unpack(b) } // RemoteAddr returns the remote address. -func (d *DoHWriter) RemoteAddr() net.Addr { return d.raddr } +func (d *DoHWriter) RemoteAddr() net.Addr { + return d.raddr +} // LocalAddr returns the local address. -func (d *DoHWriter) LocalAddr() net.Addr { return d.laddr } +func (d *DoHWriter) LocalAddr() net.Addr { + return d.laddr +} -// Request returns the HTTP request -func (d *DoHWriter) Request() *http.Request { return d.request } +// Request returns the HTTP request. +func (d *DoHWriter) Request() *http.Request { + return d.request +} + +// Close no-op implementation. +func (d *DoHWriter) Close() error { + return nil +} + +// TsigStatus no-op implementation. +func (d *DoHWriter) TsigStatus() error { + return nil +} + +// TsigTimersOnly no-op implementation. +func (d *DoHWriter) TsigTimersOnly(_ bool) {} + +// Hijack no-op implementation. +func (d *DoHWriter) Hijack() {}