diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index c9f0da0cd..cf5d08a45 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -7,6 +7,10 @@ import ( "net" "net/http" "strconv" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/plugin/pkg/response" "github.com/miekg/dns" ) @@ -129,8 +133,11 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { buf, _ := dw.Msg.Pack() + mt, _ := response.Typify(dw.Msg, time.Now().UTC()) + age := dnsutil.MinimalTTL(dw.Msg, mt) + w.Header().Set("Content-Type", mimeTypeDOH) - w.Header().Set("Cache-Control", "max-age=128") // TODO(issues/1823): implement proper fix. + w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%f", age.Seconds())) w.Header().Set("Content-Length", strconv.Itoa(len(buf))) w.WriteHeader(http.StatusOK) diff --git a/plugin/cache/cache.go b/plugin/cache/cache.go index ed39fee86..c46267658 100644 --- a/plugin/cache/cache.go +++ b/plugin/cache/cache.go @@ -9,6 +9,7 @@ import ( "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/pkg/cache" + "github.com/coredns/coredns/plugin/pkg/dnsutil" "github.com/coredns/coredns/plugin/pkg/response" "github.com/coredns/coredns/request" @@ -158,7 +159,7 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { duration = w.nttl } - msgTTL := minMsgTTL(res, mt) + msgTTL := dnsutil.MinimalTTL(res, mt) if msgTTL < duration { duration = msgTTL } @@ -226,9 +227,8 @@ func (w *ResponseWriter) Write(buf []byte) (int, error) { } const ( - maxTTL = 1 * time.Hour - maxNTTL = 30 * time.Minute - failSafeTTL = 5 * time.Second + maxTTL = dnsutil.MaximumDefaulTTL + maxNTTL = dnsutil.MaximumDefaulTTL / 2 defaultCap = 10000 // default capacity of the cache. diff --git a/plugin/cache/item.go b/plugin/cache/item.go index 5761cdf87..f4858c3b1 100644 --- a/plugin/cache/item.go +++ b/plugin/cache/item.go @@ -4,7 +4,6 @@ import ( "time" "github.com/coredns/coredns/plugin/cache/freq" - "github.com/coredns/coredns/plugin/pkg/response" "github.com/miekg/dns" ) @@ -87,58 +86,3 @@ func (i *item) ttl(now time.Time) int { ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds()) return ttl } - -func minMsgTTL(m *dns.Msg, mt response.Type) time.Duration { - if mt != response.NoError && mt != response.NameError && mt != response.NoData { - return 0 - } - - // No data to examine, return a short ttl as a fail safe. - if len(m.Answer)+len(m.Ns)+len(m.Extra) == 0 { - return failSafeTTL - } - - minTTL := maxTTL - for _, r := range m.Answer { - switch mt { - case response.NameError, response.NoData: - if r.Header().Rrtype == dns.TypeSOA { - minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second - } - case response.NoError, response.Delegation: - if r.Header().Ttl < uint32(minTTL.Seconds()) { - minTTL = time.Duration(r.Header().Ttl) * time.Second - } - } - } - for _, r := range m.Ns { - switch mt { - case response.NameError, response.NoData: - if r.Header().Rrtype == dns.TypeSOA { - minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second - } - case response.NoError, response.Delegation: - if r.Header().Ttl < uint32(minTTL.Seconds()) { - minTTL = time.Duration(r.Header().Ttl) * time.Second - } - } - } - - for _, r := range m.Extra { - if r.Header().Rrtype == dns.TypeOPT { - // OPT records use TTL field for extended rcode and flags - continue - } - switch mt { - case response.NameError, response.NoData: - if r.Header().Rrtype == dns.TypeSOA { - minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second - } - case response.NoError, response.Delegation: - if r.Header().Ttl < uint32(minTTL.Seconds()) { - minTTL = time.Duration(r.Header().Ttl) * time.Second - } - } - } - return minTTL -} diff --git a/plugin/pkg/dnsutil/ttl.go b/plugin/pkg/dnsutil/ttl.go new file mode 100644 index 000000000..e969fa8a6 --- /dev/null +++ b/plugin/pkg/dnsutil/ttl.go @@ -0,0 +1,72 @@ +package dnsutil + +import ( + "time" + + "github.com/coredns/coredns/plugin/pkg/response" + + "github.com/miekg/dns" +) + +// MinimalTTL scans the message returns the lowest TTL found taking into the response.Type of the message. +func MinimalTTL(m *dns.Msg, mt response.Type) time.Duration { + if mt != response.NoError && mt != response.NameError && mt != response.NoData { + return MinimalDefaultTTL + } + + // No data to examine, return a short ttl as a fail safe. + if len(m.Answer)+len(m.Ns)+len(m.Extra) == 0 { + return MinimalDefaultTTL + } + + minTTL := MaximumDefaulTTL + for _, r := range m.Answer { + switch mt { + case response.NameError, response.NoData: + if r.Header().Rrtype == dns.TypeSOA { + minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second + } + case response.NoError, response.Delegation: + if r.Header().Ttl < uint32(minTTL.Seconds()) { + minTTL = time.Duration(r.Header().Ttl) * time.Second + } + } + } + for _, r := range m.Ns { + switch mt { + case response.NameError, response.NoData: + if r.Header().Rrtype == dns.TypeSOA { + minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second + } + case response.NoError, response.Delegation: + if r.Header().Ttl < uint32(minTTL.Seconds()) { + minTTL = time.Duration(r.Header().Ttl) * time.Second + } + } + } + + for _, r := range m.Extra { + if r.Header().Rrtype == dns.TypeOPT { + // OPT records use TTL field for extended rcode and flags + continue + } + switch mt { + case response.NameError, response.NoData: + if r.Header().Rrtype == dns.TypeSOA { + minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second + } + case response.NoError, response.Delegation: + if r.Header().Ttl < uint32(minTTL.Seconds()) { + minTTL = time.Duration(r.Header().Ttl) * time.Second + } + } + } + return minTTL +} + +const ( + // MinimalDefaultTTL is the absolute lowest TTL we use in CoreDNS. + MinimalDefaultTTL = 5 * time.Second + // MaximumDefaulTTL is the maximum TTL was use on RRsets in CoreDNS. + MaximumDefaulTTL = 1 * time.Hour +) diff --git a/plugin/cache/minttl_test.go b/plugin/pkg/dnsutil/ttl_test.go similarity index 84% rename from plugin/cache/minttl_test.go rename to plugin/pkg/dnsutil/ttl_test.go index 376c638a1..ee11d06f3 100644 --- a/plugin/cache/minttl_test.go +++ b/plugin/pkg/dnsutil/ttl_test.go @@ -1,4 +1,4 @@ -package cache +package dnsutil import ( "testing" @@ -12,7 +12,7 @@ import ( // See https://github.com/kubernetes/dns/issues/121, add some specific tests for those use cases. -func TestMinMsgTTL(t *testing.T) { +func TestMinimalTTL(t *testing.T) { m := new(dns.Msg) m.SetQuestion("z.alm.im.", dns.TypeA) m.Ns = []dns.RR{ @@ -25,7 +25,7 @@ func TestMinMsgTTL(t *testing.T) { if mt != response.NoData { t.Fatalf("Expected type to be response.NoData, got %s", mt) } - dur := minMsgTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) + dur := MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) if dur != time.Duration(3600*time.Second) { t.Fatalf("Expected minttl duration to be %d, got %d", 3600, dur) } @@ -35,13 +35,13 @@ func TestMinMsgTTL(t *testing.T) { if mt != response.NameError { t.Fatalf("Expected type to be response.NameError, got %s", mt) } - dur = minMsgTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) + dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) if dur != time.Duration(3600*time.Second) { t.Fatalf("Expected minttl duration to be %d, got %d", 3600, dur) } } -func BenchmarkMinMsgTTL(b *testing.B) { +func BenchmarkMinimalTTL(b *testing.B) { m := new(dns.Msg) m.SetQuestion("example.org.", dns.TypeA) m.Ns = []dns.RR{ @@ -64,9 +64,9 @@ func BenchmarkMinMsgTTL(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - dur := minMsgTTL(m, mt) + dur := MinimalTTL(m, mt) if dur != 1000*time.Second { - b.Fatalf("Wrong minMsgTTL %d, expected %d", dur, 1000*time.Second) + b.Fatalf("Wrong MinimalTTL %d, expected %d", dur, 1000*time.Second) } } }