diff --git a/cmd/frostfs-node/cache.go b/cmd/frostfs-node/cache.go index 2dc8b7f698..dfbaf3525c 100644 --- a/cmd/frostfs-node/cache.go +++ b/cmd/frostfs-node/cache.go @@ -24,25 +24,72 @@ type valueWithTime[V any] struct { e error } -// valueInProgress is a struct that contains -// values that are being fetched/updated. -type valueInProgress[V any] struct { - m sync.RWMutex - v V - e error +type locker struct { + mtx *sync.Mutex + waiters int // not protected by mtx, must used outer mutex to update concurrently +} + +type keyLocker[K comparable] struct { + lockers map[K]*locker + lockersMtx *sync.Mutex +} + +func newKeyLocker[K comparable]() *keyLocker[K] { + return &keyLocker[K]{ + lockers: make(map[K]*locker), + lockersMtx: &sync.Mutex{}, + } +} + +func (l *keyLocker[K]) LockKey(key K) { + l.lockersMtx.Lock() + + if locker, found := l.lockers[key]; found { + locker.waiters++ + l.lockersMtx.Unlock() + + locker.mtx.Lock() + return + } + + locker := &locker{ + mtx: &sync.Mutex{}, + waiters: 1, + } + locker.mtx.Lock() + + l.lockers[key] = locker + l.lockersMtx.Unlock() +} + +func (l *keyLocker[K]) UnlockKey(key K) { + l.lockersMtx.Lock() + defer l.lockersMtx.Unlock() + + locker, found := l.lockers[key] + if !found { + return + } + + if locker.waiters == 1 { + delete(l.lockers, key) + } + locker.waiters-- + + locker.mtx.Unlock() } // entity that provides TTL cache interface. type ttlNetCache[K comparable, V any] struct { - m sync.RWMutex // protects progMap - progMap map[K]*valueInProgress[V] // contains fetch-in-progress keys - ttl time.Duration + ttl time.Duration sz int cache *lru.Cache[K, *valueWithTime[V]] netRdr netValueReader[K, V] + + keyLocker *keyLocker[K] } // complicates netValueReader with TTL caching mechanism. @@ -51,72 +98,48 @@ func newNetworkTTLCache[K comparable, V any](sz int, ttl time.Duration, netRdr n fatalOnErr(err) return &ttlNetCache[K, V]{ - ttl: ttl, - sz: sz, - cache: cache, - netRdr: netRdr, - progMap: make(map[K]*valueInProgress[V]), + ttl: ttl, + sz: sz, + cache: cache, + netRdr: netRdr, + keyLocker: newKeyLocker[K](), } } -func waitForUpdate[V any](vip *valueInProgress[V]) (V, error) { - vip.m.RLock() - defer vip.m.RUnlock() - - return vip.v, vip.e -} - // reads value by the key. // // updates the value from the network on cache miss or by TTL. // // returned value should not be modified. func (c *ttlNetCache[K, V]) get(key K) (V, error) { - valWithTime, ok := c.cache.Peek(key) - if ok { - if time.Since(valWithTime.t) < c.ttl { - return valWithTime.v, valWithTime.e - } - - c.cache.Remove(key) + val, ok := c.cache.Peek(key) + if ok && time.Since(val.t) < c.ttl { + return val.v, val.e } - c.m.RLock() - valInProg, ok := c.progMap[key] - c.m.RUnlock() + c.keyLocker.LockKey(key) + defer c.keyLocker.UnlockKey(key) - if ok { - return waitForUpdate(valInProg) + val, ok = c.cache.Peek(key) + if ok && time.Since(val.t) < c.ttl { + return val.v, val.e } - c.m.Lock() - valInProg, ok = c.progMap[key] - if ok { - c.m.Unlock() - return waitForUpdate(valInProg) - } - - valInProg = &valueInProgress[V]{} - valInProg.m.Lock() - c.progMap[key] = valInProg - - c.m.Unlock() - v, err := c.netRdr(key) - c.set(key, v, err) - valInProg.v = v - valInProg.e = err - valInProg.m.Unlock() - - c.m.Lock() - delete(c.progMap, key) - c.m.Unlock() + c.cache.Add(key, &valueWithTime[V]{ + v: v, + t: time.Now(), + e: err, + }) return v, err } func (c *ttlNetCache[K, V]) set(k K, v V, e error) { + c.keyLocker.LockKey(k) + defer c.keyLocker.UnlockKey(k) + c.cache.Add(k, &valueWithTime[V]{ v: v, t: time.Now(), @@ -125,6 +148,9 @@ func (c *ttlNetCache[K, V]) set(k K, v V, e error) { } func (c *ttlNetCache[K, V]) remove(key K) { + c.keyLocker.LockKey(key) + defer c.keyLocker.UnlockKey(key) + c.cache.Remove(key) } diff --git a/cmd/frostfs-node/cache_test.go b/cmd/frostfs-node/cache_test.go new file mode 100644 index 0000000000..a3e1c4ea65 --- /dev/null +++ b/cmd/frostfs-node/cache_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestKeyLocker(t *testing.T) { + taken := false + eg, _ := errgroup.WithContext(context.Background()) + keyLocker := newKeyLocker[int]() + for i := 0; i < 100; i++ { + eg.Go(func() error { + keyLocker.LockKey(0) + defer keyLocker.UnlockKey(0) + + require.False(t, taken) + taken = true + require.True(t, taken) + time.Sleep(10 * time.Millisecond) + taken = false + require.False(t, taken) + + return nil + }) + } + require.NoError(t, eg.Wait()) +}