package sync

import "sync"

// locker is the per-key lock entry stored in KeyLocker.lockers.
type locker struct {
	// mtx is the mutex that callers of KeyLocker.Lock actually block on.
	mtx sync.Mutex

	// waiters counts goroutines that currently hold mtx or are blocked
	// waiting for it. It is NOT guarded by mtx: every read and write of it
	// must happen while holding the owning KeyLocker's lockersMtx.
	waiters int
}

// KeyLocker provides an independent mutual-exclusion lock for each key of
// type K: goroutines locking different keys never block each other. A key's
// entry lives in the table only while some goroutine holds or is waiting for
// that key's lock. Construct one with NewKeyLocker. Because it embeds a
// sync.Mutex, a KeyLocker must not be copied after first use.
type KeyLocker[K comparable] struct {
	// lockers maps each currently held or awaited key to its lock entry.
	lockers map[K]*locker

	// lockersMtx guards lockers and the waiters counter of every entry.
	lockersMtx sync.Mutex
}

// NewKeyLocker returns a ready-to-use KeyLocker with an empty lock table.
func NewKeyLocker[K comparable]() *KeyLocker[K] {
	kl := new(KeyLocker[K])
	kl.lockers = map[K]*locker{}
	return kl
}

// Lock acquires the lock associated with key, blocking until it is
// available. Locks for distinct keys do not block each other.
//
// If key has no entry yet, a new one is created with its waiter count at 1
// and its mutex already held, and only then published in the map. The
// creating goroutine therefore owns the first acquisition. If the key is
// already present, the waiter count is bumped under lockersMtx. The call
// then blocks on the key's mutex only after lockersMtx has been released,
// so waiting on one key never stalls operations on other keys.
//
// A zero-value KeyLocker is also usable, because the map is allocated
// lazily here.
func (l *KeyLocker[K]) Lock(key K) {
	l.lockersMtx.Lock()

	// Support a KeyLocker that was not built by NewKeyLocker; writing to a
	// nil map would panic.
	if l.lockers == nil {
		l.lockers = make(map[K]*locker)
	}

	if lk, found := l.lockers[key]; found {
		// Register as a waiter before releasing lockersMtx so a concurrent
		// Unlock sees waiters > 1 and does not delete this entry.
		lk.waiters++
		l.lockersMtx.Unlock()

		lk.mtx.Lock()
		return
	}

	// Named lk rather than locker to avoid shadowing the locker type.
	lk := &locker{waiters: 1}
	// The mutex is brand new and not yet visible to anyone else, so this
	// cannot block while lockersMtx is held.
	lk.mtx.Lock()

	l.lockers[key] = lk
	l.lockersMtx.Unlock()
}

// Unlock releases the lock for key that was previously acquired with Lock.
//
// When the caller is the last goroutine holding or waiting on key, the
// entry is removed from the table so idle keys do not accumulate. Calling
// Unlock for a key that has no entry is a silent no-op.
//
// As with sync.Mutex, there is no ownership check. Unlocking a key that
// another goroutine currently holds releases that goroutine's lock.
func (l *KeyLocker[K]) Unlock(key K) {
	l.lockersMtx.Lock()
	defer l.lockersMtx.Unlock()

	lk, ok := l.lockers[key]
	if !ok {
		return
	}

	// This goroutine is no longer a holder/waiter; drop the entry once
	// nobody else is queued on it.
	lk.waiters--
	if lk.waiters == 0 {
		delete(l.lockers, key)
	}

	lk.mtx.Unlock()
}