diff --git a/core/dnsserver/https.go b/core/dnsserver/https.go index 915d366ca..532124575 100644 --- a/core/dnsserver/https.go +++ b/core/dnsserver/https.go @@ -1,62 +1,11 @@ package dnsserver import ( - "encoding/base64" - "fmt" - "io/ioutil" "net" - "net/http" "github.com/coredns/coredns/plugin/pkg/nonwriter" - "github.com/miekg/dns" ) -// mimeTypeDOH is the DoH mimetype that should be used. -const mimeTypeDOH = "application/dns-message" - -// pathDOH is the URL path that should be used. -const pathDOH = "/dns-query" - -// postRequestToMsg extracts the dns message from the request body. -func postRequestToMsg(req *http.Request) (*dns.Msg, error) { - defer req.Body.Close() - - buf, err := ioutil.ReadAll(req.Body) - if err != nil { - return nil, err - } - m := new(dns.Msg) - err = m.Unpack(buf) - return m, err -} - -// getRequestToMsg extract the dns message from the GET request. -func getRequestToMsg(req *http.Request) (*dns.Msg, error) { - values := req.URL.Query() - b64, ok := values["dns"] - if !ok { - return nil, fmt.Errorf("no 'dns' query parameter found") - } - if len(b64) != 1 { - return nil, fmt.Errorf("multiple 'dns' query values found") - } - return base64ToMsg(b64[0]) -} - -func base64ToMsg(b64 string) (*dns.Msg, error) { - buf, err := b64Enc.DecodeString(b64) - if err != nil { - return nil, err - } - - m := new(dns.Msg) - err = m.Unpack(buf) - - return m, err -} - -var b64Enc = base64.RawURLEncoding - // DoHWriter is a nonwriter.Writer that adds more specific LocalAddr and RemoteAddr methods. type DoHWriter struct { nonwriter.Writer diff --git a/core/dnsserver/https_test.go b/core/dnsserver/https_test.go deleted file mode 100644 index a0ddc4b25..000000000 --- a/core/dnsserver/https_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package dnsserver - -import ( - "bytes" - "encoding/base64" - "net/http" - "testing" - - "github.com/miekg/dns" -) - -func TestPostRequest(t *testing.T) { - const ex = "example.org." - - m := new(dns.Msg) - m.SetQuestion(ex, dns.TypeDNSKEY) - - out, _ := m.Pack() - req, err := http.NewRequest(http.MethodPost, "https://"+ex+pathDOH+"?bla=foo:443", bytes.NewReader(out)) - if err != nil { - t.Errorf("Failure to make request: %s", err) - } - req.Header.Set("content-type", mimeTypeDOH) - req.Header.Set("accept", mimeTypeDOH) - - m, err = postRequestToMsg(req) - if err != nil { - t.Fatalf("Failure to get message from request: %s", err) - } - - if x := m.Question[0].Name; x != ex { - t.Errorf("Qname expected %s, got %s", ex, x) - } - if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { - t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) - } -} - -func TestGetRequest(t *testing.T) { - const ex = "example.org." - - m := new(dns.Msg) - m.SetQuestion(ex, dns.TypeDNSKEY) - - out, _ := m.Pack() - b64 := base64.RawURLEncoding.EncodeToString(out) - - req, err := http.NewRequest(http.MethodGet, "https://"+ex+pathDOH+"?dns="+b64, nil) - if err != nil { - t.Errorf("Failure to make request: %s", err) - } - req.Header.Set("content-type", mimeTypeDOH) - req.Header.Set("accept", mimeTypeDOH) - - m, err = getRequestToMsg(req) - if err != nil { - t.Fatalf("Failure to get message from request: %s", err) - } - - if x := m.Question[0].Name; x != ex { - t.Errorf("Qname expected %s, got %s", ex, x) - } - if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { - t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) - } -} diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index cf5d08a45..9b1eaaa7e 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -10,9 +10,8 @@ import ( "time" "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/plugin/pkg/doh" "github.com/coredns/coredns/plugin/pkg/response" - - "github.com/miekg/dns" ) // ServerHTTPS represents an instance of a DNS-over-HTTPS server. @@ -99,24 +98,12 @@ func (s *ServerHTTPS) Stop() error { // chain, converts it back and write it to the client. func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { - msg := new(dns.Msg) - var err error - - if r.URL.Path != pathDOH { + if r.URL.Path != doh.Path { http.Error(w, "", http.StatusNotFound) return } - switch r.Method { - case http.MethodPost: - msg, err = postRequestToMsg(r) - case http.MethodGet: - msg, err = getRequestToMsg(r) - default: - http.Error(w, "", http.StatusMethodNotAllowed) - return - } - + msg, err := doh.RequestToMsg(r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -136,7 +123,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { mt, _ := response.Typify(dw.Msg, time.Now().UTC()) age := dnsutil.MinimalTTL(dw.Msg, mt) - w.Header().Set("Content-Type", mimeTypeDOH) + w.Header().Set("Content-Type", doh.MimeType) w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%f", age.Seconds())) w.Header().Set("Content-Length", strconv.Itoa(len(buf))) w.WriteHeader(http.StatusOK) diff --git a/plugin/pkg/doh/doh.go b/plugin/pkg/doh/doh.go new file mode 100644 index 000000000..e0a398e9c --- /dev/null +++ b/plugin/pkg/doh/doh.go @@ -0,0 +1,119 @@ +package doh + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "net/http" + + "github.com/miekg/dns" +) + +// MimeType is the DoH mimetype that should be used. +const MimeType = "application/dns-message" + +// Path is the URL path that should be used. +const Path = "/dns-query" + +// NewRequest returns a new DoH request given a method, URL (without any paths, so exclude /dns-query) and dns.Msg. +func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) { + buf, err := m.Pack() + if err != nil { + return nil, err + } + + switch method { + case http.MethodGet: + b64 := base64.RawURLEncoding.EncodeToString(buf) + + req, err := http.NewRequest(http.MethodGet, "https://"+url+Path+"?dns="+b64, nil) + if err != nil { + return req, err + } + + req.Header.Set("content-type", MimeType) + req.Header.Set("accept", MimeType) + return req, nil + + case http.MethodPost: + req, err := http.NewRequest(http.MethodPost, "https://"+url+Path+"?bla=foo:443", bytes.NewReader(buf)) + if err != nil { + return req, err + } + + req.Header.Set("content-type", MimeType) + req.Header.Set("accept", MimeType) + return req, nil + + default: + return nil, fmt.Errorf("method not allowed: %s", method) + } + +} + +// ResponseToMsg converts a http.Repsonse to a dns message. +func ResponseToMsg(resp *http.Response) (*dns.Msg, error) { + defer resp.Body.Close() + + return toMsg(resp.Body) +} + +// RequestToMsg converts a http.Request to a dns message. +func RequestToMsg(req *http.Request) (*dns.Msg, error) { + switch req.Method { + case http.MethodGet: + return requestToMsgGet(req) + + case http.MethodPost: + return requestToMsgPost(req) + + default: + return nil, fmt.Errorf("method not allowed: %s", req.Method) + } + +} + +// requestToMsgPost extracts the dns message from the request body. +func requestToMsgPost(req *http.Request) (*dns.Msg, error) { + defer req.Body.Close() + return toMsg(req.Body) +} + +// requestToMsgGet extract the dns message from the GET request. +func requestToMsgGet(req *http.Request) (*dns.Msg, error) { + values := req.URL.Query() + b64, ok := values["dns"] + if !ok { + return nil, fmt.Errorf("no 'dns' query parameter found") + } + if len(b64) != 1 { + return nil, fmt.Errorf("multiple 'dns' query values found") + } + return base64ToMsg(b64[0]) +} + +func toMsg(r io.ReadCloser) (*dns.Msg, error) { + buf, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + m := new(dns.Msg) + err = m.Unpack(buf) + return m, err +} + +func base64ToMsg(b64 string) (*dns.Msg, error) { + buf, err := b64Enc.DecodeString(b64) + if err != nil { + return nil, err + } + + m := new(dns.Msg) + err = m.Unpack(buf) + + return m, err +} + +var b64Enc = base64.RawURLEncoding diff --git a/plugin/pkg/doh/doh_test.go b/plugin/pkg/doh/doh_test.go new file mode 100644 index 000000000..449166151 --- /dev/null +++ b/plugin/pkg/doh/doh_test.go @@ -0,0 +1,52 @@ +package doh + +import ( + "net/http" + "testing" + + "github.com/miekg/dns" +) + +func TestPostRequest(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeDNSKEY) + + req, err := NewRequest(http.MethodPost, "https://example.org:443", m) + if err != nil { + t.Errorf("Failure to make request: %s", err) + } + + m, err = RequestToMsg(req) + if err != nil { + t.Fatalf("Failure to get message from request: %s", err) + } + + if x := m.Question[0].Name; x != "example.org." { + t.Errorf("Qname expected %s, got %s", "example.org.", x) + } + if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { + t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) + } +} + +func TestGetRequest(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeDNSKEY) + + req, err := NewRequest(http.MethodGet, "https://example.org:443", m) + if err != nil { + t.Errorf("Failure to make request: %s", err) + } + + m, err = RequestToMsg(req) + if err != nil { + t.Fatalf("Failure to get message from request: %s", err) + } + + if x := m.Question[0].Name; x != "example.org." { + t.Errorf("Qname expected %s, got %s", "example.org.", x) + } + if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { + t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) + } +}