make CoreDNS DoH Server (#1619)
* WIP: make CoreDNS DoH Server * It works * Fix tests * Review from Tom - on diff. PR * correct mime type * Cleanups and use the pkg/nonwriter * rename and updates * implement get * implement GET * Code review comments * correct context * tweaks * code review
This commit is contained in:
parent
67c9075331
commit
18b92e1117
7 changed files with 309 additions and 7 deletions
|
@ -36,6 +36,8 @@ func Transport(s string) string {
|
|||
return TransportDNS
|
||||
case strings.HasPrefix(s, TransportGRPC+"://"):
|
||||
return TransportGRPC
|
||||
case strings.HasPrefix(s, TransportHTTPS+"://"):
|
||||
return TransportHTTPS
|
||||
}
|
||||
return TransportDNS
|
||||
}
|
||||
|
@ -58,6 +60,9 @@ func normalizeZone(str string) (zoneAddr, error) {
|
|||
case strings.HasPrefix(str, TransportGRPC+"://"):
|
||||
trans = TransportGRPC
|
||||
str = str[len(TransportGRPC+"://"):]
|
||||
case strings.HasPrefix(str, TransportHTTPS+"://"):
|
||||
trans = TransportHTTPS
|
||||
str = str[len(TransportHTTPS+"://"):]
|
||||
}
|
||||
|
||||
host, port, ipnet, err := plugin.SplitHostPort(str)
|
||||
|
@ -75,6 +80,9 @@ func normalizeZone(str string) (zoneAddr, error) {
|
|||
if trans == TransportGRPC {
|
||||
port = GRPCPort
|
||||
}
|
||||
if trans == TransportHTTPS {
|
||||
port = HTTPSPort
|
||||
}
|
||||
}
|
||||
|
||||
return zoneAddr{Zone: dns.Fqdn(host), Port: port, Transport: trans, IPNet: ipnet}, nil
|
||||
|
@ -97,9 +105,10 @@ func SplitProtocolHostPort(address string) (protocol string, ip string, port str
|
|||
|
||||
// Supported transports.
|
||||
const (
|
||||
TransportDNS = "dns"
|
||||
TransportTLS = "tls"
|
||||
TransportGRPC = "grpc"
|
||||
TransportDNS = "dns"
|
||||
TransportTLS = "tls"
|
||||
TransportGRPC = "grpc"
|
||||
TransportHTTPS = "https"
|
||||
)
|
||||
|
||||
type zoneOverlap struct {
|
||||
|
|
56
core/dnsserver/https.go
Normal file
56
core/dnsserver/https.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// mimeTypeDOH is the DoH mimetype that should be used.
|
||||
const mimeTypeDOH = "application/dns-message"
|
||||
|
||||
// pathDOH is the URL path that should be used.
|
||||
const pathDOH = "/dns-query"
|
||||
|
||||
// postRequestToMsg extracts the dns message from the request body.
|
||||
func postRequestToMsg(req *http.Request) (*dns.Msg, error) {
|
||||
defer req.Body.Close()
|
||||
|
||||
buf, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
err = m.Unpack(buf)
|
||||
return m, err
|
||||
}
|
||||
|
||||
// getRequestToMsg extract the dns message from the GET request.
|
||||
func getRequestToMsg(req *http.Request) (*dns.Msg, error) {
|
||||
values := req.URL.Query()
|
||||
b64, ok := values["dns"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no 'dns' query parameter found")
|
||||
}
|
||||
if len(b64) != 1 {
|
||||
return nil, fmt.Errorf("multiple 'dns' query values found")
|
||||
}
|
||||
return base64ToMsg(b64[0])
|
||||
}
|
||||
|
||||
func base64ToMsg(b64 string) (*dns.Msg, error) {
|
||||
buf, err := b64Enc.DecodeString(b64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
err = m.Unpack(buf)
|
||||
|
||||
return m, err
|
||||
}
|
||||
|
||||
var b64Enc = base64.RawURLEncoding
|
66
core/dnsserver/https_test.go
Normal file
66
core/dnsserver/https_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestPostRequest(t *testing.T) {
|
||||
const ex = "example.org."
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(ex, dns.TypeDNSKEY)
|
||||
|
||||
out, _ := m.Pack()
|
||||
req, err := http.NewRequest(http.MethodPost, "https://"+ex+pathDOH+"?bla=foo:443", bytes.NewReader(out))
|
||||
if err != nil {
|
||||
t.Errorf("Failure to make request: %s", err)
|
||||
}
|
||||
req.Header.Set("content-type", mimeTypeDOH)
|
||||
req.Header.Set("accept", mimeTypeDOH)
|
||||
|
||||
m, err = postRequestToMsg(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failure to get message from request: %s", err)
|
||||
}
|
||||
|
||||
if x := m.Question[0].Name; x != ex {
|
||||
t.Errorf("Qname expected %s, got %s", ex, x)
|
||||
}
|
||||
if x := m.Question[0].Qtype; x != dns.TypeDNSKEY {
|
||||
t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequest(t *testing.T) {
|
||||
const ex = "example.org."
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(ex, dns.TypeDNSKEY)
|
||||
|
||||
out, _ := m.Pack()
|
||||
b64 := base64.RawURLEncoding.EncodeToString(out)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://"+ex+pathDOH+"?dns="+b64, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Failure to make request: %s", err)
|
||||
}
|
||||
req.Header.Set("content-type", mimeTypeDOH)
|
||||
req.Header.Set("accept", mimeTypeDOH)
|
||||
|
||||
m, err = getRequestToMsg(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failure to get message from request: %s", err)
|
||||
}
|
||||
|
||||
if x := m.Question[0].Name; x != ex {
|
||||
t.Errorf("Qname expected %s, got %s", ex, x)
|
||||
}
|
||||
if x := m.Question[0].Qtype; x != dns.TypeDNSKEY {
|
||||
t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY)
|
||||
}
|
||||
}
|
|
@ -133,6 +133,12 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
|
|||
}
|
||||
servers = append(servers, s)
|
||||
|
||||
case TransportHTTPS:
|
||||
s, err := NewServerHTTPS(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
servers = append(servers, s)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -235,6 +241,8 @@ const (
|
|||
TLSPort = "853"
|
||||
// GRPCPort is the default port for DNS-over-gRPC.
|
||||
GRPCPort = "443"
|
||||
// HTTPSPort is the default port for DNS-over-HTTPS.
|
||||
HTTPSPort = "443"
|
||||
)
|
||||
|
||||
// These "soft defaults" are configurable by
|
||||
|
|
149
core/dnsserver/server-https.go
Normal file
149
core/dnsserver/server-https.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package dnsserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/coredns/coredns/plugin/pkg/nonwriter"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// ServerHTTPS represents an instance of a DNS-over-HTTPS server.
|
||||
type ServerHTTPS struct {
|
||||
*Server
|
||||
httpsServer *http.Server
|
||||
listenAddr net.Addr
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// NewServerHTTPS returns a new CoreDNS GRPC server and compiles all plugins in to it.
|
||||
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
|
||||
s, err := NewServer(addr, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The *tls* plugin must make sure that multiple conflicting
|
||||
// TLS configuration return an error: it can only be specified once.
|
||||
var tlsConfig *tls.Config
|
||||
for _, conf := range s.zones {
|
||||
// Should we error if some configs *don't* have TLS?
|
||||
tlsConfig = conf.TLSConfig
|
||||
}
|
||||
|
||||
sh := &ServerHTTPS{Server: s, tlsConfig: tlsConfig, httpsServer: new(http.Server)}
|
||||
sh.httpsServer.Handler = sh
|
||||
|
||||
return sh, nil
|
||||
}
|
||||
|
||||
// Serve implements caddy.TCPServer interface.
|
||||
func (s *ServerHTTPS) Serve(l net.Listener) error {
|
||||
s.m.Lock()
|
||||
s.listenAddr = l.Addr()
|
||||
s.m.Unlock()
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
l = tls.NewListener(l, s.tlsConfig)
|
||||
}
|
||||
return s.httpsServer.Serve(l)
|
||||
}
|
||||
|
||||
// ServePacket implements caddy.UDPServer interface.
|
||||
func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil }
|
||||
|
||||
// Listen implements caddy.TCPServer interface.
|
||||
func (s *ServerHTTPS) Listen() (net.Listener, error) {
|
||||
|
||||
l, err := net.Listen("tcp", s.Addr[len(TransportHTTPS+"://"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// ListenPacket implements caddy.UDPServer interface.
|
||||
func (s *ServerHTTPS) ListenPacket() (net.PacketConn, error) { return nil, nil }
|
||||
|
||||
// OnStartupComplete lists the sites served by this server
|
||||
// and any relevant information, assuming Quiet is false.
|
||||
func (s *ServerHTTPS) OnStartupComplete() {
|
||||
if Quiet {
|
||||
return
|
||||
}
|
||||
|
||||
out := startUpZones(TransportHTTPS+"://", s.Addr, s.zones)
|
||||
if out != "" {
|
||||
fmt.Print(out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Stop stops the server. It blocks until the server is totally stopped.
|
||||
func (s *ServerHTTPS) Stop() error {
|
||||
s.m.Lock()
|
||||
defer s.m.Unlock()
|
||||
if s.httpsServer != nil {
|
||||
s.httpsServer.Shutdown(context.Background())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeHTTP is the handler that gets the HTTP request and converts to the dns format, calls the plugin
|
||||
// chain, converts it back and write it to the client.
|
||||
func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
msg := new(dns.Msg)
|
||||
var err error
|
||||
|
||||
if r.URL.Path != pathDOH {
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
msg, err = postRequestToMsg(r)
|
||||
case http.MethodGet:
|
||||
msg, err = getRequestToMsg(r)
|
||||
default:
|
||||
http.Error(w, "", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a non-writer with the correct addresses in it.
|
||||
dw := &nonwriter.Writer{Laddr: s.listenAddr}
|
||||
h, p, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
po, _ := strconv.Atoi(p)
|
||||
ip := net.ParseIP(h)
|
||||
dw.Raddr = &net.TCPAddr{IP: ip, Port: po}
|
||||
|
||||
// We just call the normal chain handler - all error handling is done there.
|
||||
// We should expect a packet to be returned that we can send to the client.
|
||||
s.ServeDNS(context.Background(), dw, msg)
|
||||
|
||||
buf, _ := dw.Msg.Pack()
|
||||
|
||||
w.Header().Set("Content-Type", mimeTypeDOH)
|
||||
w.Header().Set("Cache-Control", "max-age=128") // TODO(issues/1823): implement proper fix.
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
w.Write(buf)
|
||||
}
|
||||
|
||||
// Shutdown stops the server (non gracefully).
|
||||
func (s *ServerHTTPS) Shutdown() error {
|
||||
if s.httpsServer != nil {
|
||||
s.httpsServer.Shutdown(context.Background())
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -71,6 +71,8 @@ func (h Host) Normalize() string {
|
|||
s = s[len(TransportDNS+"://"):]
|
||||
case strings.HasPrefix(s, TransportGRPC+"://"):
|
||||
s = s[len(TransportGRPC+"://"):]
|
||||
case strings.HasPrefix(s, TransportHTTPS+"://"):
|
||||
s = s[len(TransportHTTPS+"://"):]
|
||||
}
|
||||
|
||||
// The error can be ignore here, because this function is called after the corefile
|
||||
|
@ -138,7 +140,8 @@ func SplitHostPort(s string) (host, port string, ipnet *net.IPNet, err error) {
|
|||
|
||||
// Duplicated from core/dnsserver/address.go !
|
||||
const (
|
||||
TransportDNS = "dns"
|
||||
TransportTLS = "tls"
|
||||
TransportGRPC = "grpc"
|
||||
TransportDNS = "dns"
|
||||
TransportTLS = "tls"
|
||||
TransportGRPC = "grpc"
|
||||
TransportHTTPS = "https"
|
||||
)
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
package nonwriter
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -9,6 +11,11 @@ import (
|
|||
type Writer struct {
|
||||
dns.ResponseWriter
|
||||
Msg *dns.Msg
|
||||
|
||||
// Raddr is the remote's address. This can be optionally set.
|
||||
Raddr net.Addr
|
||||
// Laddr is our address. This can be optionally set.
|
||||
Laddr net.Addr
|
||||
}
|
||||
|
||||
// New makes and returns a new NonWriter.
|
||||
|
@ -20,4 +27,8 @@ func (w *Writer) WriteMsg(res *dns.Msg) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (w *Writer) Write(buf []byte) (int, error) { return len(buf), nil }
|
||||
// RemoteAddr returns the remote address.
|
||||
func (w *Writer) RemoteAddr() net.Addr { return w.Raddr }
|
||||
|
||||
// LocalAddr returns the local address.
|
||||
func (w *Writer) LocalAddr() net.Addr { return w.Laddr }
|
||||
|
|
Loading…
Add table
Reference in a new issue