diff --git a/plugin/cache/cache_test.go b/plugin/cache/cache_test.go index 9417a59f6..d839ea1a3 100644 --- a/plugin/cache/cache_test.go +++ b/plugin/cache/cache_test.go @@ -191,7 +191,7 @@ func TestCache(t *testing.T) { c, crr := newTestCache(maxTTL) - for _, tc := range cacheTestCases { + for n, tc := range cacheTestCases { m := tc.in.Msg() m = cacheMsg(m, tc) @@ -204,11 +204,15 @@ func TestCache(t *testing.T) { crr.set(m, k, mt, c.pttl) } - i, _ := c.get(time.Now().UTC(), state, "dns://:53") + i := c.getIgnoreTTL(time.Now().UTC(), state, "dns://:53") ok := i != nil - if ok != tc.shouldCache { - t.Errorf("Cached message that should not have been cached: %s", state.Name()) + if !tc.shouldCache && ok { + t.Errorf("Test %d: Cached message that should not have been cached: %s", n, state.Name()) + continue + } + if tc.shouldCache && !ok { + t.Errorf("Test %d: Did not cache message that should have been cached: %s", n, state.Name()) continue } diff --git a/plugin/cache/handler.go b/plugin/cache/handler.go index b7adc3a9e..2b4c89350 100644 --- a/plugin/cache/handler.go +++ b/plugin/cache/handler.go @@ -89,38 +89,23 @@ func (c *Cache) shouldPrefetch(i *item, now time.Time) bool { // Name implements the Handler interface. func (c *Cache) Name() string { return "cache" } -func (c *Cache) get(now time.Time, state request.Request, server string) (*item, bool) { - k := hash(state.Name(), state.QType()) - cacheRequests.WithLabelValues(server, c.zonesMetricLabel).Inc() - - if i, ok := c.ncache.Get(k); ok && i.(*item).ttl(now) > 0 { - cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel).Inc() - return i.(*item), true - } - - if i, ok := c.pcache.Get(k); ok && i.(*item).ttl(now) > 0 { - cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel).Inc() - return i.(*item), true - } - cacheMisses.WithLabelValues(server, c.zonesMetricLabel).Inc() - return nil, false -} - // getIgnoreTTL unconditionally returns an item if it exists in the cache. func (c *Cache) getIgnoreTTL(now time.Time, state request.Request, server string) *item { k := hash(state.Name(), state.QType()) cacheRequests.WithLabelValues(server, c.zonesMetricLabel).Inc() if i, ok := c.ncache.Get(k); ok { - ttl := i.(*item).ttl(now) - if ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds())) { + itm := i.(*item) + ttl := itm.ttl(now) + if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel).Inc() return i.(*item) } } if i, ok := c.pcache.Get(k); ok { - ttl := i.(*item).ttl(now) - if ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds())) { + itm := i.(*item) + ttl := itm.ttl(now) + if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel).Inc() return i.(*item) } diff --git a/plugin/cache/item.go b/plugin/cache/item.go index 3b47a3b6b..56d188b36 100644 --- a/plugin/cache/item.go +++ b/plugin/cache/item.go @@ -1,14 +1,18 @@ package cache import ( + "strings" "time" "github.com/coredns/coredns/plugin/cache/freq" + "github.com/coredns/coredns/request" "github.com/miekg/dns" ) type item struct { + Name string + QType uint16 Rcode int AuthenticatedData bool RecursionAvailable bool @@ -24,6 +28,10 @@ type item struct { func newItem(m *dns.Msg, now time.Time, d time.Duration) *item { i := new(item) + if len(m.Question) != 0 { + i.Name = m.Question[0].Name + i.QType = m.Question[0].Qtype + } i.Rcode = m.Rcode i.AuthenticatedData = m.AuthenticatedData i.RecursionAvailable = m.RecursionAvailable @@ -87,3 +95,10 @@ func (i *item) ttl(now time.Time) int { ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds()) return ttl } + +func (i *item) matches(state request.Request) bool { + if state.QType() == i.QType && strings.EqualFold(state.QName(), i.Name) { + return true + } + return false +}