diff --git a/plugin/pkg/cache/cache.go b/plugin/pkg/cache/cache.go index 8a5ad783e..3a2c8ff7f 100644 --- a/plugin/pkg/cache/cache.go +++ b/plugin/pkg/cache/cache.go @@ -76,12 +76,15 @@ func newShard(size int) *shard { return &shard{items: make(map[uint64]interface{ // Add adds element indexed by key into the cache. Any existing element is overwritten func (s *shard) Add(key uint64, el interface{}) { - l := s.Len() - if l+1 > s.size { - s.Evict() - } - s.Lock() + if len(s.items) >= s.size { + if _, ok := s.items[key]; !ok { + for k := range s.items { + delete(s.items, k) + break + } + } + } s.items[key] = el s.Unlock() } @@ -95,24 +98,12 @@ func (s *shard) Remove(key uint64) { // Evict removes a random element from the cache. func (s *shard) Evict() { - hasKey := false - var key uint64 - - s.RLock() + s.Lock() for k := range s.items { - key = k - hasKey = true + delete(s.items, k) break } - s.RUnlock() - - if !hasKey { - // empty cache - return - } - - // If this item is gone between the RUnlock and Lock race we don't care. - s.Remove(key) + s.Unlock() } // Get looks up the element indexed under key. diff --git a/plugin/pkg/cache/cache_test.go b/plugin/pkg/cache/cache_test.go index 0c56bb9b3..2714967a6 100644 --- a/plugin/pkg/cache/cache_test.go +++ b/plugin/pkg/cache/cache_test.go @@ -3,12 +3,23 @@ package cache import "testing" func TestCacheAddAndGet(t *testing.T) { - c := New(4) + const N = shardSize * 4 + c := New(N) c.Add(1, 1) if _, found := c.Get(1); !found { t.Fatal("Failed to find inserted record") } + + for i := 0; i < N; i++ { + c.Add(uint64(i), 1) + } + for i := 0; i < N; i++ { + c.Add(uint64(i), 1) + if c.Len() != N { + t.Fatal("A item was unnecessarily evicted from the cache") + } + } } func TestCacheLen(t *testing.T) { @@ -30,6 +41,18 @@ func TestCacheLen(t *testing.T) { } } +func TestCacheSharding(t *testing.T) { + c := New(shardSize) + for i := 0; i < shardSize*2; i++ { + c.Add(uint64(i), 1) + } + for i, s := range c.shards { + if s.Len() == 0 { + t.Errorf("Failed to populate shard: %d", i) + } + } +} + func BenchmarkCache(b *testing.B) { b.ReportAllocs() diff --git a/plugin/pkg/cache/shard_test.go b/plugin/pkg/cache/shard_test.go index 26675cee1..a3831305d 100644 --- a/plugin/pkg/cache/shard_test.go +++ b/plugin/pkg/cache/shard_test.go @@ -1,14 +1,40 @@ package cache -import "testing" +import ( + "sync" + "testing" +) func TestShardAddAndGet(t *testing.T) { - s := newShard(4) + s := newShard(1) s.Add(1, 1) if _, found := s.Get(1); !found { t.Fatal("Failed to find inserted record") } + + s.Add(2, 1) + if _, found := s.Get(1); found { + t.Fatal("Failed to evict record") + } + if _, found := s.Get(2); !found { + t.Fatal("Failed to find inserted record") + } +} + +func TestAddEvict(t *testing.T) { + const size = 1024 + s := newShard(size) + + for i := uint64(0); i < size; i++ { + s.Add(i, 1) + } + for i := uint64(0); i < size; i++ { + s.Add(i, 1) + if s.Len() != size { + t.Fatal("A item was unnecessarily evicted from the cache") + } + } } func TestShardLen(t *testing.T) { @@ -57,4 +83,57 @@ func TestShardLenEvict(t *testing.T) { if l := s.Len(); l != 4 { t.Fatalf("Shard size should %d, got %d", 4, l) } + + // Make sure we don't accidentally evict an element when + // we the key is already stored. + for i := 0; i < 4; i++ { + s.Add(5, 1) + if l := s.Len(); l != 4 { + t.Fatalf("Shard size should %d, got %d", 4, l) + } + } +} + +func TestShardEvictParallel(t *testing.T) { + s := newShard(shardSize) + for i := uint64(0); i < shardSize; i++ { + s.Add(i, struct{}{}) + } + start := make(chan struct{}) + var wg sync.WaitGroup + for i := 0; i < shardSize; i++ { + wg.Add(1) + go func() { + <-start + s.Evict() + wg.Done() + }() + } + close(start) // start evicting in parallel + wg.Wait() + if s.Len() != 0 { + t.Fatalf("Failed to evict all keys in parallel: %d", s.Len()) + } +} + +func BenchmarkShard(b *testing.B) { + s := newShard(shardSize) + b.ResetTimer() + for i := 0; i < b.N; i++ { + k := uint64(i) % shardSize * 2 + s.Add(k, 1) + s.Get(k) + } +} + +func BenchmarkShardParallel(b *testing.B) { + s := newShard(shardSize) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for i := uint64(0); pb.Next(); i++ { + k := i % shardSize * 2 + s.Add(k, 1) + s.Get(k) + } + }) }