diff --git a/core/dnsserver/config.go b/core/dnsserver/config.go index 6a720ef13..4ff2ecda1 100644 --- a/core/dnsserver/config.go +++ b/core/dnsserver/config.go @@ -3,6 +3,7 @@ package dnsserver import ( "crypto/tls" "fmt" + "net/http" "github.com/coredns/caddy" "github.com/coredns/coredns/plugin" @@ -31,6 +32,11 @@ type Config struct { // DNS-over-TLS or DNS-over-gRPC. Transport string + // If this function is not nil it will be used to inspect and validate + // HTTP requests. Although this isn't referenced in-tree, external plugins + // may depend on it. + HTTPRequestValidateFunc func(*http.Request) bool + // If this function is not nil it will be used to further filter access // to this handler. The primary use is to limit access to a reverse zone // on a non-octet boundary, i.e. /17 diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index 27757861c..057dac49c 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -20,12 +20,13 @@ import ( // ServerHTTPS represents an instance of a DNS-over-HTTPS server. type ServerHTTPS struct { *Server - httpsServer *http.Server - listenAddr net.Addr - tlsConfig *tls.Config + httpsServer *http.Server + listenAddr net.Addr + tlsConfig *tls.Config + validRequest func(*http.Request) bool } -// NewServerHTTPS returns a new CoreDNS GRPC server and compiles all plugins in to it. +// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it. func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { s, err := NewServer(addr, group) if err != nil { @@ -45,12 +46,23 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { // or the upgrade won't happen. tlsConfig.NextProtos = []string{"h2", "http/1.1"} + // Use a custom request validation func or use the standard DoH path check. + var validator func(*http.Request) bool + for _, conf := range s.zones { + validator = conf.HTTPRequestValidateFunc + } + if validator == nil { + validator = func(r *http.Request) bool { return r.URL.Path == doh.Path } + } + srv := &http.Server{ ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, IdleTimeout: 120 * time.Second, } - sh := &ServerHTTPS{Server: s, tlsConfig: tlsConfig, httpsServer: srv} + sh := &ServerHTTPS{ + Server: s, tlsConfig: tlsConfig, httpsServer: srv, validRequest: validator, + } sh.httpsServer.Handler = sh return sh, nil @@ -114,7 +126,7 @@ func (s *ServerHTTPS) Stop() error { // chain, converts it back and write it to the client. func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != doh.Path { + if !s.validRequest(r) { http.Error(w, "", http.StatusNotFound) return } diff --git a/core/dnsserver/server_https_test.go b/core/dnsserver/server_https_test.go new file mode 100644 index 000000000..6663c1075 --- /dev/null +++ b/core/dnsserver/server_https_test.go @@ -0,0 +1,66 @@ +package dnsserver + +import ( + "bytes" + "crypto/tls" + "net/http" + "net/http/httptest" + "regexp" + "testing" + + "github.com/miekg/dns" +) + +var ( + validPath = regexp.MustCompile("^/(dns-query|(?P[0-9a-f]+))$") + validator = func(r *http.Request) bool { return validPath.MatchString(r.URL.Path) } +) + +func testServerHTTPS(t *testing.T, path string, validator func(*http.Request) bool) *http.Response { + c := Config{ + Zone: "example.com.", + Transport: "https", + TLSConfig: &tls.Config{}, + ListenHosts: []string{"127.0.0.1"}, + Port: "443", + HTTPRequestValidateFunc: validator, + } + s, err := NewServerHTTPS("127.0.0.1:443", []*Config{&c}) + if err != nil { + t.Log(err) + t.Fatal("could not create HTTPS server") + } + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeDNSKEY) + buf, err := m.Pack() + if err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(buf)) + w := httptest.NewRecorder() + s.ServeHTTP(w, r) + + return w.Result() +} + +func TestCustomHTTPRequestValidator(t *testing.T) { + testCases := map[string]struct { + path string + expected int + validator func(*http.Request) bool + }{ + "default": {"/dns-query", http.StatusOK, nil}, + "custom validator": {"/b10cada", http.StatusOK, validator}, + "no validator set": {"/adb10c", http.StatusNotFound, nil}, + "invalid path with validator": {"/helloworld", http.StatusNotFound, validator}, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + res := testServerHTTPS(t, tc.path, tc.validator) + if res.StatusCode != tc.expected { + t.Error("unexpected HTTP code", res.StatusCode) + } + }) + } +}