forked from TrueCloudLab/frostfs-node
[#231] node: Fix race condition in TTL cache
Use key locker to lock by key. Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
This commit is contained in:
parent
ddbc9e255f
commit
04be9415d9
2 changed files with 112 additions and 54 deletions
|
@ -24,18 +24,63 @@ type valueWithTime[V any] struct {
|
||||||
e error
|
e error
|
||||||
}
|
}
|
||||||
|
|
||||||
// valueInProgress is a struct that contains
|
type locker struct {
|
||||||
// values that are being fetched/updated.
|
mtx *sync.Mutex
|
||||||
type valueInProgress[V any] struct {
|
waiters int // not protected by mtx, must used outer mutex to update concurrently
|
||||||
m sync.RWMutex
|
}
|
||||||
v V
|
|
||||||
e error
|
type keyLocker[K comparable] struct {
|
||||||
|
lockers map[K]*locker
|
||||||
|
lockersMtx *sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newKeyLocker[K comparable]() *keyLocker[K] {
|
||||||
|
return &keyLocker[K]{
|
||||||
|
lockers: make(map[K]*locker),
|
||||||
|
lockersMtx: &sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *keyLocker[K]) LockKey(key K) {
|
||||||
|
l.lockersMtx.Lock()
|
||||||
|
|
||||||
|
if locker, found := l.lockers[key]; found {
|
||||||
|
locker.waiters++
|
||||||
|
l.lockersMtx.Unlock()
|
||||||
|
|
||||||
|
locker.mtx.Lock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
locker := &locker{
|
||||||
|
mtx: &sync.Mutex{},
|
||||||
|
waiters: 1,
|
||||||
|
}
|
||||||
|
locker.mtx.Lock()
|
||||||
|
|
||||||
|
l.lockers[key] = locker
|
||||||
|
l.lockersMtx.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *keyLocker[K]) UnlockKey(key K) {
|
||||||
|
l.lockersMtx.Lock()
|
||||||
|
defer l.lockersMtx.Unlock()
|
||||||
|
|
||||||
|
locker, found := l.lockers[key]
|
||||||
|
if !found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if locker.waiters == 1 {
|
||||||
|
delete(l.lockers, key)
|
||||||
|
}
|
||||||
|
locker.waiters--
|
||||||
|
|
||||||
|
locker.mtx.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// entity that provides TTL cache interface.
|
// entity that provides TTL cache interface.
|
||||||
type ttlNetCache[K comparable, V any] struct {
|
type ttlNetCache[K comparable, V any] struct {
|
||||||
m sync.RWMutex // protects progMap
|
|
||||||
progMap map[K]*valueInProgress[V] // contains fetch-in-progress keys
|
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
|
|
||||||
sz int
|
sz int
|
||||||
|
@ -43,6 +88,8 @@ type ttlNetCache[K comparable, V any] struct {
|
||||||
cache *lru.Cache[K, *valueWithTime[V]]
|
cache *lru.Cache[K, *valueWithTime[V]]
|
||||||
|
|
||||||
netRdr netValueReader[K, V]
|
netRdr netValueReader[K, V]
|
||||||
|
|
||||||
|
keyLocker *keyLocker[K]
|
||||||
}
|
}
|
||||||
|
|
||||||
// complicates netValueReader with TTL caching mechanism.
|
// complicates netValueReader with TTL caching mechanism.
|
||||||
|
@ -55,68 +102,44 @@ func newNetworkTTLCache[K comparable, V any](sz int, ttl time.Duration, netRdr n
|
||||||
sz: sz,
|
sz: sz,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
netRdr: netRdr,
|
netRdr: netRdr,
|
||||||
progMap: make(map[K]*valueInProgress[V]),
|
keyLocker: newKeyLocker[K](),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForUpdate[V any](vip *valueInProgress[V]) (V, error) {
|
|
||||||
vip.m.RLock()
|
|
||||||
defer vip.m.RUnlock()
|
|
||||||
|
|
||||||
return vip.v, vip.e
|
|
||||||
}
|
|
||||||
|
|
||||||
// reads value by the key.
|
// reads value by the key.
|
||||||
//
|
//
|
||||||
// updates the value from the network on cache miss or by TTL.
|
// updates the value from the network on cache miss or by TTL.
|
||||||
//
|
//
|
||||||
// returned value should not be modified.
|
// returned value should not be modified.
|
||||||
func (c *ttlNetCache[K, V]) get(key K) (V, error) {
|
func (c *ttlNetCache[K, V]) get(key K) (V, error) {
|
||||||
valWithTime, ok := c.cache.Peek(key)
|
val, ok := c.cache.Peek(key)
|
||||||
if ok {
|
if ok && time.Since(val.t) < c.ttl {
|
||||||
if time.Since(valWithTime.t) < c.ttl {
|
return val.v, val.e
|
||||||
return valWithTime.v, valWithTime.e
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.cache.Remove(key)
|
c.keyLocker.LockKey(key)
|
||||||
|
defer c.keyLocker.UnlockKey(key)
|
||||||
|
|
||||||
|
val, ok = c.cache.Peek(key)
|
||||||
|
if ok && time.Since(val.t) < c.ttl {
|
||||||
|
return val.v, val.e
|
||||||
}
|
}
|
||||||
|
|
||||||
c.m.RLock()
|
|
||||||
valInProg, ok := c.progMap[key]
|
|
||||||
c.m.RUnlock()
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
return waitForUpdate(valInProg)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.m.Lock()
|
|
||||||
valInProg, ok = c.progMap[key]
|
|
||||||
if ok {
|
|
||||||
c.m.Unlock()
|
|
||||||
return waitForUpdate(valInProg)
|
|
||||||
}
|
|
||||||
|
|
||||||
valInProg = &valueInProgress[V]{}
|
|
||||||
valInProg.m.Lock()
|
|
||||||
c.progMap[key] = valInProg
|
|
||||||
|
|
||||||
c.m.Unlock()
|
|
||||||
|
|
||||||
v, err := c.netRdr(key)
|
v, err := c.netRdr(key)
|
||||||
c.set(key, v, err)
|
|
||||||
|
|
||||||
valInProg.v = v
|
c.cache.Add(key, &valueWithTime[V]{
|
||||||
valInProg.e = err
|
v: v,
|
||||||
valInProg.m.Unlock()
|
t: time.Now(),
|
||||||
|
e: err,
|
||||||
c.m.Lock()
|
})
|
||||||
delete(c.progMap, key)
|
|
||||||
c.m.Unlock()
|
|
||||||
|
|
||||||
return v, err
|
return v, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ttlNetCache[K, V]) set(k K, v V, e error) {
|
func (c *ttlNetCache[K, V]) set(k K, v V, e error) {
|
||||||
|
c.keyLocker.LockKey(k)
|
||||||
|
defer c.keyLocker.UnlockKey(k)
|
||||||
|
|
||||||
c.cache.Add(k, &valueWithTime[V]{
|
c.cache.Add(k, &valueWithTime[V]{
|
||||||
v: v,
|
v: v,
|
||||||
t: time.Now(),
|
t: time.Now(),
|
||||||
|
@ -125,6 +148,9 @@ func (c *ttlNetCache[K, V]) set(k K, v V, e error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ttlNetCache[K, V]) remove(key K) {
|
func (c *ttlNetCache[K, V]) remove(key K) {
|
||||||
|
c.keyLocker.LockKey(key)
|
||||||
|
defer c.keyLocker.UnlockKey(key)
|
||||||
|
|
||||||
c.cache.Remove(key)
|
c.cache.Remove(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
32
cmd/frostfs-node/cache_test.go
Normal file
32
cmd/frostfs-node/cache_test.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeyLocker(t *testing.T) {
|
||||||
|
taken := false
|
||||||
|
eg, _ := errgroup.WithContext(context.Background())
|
||||||
|
keyLocker := newKeyLocker[int]()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
eg.Go(func() error {
|
||||||
|
keyLocker.LockKey(0)
|
||||||
|
defer keyLocker.UnlockKey(0)
|
||||||
|
|
||||||
|
require.False(t, taken)
|
||||||
|
taken = true
|
||||||
|
require.True(t, taken)
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
taken = false
|
||||||
|
require.False(t, taken)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
require.NoError(t, eg.Wait())
|
||||||
|
}
|
Loading…
Reference in a new issue