lib/cache: fix locking so we don't try to create the item many times

Before this fix, if several Get requests for the same key were submitted
in quick succession, the item create function could be run multiple
times, because the cache mutex was unlocked around the creation code.

This fixes the problem by giving each cache entry its own mutex which is
held while the item is being created.
Nick Craig-Wood 2020-12-27 17:28:35 +00:00
parent b2b5b7598c
commit bfcf6baf93
2 changed files with 47 additions and 18 deletions
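To make the change easier to follow before reading the diff, here is a minimal standalone sketch of the locking pattern the commit introduces. This is an illustration only, not the code in lib/cache: the type and function names are invented and error handling is omitted. The map mutex is held just long enough to find or insert a placeholder entry; the expensive create call then runs under a per-entry mutex, so concurrent lookups for the same key wait for the first creator instead of each running create themselves.

package itemcache

import "sync"

type entry struct {
	createMu sync.Mutex  // held while the value is being created
	value    interface{} // the created value
	ok       bool        // true once value is valid
}

type cache struct {
	mu sync.Mutex
	m  map[string]*entry
}

func newCache() *cache {
	return &cache{m: make(map[string]*entry)}
}

func (c *cache) get(key string, create func(string) (interface{}, bool)) interface{} {
	c.mu.Lock()
	e, found := c.m[key]
	if !found {
		e = &entry{} // placeholder, so racing gets share one entry
		c.m[key] = e
	}
	c.mu.Unlock()

	e.createMu.Lock() // everyone queues here; only the inserter creates
	if !found {
		e.value, e.ok = create(key)
	}
	e.createMu.Unlock()

	c.mu.Lock()
	if !found && !e.ok {
		delete(c.m, key) // creation failed, so a later get will retry
	}
	c.mu.Unlock()
	return e.value
}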

lib/cache/cache.go

@@ -29,8 +29,10 @@ func New() *Cache {
 // cacheEntry is stored in the cache
 type cacheEntry struct {
+	createMu sync.Mutex  // held while creating the item
 	value    interface{} // cached item
 	err      error       // creation error
+	ok       bool        // true if entry is valid
 	key      string      // key
 	lastUsed time.Time   // time used for expiry
 	pinCount int         // non zero if the entry should not be removed
@@ -55,23 +57,27 @@ func (c *Cache) used(entry *cacheEntry) {
 // afresh with the create function.
 func (c *Cache) Get(key string, create CreateFunc) (value interface{}, err error) {
 	c.mu.Lock()
-	entry, ok := c.cache[key]
-	if !ok {
-		c.mu.Unlock() // Unlock in case Get is called recursively
-		value, ok, err = create(key)
-		if err != nil && !ok {
-			return value, err
-		}
+	entry, found := c.cache[key]
+	if !found {
 		entry = &cacheEntry{
-			value: value,
-			key:   key,
-			err:   err,
+			key: key,
 		}
-		c.mu.Lock()
 		c.cache[key] = entry
 	}
-	defer c.mu.Unlock()
-	c.used(entry)
+	c.mu.Unlock()
+	// Only one racing Get will have found=false here
+	entry.createMu.Lock()
+	if !found {
+		entry.value, entry.ok, entry.err = create(key)
+	}
+	entry.createMu.Unlock()
+	c.mu.Lock()
+	if !found && !entry.ok {
+		delete(c.cache, key)
+	} else {
+		c.used(entry)
+	}
+	c.mu.Unlock()
 	return entry.value, entry.err
 }
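As a usage sketch of the net effect (hypothetical caller code, not part of this commit, assuming only the exported New, Cache.Get and the CreateFunc shape visible in the hunks above and in the test below): however many goroutines call Get with the same key, the create callback runs at most once, and every racing caller receives the same value or error.

package main

import (
	"fmt"

	"github.com/rclone/rclone/lib/cache"
)

func main() {
	c := cache.New()
	// The callback matches the CreateFunc shape used in the tests:
	// it returns the value, whether the value is usable, and an error.
	value, err := c.Get("some-key", func(key string) (interface{}, bool, error) {
		// Pretend this is expensive; with the per-entry mutex it runs at
		// most once per key even when many Gets race.
		return "value for " + key, true, nil
	})
	fmt.Println(value, err)
}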
@@ -102,6 +108,7 @@ func (c *Cache) Put(key string, value interface{}) {
 	entry := &cacheEntry{
 		value: value,
 		key:   key,
+		ok:    true,
 	}
 	c.used(entry)
 	c.cache[key] = entry
@@ -112,7 +119,7 @@ func (c *Cache) GetMaybe(key string) (value interface{}, found bool) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	entry, found := c.cache[key]
-	if !found {
+	if !found || !entry.ok {
 		return nil, found
 	}
 	c.used(entry)
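A note on why the new ok field is needed (this follows from the diff rather than the commit message): Get now inserts the entry into the map before create has run, so an entry can exist while its value is still being made, or after creation has failed. GetMaybe therefore treats an entry with ok unset as not found, and Put sets ok to true because it stores an already-valid value.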

lib/cache/cache_test.go

@@ -3,6 +3,8 @@ package cache
 import (
 	"errors"
 	"fmt"
+	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
@@ -11,7 +13,7 @@ import (
 )
 
 var (
-	called      = 0
+	called      = int32(0)
 	errSentinel = errors.New("an error")
 	errCached   = errors.New("a cached error")
 )
@@ -19,17 +21,19 @@ var (
 func setup(t *testing.T) (*Cache, CreateFunc) {
 	called = 0
 	create := func(path string) (interface{}, bool, error) {
-		assert.Equal(t, 0, called)
-		called++
+		newCalled := atomic.AddInt32(&called, 1)
+		assert.Equal(t, int32(1), newCalled)
 		switch path {
 		case "/":
+			time.Sleep(100 * time.Millisecond)
 			return "/", true, nil
 		case "/file.txt":
 			return "/file.txt", true, errCached
 		case "/error":
 			return nil, false, errSentinel
 		}
-		panic(fmt.Sprintf("Unknown path %q", path))
+		assert.Fail(t, fmt.Sprintf("Unknown path %q", path))
+		return nil, false, nil
 	}
 	c := New()
 	return c, create
@@ -51,6 +55,24 @@ func TestGet(t *testing.T) {
 	assert.Equal(t, f, f2)
 }
 
+func TestGetConcurrent(t *testing.T) {
+	c, create := setup(t)
+	assert.Equal(t, 0, len(c.cache))
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			_, err := c.Get("/", create)
+			require.NoError(t, err)
+		}()
+	}
+	wg.Wait()
+
+	assert.Equal(t, 1, len(c.cache))
+}
+
 func TestGetFile(t *testing.T) {
 	c, create := setup(t)
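A short gloss on the test changes (my reading, not stated in the commit): called becomes an int32 updated with atomic.AddInt32 so the counter stays race-detector clean even if create were, incorrectly, run from several of TestGetConcurrent's goroutines at once, which is exactly the bug being tested for. The panic is replaced by assert.Fail plus a return, presumably because a panic inside one of those goroutines would bring down the whole test binary rather than just failing the test. The 100ms sleep in the "/" case widens the window in which the 100 concurrent Gets overlap, so the old racy behaviour would be caught reliably.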