package sync import "sync" type locker struct { mtx sync.RWMutex userCount int // not protected by mtx, must used outer mutex to update concurrently } type KeyLocker[K comparable] struct { lockers map[K]*locker lockersMtx sync.Mutex } func NewKeyLocker[K comparable]() *KeyLocker[K] { return &KeyLocker[K]{ lockers: make(map[K]*locker), } } func (l *KeyLocker[K]) Lock(key K) { l.lock(key, false) } func (l *KeyLocker[K]) RLock(key K) { l.lock(key, true) } func (l *KeyLocker[K]) lock(key K, read bool) { l.lockersMtx.Lock() if locker, found := l.lockers[key]; found { locker.userCount++ l.lockersMtx.Unlock() if read { locker.mtx.RLock() } else { locker.mtx.Lock() } return } locker := &locker{ userCount: 1, } if read { locker.mtx.RLock() } else { locker.mtx.Lock() } l.lockers[key] = locker l.lockersMtx.Unlock() } func (l *KeyLocker[K]) Unlock(key K) { l.unlock(key, false) } func (l *KeyLocker[K]) RUnlock(key K) { l.unlock(key, true) } func (l *KeyLocker[K]) unlock(key K, read bool) { l.lockersMtx.Lock() defer l.lockersMtx.Unlock() locker, found := l.lockers[key] if !found { return } if locker.userCount == 1 { delete(l.lockers, key) } locker.userCount-- if read { locker.mtx.RUnlock() } else { locker.mtx.Unlock() } }