diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index 30775c598..434efa296 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -113,7 +113,6 @@ func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { cacheSize.WithLabelValues(Denial).Set(float64(w.ncache.Len())) } - setMsgTTL(res, uint32(duration.Seconds())) if w.prefetch { return nil } diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index adac7d67b..f364e69f1 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "golang.org/x/net/context" + "github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware/pkg/cache" "github.com/coredns/coredns/middleware/pkg/response" @@ -205,3 +207,45 @@ func TestCache(t *testing.T) { } } } + +func BenchmarkCacheResponse(b *testing.B) { + c := &Cache{Zones: []string{"."}, pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxTTL} + c.pcache = cache.New(c.pcap) + c.ncache = cache.New(c.ncap) + c.prefetch = 1 + c.duration = 1 * time.Second + c.Next = BackendHandler() + + ctx := context.TODO() + + reqs := make([]*dns.Msg, 5) + for i, q := range []string{"example1", "example2", "a", "b", "ddd"} { + reqs[i] = new(dns.Msg) + reqs[i].SetQuestion(q+".example.org.", dns.TypeA) + } + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + req := reqs[i] + c.ServeDNS(ctx, &test.ResponseWriter{}, req) + i++ + i = i % 5 + } + }) +} + +func BackendHandler() middleware.Handler { + return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response = true + m.RecursionAvailable = true + + owner := m.Question[0].Name + m.Answer = []dns.RR{test.A(owner + " 303 IN A 127.0.0.53")} + + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} diff --git a/middleware/cache/handler.go b/middleware/cache/handler.go index 520b23767..ce3df2f75 100644 --- a/middleware/cache/handler.go +++ b/middleware/cache/handler.go @@ -29,6 +29,7 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) i, ttl := c.get(now, qname, qtype, do) if i != nil && ttl > 0 { resp := i.toMsg(r) + state.SizeAndDo(resp) resp, _ = state.Scrub(resp) w.WriteMsg(resp) diff --git a/middleware/cache/item.go b/middleware/cache/item.go index 5084bcf1c..02571ac5c 100644 --- a/middleware/cache/item.go +++ b/middleware/cache/item.go @@ -63,12 +63,29 @@ func (i *item) toMsg(m *dns.Msg) *dns.Msg { m1.Rcode = i.Rcode m1.Compress = true - m1.Answer = i.Answer - m1.Ns = i.Ns - m1.Extra = i.Extra + m1.Answer = make([]dns.RR, len(i.Answer)) + m1.Ns = make([]dns.RR, len(i.Ns)) + m1.Extra = make([]dns.RR, len(i.Extra)) - ttl := int(i.origTTL) - int(time.Now().UTC().Sub(i.stored).Seconds()) - setMsgTTL(m1, uint32(ttl)) + ttl := uint32(i.ttl(time.Now())) + if ttl < minTTL { + ttl = minTTL + } + + for j, r := range i.Answer { + m1.Answer[j] = dns.Copy(r) + m1.Answer[j].Header().Ttl = ttl + } + for j, r := range i.Ns { + m1.Ns[j] = dns.Copy(r) + m1.Ns[j].Header().Ttl = ttl + } + for j, r := range i.Extra { + m1.Extra[j] = dns.Copy(r) + if m1.Extra[j].Header().Rrtype != dns.TypeOPT { + m1.Extra[j].Header().Ttl = ttl + } + } return m1 } @@ -77,27 +94,6 @@ func (i *item) ttl(now time.Time) int { return ttl } -// setMsgTTL sets the ttl on all RRs in all sections. If ttl is smaller than minTTL -// that value is used. -func setMsgTTL(m *dns.Msg, ttl uint32) { - if ttl < minTTL { - ttl = minTTL - } - - for _, r := range m.Answer { - r.Header().Ttl = ttl - } - for _, r := range m.Ns { - r.Header().Ttl = ttl - } - for _, r := range m.Extra { - if r.Header().Rrtype == dns.TypeOPT { - continue - } - r.Header().Ttl = ttl - } -} - func minMsgTTL(m *dns.Msg, mt response.Type) time.Duration { if mt != response.NoError && mt != response.NameError && mt != response.NoData { return 0