New cache implementation and prefetch handing in mw/cache (#731)

* cache: add sharded cache implementation

Add Cache impl and a few tests. This cache is 256-way sharded, mainly
so each shard has it's own lock. The main cache structure is a readonly
jump plane into the right shard.

This should remove the single lock contention on the main lock and
provide more concurrent throughput - Obviously this hasn't been tested
or measured.

The key into the cache was made a uint32 (hash.fnv) and the hashing op
is not using strings.ToLower anymore remove any GC in that code path.

* here too

* Minimum shard size

* typos

* blurp

* small cleanups no defer

* typo

* Add freq based on Johns idea

* cherry-pick conflict resolv

* typo

* update from early code review from john

* add prefetch to the cache

* mw/cache: add prefetch

* remove println

* remove comment

* Fix tests

* Test prefetch in setup

* Add start of cache

* try add diff cache options

* Add hacky testcase

* not needed

* allow the use of a percentage for prefetch

If the TTL falls below xx% do a prefetch, if the record was popular.
Some other fixes and correctly prefetch only popular records.
This commit is contained in:
Miek Gieben 2017-06-13 12:39:10 -07:00 committed by GitHub
parent b1efd3736e
commit e9eda7e7c8
23 changed files with 595 additions and 142 deletions

View file

@ -10,13 +10,12 @@ cache [TTL] [ZONES...]
* **TTL** max TTL in seconds. If not specified, the maximum TTL will be used which is 3600 for * **TTL** max TTL in seconds. If not specified, the maximum TTL will be used which is 3600 for
noerror responses and 1800 for denial of existence ones. noerror responses and 1800 for denial of existence ones.
A set TTL of 300 *cache 300* would cache the record up to 300 seconds. Setting a TTL of 300 *cache 300* would cache the record up to 300 seconds.
Smaller record provided TTLs will take precedence.
* **ZONES** zones it should cache for. If empty, the zones from the configuration block are used. * **ZONES** zones it should cache for. If empty, the zones from the configuration block are used.
Each element in the cache is cached according to its TTL (with **TTL** as the max). Each element in the cache is cached according to its TTL (with **TTL** as the max).
For the negative cache, the SOA's MinTTL value is used. A cache can contain up to 10,000 items by For the negative cache, the SOA's MinTTL value is used. A cache can contain up to 10,000 items by
default. A TTL of zero is not allowed. No cache invalidation triggered by other middlewares is available. Therefore even reloaded items might still be cached for the duration of the TTL. default. A TTL of zero is not allowed.
If you want more control: If you want more control:
@ -24,16 +23,21 @@ If you want more control:
cache [TTL] [ZONES...] { cache [TTL] [ZONES...] {
success CAPACITY [TTL] success CAPACITY [TTL]
denial CAPACITY [TTL] denial CAPACITY [TTL]
prefetch AMOUNT [[DURATION] [PERCENTAGE%]]
} }
~~~ ~~~
* **TTL** and **ZONES** as above. * **TTL** and **ZONES** as above.
* `success`, override the settings for caching successful responses, **CAPACITY** indicates the maximum * `success`, override the settings for caching successful responses, **CAPACITY** indicates the maximum
number of packets we cache before we start evicting (LRU). **TTL** overrides the cache maximum TTL. number of packets we cache before we start evicting (*randomly*). **TTL** overrides the cache maximum TTL.
* `denial`, override the settings for caching denial of existence responses, **CAPACITY** indicates the maximum * `denial`, override the settings for caching denial of existence responses, **CAPACITY** indicates the maximum
number of packets we cache before we start evicting (LRU). **TTL** overrides the cache maximum TTL. number of packets we cache before we start evicting (LRU). **TTL** overrides the cache maximum TTL.
There is a third category (`error`) but those responses are never cached.
There is a third category (`error`) but those responses are never cached. * `prefetch`, will prefetch popular items when they are about to be expunged from the cache.
Popular means **AMOUNT** queries have been seen no gaps of **DURATION** or more between them.
**DURATION** defaults to 1m. Prefetching will happen when the TTL drops below **PERCENTAGE**,
which defaults to `10%`. Values should be in the range `[10%, 90%]`. Note the percent sign is
mandatory. **PERCENTAGE** is treated as an `int`.
The minimum TTL allowed on resource records is 5 seconds. The minimum TTL allowed on resource records is 5 seconds.

View file

@ -2,15 +2,15 @@
package cache package cache
import ( import (
"encoding/binary"
"hash/fnv"
"log" "log"
"strconv"
"strings"
"time" "time"
"github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/coredns/coredns/middleware/pkg/response" "github.com/coredns/coredns/middleware/pkg/response"
"github.com/hashicorp/golang-lru"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -20,48 +20,73 @@ type Cache struct {
Next middleware.Handler Next middleware.Handler
Zones []string Zones []string
ncache *lru.Cache ncache *cache.Cache
ncap int ncap int
nttl time.Duration nttl time.Duration
pcache *lru.Cache pcache *cache.Cache
pcap int pcap int
pttl time.Duration pttl time.Duration
// Prefetch.
prefetch int
duration time.Duration
percentage int
} }
// Return key under which we store the item. The empty string is returned // Return key under which we store the item, -1 will be returned if we don't store the
// when we don't want to cache the message. Currently we do not cache Truncated, errors // message.
// zone transfers or dynamic update messages. // Currently we do not cache Truncated, errors zone transfers or dynamic update messages.
func key(m *dns.Msg, t response.Type, do bool) string { func key(m *dns.Msg, t response.Type, do bool) int {
// We don't store truncated responses. // We don't store truncated responses.
if m.Truncated { if m.Truncated {
return "" return -1
} }
// Nor errors or Meta or Update // Nor errors or Meta or Update
if t == response.OtherError || t == response.Meta || t == response.Update { if t == response.OtherError || t == response.Meta || t == response.Update {
return "" return -1
} }
qtype := m.Question[0].Qtype return int(hash(m.Question[0].Name, m.Question[0].Qtype, do))
qname := strings.ToLower(m.Question[0].Name)
return rawKey(qname, qtype, do)
} }
func rawKey(qname string, qtype uint16, do bool) string { var one = []byte("1")
var zero = []byte("0")
func hash(qname string, qtype uint16, do bool) uint32 {
h := fnv.New32()
if do { if do {
return "1" + qname + "." + strconv.Itoa(int(qtype)) h.Write(one)
} else {
h.Write(zero)
} }
return "0" + qname + "." + strconv.Itoa(int(qtype))
b := make([]byte, 2)
binary.BigEndian.PutUint16(b, qtype)
h.Write(b)
for i := range qname {
c := qname[i]
if c >= 'A' && c <= 'Z' {
c += 'a' - 'A'
}
h.Write([]byte{c})
}
return h.Sum32()
} }
// ResponseWriter is a response writer that caches the reply message. // ResponseWriter is a response writer that caches the reply message.
type ResponseWriter struct { type ResponseWriter struct {
dns.ResponseWriter dns.ResponseWriter
*Cache *Cache
prefetch bool // When true write nothing back to the client.
} }
// WriteMsg implements the dns.ResponseWriter interface. // WriteMsg implements the dns.ResponseWriter interface.
func (c *ResponseWriter) WriteMsg(res *dns.Msg) error { func (w *ResponseWriter) WriteMsg(res *dns.Msg) error {
do := false do := false
mt, opt := response.Typify(res, time.Now().UTC()) mt, opt := response.Typify(res, time.Now().UTC())
if opt != nil { if opt != nil {
@ -71,9 +96,9 @@ func (c *ResponseWriter) WriteMsg(res *dns.Msg) error {
// key returns empty string for anything we don't want to cache. // key returns empty string for anything we don't want to cache.
key := key(res, mt, do) key := key(res, mt, do)
duration := c.pttl duration := w.pttl
if mt == response.NameError || mt == response.NoData { if mt == response.NameError || mt == response.NoData {
duration = c.nttl duration = w.nttl
} }
msgTTL := minMsgTTL(res, mt) msgTTL := minMsgTTL(res, mt)
@ -81,20 +106,23 @@ func (c *ResponseWriter) WriteMsg(res *dns.Msg) error {
duration = msgTTL duration = msgTTL
} }
if key != "" { if key != -1 {
c.set(res, key, mt, duration) w.set(res, key, mt, duration)
cacheSize.WithLabelValues(Success).Set(float64(c.pcache.Len())) cacheSize.WithLabelValues(Success).Set(float64(w.pcache.Len()))
cacheSize.WithLabelValues(Denial).Set(float64(c.ncache.Len())) cacheSize.WithLabelValues(Denial).Set(float64(w.ncache.Len()))
} }
setMsgTTL(res, uint32(duration.Seconds())) setMsgTTL(res, uint32(duration.Seconds()))
if w.prefetch {
return nil
}
return c.ResponseWriter.WriteMsg(res) return w.ResponseWriter.WriteMsg(res)
} }
func (c *ResponseWriter) set(m *dns.Msg, key string, mt response.Type, duration time.Duration) { func (w *ResponseWriter) set(m *dns.Msg, key int, mt response.Type, duration time.Duration) {
if key == "" { if key == -1 {
log.Printf("[ERROR] Caching called with empty cache key") log.Printf("[ERROR] Caching called with empty cache key")
return return
} }
@ -102,11 +130,11 @@ func (c *ResponseWriter) set(m *dns.Msg, key string, mt response.Type, duration
switch mt { switch mt {
case response.NoError, response.Delegation: case response.NoError, response.Delegation:
i := newItem(m, duration) i := newItem(m, duration)
c.pcache.Add(key, i) w.pcache.Add(uint32(key), i)
case response.NameError, response.NoData: case response.NameError, response.NoData:
i := newItem(m, duration) i := newItem(m, duration)
c.ncache.Add(key, i) w.ncache.Add(uint32(key), i)
case response.OtherError: case response.OtherError:
// don't cache these // don't cache these
@ -116,9 +144,12 @@ func (c *ResponseWriter) set(m *dns.Msg, key string, mt response.Type, duration
} }
// Write implements the dns.ResponseWriter interface. // Write implements the dns.ResponseWriter interface.
func (c *ResponseWriter) Write(buf []byte) (int, error) { func (w *ResponseWriter) Write(buf []byte) (int, error) {
log.Printf("[WARNING] Caching called with Write: not caching reply") log.Printf("[WARNING] Caching called with Write: not caching reply")
n, err := c.ResponseWriter.Write(buf) if w.prefetch {
return 0, nil
}
n, err := w.ResponseWriter.Write(buf)
return n, err return n, err
} }

View file

@ -7,10 +7,10 @@ import (
"time" "time"
"github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/coredns/coredns/middleware/pkg/response" "github.com/coredns/coredns/middleware/pkg/response"
"github.com/coredns/coredns/middleware/test" "github.com/coredns/coredns/middleware/test"
lru "github.com/hashicorp/golang-lru"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -148,10 +148,10 @@ func cacheMsg(m *dns.Msg, tc cacheTestCase) *dns.Msg {
func newTestCache(ttl time.Duration) (*Cache, *ResponseWriter) { func newTestCache(ttl time.Duration) (*Cache, *ResponseWriter) {
c := &Cache{Zones: []string{"."}, pcap: defaultCap, ncap: defaultCap, pttl: ttl, nttl: ttl} c := &Cache{Zones: []string{"."}, pcap: defaultCap, ncap: defaultCap, pttl: ttl, nttl: ttl}
c.pcache, _ = lru.New(c.pcap) c.pcache = cache.New(c.pcap)
c.ncache, _ = lru.New(c.ncap) c.ncache = cache.New(c.ncap)
crr := &ResponseWriter{nil, c} crr := &ResponseWriter{ResponseWriter: nil, Cache: c}
return c, crr return c, crr
} }
@ -176,7 +176,8 @@ func TestCache(t *testing.T) {
name := middleware.Name(m.Question[0].Name).Normalize() name := middleware.Name(m.Question[0].Name).Normalize()
qtype := m.Question[0].Qtype qtype := m.Question[0].Qtype
i, ok, _ := c.get(name, qtype, do) i, _ := c.get(time.Now().UTC(), name, qtype, do)
ok := i != nil
if ok != tc.shouldCache { if ok != tc.shouldCache {
t.Errorf("cached message that should not have been cached: %s", name) t.Errorf("cached message that should not have been cached: %s", name)

54
middleware/cache/freq/freq.go vendored Normal file
View file

@ -0,0 +1,54 @@
// Package freq keeps track of last X seen events. The events themselves are not stored
// here. So the Freq type should be added next to the thing it is tracking.
package freq
import (
"sync"
"time"
)
type Freq struct {
// Last time we saw a query for this element.
last time.Time
// Number of this in the last time slice.
hits int
sync.RWMutex
}
// New returns a new initialized Freq.
func New(t time.Time) *Freq {
return &Freq{last: t, hits: 0}
}
// Updates updates the number of hits. Last time seen will be set to now.
// If the last time we've seen this entity is within now - d, we increment hits, otherwise
// we reset hits to 1. It returns the number of hits.
func (f *Freq) Update(d time.Duration, now time.Time) int {
earliest := now.Add(-1 * d)
f.Lock()
defer f.Unlock()
if f.last.Before(earliest) {
f.last = now
f.hits = 1
return f.hits
}
f.last = now
f.hits++
return f.hits
}
// Hits returns the number of hits that we have seen, according to the updates we have done to f.
func (f *Freq) Hits() int {
f.RLock()
defer f.RUnlock()
return f.hits
}
// Reset resets f to time t and hits to hits.
func (f *Freq) Reset(t time.Time, hits int) {
f.Lock()
defer f.Unlock()
f.last = t
f.hits = hits
}

36
middleware/cache/freq/freq_test.go vendored Normal file
View file

@ -0,0 +1,36 @@
package freq
import (
"testing"
"time"
)
func TestFreqUpdate(t *testing.T) {
now := time.Now().UTC()
f := New(now)
window := 1 * time.Minute
f.Update(window, time.Now().UTC())
f.Update(window, time.Now().UTC())
f.Update(window, time.Now().UTC())
hitsCheck(t, f, 3)
f.Reset(now, 0)
history := time.Now().UTC().Add(-3 * time.Minute)
f.Update(window, history)
hitsCheck(t, f, 1)
}
func TestReset(t *testing.T) {
f := New(time.Now().UTC())
f.Update(1*time.Minute, time.Now().UTC())
hitsCheck(t, f, 1)
f.Reset(time.Now().UTC(), 0)
hitsCheck(t, f, 0)
}
func hitsCheck(t *testing.T, f *Freq, expected int) {
if x := f.Hits(); x != expected {
t.Fatalf("Expected hits to be %d, got %d", expected, x)
}
}

View file

@ -24,36 +24,58 @@ func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
do := state.Do() // TODO(): might need more from OPT record? Like the actual bufsize? do := state.Do() // TODO(): might need more from OPT record? Like the actual bufsize?
if i, ok, expired := c.get(qname, qtype, do); ok && !expired { now := time.Now().UTC()
i, ttl := c.get(now, qname, qtype, do)
if i != nil && ttl > 0 {
resp := i.toMsg(r) resp := i.toMsg(r)
state.SizeAndDo(resp) state.SizeAndDo(resp)
resp, _ = state.Scrub(resp) resp, _ = state.Scrub(resp)
w.WriteMsg(resp) w.WriteMsg(resp)
i.Freq.Update(c.duration, now)
pct := 100
if i.origTTL != 0 { // you'll never know
pct = int(float64(ttl) / float64(i.origTTL) * 100)
}
if c.prefetch > 0 && i.Freq.Hits() > c.prefetch && pct < c.percentage {
// When prefetching we loose the item i, and with it the frequency
// that we've gathered sofar. See we copy the frequence info back
// into the new item that was stored in the cache.
prr := &ResponseWriter{ResponseWriter: w, Cache: c, prefetch: true}
middleware.NextOrFailure(c.Name(), c.Next, ctx, prr, r)
if i1, _ := c.get(now, qname, qtype, do); i1 != nil {
i1.Freq.Reset(now, i.Freq.Hits())
}
}
return dns.RcodeSuccess, nil return dns.RcodeSuccess, nil
} }
crr := &ResponseWriter{w, c} crr := &ResponseWriter{ResponseWriter: w, Cache: c}
return middleware.NextOrFailure(c.Name(), c.Next, ctx, crr, r) return middleware.NextOrFailure(c.Name(), c.Next, ctx, crr, r)
} }
// Name implements the Handler interface. // Name implements the Handler interface.
func (c *Cache) Name() string { return "cache" } func (c *Cache) Name() string { return "cache" }
func (c *Cache) get(qname string, qtype uint16, do bool) (*item, bool, bool) { func (c *Cache) get(now time.Time, qname string, qtype uint16, do bool) (*item, int) {
k := rawKey(qname, qtype, do) k := hash(qname, qtype, do)
if i, ok := c.ncache.Get(k); ok { if i, ok := c.ncache.Get(k); ok {
cacheHits.WithLabelValues(Denial).Inc() cacheHits.WithLabelValues(Denial).Inc()
return i.(*item), ok, i.(*item).expired(time.Now()) return i.(*item), i.(*item).ttl(now)
} }
if i, ok := c.pcache.Get(k); ok { if i, ok := c.pcache.Get(k); ok {
cacheHits.WithLabelValues(Success).Inc() cacheHits.WithLabelValues(Success).Inc()
return i.(*item), ok, i.(*item).expired(time.Now()) return i.(*item), i.(*item).ttl(now)
} }
cacheMisses.Inc() cacheMisses.Inc()
return nil, false, false return nil, 0
} }
var ( var (

View file

@ -3,6 +3,7 @@ package cache
import ( import (
"time" "time"
"github.com/coredns/coredns/middleware/cache/freq"
"github.com/coredns/coredns/middleware/pkg/response" "github.com/coredns/coredns/middleware/pkg/response"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -18,6 +19,8 @@ type item struct {
origTTL uint32 origTTL uint32
stored time.Time stored time.Time
*freq.Freq
} }
func newItem(m *dns.Msg, d time.Duration) *item { func newItem(m *dns.Msg, d time.Duration) *item {
@ -43,10 +46,12 @@ func newItem(m *dns.Msg, d time.Duration) *item {
i.origTTL = uint32(d.Seconds()) i.origTTL = uint32(d.Seconds())
i.stored = time.Now().UTC() i.stored = time.Now().UTC()
i.Freq = new(freq.Freq)
return i return i
} }
// toMsg turns i into a message, it tailers the reply to m. // toMsg turns i into a message, it tailors the reply to m.
// The Authoritative bit is always set to 0, because the answer is from the cache. // The Authoritative bit is always set to 0, because the answer is from the cache.
func (i *item) toMsg(m *dns.Msg) *dns.Msg { func (i *item) toMsg(m *dns.Msg) *dns.Msg {
m1 := new(dns.Msg) m1 := new(dns.Msg)
@ -67,9 +72,9 @@ func (i *item) toMsg(m *dns.Msg) *dns.Msg {
return m1 return m1
} }
func (i *item) expired(now time.Time) bool { func (i *item) ttl(now time.Time) int {
ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds()) ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds())
return ttl < 0 return ttl
} }
// setMsgTTL sets the ttl on all RRs in all sections. If ttl is smaller than minTTL // setMsgTTL sets the ttl on all RRs in all sections. If ttl is smaller than minTTL

View file

@ -1,20 +0,0 @@
package cache
import (
"testing"
"github.com/miekg/dns"
)
func TestKey(t *testing.T) {
if x := rawKey("miek.nl.", dns.TypeMX, false); x != "0miek.nl..15" {
t.Errorf("failed to create correct key, got %s", x)
}
if x := rawKey("miek.nl.", dns.TypeMX, true); x != "1miek.nl..15" {
t.Errorf("failed to create correct key, got %s", x)
}
// rawKey does not lowercase.
if x := rawKey("miEK.nL.", dns.TypeMX, true); x != "1miEK.nL..15" {
t.Errorf("failed to create correct key, got %s", x)
}
}

54
middleware/cache/prefech_test.go vendored Normal file
View file

@ -0,0 +1,54 @@
package cache
import (
"fmt"
"testing"
"time"
"github.com/coredns/coredns/middleware"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/coredns/coredns/middleware/pkg/dnsrecorder"
"github.com/coredns/coredns/middleware/test"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
var p = false
func TestPrefetch(t *testing.T) {
c := &Cache{Zones: []string{"."}, pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxTTL}
c.pcache = cache.New(c.pcap)
c.ncache = cache.New(c.ncap)
c.prefetch = 1
c.duration = 1 * time.Second
c.Next = PrefetchHandler(t, dns.RcodeSuccess, nil)
ctx := context.TODO()
req := new(dns.Msg)
req.SetQuestion("lowttl.example.org.", dns.TypeA)
rec := dnsrecorder.New(&test.ResponseWriter{})
c.ServeDNS(ctx, rec, req)
p = true // prefetch should be true for the 2nd fetch
c.ServeDNS(ctx, rec, req)
}
func PrefetchHandler(t *testing.T, rcode int, err error) middleware.Handler {
return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
m := new(dns.Msg)
m.SetQuestion("lowttl.example.org.", dns.TypeA)
m.Response = true
m.RecursionAvailable = true
m.Answer = append(m.Answer, test.A("lowttl.example.org. 80 IN A 127.0.0.53"))
if p != w.(*ResponseWriter).prefetch {
err = fmt.Errorf("cache prefetch not equal to p: got %t, want %t", p, w.(*ResponseWriter).prefetch)
t.Fatal(err)
}
w.WriteMsg(m)
return rcode, err
})
}

View file

@ -7,8 +7,8 @@ import (
"github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/hashicorp/golang-lru"
"github.com/mholt/caddy" "github.com/mholt/caddy"
) )
@ -38,7 +38,7 @@ func setup(c *caddy.Controller) error {
func cacheParse(c *caddy.Controller) (*Cache, error) { func cacheParse(c *caddy.Controller) (*Cache, error) {
ca := &Cache{pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxNTTL} ca := &Cache{pcap: defaultCap, ncap: defaultCap, pttl: maxTTL, nttl: maxNTTL, prefetch: 0, duration: 1 * time.Minute}
for c.Next() { for c.Next() {
// cache [ttl] [zones..] // cache [ttl] [zones..]
@ -109,6 +109,46 @@ func cacheParse(c *caddy.Controller) (*Cache, error) {
} }
ca.nttl = time.Duration(nttl) * time.Second ca.nttl = time.Duration(nttl) * time.Second
} }
case "prefetch":
args := c.RemainingArgs()
if len(args) == 0 || len(args) > 3 {
return nil, c.ArgErr()
}
amount, err := strconv.Atoi(args[0])
if err != nil {
return nil, err
}
if amount < 0 {
return nil, fmt.Errorf("prefetch amount should be positive: %d", amount)
}
ca.prefetch = amount
ca.duration = 1 * time.Minute
ca.percentage = 10
if len(args) > 1 {
dur, err := time.ParseDuration(args[1])
if err != nil {
return nil, err
}
ca.duration = dur
}
if len(args) > 2 {
pct := args[2]
if x := pct[len(pct)-1]; x != '%' {
return nil, fmt.Errorf("last character of percentage should be `%%`, but is: %q", x)
}
pct = pct[:len(pct)-1]
num, err := strconv.Atoi(pct)
if err != nil {
return nil, err
}
if num < 10 || num > 90 {
return nil, fmt.Errorf("percentage should fall in range [10, 90]: %d", num)
}
ca.percentage = num
}
default: default:
return nil, c.ArgErr() return nil, c.ArgErr()
} }
@ -118,17 +158,10 @@ func cacheParse(c *caddy.Controller) (*Cache, error) {
origins[i] = middleware.Host(origins[i]).Normalize() origins[i] = middleware.Host(origins[i]).Normalize()
} }
var err error
ca.Zones = origins ca.Zones = origins
ca.pcache, err = lru.New(ca.pcap) ca.pcache = cache.New(ca.pcap)
if err != nil { ca.ncache = cache.New(ca.ncap)
return nil, err
}
ca.ncache, err = lru.New(ca.ncap)
if err != nil {
return nil, err
}
return ca, nil return ca, nil
} }

View file

@ -9,46 +9,57 @@ import (
func TestSetup(t *testing.T) { func TestSetup(t *testing.T) {
tests := []struct { tests := []struct {
input string input string
shouldErr bool shouldErr bool
expectedNcap int expectedNcap int
expectedPcap int expectedPcap int
expectedNttl time.Duration expectedNttl time.Duration
expectedPttl time.Duration expectedPttl time.Duration
expectedPrefetch int
}{ }{
{`cache`, false, defaultCap, defaultCap, maxNTTL, maxTTL}, {`cache`, false, defaultCap, defaultCap, maxNTTL, maxTTL, 0},
{`cache {}`, false, defaultCap, defaultCap, maxNTTL, maxTTL}, {`cache {}`, false, defaultCap, defaultCap, maxNTTL, maxTTL, 0},
{`cache example.nl { {`cache example.nl {
success 10 success 10
}`, false, defaultCap, 10, maxNTTL, maxTTL}, }`, false, defaultCap, 10, maxNTTL, maxTTL, 0},
{`cache example.nl { {`cache example.nl {
success 10 success 10
denial 10 15 denial 10 15
}`, false, 10, 10, 15 * time.Second, maxTTL}, }`, false, 10, 10, 15 * time.Second, maxTTL, 0},
{`cache 25 example.nl { {`cache 25 example.nl {
success 10 success 10
denial 10 15 denial 10 15
}`, false, 10, 10, 15 * time.Second, 25 * time.Second}, }`, false, 10, 10, 15 * time.Second, 25 * time.Second, 0},
{`cache aaa example.nl`, false, defaultCap, defaultCap, maxNTTL, maxTTL}, {`cache aaa example.nl`, false, defaultCap, defaultCap, maxNTTL, maxTTL, 0},
{`cache {
prefetch 10
}`, false, defaultCap, defaultCap, maxNTTL, maxTTL, 10},
// fails // fails
{`cache example.nl { {`cache example.nl {
success success
denial 10 15 denial 10 15
}`, true, defaultCap, defaultCap, maxTTL, maxTTL}, }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0},
{`cache example.nl { {`cache example.nl {
success 15 success 15
denial aaa denial aaa
}`, true, defaultCap, defaultCap, maxTTL, maxTTL}, }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0},
{`cache example.nl { {`cache example.nl {
positive 15 positive 15
negative aaa negative aaa
}`, true, defaultCap, defaultCap, maxTTL, maxTTL}, }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0},
{`cache 0 example.nl`, true, defaultCap, defaultCap, maxTTL, maxTTL}, {`cache 0 example.nl`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0},
{`cache -1 example.nl`, true, defaultCap, defaultCap, maxTTL, maxTTL}, {`cache -1 example.nl`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0},
{`cache 1 example.nl { {`cache 1 example.nl {
positive 0 positive 0
}`, true, defaultCap, defaultCap, maxTTL, maxTTL}, }`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0},
{`cache 1 example.nl {
positive 0
prefetch -1
}`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0},
{`cache 1 example.nl {
prefetch 0 blurp
}`, true, defaultCap, defaultCap, maxTTL, maxTTL, 0},
} }
for i, test := range tests { for i, test := range tests {
c := caddy.NewTestController("dns", test.input) c := caddy.NewTestController("dns", test.input)
@ -76,5 +87,8 @@ func TestSetup(t *testing.T) {
if ca.pttl != test.expectedPttl { if ca.pttl != test.expectedPttl {
t.Errorf("Test %v: Expected pttl %v but found: %v", i, test.expectedPttl, ca.pttl) t.Errorf("Test %v: Expected pttl %v but found: %v", i, test.expectedPttl, ca.pttl)
} }
if ca.prefetch != test.expectedPrefetch {
t.Errorf("Test %v: Expected prefetch %v but found: %v", i, test.expectedPrefetch, ca.prefetch)
}
} }
} }

View file

@ -2,14 +2,13 @@ package dnssec
import ( import (
"hash/fnv" "hash/fnv"
"strconv"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// Key serializes the RRset and return a signature cache key. // hash serializes the RRset and return a signature cache key.
func key(rrs []dns.RR) string { func hash(rrs []dns.RR) uint32 {
h := fnv.New64() h := fnv.New32()
buf := make([]byte, 256) buf := make([]byte, 256)
for _, r := range rrs { for _, r := range rrs {
off, err := dns.PackRR(r, buf, 0, nil, false) off, err := dns.PackRR(r, buf, 0, nil, false)
@ -18,6 +17,6 @@ func key(rrs []dns.RR) string {
} }
} }
i := h.Sum64() i := h.Sum32()
return strconv.FormatUint(i, 10) return i
} }

View file

@ -4,10 +4,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/coredns/coredns/middleware/test" "github.com/coredns/coredns/middleware/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/hashicorp/golang-lru"
) )
func TestCacheSet(t *testing.T) { func TestCacheSet(t *testing.T) {
@ -21,11 +20,11 @@ func TestCacheSet(t *testing.T) {
t.Fatalf("failed to parse key: %v\n", err) t.Fatalf("failed to parse key: %v\n", err)
} }
cache, _ := lru.New(defaultCap) c := cache.New(defaultCap)
m := testMsg() m := testMsg()
state := request.Request{Req: m} state := request.Request{Req: m}
k := key(m.Answer) // calculate *before* we add the sig k := hash(m.Answer) // calculate *before* we add the sig
d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil, cache) d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, nil, c)
m = d.Sign(state, "miek.nl.", time.Now().UTC()) m = d.Sign(state, "miek.nl.", time.Now().UTC())
_, ok := d.get(k) _, ok := d.get(k)

View file

@ -6,11 +6,11 @@ import (
"time" "time"
"github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/coredns/coredns/middleware/pkg/response" "github.com/coredns/coredns/middleware/pkg/response"
"github.com/coredns/coredns/middleware/pkg/singleflight" "github.com/coredns/coredns/middleware/pkg/singleflight"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/hashicorp/golang-lru"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -21,15 +21,15 @@ type Dnssec struct {
zones []string zones []string
keys []*DNSKEY keys []*DNSKEY
inflight *singleflight.Group inflight *singleflight.Group
cache *lru.Cache cache *cache.Cache
} }
// New returns a new Dnssec. // New returns a new Dnssec.
func New(zones []string, keys []*DNSKEY, next middleware.Handler, cache *lru.Cache) Dnssec { func New(zones []string, keys []*DNSKEY, next middleware.Handler, c *cache.Cache) Dnssec {
return Dnssec{Next: next, return Dnssec{Next: next,
zones: zones, zones: zones,
keys: keys, keys: keys,
cache: cache, cache: c,
inflight: new(singleflight.Group), inflight: new(singleflight.Group),
} }
} }
@ -90,7 +90,7 @@ func (d Dnssec) Sign(state request.Request, zone string, now time.Time) *dns.Msg
} }
func (d Dnssec) sign(rrs []dns.RR, signerName string, ttl, incep, expir uint32) ([]dns.RR, error) { func (d Dnssec) sign(rrs []dns.RR, signerName string, ttl, incep, expir uint32) ([]dns.RR, error) {
k := key(rrs) k := hash(rrs)
sgs, ok := d.get(k) sgs, ok := d.get(k)
if ok { if ok {
return sgs, nil return sgs, nil
@ -110,11 +110,11 @@ func (d Dnssec) sign(rrs []dns.RR, signerName string, ttl, incep, expir uint32)
return sigs.([]dns.RR), err return sigs.([]dns.RR), err
} }
func (d Dnssec) set(key string, sigs []dns.RR) { func (d Dnssec) set(key uint32, sigs []dns.RR) {
d.cache.Add(key, sigs) d.cache.Add(key, sigs)
} }
func (d Dnssec) get(key string) ([]dns.RR, bool) { func (d Dnssec) get(key uint32) ([]dns.RR, bool) {
if s, ok := d.cache.Get(key); ok { if s, ok := d.cache.Get(key); ok {
cacheHits.Inc() cacheHits.Inc()
return s.([]dns.RR), true return s.([]dns.RR), true

View file

@ -4,10 +4,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/coredns/coredns/middleware/test" "github.com/coredns/coredns/middleware/test"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/hashicorp/golang-lru"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -69,8 +69,8 @@ func TestSigningDifferentZone(t *testing.T) {
m := testMsgEx() m := testMsgEx()
state := request.Request{Req: m} state := request.Request{Req: m}
cache, _ := lru.New(defaultCap) c := cache.New(defaultCap)
d := New([]string{"example.org."}, []*DNSKEY{key}, nil, cache) d := New([]string{"example.org."}, []*DNSKEY{key}, nil, c)
m = d.Sign(state, "example.org.", time.Now().UTC()) m = d.Sign(state, "example.org.", time.Now().UTC())
if !section(m.Answer, 1) { if !section(m.Answer, 1) {
t.Errorf("answer section should have 1 sig") t.Errorf("answer section should have 1 sig")
@ -183,8 +183,8 @@ func testMsgDname() *dns.Msg {
func newDnssec(t *testing.T, zones []string) (Dnssec, func(), func()) { func newDnssec(t *testing.T, zones []string) (Dnssec, func(), func()) {
k, rm1, rm2 := newKey(t) k, rm1, rm2 := newKey(t)
cache, _ := lru.New(defaultCap) c := cache.New(defaultCap)
d := New(zones, []*DNSKEY{k}, nil, cache) d := New(zones, []*DNSKEY{k}, nil, c)
return d, rm1, rm2 return d, rm1, rm2
} }

View file

@ -6,10 +6,10 @@ import (
"testing" "testing"
"github.com/coredns/coredns/middleware/file" "github.com/coredns/coredns/middleware/file"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/coredns/coredns/middleware/pkg/dnsrecorder" "github.com/coredns/coredns/middleware/pkg/dnsrecorder"
"github.com/coredns/coredns/middleware/test" "github.com/coredns/coredns/middleware/test"
"github.com/hashicorp/golang-lru"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
@ -89,8 +89,8 @@ func TestLookupZone(t *testing.T) {
dnskey, rm1, rm2 := newKey(t) dnskey, rm1, rm2 := newKey(t)
defer rm1() defer rm1()
defer rm2() defer rm2()
cache, _ := lru.New(defaultCap) c := cache.New(defaultCap)
dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, fm, cache) dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, fm, c)
ctx := context.TODO() ctx := context.TODO()
for _, tc := range dnsTestCases { for _, tc := range dnsTestCases {
@ -128,8 +128,8 @@ func TestLookupDNSKEY(t *testing.T) {
dnskey, rm1, rm2 := newKey(t) dnskey, rm1, rm2 := newKey(t)
defer rm1() defer rm1()
defer rm2() defer rm2()
cache, _ := lru.New(defaultCap) c := cache.New(defaultCap)
dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, test.ErrorHandler(), cache) dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, test.ErrorHandler(), c)
ctx := context.TODO() ctx := context.TODO()
for _, tc := range dnssecTestCases { for _, tc := range dnssecTestCases {

View file

@ -6,8 +6,8 @@ import (
"github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/hashicorp/golang-lru"
"github.com/mholt/caddy" "github.com/mholt/caddy"
) )
@ -24,12 +24,9 @@ func setup(c *caddy.Controller) error {
return middleware.Error("dnssec", err) return middleware.Error("dnssec", err)
} }
cache, err := lru.New(capacity) ca := cache.New(capacity)
if err != nil {
return err
}
dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler { dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler {
return New(zones, keys, next, cache) return New(zones, keys, next, ca)
}) })
// Export the capacity for the metrics. This only happens once, because this is a re-load change only. // Export the capacity for the metrics. This only happens once, because this is a re-load change only.

View file

@ -9,6 +9,7 @@ import (
"github.com/coredns/coredns/middleware" "github.com/coredns/coredns/middleware"
"github.com/coredns/coredns/middleware/etcd/msg" "github.com/coredns/coredns/middleware/etcd/msg"
"github.com/coredns/coredns/middleware/pkg/cache"
"github.com/coredns/coredns/middleware/pkg/singleflight" "github.com/coredns/coredns/middleware/pkg/singleflight"
"github.com/coredns/coredns/middleware/proxy" "github.com/coredns/coredns/middleware/proxy"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
@ -90,7 +91,10 @@ func (e *Etcd) Records(name string, exact bool) ([]msg.Service, error) {
// get is a wrapper for client.Get that uses SingleInflight to suppress multiple outstanding queries. // get is a wrapper for client.Get that uses SingleInflight to suppress multiple outstanding queries.
func (e *Etcd) get(path string, recursive bool) (*etcdc.Response, error) { func (e *Etcd) get(path string, recursive bool) (*etcdc.Response, error) {
resp, err := e.Inflight.Do(path, func() (interface{}, error) {
hash := cache.Hash([]byte(path))
resp, err := e.Inflight.Do(hash, func() (interface{}, error) {
ctx, cancel := context.WithTimeout(e.Ctx, etcdTimeout) ctx, cancel := context.WithTimeout(e.Ctx, etcdTimeout)
defer cancel() defer cancel()
r, e := e.Client.Get(ctx, path, &etcdc.GetOptions{Sort: false, Recursive: recursive}) r, e := e.Client.Get(ctx, path, &etcdc.GetOptions{Sort: false, Recursive: recursive})

129
middleware/pkg/cache/cache.go vendored Normal file
View file

@ -0,0 +1,129 @@
// Package cache implements a cache. The cache hold 256 shards, each shard
// holds a cache: a map with a mutex. There is no fancy expunge algorithm, it
// just randomly evicts elements when it gets full.
package cache
import (
"hash/fnv"
"sync"
)
// Hash returns the FNV hash of what.
func Hash(what []byte) uint32 {
h := fnv.New32()
h.Write(what)
return h.Sum32()
}
// Cache is cache.
type Cache struct {
shards [shardSize]*shard
}
// shard is a cache with random eviction.
type shard struct {
items map[uint32]interface{}
size int
sync.RWMutex
}
// New returns a new cache.
func New(size int) *Cache {
ssize := size / shardSize
if ssize < 512 {
ssize = 512
}
c := &Cache{}
// Initialize all the shards
for i := 0; i < shardSize; i++ {
c.shards[i] = newShard(ssize)
}
return c
}
// Add adds a new element to the cache. If the element already exists it is overwritten.
func (c *Cache) Add(key uint32, el interface{}) {
shard := key & (shardSize - 1)
c.shards[shard].Add(key, el)
}
// Get looks up element index under key.
func (c *Cache) Get(key uint32) (interface{}, bool) {
shard := key & (shardSize - 1)
return c.shards[shard].Get(key)
}
// Remove removes the element indexed with key.
func (c *Cache) Remove(key uint32) {
shard := key & (shardSize - 1)
c.shards[shard].Remove(key)
}
// Len returns the number of elements in the cache.
func (c *Cache) Len() int {
l := 0
for _, s := range c.shards {
l += s.Len()
}
return l
}
// newShard returns a new shard with size.
func newShard(size int) *shard { return &shard{items: make(map[uint32]interface{}), size: size} }
// Add adds element indexed by key into the cache. Any existing element is overwritten
func (s *shard) Add(key uint32, el interface{}) {
l := s.Len()
if l+1 > s.size {
s.Evict()
}
s.Lock()
s.items[key] = el
s.Unlock()
}
// Remove removes the element indexed by key from the cache.
func (s *shard) Remove(key uint32) {
s.Lock()
delete(s.items, key)
s.Unlock()
}
// Evict removes a random element from the cache.
func (s *shard) Evict() {
s.Lock()
defer s.Unlock()
key := -1
for k := range s.items {
key = int(k)
break
}
if key == -1 {
// empty cache
return
}
delete(s.items, uint32(key))
}
// Get looks up the element indexed under key.
func (s *shard) Get(key uint32) (interface{}, bool) {
s.RLock()
el, found := s.items[key]
s.RUnlock()
return el, found
}
// Len returns the current length of the cache.
func (s *shard) Len() int {
s.RLock()
l := len(s.items)
s.RUnlock()
return l
}
const shardSize = 256

31
middleware/pkg/cache/cache_test.go vendored Normal file
View file

@ -0,0 +1,31 @@
package cache
import "testing"
func TestCacheAddAndGet(t *testing.T) {
c := New(4)
c.Add(1, 1)
if _, found := c.Get(1); !found {
t.Fatal("Failed to find inserted record")
}
}
func TestCacheLen(t *testing.T) {
c := New(4)
c.Add(1, 1)
if l := c.Len(); l != 1 {
t.Fatalf("Cache size should %d, got %d", 1, l)
}
c.Add(1, 1)
if l := c.Len(); l != 1 {
t.Fatalf("Cache size should %d, got %d", 1, l)
}
c.Add(2, 2)
if l := c.Len(); l != 2 {
t.Fatalf("Cache size should %d, got %d", 2, l)
}
}

60
middleware/pkg/cache/shard_test.go vendored Normal file
View file

@ -0,0 +1,60 @@
package cache
import "testing"
func TestShardAddAndGet(t *testing.T) {
s := newShard(4)
s.Add(1, 1)
if _, found := s.Get(1); !found {
t.Fatal("Failed to find inserted record")
}
}
func TestShardLen(t *testing.T) {
s := newShard(4)
s.Add(1, 1)
if l := s.Len(); l != 1 {
t.Fatalf("Shard size should %d, got %d", 1, l)
}
s.Add(1, 1)
if l := s.Len(); l != 1 {
t.Fatalf("Shard size should %d, got %d", 1, l)
}
s.Add(2, 2)
if l := s.Len(); l != 2 {
t.Fatalf("Shard size should %d, got %d", 2, l)
}
}
func TestShardEvict(t *testing.T) {
s := newShard(1)
s.Add(1, 1)
s.Add(2, 2)
// 1 should be gone
if _, found := s.Get(1); found {
t.Fatal("Found item that should have been evicted")
}
}
func TestShardLenEvict(t *testing.T) {
s := newShard(4)
s.Add(1, 1)
s.Add(2, 1)
s.Add(3, 1)
s.Add(4, 1)
if l := s.Len(); l != 4 {
t.Fatalf("Shard size should %d, got %d", 4, l)
}
// This should evict one element
s.Add(5, 1)
if l := s.Len(); l != 4 {
t.Fatalf("Shard size should %d, got %d", 4, l)
}
}

View file

@ -31,17 +31,17 @@ type call struct {
// units of work can be executed with duplicate suppression. // units of work can be executed with duplicate suppression.
type Group struct { type Group struct {
mu sync.Mutex // protects m mu sync.Mutex // protects m
m map[string]*call // lazily initialized m map[uint32]*call // lazily initialized
} }
// Do executes and returns the results of the given function, making // Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a // sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the // time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results. // original to complete and receives the same results.
func (g *Group) Do(key string, fn func() (interface{}, error)) (interface{}, error) { func (g *Group) Do(key uint32, fn func() (interface{}, error)) (interface{}, error) {
g.mu.Lock() g.mu.Lock()
if g.m == nil { if g.m == nil {
g.m = make(map[string]*call) g.m = make(map[uint32]*call)
} }
if c, ok := g.m[key]; ok { if c, ok := g.m[key]; ok {
g.mu.Unlock() g.mu.Unlock()

View file

@ -27,7 +27,7 @@ import (
func TestDo(t *testing.T) { func TestDo(t *testing.T) {
var g Group var g Group
v, err := g.Do("key", func() (interface{}, error) { v, err := g.Do(1, func() (interface{}, error) {
return "bar", nil return "bar", nil
}) })
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
@ -41,7 +41,7 @@ func TestDo(t *testing.T) {
func TestDoErr(t *testing.T) { func TestDoErr(t *testing.T) {
var g Group var g Group
someErr := errors.New("Some error") someErr := errors.New("Some error")
v, err := g.Do("key", func() (interface{}, error) { v, err := g.Do(1, func() (interface{}, error) {
return nil, someErr return nil, someErr
}) })
if err != someErr { if err != someErr {
@ -66,7 +66,7 @@ func TestDoDupSuppress(t *testing.T) {
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
v, err := g.Do("key", fn) v, err := g.Do(1, fn)
if err != nil { if err != nil {
t.Errorf("Do error: %v", err) t.Errorf("Do error: %v", err)
} }