diff --git a/plugin/cache/handler.go b/plugin/cache/handler.go index 2d608e8d3..4dc29167a 100644 --- a/plugin/cache/handler.go +++ b/plugin/cache/handler.go @@ -29,28 +29,10 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) i, found := c.get(now, state, server) if i != nil && found { resp := i.toMsg(r, now) - w.WriteMsg(resp) - if c.prefetch > 0 { - ttl := i.ttl(now) - i.Freq.Update(c.duration, now) - - threshold := int(math.Ceil(float64(c.percentage) / 100 * float64(i.origTTL))) - if i.Freq.Hits() >= c.prefetch && ttl <= threshold { - cw := newPrefetchResponseWriter(server, state, c) - go func(w dns.ResponseWriter) { - cachePrefetches.WithLabelValues(server).Inc() - plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r) - - // When prefetching we loose the item i, and with it the frequency - // that we've gathered sofar. See we copy the frequencies info back - // into the new item that was stored in the cache. - if i1 := c.exists(state); i1 != nil { - i1.Freq.Reset(now, i.Freq.Hits()) - } - }(cw) - } + if c.shouldPrefetch(i, now) { + go c.doPrefetch(ctx, state, server, i, now) } return dns.RcodeSuccess, nil } @@ -59,6 +41,29 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) return plugin.NextOrFailure(c.Name(), c.Next, ctx, crr, r) } +func (c *Cache) doPrefetch(ctx context.Context, state request.Request, server string, i *item, now time.Time) { + cw := newPrefetchResponseWriter(server, state, c) + + cachePrefetches.WithLabelValues(server).Inc() + plugin.NextOrFailure(c.Name(), c.Next, ctx, cw, state.Req) + + // When prefetching we loose the item i, and with it the frequency + // that we've gathered sofar. See we copy the frequencies info back + // into the new item that was stored in the cache. + if i1 := c.exists(state); i1 != nil { + i1.Freq.Reset(now, i.Freq.Hits()) + } +} + +func (c *Cache) shouldPrefetch(i *item, now time.Time) bool { + if c.prefetch <= 0 { + return false + } + i.Freq.Update(c.duration, now) + threshold := int(math.Ceil(float64(c.percentage) / 100 * float64(i.origTTL))) + return i.Freq.Hits() >= c.prefetch && i.ttl(now) <= threshold +} + // Name implements the Handler interface. func (c *Cache) Name() string { return "cache" }