From 0fe5e34fb03f35e59bc7da19a5d4e27b84d0140c Mon Sep 17 00:00:00 2001 From: Dmitrii Stepanov Date: Tue, 25 Apr 2023 16:48:46 +0300 Subject: [PATCH] [#231] node: Fix race condition in TTL cache Use key locker to lock by key. Signed-off-by: Dmitrii Stepanov --- cmd/frostfs-node/cache.go | 92 ++++++++++++++++++++++++++++++---- cmd/frostfs-node/cache_test.go | 32 ++++++++++++ 2 files changed, 114 insertions(+), 10 deletions(-) create mode 100644 cmd/frostfs-node/cache_test.go diff --git a/cmd/frostfs-node/cache.go b/cmd/frostfs-node/cache.go index 3d4fc7375..dfbaf3525 100644 --- a/cmd/frostfs-node/cache.go +++ b/cmd/frostfs-node/cache.go @@ -24,6 +24,61 @@ type valueWithTime[V any] struct { e error } +type locker struct { + mtx *sync.Mutex + waiters int // not protected by mtx, must used outer mutex to update concurrently +} + +type keyLocker[K comparable] struct { + lockers map[K]*locker + lockersMtx *sync.Mutex +} + +func newKeyLocker[K comparable]() *keyLocker[K] { + return &keyLocker[K]{ + lockers: make(map[K]*locker), + lockersMtx: &sync.Mutex{}, + } +} + +func (l *keyLocker[K]) LockKey(key K) { + l.lockersMtx.Lock() + + if locker, found := l.lockers[key]; found { + locker.waiters++ + l.lockersMtx.Unlock() + + locker.mtx.Lock() + return + } + + locker := &locker{ + mtx: &sync.Mutex{}, + waiters: 1, + } + locker.mtx.Lock() + + l.lockers[key] = locker + l.lockersMtx.Unlock() +} + +func (l *keyLocker[K]) UnlockKey(key K) { + l.lockersMtx.Lock() + defer l.lockersMtx.Unlock() + + locker, found := l.lockers[key] + if !found { + return + } + + if locker.waiters == 1 { + delete(l.lockers, key) + } + locker.waiters-- + + locker.mtx.Unlock() +} + // entity that provides TTL cache interface. type ttlNetCache[K comparable, V any] struct { ttl time.Duration @@ -33,6 +88,8 @@ type ttlNetCache[K comparable, V any] struct { cache *lru.Cache[K, *valueWithTime[V]] netRdr netValueReader[K, V] + + keyLocker *keyLocker[K] } // complicates netValueReader with TTL caching mechanism. @@ -41,10 +98,11 @@ func newNetworkTTLCache[K comparable, V any](sz int, ttl time.Duration, netRdr n fatalOnErr(err) return &ttlNetCache[K, V]{ - ttl: ttl, - sz: sz, - cache: cache, - netRdr: netRdr, + ttl: ttl, + sz: sz, + cache: cache, + netRdr: netRdr, + keyLocker: newKeyLocker[K](), } } @@ -55,22 +113,33 @@ func newNetworkTTLCache[K comparable, V any](sz int, ttl time.Duration, netRdr n // returned value should not be modified. func (c *ttlNetCache[K, V]) get(key K) (V, error) { val, ok := c.cache.Peek(key) - if ok { - if time.Since(val.t) < c.ttl { - return val.v, val.e - } + if ok && time.Since(val.t) < c.ttl { + return val.v, val.e + } - c.cache.Remove(key) + c.keyLocker.LockKey(key) + defer c.keyLocker.UnlockKey(key) + + val, ok = c.cache.Peek(key) + if ok && time.Since(val.t) < c.ttl { + return val.v, val.e } v, err := c.netRdr(key) - c.set(key, v, err) + c.cache.Add(key, &valueWithTime[V]{ + v: v, + t: time.Now(), + e: err, + }) return v, err } func (c *ttlNetCache[K, V]) set(k K, v V, e error) { + c.keyLocker.LockKey(k) + defer c.keyLocker.UnlockKey(k) + c.cache.Add(k, &valueWithTime[V]{ v: v, t: time.Now(), @@ -79,6 +148,9 @@ func (c *ttlNetCache[K, V]) set(k K, v V, e error) { } func (c *ttlNetCache[K, V]) remove(key K) { + c.keyLocker.LockKey(key) + defer c.keyLocker.UnlockKey(key) + c.cache.Remove(key) } diff --git a/cmd/frostfs-node/cache_test.go b/cmd/frostfs-node/cache_test.go new file mode 100644 index 000000000..a3e1c4ea6 --- /dev/null +++ b/cmd/frostfs-node/cache_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestKeyLocker(t *testing.T) { + taken := false + eg, _ := errgroup.WithContext(context.Background()) + keyLocker := newKeyLocker[int]() + for i := 0; i < 100; i++ { + eg.Go(func() error { + keyLocker.LockKey(0) + defer keyLocker.UnlockKey(0) + + require.False(t, taken) + taken = true + require.True(t, taken) + time.Sleep(10 * time.Millisecond) + taken = false + require.False(t, taken) + + return nil + }) + } + require.NoError(t, eg.Wait()) +}