[#231] node: Fix race condition in TTL cache

Use key locker to lock by key.

Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
This commit is contained in:
Dmitrii Stepanov 2023-04-21 13:49:05 +03:00
parent ddbc9e255f
commit 04be9415d9
2 changed files with 112 additions and 54 deletions

View file

@ -24,25 +24,72 @@ type valueWithTime[V any] struct {
e error e error
} }
// valueInProgress is a struct that contains type locker struct {
// values that are being fetched/updated. mtx *sync.Mutex
type valueInProgress[V any] struct { waiters int // not protected by mtx, must used outer mutex to update concurrently
m sync.RWMutex }
v V
e error type keyLocker[K comparable] struct {
lockers map[K]*locker
lockersMtx *sync.Mutex
}
func newKeyLocker[K comparable]() *keyLocker[K] {
return &keyLocker[K]{
lockers: make(map[K]*locker),
lockersMtx: &sync.Mutex{},
}
}
func (l *keyLocker[K]) LockKey(key K) {
l.lockersMtx.Lock()
if locker, found := l.lockers[key]; found {
locker.waiters++
l.lockersMtx.Unlock()
locker.mtx.Lock()
return
}
locker := &locker{
mtx: &sync.Mutex{},
waiters: 1,
}
locker.mtx.Lock()
l.lockers[key] = locker
l.lockersMtx.Unlock()
}
func (l *keyLocker[K]) UnlockKey(key K) {
l.lockersMtx.Lock()
defer l.lockersMtx.Unlock()
locker, found := l.lockers[key]
if !found {
return
}
if locker.waiters == 1 {
delete(l.lockers, key)
}
locker.waiters--
locker.mtx.Unlock()
} }
// entity that provides TTL cache interface. // entity that provides TTL cache interface.
type ttlNetCache[K comparable, V any] struct { type ttlNetCache[K comparable, V any] struct {
m sync.RWMutex // protects progMap ttl time.Duration
progMap map[K]*valueInProgress[V] // contains fetch-in-progress keys
ttl time.Duration
sz int sz int
cache *lru.Cache[K, *valueWithTime[V]] cache *lru.Cache[K, *valueWithTime[V]]
netRdr netValueReader[K, V] netRdr netValueReader[K, V]
keyLocker *keyLocker[K]
} }
// complicates netValueReader with TTL caching mechanism. // complicates netValueReader with TTL caching mechanism.
@ -51,72 +98,48 @@ func newNetworkTTLCache[K comparable, V any](sz int, ttl time.Duration, netRdr n
fatalOnErr(err) fatalOnErr(err)
return &ttlNetCache[K, V]{ return &ttlNetCache[K, V]{
ttl: ttl, ttl: ttl,
sz: sz, sz: sz,
cache: cache, cache: cache,
netRdr: netRdr, netRdr: netRdr,
progMap: make(map[K]*valueInProgress[V]), keyLocker: newKeyLocker[K](),
} }
} }
func waitForUpdate[V any](vip *valueInProgress[V]) (V, error) {
vip.m.RLock()
defer vip.m.RUnlock()
return vip.v, vip.e
}
// reads value by the key. // reads value by the key.
// //
// updates the value from the network on cache miss or by TTL. // updates the value from the network on cache miss or by TTL.
// //
// returned value should not be modified. // returned value should not be modified.
func (c *ttlNetCache[K, V]) get(key K) (V, error) { func (c *ttlNetCache[K, V]) get(key K) (V, error) {
valWithTime, ok := c.cache.Peek(key) val, ok := c.cache.Peek(key)
if ok { if ok && time.Since(val.t) < c.ttl {
if time.Since(valWithTime.t) < c.ttl { return val.v, val.e
return valWithTime.v, valWithTime.e
}
c.cache.Remove(key)
} }
c.m.RLock() c.keyLocker.LockKey(key)
valInProg, ok := c.progMap[key] defer c.keyLocker.UnlockKey(key)
c.m.RUnlock()
if ok { val, ok = c.cache.Peek(key)
return waitForUpdate(valInProg) if ok && time.Since(val.t) < c.ttl {
return val.v, val.e
} }
c.m.Lock()
valInProg, ok = c.progMap[key]
if ok {
c.m.Unlock()
return waitForUpdate(valInProg)
}
valInProg = &valueInProgress[V]{}
valInProg.m.Lock()
c.progMap[key] = valInProg
c.m.Unlock()
v, err := c.netRdr(key) v, err := c.netRdr(key)
c.set(key, v, err)
valInProg.v = v c.cache.Add(key, &valueWithTime[V]{
valInProg.e = err v: v,
valInProg.m.Unlock() t: time.Now(),
e: err,
c.m.Lock() })
delete(c.progMap, key)
c.m.Unlock()
return v, err return v, err
} }
func (c *ttlNetCache[K, V]) set(k K, v V, e error) { func (c *ttlNetCache[K, V]) set(k K, v V, e error) {
c.keyLocker.LockKey(k)
defer c.keyLocker.UnlockKey(k)
c.cache.Add(k, &valueWithTime[V]{ c.cache.Add(k, &valueWithTime[V]{
v: v, v: v,
t: time.Now(), t: time.Now(),
@ -125,6 +148,9 @@ func (c *ttlNetCache[K, V]) set(k K, v V, e error) {
} }
func (c *ttlNetCache[K, V]) remove(key K) { func (c *ttlNetCache[K, V]) remove(key K) {
c.keyLocker.LockKey(key)
defer c.keyLocker.UnlockKey(key)
c.cache.Remove(key) c.cache.Remove(key)
} }

View file

@ -0,0 +1,32 @@
package main
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
func TestKeyLocker(t *testing.T) {
taken := false
eg, _ := errgroup.WithContext(context.Background())
keyLocker := newKeyLocker[int]()
for i := 0; i < 100; i++ {
eg.Go(func() error {
keyLocker.LockKey(0)
defer keyLocker.UnlockKey(0)
require.False(t, taken)
taken = true
require.True(t, taken)
time.Sleep(10 * time.Millisecond)
taken = false
require.False(t, taken)
return nil
})
}
require.NoError(t, eg.Wait())
}