lib/cache: fix locking so we don't try to create the item many times

Before this fix, if several Get requests were submitted very quickly,
this could run the item create function multiple times due to the
unlock of the mutex in the creation code.

This fixes the problem by having a mutex in each cache entry which is
held when the item is being created.
This commit is contained in:
Nick Craig-Wood 2020-12-27 17:28:35 +00:00
parent b2b5b7598c
commit bfcf6baf93
2 changed files with 47 additions and 18 deletions

35
lib/cache/cache.go vendored
View file

@ -29,8 +29,10 @@ func New() *Cache {
// cacheEntry is stored in the cache // cacheEntry is stored in the cache
type cacheEntry struct { type cacheEntry struct {
createMu sync.Mutex // held while creating the item
value interface{} // cached item value interface{} // cached item
err error // creation error err error // creation error
ok bool // true if entry is valid
key string // key key string // key
lastUsed time.Time // time used for expiry lastUsed time.Time // time used for expiry
pinCount int // non zero if the entry should not be removed pinCount int // non zero if the entry should not be removed
@ -55,23 +57,27 @@ func (c *Cache) used(entry *cacheEntry) {
// afresh with the create function. // afresh with the create function.
func (c *Cache) Get(key string, create CreateFunc) (value interface{}, err error) { func (c *Cache) Get(key string, create CreateFunc) (value interface{}, err error) {
c.mu.Lock() c.mu.Lock()
entry, ok := c.cache[key] entry, found := c.cache[key]
if !ok { if !found {
c.mu.Unlock() // Unlock in case Get is called recursively
value, ok, err = create(key)
if err != nil && !ok {
return value, err
}
entry = &cacheEntry{ entry = &cacheEntry{
value: value, key: key,
key: key,
err: err,
} }
c.mu.Lock()
c.cache[key] = entry c.cache[key] = entry
} }
defer c.mu.Unlock() c.mu.Unlock()
c.used(entry) // Only one racing Get will have found=false here
entry.createMu.Lock()
if !found {
entry.value, entry.ok, entry.err = create(key)
}
entry.createMu.Unlock()
c.mu.Lock()
if !found && !entry.ok {
delete(c.cache, key)
} else {
c.used(entry)
}
c.mu.Unlock()
return entry.value, entry.err return entry.value, entry.err
} }
@ -102,6 +108,7 @@ func (c *Cache) Put(key string, value interface{}) {
entry := &cacheEntry{ entry := &cacheEntry{
value: value, value: value,
key: key, key: key,
ok: true,
} }
c.used(entry) c.used(entry)
c.cache[key] = entry c.cache[key] = entry
@ -112,7 +119,7 @@ func (c *Cache) GetMaybe(key string) (value interface{}, found bool) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
entry, found := c.cache[key] entry, found := c.cache[key]
if !found { if !found || !entry.ok {
return nil, found return nil, found
} }
c.used(entry) c.used(entry)

View file

@ -3,6 +3,8 @@ package cache
import ( import (
"errors" "errors"
"fmt" "fmt"
"sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -11,7 +13,7 @@ import (
) )
var ( var (
called = 0 called = int32(0)
errSentinel = errors.New("an error") errSentinel = errors.New("an error")
errCached = errors.New("a cached error") errCached = errors.New("a cached error")
) )
@ -19,17 +21,19 @@ var (
func setup(t *testing.T) (*Cache, CreateFunc) { func setup(t *testing.T) (*Cache, CreateFunc) {
called = 0 called = 0
create := func(path string) (interface{}, bool, error) { create := func(path string) (interface{}, bool, error) {
assert.Equal(t, 0, called) newCalled := atomic.AddInt32(&called, 1)
called++ assert.Equal(t, int32(1), newCalled)
switch path { switch path {
case "/": case "/":
time.Sleep(100 * time.Millisecond)
return "/", true, nil return "/", true, nil
case "/file.txt": case "/file.txt":
return "/file.txt", true, errCached return "/file.txt", true, errCached
case "/error": case "/error":
return nil, false, errSentinel return nil, false, errSentinel
} }
panic(fmt.Sprintf("Unknown path %q", path)) assert.Fail(t, fmt.Sprintf("Unknown path %q", path))
return nil, false, nil
} }
c := New() c := New()
return c, create return c, create
@ -51,6 +55,24 @@ func TestGet(t *testing.T) {
assert.Equal(t, f, f2) assert.Equal(t, f, f2)
} }
func TestGetConcurrent(t *testing.T) {
c, create := setup(t)
assert.Equal(t, 0, len(c.cache))
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := c.Get("/", create)
require.NoError(t, err)
}()
}
wg.Wait()
assert.Equal(t, 1, len(c.cache))
}
func TestGetFile(t *testing.T) { func TestGetFile(t *testing.T) {
c, create := setup(t) c, create := setup(t)