diff --git a/middleware/cache/README.md b/middleware/cache/README.md index 145ce2fdf..6477fe891 100644 --- a/middleware/cache/README.md +++ b/middleware/cache/README.md @@ -10,13 +10,12 @@ cache [TTL] [ZONES...] * **TTL** max TTL in seconds. If not specified, the maximum TTL will be used which is 3600 for noerror responses and 1800 for denial of existence ones. - A set TTL of 300 *cache 300* would cache the record up to 300 seconds. - Smaller record provided TTLs will take precedence. + Setting a TTL of 300 *cache 300* would cache the record up to 300 seconds. * **ZONES** zones it should cache for. If empty, the zones from the configuration block are used. Each element in the cache is cached according to its TTL (with **TTL** as the max). For the negative cache, the SOA's MinTTL value is used. A cache can contain up to 10,000 items by -default. A TTL of zero is not allowed. No cache invalidation triggered by other middlewares is available. Therefore even reloaded items might still be cached for the duration of the TTL. +default. A TTL of zero is not allowed. If you want more control: @@ -24,16 +23,21 @@ If you want more control: cache [TTL] [ZONES...] { success CAPACITY [TTL] denial CAPACITY [TTL] + prefetch AMOUNT [[DURATION] [PERCENTAGE%]] } ~~~ * **TTL** and **ZONES** as above. * `success`, override the settings for caching successful responses, **CAPACITY** indicates the maximum - number of packets we cache before we start evicting (LRU). **TTL** overrides the cache maximum TTL. + number of packets we cache before we start evicting (*randomly*). **TTL** overrides the cache maximum TTL. * `denial`, override the settings for caching denial of existence responses, **CAPACITY** indicates the maximum number of packets we cache before we start evicting (LRU). **TTL** overrides the cache maximum TTL. - -There is a third category (`error`) but those responses are never cached. + There is a third category (`error`) but those responses are never cached. +* `prefetch`, will prefetch popular items when they are about to be expunged from the cache. + Popular means **AMOUNT** queries have been seen no gaps of **DURATION** or more between them. + **DURATION** defaults to 1m. Prefetching will happen when the TTL drops below **PERCENTAGE**, + which defaults to `10%`. Values should be in the range `[10%, 90%]`. Note the percent sign is + mandatory. **PERCENTAGE** is treated as an `int`. The minimum TTL allowed on resource records is 5 seconds. diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index e2a669723..30775c598 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -2,15 +2,15 @@ package cache import ( + "encoding/binary" + "hash/fnv" "log" - "strconv" - "strings" "time" "github.com/coredns/coredns/middleware" + "github.com/coredns/coredns/middleware/pkg/cache" "github.com/coredns/coredns/middleware/pkg/response" - "github.com/hashicorp/golang-lru" "github.com/miekg/dns" ) @@ -20,48 +20,73 @@ type Cache struct { Next middleware.Handler Zones []string - ncache *lru.Cache + ncache *cache.Cache ncap int nttl time.Duration - pcache *lru.Cache + pcache *cache.Cache pcap int pttl time.Duration + + // Prefetch. + prefetch int + duration time.Duration + percentage int } -// Return key under which we store the item. The empty string is returned -// when we don't want to cache the message. Currently we do not cache Truncated, errors -// zone transfers or dynamic update messages. -func key(m *dns.Msg, t response.Type, do bool) string { +// Return key under which we store the item, -1 will be returned if we don't store the +// message. +// Currently we do not cache Truncated, errors zone transfers or dynamic update messages. +func key(m *dns.Msg, t response.Type, do bool) int { // We don't store truncated responses. if m.Truncated { - return "" + return -1 } // Nor errors or Meta or Update if t == response.OtherError || t == response.Meta || t == response.Update { - return "" + return -1 } - qtype := m.Question[0].Qtype - qname := strings.ToLower(m.Question[0].Name) - return rawKey(qname, qtype, do) + return int(hash(m.Question[0].Name, m.Question[0].Qtype, do)) } -func rawKey(qname string, qtype uint16, do bool) string { +var one = []byte("1") +var zero = []byte("0") + +func hash(qname string, qtype uint16, do bool) uint32 { + h := fnv.New32() + if do { - return "1" + qname + "." + strconv.Itoa(int(qtype)) + h.Write(one) + } else { + h.Write(zero) } - return "0" + qname + "." + strconv.Itoa(int(qtype)) + + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, qtype) + h.Write(b) + + for i := range qname { + c := qname[i] + if c >= 'A' && c <= 'Z' { + c += 'a' - 'A' + } + h.Write([]byte{c}) + } + + return h.Sum32() } // ResponseWriter is a response writer that caches the reply message. type ResponseWriter struct { dns.ResponseWriter *Cache + + prefetch bool // When true write nothing back to the client. } // WriteMsg implements the dns.ResponseWriter interface. -func (c *ResponseWriter) WriteMsg(res *dns.Msg) error { +func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { do := false mt, opt := response.Typify(res, time.Now().UTC()) if opt != nil { @@ -71,9 +96,9 @@ func (c *ResponseWriter) WriteMsg(res *dns.Msg) error { // key returns empty string for anything we don't want to cache. key := key(res, mt, do) - duration := c.pttl + duration := w.pttl if mt == response.NameError || mt == response.NoData { - duration = c.nttl + duration = w.nttl } msgTTL := minMsgTTL(res, mt) @@ -81,20 +106,23 @@ func (c *ResponseWriter) WriteMsg(res *dns.Msg) error { duration = msgTTL } - if key != "" { - c.set(res, key, mt, duration) + if key != -1 { + w.set(res, key, mt, duration) - cacheSize.WithLabelValues(Success).Set(float64(c.pcache.Len())) - cacheSize.WithLabelValues(Denial).Set(float64(c.ncache.Len())) + cacheSize.WithLabelValues(Success).Set(float64(w.pcache.Len())) + cacheSize.WithLabelValues(Denial).Set(float64(w.ncache.Len())) } setMsgTTL(res, uint32(duration.Seconds())) + if w.prefetch { + return nil + } - return c.ResponseWriter.WriteMsg(res) + return w.ResponseWriter.WriteMsg(res) } -func (c *ResponseWriter) set(m *dns.Msg, key string, mt response.Type, duration time.Duration) { - if key == "" { +func (w *ResponseWriter) set(m *dns.Msg, key int, mt response.Type, duration time.Duration) { + if key == -1 { log.Printf("[ERROR] Caching called with empty cache key") return } @@ -102,11 +130,11 @@ func (c *ResponseWriter) set(m *dns.Msg, key string, mt response.Type, duration switch mt { case response.NoError, response.Delegation: i := newItem(m, duration) - c.pcache.Add(key, i) + w.pcache.Add(uint32(key), i) case response.NameError, response.NoData: i := newItem(m, duration) - c.ncache.Add(key, i) + w.ncache.Add(uint32(key), i) case response.OtherError: // don't cache these @@ -116,9 +144,12 @@ func (c *ResponseWriter) set(m *dns.Msg, key string, mt response.Type, duration } // Write implements the dns.ResponseWriter interface. -func (c *ResponseWriter) Write(buf []byte) (int, error) { +func (w *ResponseWriter) Write(buf []byte) (int, error) { log.Printf("[WARNING] Caching called with Write: not caching reply") - n, err := c.ResponseWriter.Write(buf) + if w.prefetch { + return 0, nil + } + n, err := w.ResponseWriter.Write(buf) return n, err } diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 18aa05fe5..adac7d67b 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -7,10 +7,10 @@ import ( "time" "github.com/coredns/coredns/middleware" + "github.com/coredns/coredns/middleware/pkg/cache" "github.com/coredns/coredns/middleware/pkg/response" "github.com/coredns/coredns/middleware/test" - lru "github.com/hashicorp/golang-lru" "github.com/miekg/dns" ) @@ -148,10 +148,10 @@ func cacheMsg(m *dns.Msg, tc cacheTestCase) *dns.Msg { func newTestCache(ttl time.Duration) (*Cache, *ResponseWriter) { c := &Cache{Zones: []string{"."}, pcap: defaultCap, ncap: defaultCap, pttl: ttl, nttl: ttl} - c.pcache, _ = lru.New(c.pcap) - c.ncache, _ = lru.New(c.ncap) + c.pcache = cache.New(c.pcap) + c.ncache = cache.New(c.ncap) - crr := &ResponseWriter{nil, c} + crr := &ResponseWriter{ResponseWriter: nil, Cache: c} return c, crr } @@ -176,7 +176,8 @@ func TestCache(t *testing.T) { name := middleware.Name(m.Question[0].Name).Normalize() qtype := m.Question[0].Qtype - i, ok, _ := c.get(name, qtype, do) + i, _ := c.get(time.Now().UTC(), name, qtype, do) + ok := i != nil if ok != tc.shouldCache { t.Errorf("cached message that should not have been cached: %s", name) diff --git a/middleware/cache/freq/freq.go b/middleware/cache/freq/freq.go new file mode 100644 index 000000000..528eedc1c --- /dev/null +++ b/middleware/cache/freq/freq.go @@ -0,0 +1,54 @@ +// Package freq keeps track of last X seen events. The events themselves are not stored +// here. So the Freq type should be added next to the thing it is tracking. +package freq + +import ( + "sync" + "time" +) + +type Freq struct { + // Last time we saw a query for this element. + last time.Time + // Number of this in the last time slice. + hits int + + sync.RWMutex +} + +// New returns a new initialized Freq. +func New(t time.Time) *Freq { + return &Freq{last: t, hits: 0} +} + +// Updates updates the number of hits. Last time seen will be set to now. +// If the last time we've seen this entity is within now - d, we increment hits, otherwise +// we reset hits to 1. It returns the number of hits. +func (f *Freq) Update(d time.Duration, now time.Time) int { + earliest := now.Add(-1 * d) + f.Lock() + defer f.Unlock() + if f.last.Before(earliest) { + f.last = now + f.hits = 1 + return f.hits + } + f.last = now + f.hits++ + return f.hits +} + +// Hits returns the number of hits that we have seen, according to the updates we have done to f. +func (f *Freq) Hits() int { + f.RLock() + defer f.RUnlock() + return f.hits +} + +// Reset resets f to time t and hits to hits. +func (f *Freq) Reset(t time.Time, hits int) { + f.Lock() + defer f.Unlock() + f.last = t + f.hits = hits +} diff --git a/middleware/cache/freq/freq_test.go b/middleware/cache/freq/freq_test.go new file mode 100644 index 000000000..740194c86 --- /dev/null +++ b/middleware/cache/freq/freq_test.go @@ -0,0 +1,36 @@ +package freq + +import ( + "testing" + "time" +) + +func TestFreqUpdate(t *testing.T) { + now := time.Now().UTC() + f := New(now) + window := 1 * time.Minute + + f.Update(window, time.Now().UTC()) + f.Update(window, time.Now().UTC()) + f.Update(window, time.Now().UTC()) + hitsCheck(t, f, 3) + + f.Reset(now, 0) + history := time.Now().UTC().Add(-3 * time.Minute) + f.Update(window, history) + hitsCheck(t, f, 1) +} + +func TestReset(t *testing.T) { + f := New(time.Now().UTC()) + f.Update(1*time.Minute, time.Now().UTC()) + hitsCheck(t, f, 1) + f.Reset(time.Now().UTC(), 0) + hitsCheck(t, f, 0) +} + +func hitsCheck(t *testing.T, f *Freq, expected int) { + if x := f.Hits(); x != expected { + t.Fatalf("Expected hits to be %d, got %d", expected, x) + } +} diff --git a/middleware/cache/handler.go b/middleware/cache/handler.go index 195322e31..520b23767 100644 --- a/middleware/cache/handler.go +++ b/middleware/cache/handler.go @@ -24,36 +24,58 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) do := state.Do() // TODO(): might need more from OPT record? Like the actual bufsize? - if i, ok, expired := c.get(qname, qtype, do); ok && !expired { + now := time.Now().UTC() + + i, ttl := c.get(now, qname, qtype, do) + if i != nil && ttl > 0 { resp := i.toMsg(r) state.SizeAndDo(resp) resp, _ = state.Scrub(resp) w.WriteMsg(resp) + i.Freq.Update(c.duration, now) + + pct := 100 + if i.origTTL != 0 { // you'll never know + pct = int(float64(ttl) / float64(i.origTTL) * 100) + } + + if c.prefetch > 0 && i.Freq.Hits() > c.prefetch && pct < c.percentage { + // When prefetching we loose the item i, and with it the frequency + // that we've gathered sofar. See we copy the frequence info back + // into the new item that was stored in the cache. + prr := &ResponseWriter{ResponseWriter: w, Cache: c, prefetch: true} + middleware.NextOrFailure(c.Name(), c.Next, ctx, prr, r) + + if i1, _ := c.get(now, qname, qtype, do); i1 != nil { + i1.Freq.Reset(now, i.Freq.Hits()) + } + } + return dns.RcodeSuccess, nil } - crr := &ResponseWriter{w, c} + crr := &ResponseWriter{ResponseWriter: w, Cache: c} return middleware.NextOrFailure(c.Name(), c.Next, ctx, crr, r) } // Name implements the Handler interface. func (c *Cache) Name() string { return "cache" } -func (c *Cache) get(qname string, qtype uint16, do bool) (*item, bool, bool) { - k := rawKey(qname, qtype, do) +func (c *Cache) get(now time.Time, qname string, qtype uint16, do bool) (*item, int) { + k := hash(qname, qtype, do) if i, ok := c.ncache.Get(k); ok { cacheHits.WithLabelValues(Denial).Inc() - return i.(*item), ok, i.(*item).expired(time.Now()) + return i.(*item), i.(*item).ttl(now) } if i, ok := c.pcache.Get(k); ok { cacheHits.WithLabelValues(Success).Inc() - return i.(*item), ok, i.(*item).expired(time.Now()) + return i.(*item), i.(*item).ttl(now) } cacheMisses.Inc() - return nil, false, false + return nil, 0 } var ( diff --git a/middleware/cache/item.go b/middleware/cache/item.go index 6a75afdf9..5084bcf1c 100644 --- a/middleware/cache/item.go +++ b/middleware/cache/item.go @@ -3,6 +3,7 @@ package cache import ( "time" + "github.com/coredns/coredns/middleware/cache/freq" "github.com/coredns/coredns/middleware/pkg/response" "github.com/miekg/dns" ) @@ -18,6 +19,8 @@ type item struct { origTTL uint32 stored time.Time + + *freq.Freq } func newItem(m *dns.Msg, d time.Duration) *item { @@ -43,10 +46,12 @@ func newItem(m *dns.Msg, d time.Duration) *item { i.origTTL = uint32(d.Seconds()) i.stored = time.Now().UTC() + i.Freq = new(freq.Freq) + return i } -// toMsg turns i into a message, it tailers the reply to m. +// toMsg turns i into a message, it tailors the reply to m. // The Authoritative bit is always set to 0, because the answer is from the cache. func (i *item) toMsg(m *dns.Msg) *dns.Msg { m1 := new(dns.Msg) @@ -67,9 +72,9 @@ func (i *item) toMsg(m *dns.Msg) *dns.Msg { return m1 } -func (i *item) expired(now time.Time) bool { +func (i *item) ttl(now time.Time) int { ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds()) - return ttl < 0 + return ttl } // setMsgTTL sets the ttl on all RRs in all sections. If ttl is smaller than minTTL diff --git a/middleware/cache/item_test.go b/middleware/cache/item_test.go deleted file mode 100644 index b338d02bd..000000000 --- a/middleware/cache/item_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package cache - -import ( - "testing" - - "github.com/miekg/dns" -) - -func TestKey(t *testing.T) { - if x := rawKey("miek.nl.", dns.TypeMX, false); x != "0miek.nl..15" { - t.Errorf("failed to create correct key, got %s", x) - } - if x := rawKey("miek.nl.", dns.TypeMX, true); x != "1miek.nl..15" { - t.Errorf("failed to create correct key, got %s", x) - } - // rawKey does not lowercase. - if x := rawKey("miEK.nL.", dns.TypeMX, true); x != "1miEK.nL..15" { - t.Errorf("failed to create correct key, got %s", x) - } -} diff --git a/middleware/cache/prefech_test.go b/middleware/cache/prefech_test.go new file mode 100644 index 000000000..69ad5f92a --- /dev/null +++ b/middleware/cache/prefech_test.go @@ -0,0 +1,54 @@ +package cache + +import ( + "fmt" + "testing" + "time" + + "github.com/coredns/coredns/middleware" + "github.com/coredns/coredns/middleware/pkg/cache" + "github.com/coredns/coredns/middleware/pkg/dnsrecorder" + + "github.com/coredns/coredns/middleware/test" + "github.com/miekg/dns" + "golang.org/x/net/context" +) + +var p = false + +func TestPrefetch(t *testing.T) { + c := &Cache{Zones: []string{"."}, pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxTTL} + c.pcache = cache.New(c.pcap) + c.ncache = cache.New(c.ncap) + c.prefetch = 1 + c.duration = 1 * time.Second + c.Next = PrefetchHandler(t, dns.RcodeSuccess, nil) + + ctx := context.TODO() + + req := new(dns.Msg) + req.SetQuestion("lowttl.example.org.", dns.TypeA) + + rec := dnsrecorder.New(&test.ResponseWriter{}) + + c.ServeDNS(ctx, rec, req) + p = true // prefetch should be true for the 2nd fetch + c.ServeDNS(ctx, rec, req) +} + +func PrefetchHandler(t *testing.T, rcode int, err error) middleware.Handler { + return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetQuestion("lowttl.example.org.", dns.TypeA) + m.Response = true + m.RecursionAvailable = true + m.Answer = append(m.Answer, test.A("lowttl.example.org. 80 IN A 127.0.0.53")) + if p != w.(*ResponseWriter).prefetch { + err = fmt.Errorf("cache prefetch not equal to p: got %t, want %t", p, w.(*ResponseWriter).prefetch) + t.Fatal(err) + } + + w.WriteMsg(m) + return rcode, err + }) +} diff --git a/middleware/cache/setup.go b/middleware/cache/setup.go index eb835ba4e..65cfb70d1 100644 --- a/middleware/cache/setup.go +++ b/middleware/cache/setup.go @@ -7,8 +7,8 @@ import ( "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/middleware" + "github.com/coredns/coredns/middleware/pkg/cache" - "github.com/hashicorp/golang-lru" "github.com/mholt/caddy" ) @@ -38,7 +38,7 @@ func setup(c *caddy.Controller) error { func cacheParse(c *caddy.Controller) (*Cache, error) { - ca := &Cache{pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxNTTL} + ca := &Cache{pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxNTTL, prefetch: 0, duration: 1 * time.Minute} for c.Next() { // cache [ttl] [zones..] @@ -109,6 +109,46 @@ func cacheParse(c *caddy.Controller) (*Cache, error) { } ca.nttl = time.Duration(nttl) * time.Second } + case "prefetch": + args := c.RemainingArgs() + if len(args) == 0 || len(args) > 3 { + return nil, c.ArgErr() + } + amount, err := strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + if amount < 0 { + return nil, fmt.Errorf("prefetch amount should be positive: %d", amount) + } + ca.prefetch = amount + + ca.duration = 1 * time.Minute + ca.percentage = 10 + if len(args) > 1 { + dur, err := time.ParseDuration(args[1]) + if err != nil { + return nil, err + } + ca.duration = dur + } + if len(args) > 2 { + pct := args[2] + if x := pct[len(pct)-1]; x != '%' { + return nil, fmt.Errorf("last character of percentage should be `%%`, but is: %q", x) + } + pct = pct[:len(pct)-1] + + num, err := strconv.Atoi(pct) + if err != nil { + return nil, err + } + if num < 10 || num > 90 { + return nil, fmt.Errorf("percentage should fall in range [10, 90]: %d", num) + } + ca.percentage = num + } + default: return nil, c.ArgErr() } @@ -118,17 +158,10 @@ func cacheParse(c *caddy.Controller) (*Cache, error) { origins[i] = middleware.Host(origins[i]).Normalize() } - var err error ca.Zones = origins - ca.pcache, err = lru.New(ca.pcap) - if err != nil { - return nil, err - } - ca.ncache, err = lru.New(ca.ncap) - if err != nil { - return nil, err - } + ca.pcache = cache.New(ca.pcap) + ca.ncache = cache.New(ca.ncap) return ca, nil } diff --git a/middleware/cache/setup_test.go b/middleware/cache/setup_test.go index f46a93b76..afc2ecc13 100644 --- a/middleware/cache/setup_test.go +++ b/middleware/cache/setup_test.go @@ -9,46 +9,57 @@ import ( func TestSetup(t *testing.T) { tests := []struct { - input string - shouldErr bool - expectedNcap int - expectedPcap int - expectedNttl time.Duration - expectedPttl time.Duration + input string + shouldErr bool + expectedNcap int + expectedPcap int + expectedNttl time.Duration + expectedPttl time.Duration + expectedPrefetch int }{ - {`cache`, false, defaultCap, defaultCap, maxNTTL, maxTTL}, - {`cache {}`, false, defaultCap, defaultCap, maxNTTL, maxTTL}, + {`cache`, false, defaultCap, defaultCap, maxNTTL, maxTTL, 0}, + {`cache {}`, false, defaultCap, defaultCap, maxNTTL, maxTTL, 0}, {`cache example.nl { success 10 - }`, false, defaultCap, 10, maxNTTL, maxTTL}, + }`, false, defaultCap, 10, maxNTTL, maxTTL, 0}, {`cache example.nl { success 10 denial 10 15 - }`, false, 10, 10, 15 * time.Second, maxTTL}, + }`, false, 10, 10, 15 * time.Second, maxTTL, 0}, {`cache 25 example.nl { success 10 denial 10 15 - }`, false, 10, 10, 15 * time.Second, 25 * time.Second}, - {`cache aaa example.nl`, false, defaultCap, defaultCap, maxNTTL, maxTTL}, + }`, false, 10, 10, 15 * time.Second, 25 * time.Second, 0}, + {`cache aaa example.nl`, false, defaultCap, defaultCap, maxNTTL, maxTTL, 0}, + {`cache { + prefetch 10 + }`, false, defaultCap, defaultCap, maxNTTL, maxTTL, 10}, // fails {`cache example.nl { success denial 10 15 - }`, true, defaultCap, defaultCap, maxTTL, maxTTL}, + }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0}, {`cache example.nl { success 15 denial aaa - }`, true, defaultCap, defaultCap, maxTTL, maxTTL}, + }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0}, {`cache example.nl { positive 15 negative aaa - }`, true, defaultCap, defaultCap, maxTTL, maxTTL}, - {`cache 0 example.nl`, true, defaultCap, defaultCap, maxTTL, maxTTL}, - {`cache -1 example.nl`, true, defaultCap, defaultCap, maxTTL, maxTTL}, + }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0}, + {`cache 0 example.nl`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0}, + {`cache -1 example.nl`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0}, {`cache 1 example.nl { positive 0 - }`, true, defaultCap, defaultCap, maxTTL, maxTTL}, + }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0}, + {`cache 1 example.nl { + positive 0 + prefetch -1 + }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0}, + {`cache 1 example.nl { + prefetch 0 blurp + }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0}, } for i, test := range tests { c := caddy.NewTestController("dns", test.input) @@ -76,5 +87,8 @@ func TestSetup(t *testing.T) { if ca.pttl != test.expectedPttl { t.Errorf("Test %v: Expected pttl %v but found: %v", i, test.expectedPttl, ca.pttl) } + if ca.prefetch != test.expectedPrefetch { + t.Errorf("Test %v: Expected prefetch %v but found: %v", i, test.expectedPrefetch, ca.prefetch) + } } } diff --git a/middleware/dnssec/cache.go b/middleware/dnssec/cache.go index 2153c84cb..ea95b73b4 100644 --- a/middleware/dnssec/cache.go +++ b/middleware/dnssec/cache.go @@ -2,14 +2,13 @@ package dnssec import ( "hash/fnv" - "strconv" "github.com/miekg/dns" ) -// Key serializes the RRset and return a signature cache key. -func key(rrs []dns.RR) string { - h := fnv.New64() +// hash serializes the RRset and return a signature cache key. +func hash(rrs []dns.RR) uint32 { + h := fnv.New32() buf := make([]byte, 256) for _, r := range rrs { off, err := dns.PackRR(r, buf, 0, nil, false) @@ -18,6 +17,6 @@ func key(rrs []dns.RR) string { } } - i := h.Sum64() - return strconv.FormatUint(i, 10) + i := h.Sum32() + return i } diff --git a/middleware/dnssec/cache_test.go b/middleware/dnssec/cache_test.go index c88e310e3..b9434fcbe 100644 --- a/middleware/dnssec/cache_test.go +++ b/middleware/dnssec/cache_test.go @@ -4,10 +4,9 @@ import ( "testing" "time" + "github.com/coredns/coredns/middleware/pkg/cache" "github.com/coredns/coredns/middleware/test" "github.com/coredns/coredns/request" - - "github.com/hashicorp/golang-lru" ) func TestCacheSet(t *testing.T) { @@ -21,11 +20,11 @@ func TestCacheSet(t *testing.T) { t.Fatalf("failed to parse key: %v\n", err) } - cache, _ := lru.New(defaultCap) + c := cache.New(defaultCap) m := testMsg() state := request.Request{Req: m} - k := key(m.Answer) // calculate *before* we add the sig - d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil, cache) + k := hash(m.Answer) // calculate *before* we add the sig + d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil, c) m = d.Sign(state, "miek.nl.", time.Now().UTC()) _, ok := d.get(k) diff --git a/middleware/dnssec/dnssec.go b/middleware/dnssec/dnssec.go index 7ad5f2bf7..4e1e70217 100644 --- a/middleware/dnssec/dnssec.go +++ b/middleware/dnssec/dnssec.go @@ -6,11 +6,11 @@ import ( "time" "github.com/coredns/coredns/middleware" + "github.com/coredns/coredns/middleware/pkg/cache" "github.com/coredns/coredns/middleware/pkg/response" "github.com/coredns/coredns/middleware/pkg/singleflight" "github.com/coredns/coredns/request" - "github.com/hashicorp/golang-lru" "github.com/miekg/dns" ) @@ -21,15 +21,15 @@ type Dnssec struct { zones []string keys []*DNSKEY inflight *singleflight.Group - cache *lru.Cache + cache *cache.Cache } // New returns a new Dnssec. -func New(zones []string, keys []*DNSKEY, next middleware.Handler, cache *lru.Cache) Dnssec { +func New(zones []string, keys []*DNSKEY, next middleware.Handler, c *cache.Cache) Dnssec { return Dnssec{Next: next, zones: zones, keys: keys, - cache: cache, + cache: c, inflight: new(singleflight.Group), } } @@ -90,7 +90,7 @@ func (d Dnssec) Sign(state request.Request, zone string, now time.Time) *dns.Msg } func (d Dnssec) sign(rrs []dns.RR, signerName string, ttl, incep, expir uint32) ([]dns.RR, error) { - k := key(rrs) + k := hash(rrs) sgs, ok := d.get(k) if ok { return sgs, nil @@ -110,11 +110,11 @@ func (d Dnssec) sign(rrs []dns.RR, signerName string, ttl, incep, expir uint32) return sigs.([]dns.RR), err } -func (d Dnssec) set(key string, sigs []dns.RR) { +func (d Dnssec) set(key uint32, sigs []dns.RR) { d.cache.Add(key, sigs) } -func (d Dnssec) get(key string) ([]dns.RR, bool) { +func (d Dnssec) get(key uint32) ([]dns.RR, bool) { if s, ok := d.cache.Get(key); ok { cacheHits.Inc() return s.([]dns.RR), true diff --git a/middleware/dnssec/dnssec_test.go b/middleware/dnssec/dnssec_test.go index 1c9c9a545..3549a7c8f 100644 --- a/middleware/dnssec/dnssec_test.go +++ b/middleware/dnssec/dnssec_test.go @@ -4,10 +4,10 @@ import ( "testing" "time" + "github.com/coredns/coredns/middleware/pkg/cache" "github.com/coredns/coredns/middleware/test" "github.com/coredns/coredns/request" - "github.com/hashicorp/golang-lru" "github.com/miekg/dns" ) @@ -69,8 +69,8 @@ func TestSigningDifferentZone(t *testing.T) { m := testMsgEx() state := request.Request{Req: m} - cache, _ := lru.New(defaultCap) - d := New([]string{"example.org."}, []*DNSKEY{key}, nil, cache) + c := cache.New(defaultCap) + d := New([]string{"example.org."}, []*DNSKEY{key}, nil, c) m = d.Sign(state, "example.org.", time.Now().UTC()) if !section(m.Answer, 1) { t.Errorf("answer section should have 1 sig") @@ -183,8 +183,8 @@ func testMsgDname() *dns.Msg { func newDnssec(t *testing.T, zones []string) (Dnssec, func(), func()) { k, rm1, rm2 := newKey(t) - cache, _ := lru.New(defaultCap) - d := New(zones, []*DNSKEY{k}, nil, cache) + c := cache.New(defaultCap) + d := New(zones, []*DNSKEY{k}, nil, c) return d, rm1, rm2 } diff --git a/middleware/dnssec/handler_test.go b/middleware/dnssec/handler_test.go index 37a92935a..60847d17b 100644 --- a/middleware/dnssec/handler_test.go +++ b/middleware/dnssec/handler_test.go @@ -6,10 +6,10 @@ import ( "testing" "github.com/coredns/coredns/middleware/file" + "github.com/coredns/coredns/middleware/pkg/cache" "github.com/coredns/coredns/middleware/pkg/dnsrecorder" "github.com/coredns/coredns/middleware/test" - "github.com/hashicorp/golang-lru" "github.com/miekg/dns" "golang.org/x/net/context" ) @@ -89,8 +89,8 @@ func TestLookupZone(t *testing.T) { dnskey, rm1, rm2 := newKey(t) defer rm1() defer rm2() - cache, _ := lru.New(defaultCap) - dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, fm, cache) + c := cache.New(defaultCap) + dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, fm, c) ctx := context.TODO() for _, tc := range dnsTestCases { @@ -128,8 +128,8 @@ func TestLookupDNSKEY(t *testing.T) { dnskey, rm1, rm2 := newKey(t) defer rm1() defer rm2() - cache, _ := lru.New(defaultCap) - dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, test.ErrorHandler(), cache) + c := cache.New(defaultCap) + dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, test.ErrorHandler(), c) ctx := context.TODO() for _, tc := range dnssecTestCases { diff --git a/middleware/dnssec/setup.go b/middleware/dnssec/setup.go index 1b1fb6393..6935c4aaa 100644 --- a/middleware/dnssec/setup.go +++ b/middleware/dnssec/setup.go @@ -6,8 +6,8 @@ import ( "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/middleware" + "github.com/coredns/coredns/middleware/pkg/cache" - "github.com/hashicorp/golang-lru" "github.com/mholt/caddy" ) @@ -24,12 +24,9 @@ func setup(c *caddy.Controller) error { return middleware.Error("dnssec", err) } - cache, err := lru.New(capacity) - if err != nil { - return err - } + ca := cache.New(capacity) dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler { - return New(zones, keys, next, cache) + return New(zones, keys, next, ca) }) // Export the capacity for the metrics. This only happens once, because this is a re-load change only. diff --git a/middleware/etcd/etcd.go b/middleware/etcd/etcd.go index f701785f3..b438b794d 100644 --- a/middleware/etcd/etcd.go +++ b/middleware/etcd/etcd.go @@ -9,6 +9,7 @@ import ( "github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware/etcd/msg" + "github.com/coredns/coredns/middleware/pkg/cache" "github.com/coredns/coredns/middleware/pkg/singleflight" "github.com/coredns/coredns/middleware/proxy" "github.com/coredns/coredns/request" @@ -90,7 +91,10 @@ func (e *Etcd) Records(name string, exact bool) ([]msg.Service, error) { // get is a wrapper for client.Get that uses SingleInflight to suppress multiple outstanding queries. func (e *Etcd) get(path string, recursive bool) (*etcdc.Response, error) { - resp, err := e.Inflight.Do(path, func() (interface{}, error) { + + hash := cache.Hash([]byte(path)) + + resp, err := e.Inflight.Do(hash, func() (interface{}, error) { ctx, cancel := context.WithTimeout(e.Ctx, etcdTimeout) defer cancel() r, e := e.Client.Get(ctx, path, &etcdc.GetOptions{Sort: false, Recursive: recursive}) diff --git a/middleware/pkg/cache/cache.go b/middleware/pkg/cache/cache.go new file mode 100644 index 000000000..56cae2180 --- /dev/null +++ b/middleware/pkg/cache/cache.go @@ -0,0 +1,129 @@ +// Package cache implements a cache. The cache hold 256 shards, each shard +// holds a cache: a map with a mutex. There is no fancy expunge algorithm, it +// just randomly evicts elements when it gets full. +package cache + +import ( + "hash/fnv" + "sync" +) + +// Hash returns the FNV hash of what. +func Hash(what []byte) uint32 { + h := fnv.New32() + h.Write(what) + return h.Sum32() +} + +// Cache is cache. +type Cache struct { + shards [shardSize]*shard +} + +// shard is a cache with random eviction. +type shard struct { + items map[uint32]interface{} + size int + + sync.RWMutex +} + +// New returns a new cache. +func New(size int) *Cache { + ssize := size / shardSize + if ssize < 512 { + ssize = 512 + } + + c := &Cache{} + + // Initialize all the shards + for i := 0; i < shardSize; i++ { + c.shards[i] = newShard(ssize) + } + return c +} + +// Add adds a new element to the cache. If the element already exists it is overwritten. +func (c *Cache) Add(key uint32, el interface{}) { + shard := key & (shardSize - 1) + c.shards[shard].Add(key, el) +} + +// Get looks up element index under key. +func (c *Cache) Get(key uint32) (interface{}, bool) { + shard := key & (shardSize - 1) + return c.shards[shard].Get(key) +} + +// Remove removes the element indexed with key. +func (c *Cache) Remove(key uint32) { + shard := key & (shardSize - 1) + c.shards[shard].Remove(key) +} + +// Len returns the number of elements in the cache. +func (c *Cache) Len() int { + l := 0 + for _, s := range c.shards { + l += s.Len() + } + return l +} + +// newShard returns a new shard with size. +func newShard(size int) *shard { return &shard{items: make(map[uint32]interface{}), size: size} } + +// Add adds element indexed by key into the cache. Any existing element is overwritten +func (s *shard) Add(key uint32, el interface{}) { + l := s.Len() + if l+1 > s.size { + s.Evict() + } + + s.Lock() + s.items[key] = el + s.Unlock() +} + +// Remove removes the element indexed by key from the cache. +func (s *shard) Remove(key uint32) { + s.Lock() + delete(s.items, key) + s.Unlock() +} + +// Evict removes a random element from the cache. +func (s *shard) Evict() { + s.Lock() + defer s.Unlock() + + key := -1 + for k := range s.items { + key = int(k) + break + } + if key == -1 { + // empty cache + return + } + delete(s.items, uint32(key)) +} + +// Get looks up the element indexed under key. +func (s *shard) Get(key uint32) (interface{}, bool) { + s.RLock() + el, found := s.items[key] + s.RUnlock() + return el, found +} + +// Len returns the current length of the cache. +func (s *shard) Len() int { + s.RLock() + l := len(s.items) + s.RUnlock() + return l +} + +const shardSize = 256 diff --git a/middleware/pkg/cache/cache_test.go b/middleware/pkg/cache/cache_test.go new file mode 100644 index 000000000..2c92bf438 --- /dev/null +++ b/middleware/pkg/cache/cache_test.go @@ -0,0 +1,31 @@ +package cache + +import "testing" + +func TestCacheAddAndGet(t *testing.T) { + c := New(4) + c.Add(1, 1) + + if _, found := c.Get(1); !found { + t.Fatal("Failed to find inserted record") + } +} + +func TestCacheLen(t *testing.T) { + c := New(4) + + c.Add(1, 1) + if l := c.Len(); l != 1 { + t.Fatalf("Cache size should %d, got %d", 1, l) + } + + c.Add(1, 1) + if l := c.Len(); l != 1 { + t.Fatalf("Cache size should %d, got %d", 1, l) + } + + c.Add(2, 2) + if l := c.Len(); l != 2 { + t.Fatalf("Cache size should %d, got %d", 2, l) + } +} diff --git a/middleware/pkg/cache/shard_test.go b/middleware/pkg/cache/shard_test.go new file mode 100644 index 000000000..26675cee1 --- /dev/null +++ b/middleware/pkg/cache/shard_test.go @@ -0,0 +1,60 @@ +package cache + +import "testing" + +func TestShardAddAndGet(t *testing.T) { + s := newShard(4) + s.Add(1, 1) + + if _, found := s.Get(1); !found { + t.Fatal("Failed to find inserted record") + } +} + +func TestShardLen(t *testing.T) { + s := newShard(4) + + s.Add(1, 1) + if l := s.Len(); l != 1 { + t.Fatalf("Shard size should %d, got %d", 1, l) + } + + s.Add(1, 1) + if l := s.Len(); l != 1 { + t.Fatalf("Shard size should %d, got %d", 1, l) + } + + s.Add(2, 2) + if l := s.Len(); l != 2 { + t.Fatalf("Shard size should %d, got %d", 2, l) + } +} + +func TestShardEvict(t *testing.T) { + s := newShard(1) + s.Add(1, 1) + s.Add(2, 2) + // 1 should be gone + + if _, found := s.Get(1); found { + t.Fatal("Found item that should have been evicted") + } +} + +func TestShardLenEvict(t *testing.T) { + s := newShard(4) + s.Add(1, 1) + s.Add(2, 1) + s.Add(3, 1) + s.Add(4, 1) + + if l := s.Len(); l != 4 { + t.Fatalf("Shard size should %d, got %d", 4, l) + } + + // This should evict one element + s.Add(5, 1) + if l := s.Len(); l != 4 { + t.Fatalf("Shard size should %d, got %d", 4, l) + } +} diff --git a/middleware/pkg/singleflight/singleflight.go b/middleware/pkg/singleflight/singleflight.go index ff2c2ee4f..365e3ef58 100644 --- a/middleware/pkg/singleflight/singleflight.go +++ b/middleware/pkg/singleflight/singleflight.go @@ -31,17 +31,17 @@ type call struct { // units of work can be executed with duplicate suppression. type Group struct { mu sync.Mutex // protects m - m map[string]*call // lazily initialized + m map[uint32]*call // lazily initialized } // Do executes and returns the results of the given function, making // sure that only one execution is in-flight for a given key at a // time. If a duplicate comes in, the duplicate caller waits for the // original to complete and receives the same results. -func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) { +func (g *Group) Do(key uint32, fn func() (interface{}, error)) (interface{}, error) { g.mu.Lock() if g.m == nil { - g.m = make(map[string]*call) + g.m = make(map[uint32]*call) } if c, ok := g.m[key]; ok { g.mu.Unlock() diff --git a/middleware/pkg/singleflight/singleflight_test.go b/middleware/pkg/singleflight/singleflight_test.go index 47b4d3dc0..d1d406e0b 100644 --- a/middleware/pkg/singleflight/singleflight_test.go +++ b/middleware/pkg/singleflight/singleflight_test.go @@ -27,7 +27,7 @@ import ( func TestDo(t *testing.T) { var g Group - v, err := g.Do("key", func() (interface{}, error) { + v, err := g.Do(1, func() (interface{}, error) { return "bar", nil }) if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { @@ -41,7 +41,7 @@ func TestDo(t *testing.T) { func TestDoErr(t *testing.T) { var g Group someErr := errors.New("Some error") - v, err := g.Do("key", func() (interface{}, error) { + v, err := g.Do(1, func() (interface{}, error) { return nil, someErr }) if err != someErr { @@ -66,7 +66,7 @@ func TestDoDupSuppress(t *testing.T) { for i := 0; i < n; i++ { wg.Add(1) go func() { - v, err := g.Do("key", fn) + v, err := g.Do(1, fn) if err != nil { t.Errorf("Do error: %v", err) }