Fix writecache counters #595

Merged
fyrchik merged 5 commits from dstepanov-yadro/frostfs-node:fix/writecache_bbolt_db_counter into master 2024-09-04 19:51:02 +00:00
14 changed files with 335 additions and 174 deletions

View file

@ -8,6 +8,7 @@ import (
"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/netmap"
cntClient "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/morph/client/container"
putsvc "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/put"
utilSync "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/sync"
apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status"
cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
netmapSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
@ -24,59 +25,6 @@ type valueWithTime[V any] struct {
e error
}
type locker struct {
mtx sync.Mutex
waiters int // not protected by mtx, must used outer mutex to update concurrently
}
type keyLocker[K comparable] struct {
lockers map[K]*locker
lockersMtx sync.Mutex
}
func newKeyLocker[K comparable]() *keyLocker[K] {
return &keyLocker[K]{
lockers: make(map[K]*locker),
}
}
func (l *keyLocker[K]) LockKey(key K) {
l.lockersMtx.Lock()
if locker, found := l.lockers[key]; found {
locker.waiters++
l.lockersMtx.Unlock()
locker.mtx.Lock()
return
}
locker := &locker{
waiters: 1,
}
locker.mtx.Lock()
l.lockers[key] = locker
l.lockersMtx.Unlock()
}
func (l *keyLocker[K]) UnlockKey(key K) {
l.lockersMtx.Lock()
defer l.lockersMtx.Unlock()
locker, found := l.lockers[key]
if !found {
return
}
if locker.waiters == 1 {
delete(l.lockers, key)
}
locker.waiters--
locker.mtx.Unlock()
}
// entity that provides TTL cache interface.
type ttlNetCache[K comparable, V any] struct {
ttl time.Duration
@ -87,7 +35,7 @@ type ttlNetCache[K comparable, V any] struct {
netRdr netValueReader[K, V]
keyLocker *keyLocker[K]
keyLocker *utilSync.KeyLocker[K]
}
// complicates netValueReader with TTL caching mechanism.
@ -100,7 +48,7 @@ func newNetworkTTLCache[K comparable, V any](sz int, ttl time.Duration, netRdr n
sz: sz,
cache: cache,
netRdr: netRdr,
keyLocker: newKeyLocker[K](),
keyLocker: utilSync.NewKeyLocker[K](),
}
}
@ -115,8 +63,8 @@ func (c *ttlNetCache[K, V]) get(key K) (V, error) {
return val.v, val.e
}
c.keyLocker.LockKey(key)
defer c.keyLocker.UnlockKey(key)
c.keyLocker.Lock(key)
defer c.keyLocker.Unlock(key)
val, ok = c.cache.Peek(key)
if ok && time.Since(val.t) < c.ttl {
@ -135,8 +83,8 @@ func (c *ttlNetCache[K, V]) get(key K) (V, error) {
}
func (c *ttlNetCache[K, V]) set(k K, v V, e error) {
c.keyLocker.LockKey(k)
defer c.keyLocker.UnlockKey(k)
c.keyLocker.Lock(k)
defer c.keyLocker.Unlock(k)
c.cache.Add(k, &valueWithTime[V]{
v: v,
@ -146,8 +94,8 @@ func (c *ttlNetCache[K, V]) set(k K, v V, e error) {
}
func (c *ttlNetCache[K, V]) remove(key K) {
c.keyLocker.LockKey(key)
defer c.keyLocker.UnlockKey(key)
c.keyLocker.Lock(key)
defer c.keyLocker.Unlock(key)
c.cache.Remove(key)
}

View file

@ -513,4 +513,5 @@ const (
FrostFSNodeCantUnmarshalObjectFromDB = "can't unmarshal an object from the DB" // Error in ../node/cmd/frostfs-node/morph.go
RuntimeSoftMemoryLimitUpdated = "soft runtime memory limit value updated"
RuntimeSoftMemoryDefinedWithGOMEMLIMIT = "soft runtime memory defined with GOMEMLIMIT environment variable, config value skipped"
FailedToCountWritecacheItems = "failed to count writecache items"
)

View file

@ -13,7 +13,10 @@ func (t *FSTree) Open(ro bool) error {
// Init implements common.Storage.
func (t *FSTree) Init() error {
return util.MkdirAllX(t.RootPath, t.Permissions)
if err := util.MkdirAllX(t.RootPath, t.Permissions); err != nil {
return err
}
return t.initFileCounter()
}
// Close implements common.Storage.

View file

@ -0,0 +1,32 @@
package fstree
import (
"math"
"sync/atomic"
)
// FileCounter used to count files in FSTree. The implementation must be thread-safe.
type FileCounter interface {
Set(v uint64)
Inc()
Dec()
}
type noopCounter struct{}
func (c *noopCounter) Set(uint64) {}
func (c *noopCounter) Inc() {}
func (c *noopCounter) Dec() {}
type SimpleCounter struct {
v atomic.Uint64
}
func NewSimpleCounter() *SimpleCounter {
return &SimpleCounter{}
}
func (c *SimpleCounter) Set(v uint64) { c.v.Store(v) }
func (c *SimpleCounter) Inc() { c.v.Add(1) }
func (c *SimpleCounter) Dec() { c.v.Add(math.MaxUint64) }
func (c *SimpleCounter) Value() uint64 { return c.v.Load() }

View file

@ -10,6 +10,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"syscall"
"time"
@ -26,6 +27,16 @@ import (
"go.opentelemetry.io/otel/trace"
)
type keyLock interface {
Lock(string)
Unlock(string)
}
type noopKeyLock struct{}
func (l *noopKeyLock) Lock(string) {}
func (l *noopKeyLock) Unlock(string) {}
// FSTree represents an object storage as a filesystem tree.
type FSTree struct {
Info
@ -37,6 +48,12 @@ type FSTree struct {
noSync bool
readOnly bool
metrics Metrics
fileGuard keyLock
fileCounter FileCounter
fileCounterEnabled bool
suffix atomic.Uint64
}
// Info groups the information about file storage.
@ -63,10 +80,12 @@ func New(opts ...Option) *FSTree {
Permissions: 0700,
RootPath: "./",
},
Config: nil,
Depth: 4,
DirNameLen: DirNameLen,
metrics: &noopMetrics{},
Config: nil,
Depth: 4,
DirNameLen: DirNameLen,
metrics: &noopMetrics{},
fileGuard: &noopKeyLock{},
fileCounter: &noopCounter{},
}
for i := range opts {
opts[i](f)
@ -244,7 +263,17 @@ func (t *FSTree) Delete(ctx context.Context, prm common.DeletePrm) (common.Delet
p := t.treePath(prm.Address)
err = os.Remove(p)
if t.fileCounterEnabled {
t.fileGuard.Lock(p)
err = os.Remove(p)
t.fileGuard.Unlock(p)
if err == nil {
t.fileCounter.Dec()
}
} else {
err = os.Remove(p)
}
if err != nil && os.IsNotExist(err) {
err = logicerr.Wrap(new(apistatus.ObjectNotFound))
}
@ -317,45 +346,19 @@ func (t *FSTree) Put(ctx context.Context, prm common.PutPrm) (common.PutRes, err
prm.RawData = t.Compress(prm.RawData)
}
// Here is a situation:
// Feb 09 13:10:37 buky neofs-node[32445]: 2023-02-09T13:10:37.161Z info log/log.go:13 local object storage operation {"shard_id": "SkT8BfjouW6t93oLuzQ79s", "address": "7NxFz4SruSi8TqXacr2Ae22nekMhgYk1sfkddJo9PpWk/5enyUJGCyU1sfrURDnHEjZFdbGqANVhayYGfdSqtA6wA", "op": "PUT", "type": "fstree", "storage_id": ""}
// Feb 09 13:10:37 buky neofs-node[32445]: 2023-02-09T13:10:37.183Z info log/log.go:13 local object storage operation {"shard_id": "SkT8BfjouW6t93oLuzQ79s", "address": "7NxFz4SruSi8TqXacr2Ae22nekMhgYk1sfkddJo9PpWk/5enyUJGCyU1sfrURDnHEjZFdbGqANVhayYGfdSqtA6wA", "op": "metabase PUT"}
// Feb 09 13:10:37 buky neofs-node[32445]: 2023-02-09T13:10:37.862Z debug policer/check.go:231 shortage of object copies detected {"component": "Object Policer", "object": "7NxFz4SruSi8TqXacr2Ae22nekMhgYk1sfkddJo9PpWk/5enyUJGCyU1sfrURDnHEjZFdbGqANVhayYGfdSqtA6wA", "shortage": 1}
// Feb 09 13:10:37 buky neofs-node[32445]: 2023-02-09T13:10:37.862Z debug shard/get.go:124 object is missing in write-cache {"shard_id": "SkT8BfjouW6t93oLuzQ79s", "addr": "7NxFz4SruSi8TqXacr2Ae22nekMhgYk1sfkddJo9PpWk/5enyUJGCyU1sfrURDnHEjZFdbGqANVhayYGfdSqtA6wA", "skip_meta": false}
//
// 1. We put an object on node 1.
// 2. Relentless policer sees that it has only 1 copy and tries to PUT it to node 2.
// 3. PUT operation started by client at (1) also puts an object here.
// 4. Now we have concurrent writes and one of `Rename` calls will return `no such file` error.
// Even more than that, concurrent writes can corrupt data.
//
// So here is a solution:
// 1. Write a file to 'name + 1'.
// 2. If it exists, retry with temporary name being 'name + 2'.
// 3. Set some reasonable number of attempts.
//
// It is a bit kludgey, but I am unusually proud about having found this out after
// hours of research on linux kernel, dirsync mount option and ext4 FS, turned out
// to be so hecking simple.
// In a very rare situation we can have multiple partially written copies on disk,
// this will be fixed in another issue (we should remove garbage on start).
size = len(prm.RawData)
const retryCount = 5
for i := 0; i < retryCount; i++ {
tmpPath := p + "#" + strconv.FormatUint(uint64(i), 10)
err = t.writeAndRename(tmpPath, p, prm.RawData)
if err != syscall.EEXIST || i == retryCount-1 {
return common.PutRes{StorageID: []byte{}}, err
}
}
err = fmt.Errorf("couldn't read file after %d retries", retryCount)
// unreachable, but precaution never hurts, especially 1 day before release.
tmpPath := p + "#" + strconv.FormatUint(t.suffix.Add(1), 10)
err = t.writeAndRename(tmpPath, p, prm.RawData)
return common.PutRes{StorageID: []byte{}}, err
}
// writeAndRename opens tmpPath exclusively, writes data to it and renames it to p.
func (t *FSTree) writeAndRename(tmpPath, p string, data []byte) error {
if t.fileCounterEnabled {
t.fileGuard.Lock(p)
defer t.fileGuard.Unlock(p)
}
err := t.writeFile(tmpPath, data)
if err != nil {
var pe *fs.PathError
@ -364,10 +367,21 @@ func (t *FSTree) writeAndRename(tmpPath, p string, data []byte) error {
case syscall.ENOSPC:
err = common.ErrNoSpace
_ = os.RemoveAll(tmpPath)
case syscall.EEXIST:
return syscall.EEXIST
}
}
return err
}
if t.fileCounterEnabled {
t.fileCounter.Inc()
var targetFileExists bool
if _, e := os.Stat(p); e == nil {
targetFileExists = true
}
err = os.Rename(tmpPath, p)
if err == nil && targetFileExists {
t.fileCounter.Dec()
}
} else {
err = os.Rename(tmpPath, p)
}
@ -396,27 +410,6 @@ func (t *FSTree) writeFile(p string, data []byte) error {
return err
}
// PutStream puts executes handler on a file opened for write.
func (t *FSTree) PutStream(addr oid.Address, handler func(*os.File) error) error {
if t.readOnly {
return common.ErrReadOnly
}
p := t.treePath(addr)
if err := util.MkdirAllX(filepath.Dir(p), t.Permissions); err != nil {
return err
}
f, err := os.OpenFile(p, t.writeFlags(), t.Permissions)
if err != nil {
return err
}
defer f.Close()
return handler(f)
}
// Get returns an object from the storage by address.
func (t *FSTree) Get(ctx context.Context, prm common.GetPrm) (common.GetRes, error) {
var (
@ -450,6 +443,9 @@ func (t *FSTree) Get(ctx context.Context, prm common.GetPrm) (common.GetRes, err
data, err = os.ReadFile(p)
if err != nil {
if os.IsNotExist(err) {
return common.GetRes{}, logicerr.Wrap(new(apistatus.ObjectNotFound))
}
return common.GetRes{}, err
}
}
@ -509,11 +505,23 @@ func (t *FSTree) GetRange(ctx context.Context, prm common.GetRangePrm) (common.G
}, nil
}
// NumberOfObjects walks the file tree rooted at FSTree's root
// and returns number of stored objects.
func (t *FSTree) NumberOfObjects() (uint64, error) {
var counter uint64
// initFileCounter walks the file tree rooted at FSTree's root,
// counts total items count, inits counter and returns number of stored objects.
func (t *FSTree) initFileCounter() error {
if !t.fileCounterEnabled {
return nil
}
counter, err := t.countFiles()
if err != nil {
return err
}
t.fileCounter.Set(counter)
return nil
}
func (t *FSTree) countFiles() (uint64, error) {
var counter uint64
// it is simpler to just consider every file
// that is not directory as an object
err := filepath.WalkDir(t.RootPath,

View file

@ -1,10 +1,16 @@
package fstree
import (
"context"
"errors"
"testing"
"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/common"
"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/util/logicerr"
objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
func TestAddressToString(t *testing.T) {
@ -28,3 +34,68 @@ func Benchmark_addressFromString(b *testing.B) {
}
}
}
func TestObjectCounter(t *testing.T) {
t.Parallel()
counter := NewSimpleCounter()
fst := New(
WithPath(t.TempDir()),
WithDepth(2),
WithDirNameLen(2),
WithFileCounter(counter))
require.NoError(t, fst.Open(false))
require.NoError(t, fst.Init())
counterValue := counter.Value()
require.Equal(t, uint64(0), counterValue)
defer func() {
require.NoError(t, fst.Close())
}()
addr := oidtest.Address()
obj := objectSDK.New()
obj.SetID(addr.Object())
obj.SetContainerID(addr.Container())
obj.SetPayload([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0})
var putPrm common.PutPrm
putPrm.Address = addr
putPrm.RawData, _ = obj.Marshal()
var getPrm common.GetPrm
getPrm.Address = putPrm.Address
var delPrm common.DeletePrm
delPrm.Address = addr
eg, egCtx := errgroup.WithContext(context.Background())
eg.Go(func() error {
for j := 0; j < 1_000; j++ {
_, err := fst.Put(egCtx, putPrm)
if err != nil {
return err
}
}
return nil
})
eg.Go(func() error {
var le logicerr.Logical
for j := 0; j < 1_000; j++ {
_, err := fst.Delete(egCtx, delPrm)
if err != nil && !errors.As(err, &le) {
return err
}
}
return nil
})
require.NoError(t, eg.Wait())
counterValue = counter.Value()
realCount, err := fst.countFiles()
require.NoError(t, err)
require.Equal(t, realCount, counterValue)
}

View file

@ -2,6 +2,8 @@ package fstree
import (
"io/fs"
utilSync "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/sync"
)
type Option func(*FSTree)
@ -41,3 +43,11 @@ func WithMetrics(m Metrics) Option {
f.metrics = m
}
}
func WithFileCounter(c FileCounter) Option {
return func(f *FSTree) {
f.fileCounterEnabled = true
f.fileCounter = c
f.fileGuard = utilSync.NewKeyLocker[string]()
}
}

View file

@ -53,7 +53,7 @@ var storages = []storage{
},
},
{
desc: "fstree",
desc: "fstree_without_object_counter",
create: func(dir string) common.Storage {
return fstree.New(
fstree.WithPath(dir),
@ -62,6 +62,17 @@ var storages = []storage{
)
},
},
{
desc: "fstree_with_object_counter",
create: func(dir string) common.Storage {
return fstree.New(
fstree.WithPath(dir),
fstree.WithDepth(2),
fstree.WithDirNameLen(2),
fstree.WithFileCounter(fstree.NewSimpleCounter()),
)
},
},
{
desc: "blobovniczatree",
create: func(dir string) common.Storage {

View file

@ -2,6 +2,7 @@ package writecachebbolt
import (
"context"
"math"
"time"
"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/common"
@ -49,9 +50,12 @@ func (c *cache) Delete(ctx context.Context, addr oid.Address) error {
if dataSize > 0 {
storageType = writecache.StorageTypeDB
var recordDeleted bool
err := c.db.Update(func(tx *bbolt.Tx) error {
b := tx.Bucket(defaultBucket)
err := b.Delete([]byte(saddr))
key := []byte(saddr)
recordDeleted = b.Get(key) != nil
err := b.Delete(key)
return err
})
if err != nil {
@ -62,8 +66,11 @@ func (c *cache) Delete(ctx context.Context, addr oid.Address) error {
storagelog.StorageTypeField(wcStorageType),
storagelog.OpField("db DELETE"),
)
if recordDeleted {
c.objCounters.cDB.Add(math.MaxUint64)
c.estimateCacheSize()
}
deleted = true
c.objCounters.DecDB()
return nil
}
@ -75,9 +82,8 @@ func (c *cache) Delete(ctx context.Context, addr oid.Address) error {
storagelog.StorageTypeField(wcStorageType),
storagelog.OpField("fstree DELETE"),
)
c.objCounters.DecFS()
deleted = true
c.estimateCacheSize()
}
return metaerr.Wrap(err)
}

View file

@ -85,9 +85,15 @@ func (c *cache) putSmall(obj objectInfo) error {
return ErrOutOfSpace
}
var newRecord bool
err := c.db.Batch(func(tx *bbolt.Tx) error {
b := tx.Bucket(defaultBucket)
return b.Put([]byte(obj.addr), obj.data)
key := []byte(obj.addr)
newRecord = b.Get(key) == nil
if newRecord {
return b.Put(key, obj.data)
}
return nil
})
if err == nil {
storagelog.Write(c.log,
@ -95,7 +101,10 @@ func (c *cache) putSmall(obj objectInfo) error {
storagelog.StorageTypeField(wcStorageType),
storagelog.OpField("db PUT"),
)
c.objCounters.IncDB()
if newRecord {
c.objCounters.cDB.Add(1)
c.estimateCacheSize()
}
}
return err
}
@ -117,12 +126,12 @@ func (c *cache) putBig(ctx context.Context, addr string, prm common.PutPrm) erro
c.compressFlags[addr] = struct{}{}
c.mtx.Unlock()
}
c.objCounters.IncFS()
storagelog.Write(c.log,
storagelog.AddressField(addr),
storagelog.StorageTypeField(wcStorageType),
storagelog.OpField("fstree PUT"),
)
c.estimateCacheSize()
return nil
}

View file

@ -5,14 +5,21 @@ import (
"math"
"sync/atomic"
"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/local_object_storage/blobstor/fstree"
"go.etcd.io/bbolt"
)
func (c *cache) estimateCacheSize() uint64 {
db := c.objCounters.DB() * c.smallObjectSize
fstree := c.objCounters.FS() * c.maxObjectSize
c.metrics.SetEstimateSize(db, fstree)
return db + fstree
dbCount := c.objCounters.DB()
fsCount := c.objCounters.FS()
if fsCount > 0 {
fsCount-- //db file
}
dbSize := dbCount * c.smallObjectSize
fsSize := fsCount * c.maxObjectSize
c.metrics.SetEstimateSize(dbSize, fsSize)
c.metrics.SetActualCounters(dbCount, fsCount)
return dbSize + fsSize
}
func (c *cache) incSizeDB(sz uint64) uint64 {
@ -23,34 +30,35 @@ func (c *cache) incSizeFS(sz uint64) uint64 {
return sz + c.maxObjectSize
}
var _ fstree.FileCounter = &counters{}
type counters struct {
cDB, cFS atomic.Uint64
}
func (x *counters) IncDB() {
x.cDB.Add(1)
}
func (x *counters) DecDB() {
x.cDB.Add(math.MaxUint64)
}
func (x *counters) DB() uint64 {
return x.cDB.Load()
}
func (x *counters) IncFS() {
func (x *counters) FS() uint64 {
return x.cFS.Load()
}
// Set implements fstree.ObjectCounter.
func (x *counters) Set(v uint64) {
x.cFS.Store(v)
}
// Inc implements fstree.ObjectCounter.
func (x *counters) Inc() {
x.cFS.Add(1)
}
func (x *counters) DecFS() {
// Dec implements fstree.ObjectCounter.
func (x *counters) Dec() {
x.cFS.Add(math.MaxUint64)
}
func (x *counters) FS() uint64 {
return x.cFS.Load()
}
func (c *cache) initCounters() error {
var inDB uint64
err := c.db.View(func(tx *bbolt.Tx) error {
@ -63,18 +71,6 @@ func (c *cache) initCounters() error {
if err != nil {
return fmt.Errorf("could not read write-cache DB counter: %w", err)
}
inFS, err := c.fsTree.NumberOfObjects()
if err != nil {
return fmt.Errorf("could not read write-cache FS counter: %w", err)
}
if inFS > 0 {
inFS-- //small.bolt DB file
}
c.objCounters.cDB.Store(inDB)
c.objCounters.cFS.Store(inFS)
c.metrics.SetActualCounters(inDB, inFS)
return nil
}

View file

@ -3,6 +3,7 @@ package writecachebbolt
import (
"context"
"fmt"
"math"
"os"
"git.frostfs.info/TrueCloudLab/frostfs-node/internal/logs"
@ -54,29 +55,39 @@ func (c *cache) openStore(readOnly bool) error {
fstree.WithPerm(os.ModePerm),
fstree.WithDepth(1),
fstree.WithDirNameLen(1),
fstree.WithNoSync(c.noSync))
fstree.WithNoSync(c.noSync),
fstree.WithFileCounter(&c.objCounters),
)
if err := c.fsTree.Open(readOnly); err != nil {
return fmt.Errorf("could not open FSTree: %w", err)
}
if err := c.fsTree.Init(); err != nil {
return fmt.Errorf("could not init FSTree: %w", err)
}
return nil
}
func (c *cache) deleteFromDB(key string) {
var recordDeleted bool
err := c.db.Batch(func(tx *bbolt.Tx) error {
b := tx.Bucket(defaultBucket)
return b.Delete([]byte(key))
key := []byte(key)
recordDeleted = !recordDeleted && b.Get(key) != nil
return b.Delete(key)
})
if err == nil {
c.objCounters.DecDB()
c.metrics.Evict(writecache.StorageTypeDB)
storagelog.Write(c.log,
storagelog.AddressField(key),
storagelog.StorageTypeField(wcStorageType),
storagelog.OpField("db DELETE"),
)
c.estimateCacheSize()
if recordDeleted {
c.objCounters.cDB.Add(math.MaxUint64)
c.estimateCacheSize()
}
} else {
c.log.Error(logs.WritecacheCantRemoveObjectsFromTheDatabase, zap.Error(err))
}
@ -111,7 +122,6 @@ func (c *cache) deleteFromDisk(ctx context.Context, keys []string) []string {
storagelog.OpField("fstree DELETE"),
)
c.metrics.Evict(writecache.StorageTypeFSTree)
c.objCounters.DecFS()
c.estimateCacheSize()
}
}

View file

@ -0,0 +1,56 @@
package sync
import "sync"
type locker struct {
mtx sync.Mutex
waiters int // not protected by mtx, must used outer mutex to update concurrently
}
type KeyLocker[K comparable] struct {
lockers map[K]*locker
lockersMtx sync.Mutex
}
func NewKeyLocker[K comparable]() *KeyLocker[K] {
return &KeyLocker[K]{
lockers: make(map[K]*locker),
}
}
func (l *KeyLocker[K]) Lock(key K) {
l.lockersMtx.Lock()
if locker, found := l.lockers[key]; found {
locker.waiters++
l.lockersMtx.Unlock()
locker.mtx.Lock()
return
}
locker := &locker{
waiters: 1,
}
locker.mtx.Lock()
l.lockers[key] = locker
l.lockersMtx.Unlock()
}
func (l *KeyLocker[K]) Unlock(key K) {
l.lockersMtx.Lock()
defer l.lockersMtx.Unlock()
locker, found := l.lockers[key]
if !found {
return
}
if locker.waiters == 1 {
delete(l.lockers, key)
}
locker.waiters--
locker.mtx.Unlock()
}

View file

@ -1,4 +1,4 @@
package main
package sync
import (
"context"
@ -12,11 +12,11 @@ import (
func TestKeyLocker(t *testing.T) {
taken := false
eg, _ := errgroup.WithContext(context.Background())
keyLocker := newKeyLocker[int]()
keyLocker := NewKeyLocker[int]()
for i := 0; i < 100; i++ {
eg.Go(func() error {
keyLocker.LockKey(0)
defer keyLocker.UnlockKey(0)
keyLocker.Lock(0)
defer keyLocker.Unlock(0)
require.False(t, taken)
taken = true