From bfcf6baf93a283be6d56b73f58e3ac67ebbd7908 Mon Sep 17 00:00:00 2001 From: Nick Craig-Wood Date: Sun, 27 Dec 2020 17:28:35 +0000 Subject: [PATCH] lib/cache: fix locking so we don't try to create the item many times Before this fix, if several Get requests were submitted very quickly, this could run the item create function multiple times due to the unlock of the mutex in the creation code. This fixes the problem by having a mutex in each cache entry which is held when the item is being created. --- lib/cache/cache.go | 35 +++++++++++++++++++++-------------- lib/cache/cache_test.go | 30 ++++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 268ce4e88..4a4b3dfc3 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -29,8 +29,10 @@ func New() *Cache { // cacheEntry is stored in the cache type cacheEntry struct { + createMu sync.Mutex // held while creating the item value interface{} // cached item err error // creation error + ok bool // true if entry is valid key string // key lastUsed time.Time // time used for expiry pinCount int // non zero if the entry should not be removed @@ -55,23 +57,27 @@ func (c *Cache) used(entry *cacheEntry) { // afresh with the create function. func (c *Cache) Get(key string, create CreateFunc) (value interface{}, err error) { c.mu.Lock() - entry, ok := c.cache[key] - if !ok { - c.mu.Unlock() // Unlock in case Get is called recursively - value, ok, err = create(key) - if err != nil && !ok { - return value, err - } + entry, found := c.cache[key] + if !found { entry = &cacheEntry{ - value: value, - key: key, - err: err, + key: key, } - c.mu.Lock() c.cache[key] = entry } - defer c.mu.Unlock() - c.used(entry) + c.mu.Unlock() + // Only one racing Get will have found=false here + entry.createMu.Lock() + if !found { + entry.value, entry.ok, entry.err = create(key) + } + entry.createMu.Unlock() + c.mu.Lock() + if !found && !entry.ok { + delete(c.cache, key) + } else { + c.used(entry) + } + c.mu.Unlock() return entry.value, entry.err } @@ -102,6 +108,7 @@ func (c *Cache) Put(key string, value interface{}) { entry := &cacheEntry{ value: value, key: key, + ok: true, } c.used(entry) c.cache[key] = entry @@ -112,7 +119,7 @@ func (c *Cache) GetMaybe(key string) (value interface{}, found bool) { c.mu.Lock() defer c.mu.Unlock() entry, found := c.cache[key] - if !found { + if !found || !entry.ok { return nil, found } c.used(entry) diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index d2811e66d..bf9a1b1e7 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -3,6 +3,8 @@ package cache import ( "errors" "fmt" + "sync" + "sync/atomic" "testing" "time" @@ -11,7 +13,7 @@ import ( ) var ( - called = 0 + called = int32(0) errSentinel = errors.New("an error") errCached = errors.New("a cached error") ) @@ -19,17 +21,19 @@ var ( func setup(t *testing.T) (*Cache, CreateFunc) { called = 0 create := func(path string) (interface{}, bool, error) { - assert.Equal(t, 0, called) - called++ + newCalled := atomic.AddInt32(&called, 1) + assert.Equal(t, int32(1), newCalled) switch path { case "/": + time.Sleep(100 * time.Millisecond) return "/", true, nil case "/file.txt": return "/file.txt", true, errCached case "/error": return nil, false, errSentinel } - panic(fmt.Sprintf("Unknown path %q", path)) + assert.Fail(t, fmt.Sprintf("Unknown path %q", path)) + return nil, false, nil } c := New() return c, create @@ -51,6 +55,24 @@ func TestGet(t *testing.T) { assert.Equal(t, f, f2) } +func TestGetConcurrent(t *testing.T) { + c, create := setup(t) + assert.Equal(t, 0, len(c.cache)) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := c.Get("/", create) + require.NoError(t, err) + }() + } + wg.Wait() + + assert.Equal(t, 1, len(c.cache)) +} + func TestGetFile(t *testing.T) { c, create := setup(t)