Fix max-age in http server (#1890)

* Fix max-age in http server

Move the minMsgTTL to dnsutil and rename it MinimalTTL, move some
constants there as well.
Use these new function in server_https to correctly set the max-age
HTTP header.

Fixes: #1823

* Linter
This commit is contained in:
Miek Gieben 2018-06-27 21:12:27 +01:00 committed by GitHub
parent 99287d091c
commit dae506b563
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 91 additions and 68 deletions

View file

@ -7,6 +7,10 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"time"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -129,8 +133,11 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
buf, _ := dw.Msg.Pack() buf, _ := dw.Msg.Pack()
mt, _ := response.Typify(dw.Msg, time.Now().UTC())
age := dnsutil.MinimalTTL(dw.Msg, mt)
w.Header().Set("Content-Type", mimeTypeDOH) w.Header().Set("Content-Type", mimeTypeDOH)
w.Header().Set("Cache-Control", "max-age=128") // TODO(issues/1823): implement proper fix. w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%f", age.Seconds()))
w.Header().Set("Content-Length", strconv.Itoa(len(buf))) w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)

View file

@ -9,6 +9,7 @@ import (
"github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/pkg/cache" "github.com/coredns/coredns/plugin/pkg/cache"
"github.com/coredns/coredns/plugin/pkg/dnsutil"
"github.com/coredns/coredns/plugin/pkg/response" "github.com/coredns/coredns/plugin/pkg/response"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -158,7 +159,7 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
duration = w.nttl duration = w.nttl
} }
msgTTL := minMsgTTL(res, mt) msgTTL := dnsutil.MinimalTTL(res, mt)
if msgTTL < duration { if msgTTL < duration {
duration = msgTTL duration = msgTTL
} }
@ -226,9 +227,8 @@ func (w *ResponseWriter) Write(buf []byte) (int, error) {
} }
const ( const (
maxTTL = 1 * time.Hour maxTTL = dnsutil.MaximumDefaulTTL
maxNTTL = 30 * time.Minute maxNTTL = dnsutil.MaximumDefaulTTL / 2
failSafeTTL = 5 * time.Second
defaultCap = 10000 // default capacity of the cache. defaultCap = 10000 // default capacity of the cache.

56
plugin/cache/item.go vendored
View file

@ -4,7 +4,6 @@ import (
"time" "time"
"github.com/coredns/coredns/plugin/cache/freq" "github.com/coredns/coredns/plugin/cache/freq"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -87,58 +86,3 @@ func (i *item) ttl(now time.Time) int {
ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds()) ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
return ttl return ttl
} }
func minMsgTTL(m *dns.Msg, mt response.Type) time.Duration {
if mt != response.NoError && mt != response.NameError && mt != response.NoData {
return 0
}
// No data to examine, return a short ttl as a fail safe.
if len(m.Answer)+len(m.Ns)+len(m.Extra) == 0 {
return failSafeTTL
}
minTTL := maxTTL
for _, r := range m.Answer {
switch mt {
case response.NameError, response.NoData:
if r.Header().Rrtype == dns.TypeSOA {
minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second
}
case response.NoError, response.Delegation:
if r.Header().Ttl < uint32(minTTL.Seconds()) {
minTTL = time.Duration(r.Header().Ttl) * time.Second
}
}
}
for _, r := range m.Ns {
switch mt {
case response.NameError, response.NoData:
if r.Header().Rrtype == dns.TypeSOA {
minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second
}
case response.NoError, response.Delegation:
if r.Header().Ttl < uint32(minTTL.Seconds()) {
minTTL = time.Duration(r.Header().Ttl) * time.Second
}
}
}
for _, r := range m.Extra {
if r.Header().Rrtype == dns.TypeOPT {
// OPT records use TTL field for extended rcode and flags
continue
}
switch mt {
case response.NameError, response.NoData:
if r.Header().Rrtype == dns.TypeSOA {
minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second
}
case response.NoError, response.Delegation:
if r.Header().Ttl < uint32(minTTL.Seconds()) {
minTTL = time.Duration(r.Header().Ttl) * time.Second
}
}
}
return minTTL
}

72
plugin/pkg/dnsutil/ttl.go Normal file
View file

@ -0,0 +1,72 @@
package dnsutil
import (
"time"
"github.com/coredns/coredns/plugin/pkg/response"
"github.com/miekg/dns"
)
// MinimalTTL scans the message returns the lowest TTL found taking into the response.Type of the message.
func MinimalTTL(m *dns.Msg, mt response.Type) time.Duration {
if mt != response.NoError && mt != response.NameError && mt != response.NoData {
return MinimalDefaultTTL
}
// No data to examine, return a short ttl as a fail safe.
if len(m.Answer)+len(m.Ns)+len(m.Extra) == 0 {
return MinimalDefaultTTL
}
minTTL := MaximumDefaulTTL
for _, r := range m.Answer {
switch mt {
case response.NameError, response.NoData:
if r.Header().Rrtype == dns.TypeSOA {
minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second
}
case response.NoError, response.Delegation:
if r.Header().Ttl < uint32(minTTL.Seconds()) {
minTTL = time.Duration(r.Header().Ttl) * time.Second
}
}
}
for _, r := range m.Ns {
switch mt {
case response.NameError, response.NoData:
if r.Header().Rrtype == dns.TypeSOA {
minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second
}
case response.NoError, response.Delegation:
if r.Header().Ttl < uint32(minTTL.Seconds()) {
minTTL = time.Duration(r.Header().Ttl) * time.Second
}
}
}
for _, r := range m.Extra {
if r.Header().Rrtype == dns.TypeOPT {
// OPT records use TTL field for extended rcode and flags
continue
}
switch mt {
case response.NameError, response.NoData:
if r.Header().Rrtype == dns.TypeSOA {
minTTL = time.Duration(r.(*dns.SOA).Minttl) * time.Second
}
case response.NoError, response.Delegation:
if r.Header().Ttl < uint32(minTTL.Seconds()) {
minTTL = time.Duration(r.Header().Ttl) * time.Second
}
}
}
return minTTL
}
const (
// MinimalDefaultTTL is the absolute lowest TTL we use in CoreDNS.
MinimalDefaultTTL = 5 * time.Second
// MaximumDefaulTTL is the maximum TTL was use on RRsets in CoreDNS.
MaximumDefaulTTL = 1 * time.Hour
)

View file

@ -1,4 +1,4 @@
package cache package dnsutil
import ( import (
"testing" "testing"
@ -12,7 +12,7 @@ import (
// See https://github.com/kubernetes/dns/issues/121, add some specific tests for those use cases. // See https://github.com/kubernetes/dns/issues/121, add some specific tests for those use cases.
func TestMinMsgTTL(t *testing.T) { func TestMinimalTTL(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("z.alm.im.", dns.TypeA) m.SetQuestion("z.alm.im.", dns.TypeA)
m.Ns = []dns.RR{ m.Ns = []dns.RR{
@ -25,7 +25,7 @@ func TestMinMsgTTL(t *testing.T) {
if mt != response.NoData { if mt != response.NoData {
t.Fatalf("Expected type to be response.NoData, got %s", mt) t.Fatalf("Expected type to be response.NoData, got %s", mt)
} }
dur := minMsgTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) dur := MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA)
if dur != time.Duration(3600*time.Second) { if dur != time.Duration(3600*time.Second) {
t.Fatalf("Expected minttl duration to be %d, got %d", 3600, dur) t.Fatalf("Expected minttl duration to be %d, got %d", 3600, dur)
} }
@ -35,13 +35,13 @@ func TestMinMsgTTL(t *testing.T) {
if mt != response.NameError { if mt != response.NameError {
t.Fatalf("Expected type to be response.NameError, got %s", mt) t.Fatalf("Expected type to be response.NameError, got %s", mt)
} }
dur = minMsgTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA)
if dur != time.Duration(3600*time.Second) { if dur != time.Duration(3600*time.Second) {
t.Fatalf("Expected minttl duration to be %d, got %d", 3600, dur) t.Fatalf("Expected minttl duration to be %d, got %d", 3600, dur)
} }
} }
func BenchmarkMinMsgTTL(b *testing.B) { func BenchmarkMinimalTTL(b *testing.B) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("example.org.", dns.TypeA) m.SetQuestion("example.org.", dns.TypeA)
m.Ns = []dns.RR{ m.Ns = []dns.RR{
@ -64,9 +64,9 @@ func BenchmarkMinMsgTTL(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
dur := minMsgTTL(m, mt) dur := MinimalTTL(m, mt)
if dur != 1000*time.Second { if dur != 1000*time.Second {
b.Fatalf("Wrong minMsgTTL %d, expected %d", dur, 1000*time.Second) b.Fatalf("Wrong MinimalTTL %d, expected %d", dur, 1000*time.Second)
} }
} }
} }