diff --git a/core/dnsserver/address.go b/core/dnsserver/address.go index 39d656eff..8f544e97a 100644 --- a/core/dnsserver/address.go +++ b/core/dnsserver/address.go @@ -36,6 +36,8 @@ func Transport(s string) string { return TransportDNS case strings.HasPrefix(s, TransportGRPC+"://"): return TransportGRPC + case strings.HasPrefix(s, TransportHTTPS+"://"): + return TransportHTTPS } return TransportDNS } @@ -58,6 +60,9 @@ func normalizeZone(str string) (zoneAddr, error) { case strings.HasPrefix(str, TransportGRPC+"://"): trans = TransportGRPC str = str[len(TransportGRPC+"://"):] + case strings.HasPrefix(str, TransportHTTPS+"://"): + trans = TransportHTTPS + str = str[len(TransportHTTPS+"://"):] } host, port, ipnet, err := plugin.SplitHostPort(str) @@ -75,6 +80,9 @@ func normalizeZone(str string) (zoneAddr, error) { if trans == TransportGRPC { port = GRPCPort } + if trans == TransportHTTPS { + port = HTTPSPort + } } return zoneAddr{Zone: dns.Fqdn(host), Port: port, Transport: trans, IPNet: ipnet}, nil @@ -97,9 +105,10 @@ func SplitProtocolHostPort(address string) (protocol string, ip string, port str // Supported transports. const ( - TransportDNS = "dns" - TransportTLS = "tls" - TransportGRPC = "grpc" + TransportDNS = "dns" + TransportTLS = "tls" + TransportGRPC = "grpc" + TransportHTTPS = "https" ) type zoneOverlap struct { diff --git a/core/dnsserver/https.go b/core/dnsserver/https.go new file mode 100644 index 000000000..028b74709 --- /dev/null +++ b/core/dnsserver/https.go @@ -0,0 +1,56 @@ +package dnsserver + +import ( + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + + "github.com/miekg/dns" +) + +// mimeTypeDOH is the DoH mimetype that should be used. +const mimeTypeDOH = "application/dns-message" + +// pathDOH is the URL path that should be used. +const pathDOH = "/dns-query" + +// postRequestToMsg extracts the dns message from the request body. +func postRequestToMsg(req *http.Request) (*dns.Msg, error) { + defer req.Body.Close() + + buf, err := ioutil.ReadAll(req.Body) + if err != nil { + return nil, err + } + m := new(dns.Msg) + err = m.Unpack(buf) + return m, err +} + +// getRequestToMsg extract the dns message from the GET request. +func getRequestToMsg(req *http.Request) (*dns.Msg, error) { + values := req.URL.Query() + b64, ok := values["dns"] + if !ok { + return nil, fmt.Errorf("no 'dns' query parameter found") + } + if len(b64) != 1 { + return nil, fmt.Errorf("multiple 'dns' query values found") + } + return base64ToMsg(b64[0]) +} + +func base64ToMsg(b64 string) (*dns.Msg, error) { + buf, err := b64Enc.DecodeString(b64) + if err != nil { + return nil, err + } + + m := new(dns.Msg) + err = m.Unpack(buf) + + return m, err +} + +var b64Enc = base64.RawURLEncoding diff --git a/core/dnsserver/https_test.go b/core/dnsserver/https_test.go new file mode 100644 index 000000000..a0ddc4b25 --- /dev/null +++ b/core/dnsserver/https_test.go @@ -0,0 +1,66 @@ +package dnsserver + +import ( + "bytes" + "encoding/base64" + "net/http" + "testing" + + "github.com/miekg/dns" +) + +func TestPostRequest(t *testing.T) { + const ex = "example.org." + + m := new(dns.Msg) + m.SetQuestion(ex, dns.TypeDNSKEY) + + out, _ := m.Pack() + req, err := http.NewRequest(http.MethodPost, "https://"+ex+pathDOH+"?bla=foo:443", bytes.NewReader(out)) + if err != nil { + t.Errorf("Failure to make request: %s", err) + } + req.Header.Set("content-type", mimeTypeDOH) + req.Header.Set("accept", mimeTypeDOH) + + m, err = postRequestToMsg(req) + if err != nil { + t.Fatalf("Failure to get message from request: %s", err) + } + + if x := m.Question[0].Name; x != ex { + t.Errorf("Qname expected %s, got %s", ex, x) + } + if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { + t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) + } +} + +func TestGetRequest(t *testing.T) { + const ex = "example.org." + + m := new(dns.Msg) + m.SetQuestion(ex, dns.TypeDNSKEY) + + out, _ := m.Pack() + b64 := base64.RawURLEncoding.EncodeToString(out) + + req, err := http.NewRequest(http.MethodGet, "https://"+ex+pathDOH+"?dns="+b64, nil) + if err != nil { + t.Errorf("Failure to make request: %s", err) + } + req.Header.Set("content-type", mimeTypeDOH) + req.Header.Set("accept", mimeTypeDOH) + + m, err = getRequestToMsg(req) + if err != nil { + t.Fatalf("Failure to get message from request: %s", err) + } + + if x := m.Question[0].Name; x != ex { + t.Errorf("Qname expected %s, got %s", ex, x) + } + if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { + t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) + } +} diff --git a/core/dnsserver/register.go b/core/dnsserver/register.go index 779dc6b5f..ced2519af 100644 --- a/core/dnsserver/register.go +++ b/core/dnsserver/register.go @@ -133,6 +133,12 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) { } servers = append(servers, s) + case TransportHTTPS: + s, err := NewServerHTTPS(addr, group) + if err != nil { + return nil, err + } + servers = append(servers, s) } } @@ -235,6 +241,8 @@ const ( TLSPort = "853" // GRPCPort is the default port for DNS-over-gRPC. GRPCPort = "443" + // HTTPSPort is the default port for DNS-over-HTTPS. + HTTPSPort = "443" ) // These "soft defaults" are configurable by diff --git a/core/dnsserver/server-https.go b/core/dnsserver/server-https.go new file mode 100644 index 000000000..f460f0ff4 --- /dev/null +++ b/core/dnsserver/server-https.go @@ -0,0 +1,149 @@ +package dnsserver + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "strconv" + + "github.com/coredns/coredns/plugin/pkg/nonwriter" + "github.com/miekg/dns" +) + +// ServerHTTPS represents an instance of a DNS-over-HTTPS server. +type ServerHTTPS struct { + *Server + httpsServer *http.Server + listenAddr net.Addr + tlsConfig *tls.Config +} + +// NewServerHTTPS returns a new CoreDNS GRPC server and compiles all plugins in to it. +func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { + s, err := NewServer(addr, group) + if err != nil { + return nil, err + } + // The *tls* plugin must make sure that multiple conflicting + // TLS configuration return an error: it can only be specified once. + var tlsConfig *tls.Config + for _, conf := range s.zones { + // Should we error if some configs *don't* have TLS? + tlsConfig = conf.TLSConfig + } + + sh := &ServerHTTPS{Server: s, tlsConfig: tlsConfig, httpsServer: new(http.Server)} + sh.httpsServer.Handler = sh + + return sh, nil +} + +// Serve implements caddy.TCPServer interface. +func (s *ServerHTTPS) Serve(l net.Listener) error { + s.m.Lock() + s.listenAddr = l.Addr() + s.m.Unlock() + + if s.tlsConfig != nil { + l = tls.NewListener(l, s.tlsConfig) + } + return s.httpsServer.Serve(l) +} + +// ServePacket implements caddy.UDPServer interface. +func (s *ServerHTTPS) ServePacket(p net.PacketConn) error { return nil } + +// Listen implements caddy.TCPServer interface. +func (s *ServerHTTPS) Listen() (net.Listener, error) { + + l, err := net.Listen("tcp", s.Addr[len(TransportHTTPS+"://"):]) + if err != nil { + return nil, err + } + return l, nil +} + +// ListenPacket implements caddy.UDPServer interface. +func (s *ServerHTTPS) ListenPacket() (net.PacketConn, error) { return nil, nil } + +// OnStartupComplete lists the sites served by this server +// and any relevant information, assuming Quiet is false. +func (s *ServerHTTPS) OnStartupComplete() { + if Quiet { + return + } + + out := startUpZones(TransportHTTPS+"://", s.Addr, s.zones) + if out != "" { + fmt.Print(out) + } + return +} + +// Stop stops the server. It blocks until the server is totally stopped. +func (s *ServerHTTPS) Stop() error { + s.m.Lock() + defer s.m.Unlock() + if s.httpsServer != nil { + s.httpsServer.Shutdown(context.Background()) + } + return nil +} + +// ServeHTTP is the handler that gets the HTTP request and converts to the dns format, calls the plugin +// chain, converts it back and write it to the client. +func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { + + msg := new(dns.Msg) + var err error + + if r.URL.Path != pathDOH { + http.Error(w, "", http.StatusNotFound) + return + } + + switch r.Method { + case http.MethodPost: + msg, err = postRequestToMsg(r) + case http.MethodGet: + msg, err = getRequestToMsg(r) + default: + http.Error(w, "", http.StatusMethodNotAllowed) + return + } + + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Create a non-writer with the correct addresses in it. + dw := &nonwriter.Writer{Laddr: s.listenAddr} + h, p, _ := net.SplitHostPort(r.RemoteAddr) + po, _ := strconv.Atoi(p) + ip := net.ParseIP(h) + dw.Raddr = &net.TCPAddr{IP: ip, Port: po} + + // We just call the normal chain handler - all error handling is done there. + // We should expect a packet to be returned that we can send to the client. + s.ServeDNS(context.Background(), dw, msg) + + buf, _ := dw.Msg.Pack() + + w.Header().Set("Content-Type", mimeTypeDOH) + w.Header().Set("Cache-Control", "max-age=128") // TODO(issues/1823): implement proper fix. + w.Header().Set("Content-Length", strconv.Itoa(len(buf))) + w.WriteHeader(http.StatusOK) + + w.Write(buf) +} + +// Shutdown stops the server (non gracefully). +func (s *ServerHTTPS) Shutdown() error { + if s.httpsServer != nil { + s.httpsServer.Shutdown(context.Background()) + } + return nil +} diff --git a/plugin/normalize.go b/plugin/normalize.go index ef6f2eaa0..fbbe5c826 100644 --- a/plugin/normalize.go +++ b/plugin/normalize.go @@ -71,6 +71,8 @@ func (h Host) Normalize() string { s = s[len(TransportDNS+"://"):] case strings.HasPrefix(s, TransportGRPC+"://"): s = s[len(TransportGRPC+"://"):] + case strings.HasPrefix(s, TransportHTTPS+"://"): + s = s[len(TransportHTTPS+"://"):] } // The error can be ignore here, because this function is called after the corefile @@ -138,7 +140,8 @@ func SplitHostPort(s string) (host, port string, ipnet *net.IPNet, err error) { // Duplicated from core/dnsserver/address.go ! const ( - TransportDNS = "dns" - TransportTLS = "tls" - TransportGRPC = "grpc" + TransportDNS = "dns" + TransportTLS = "tls" + TransportGRPC = "grpc" + TransportHTTPS = "https" ) diff --git a/plugin/pkg/nonwriter/nonwriter.go b/plugin/pkg/nonwriter/nonwriter.go index 7819a320f..b157e4242 100644 --- a/plugin/pkg/nonwriter/nonwriter.go +++ b/plugin/pkg/nonwriter/nonwriter.go @@ -2,6 +2,8 @@ package nonwriter import ( + "net" + "github.com/miekg/dns" ) @@ -9,6 +11,11 @@ import ( type Writer struct { dns.ResponseWriter Msg *dns.Msg + + // Raddr is the remote's address. This can be optionally set. + Raddr net.Addr + // Laddr is our address. This can be optionally set. + Laddr net.Addr } // New makes and returns a new NonWriter. @@ -20,4 +27,8 @@ func (w *Writer) WriteMsg(res *dns.Msg) error { return nil } -func (w *Writer) Write(buf []byte) (int, error) { return len(buf), nil } +// RemoteAddr returns the remote address. +func (w *Writer) RemoteAddr() net.Addr { return w.Raddr } + +// LocalAddr returns the local address. +func (w *Writer) LocalAddr() net.Addr { return w.Laddr }