diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 268ce4e88..4a4b3dfc3 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -29,8 +29,10 @@ func New() *Cache { // cacheEntry is stored in the cache type cacheEntry struct { + createMu sync.Mutex // held while creating the item value interface{} // cached item err error // creation error + ok bool // true if entry is valid key string // key lastUsed time.Time // time used for expiry pinCount int // non zero if the entry should not be removed @@ -55,23 +57,27 @@ func (c *Cache) used(entry *cacheEntry) { // afresh with the create function. func (c *Cache) Get(key string, create CreateFunc) (value interface{}, err error) { c.mu.Lock() - entry, ok := c.cache[key] - if !ok { - c.mu.Unlock() // Unlock in case Get is called recursively - value, ok, err = create(key) - if err != nil && !ok { - return value, err - } + entry, found := c.cache[key] + if !found { entry = &cacheEntry{ - value: value, - key: key, - err: err, + key: key, } - c.mu.Lock() c.cache[key] = entry } - defer c.mu.Unlock() - c.used(entry) + c.mu.Unlock() + // Only one racing Get will have found=false here + entry.createMu.Lock() + if !found { + entry.value, entry.ok, entry.err = create(key) + } + entry.createMu.Unlock() + c.mu.Lock() + if !found && !entry.ok { + delete(c.cache, key) + } else { + c.used(entry) + } + c.mu.Unlock() return entry.value, entry.err } @@ -102,6 +108,7 @@ func (c *Cache) Put(key string, value interface{}) { entry := &cacheEntry{ value: value, key: key, + ok: true, } c.used(entry) c.cache[key] = entry @@ -112,7 +119,7 @@ func (c *Cache) GetMaybe(key string) (value interface{}, found bool) { c.mu.Lock() defer c.mu.Unlock() entry, found := c.cache[key] - if !found { + if !found || !entry.ok { return nil, found } c.used(entry) diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index d2811e66d..bf9a1b1e7 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -3,6 +3,8 @@ package cache import ( "errors" "fmt" + "sync" + "sync/atomic" "testing" "time" @@ -11,7 +13,7 @@ import ( ) var ( - called = 0 + called = int32(0) errSentinel = errors.New("an error") errCached = errors.New("a cached error") ) @@ -19,17 +21,19 @@ var ( func setup(t *testing.T) (*Cache, CreateFunc) { called = 0 create := func(path string) (interface{}, bool, error) { - assert.Equal(t, 0, called) - called++ + newCalled := atomic.AddInt32(&called, 1) + assert.Equal(t, int32(1), newCalled) switch path { case "/": + time.Sleep(100 * time.Millisecond) return "/", true, nil case "/file.txt": return "/file.txt", true, errCached case "/error": return nil, false, errSentinel } - panic(fmt.Sprintf("Unknown path %q", path)) + assert.Fail(t, fmt.Sprintf("Unknown path %q", path)) + return nil, false, nil } c := New() return c, create @@ -51,6 +55,24 @@ func TestGet(t *testing.T) { assert.Equal(t, f, f2) } +func TestGetConcurrent(t *testing.T) { + c, create := setup(t) + assert.Equal(t, 0, len(c.cache)) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := c.Get("/", create) + require.NoError(t, err) + }() + } + wg.Wait() + + assert.Equal(t, 1, len(c.cache)) +} + func TestGetFile(t *testing.T) { c, create := setup(t)