diff --git a/limiting/limiter.go b/limiting/limiter.go new file mode 100644 index 0000000..afbd8c0 --- /dev/null +++ b/limiting/limiter.go @@ -0,0 +1,101 @@ +package limiting + +import ( + "context" +) + +type semaphore struct { + ch chan struct{} +} + +func newSemaphore(size uint64) *semaphore { + return &semaphore{make(chan struct{}, size)} +} + +func (s *semaphore) acquire(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case s.ch <- struct{}{}: + return nil + } +} + +func (s *semaphore) tryAcquire() bool { + select { + case s.ch <- struct{}{}: + return true + default: + return false + } +} + +func (s *semaphore) release() { + <-s.ch +} + +type Limiter struct { + m map[string]*semaphore +} + +// KeyLimit defines a concurrency limit for a set of keys. +// +// All keys of one set share the same limit. +// Keys of different sets have separate limits. +// +// Sets must not overlap. +type KeyLimit struct { + Keys []string + Limit uint64 +} + +type ReleaseFunc func() + +func New(limits []KeyLimit) *Limiter { + lr := Limiter{m: make(map[string]*semaphore)} + + for _, l := range limits { + sem := newSemaphore(l.Limit) + for _, item := range l.Keys { + lr.m[item] = sem + } + } + + return &lr +} + +// Acquire reserves a slot for the given key, blocking if necessary. +// +// If the context is canceled before reservation, returns an error; +// otherwise, returns a release function that must be called exactly once. +// +// If the key was not defined in the limiter, no limit is applied. +func (lr *Limiter) Acquire(ctx context.Context, key string) (ReleaseFunc, error) { + sem, ok := lr.m[key] + if !ok { + return func() {}, nil + } + + if err := sem.acquire(ctx); err != nil { + return nil, err + } + return func() { sem.release() }, nil +} + +// TryAcquire attempts to reserve a slot without blocking. +// +// Returns a release function and true if successful, otherwise false. +// The release function must be called exactly once. +// +// If the key was not defined in the limiter, no limit is applied. +func (lr *Limiter) TryAcquire(key string) (ReleaseFunc, bool) { + sem, ok := lr.m[key] + if !ok { + return func() {}, true + } + + if ok := sem.tryAcquire(); ok { + return func() { sem.release() }, true + } + return nil, false +} diff --git a/limiting/limiter_test.go b/limiting/limiter_test.go new file mode 100644 index 0000000..562c33c --- /dev/null +++ b/limiting/limiter_test.go @@ -0,0 +1,145 @@ +package limiting_test + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "git.frostfs.info/TrueCloudLab/frostfs-qos/limiting" + "github.com/stretchr/testify/require" +) + +const ( + operationDuration = 100 * time.Millisecond + operationCount = 64 +) + +type testKeyLimit struct { + keys []string + limit int + withoutLimit bool + failCount atomic.Int32 +} + +func TestLimiter(t *testing.T) { + testLimits := []*testKeyLimit{ + {keys: []string{"A"}, limit: operationCount / 4}, + {keys: []string{"B"}, limit: operationCount / 2}, + {keys: []string{"C", "D"}, limit: operationCount / 4}, + {keys: []string{"E"}, limit: 2 * operationCount}, + {keys: []string{"F"}, withoutLimit: true}, + } + + t.Run("non-blocking mode", func(t *testing.T) { + testLimiter(t, testLimits, false) + }) + + resetFailCounts(testLimits) + + t.Run("blocking mode", func(t *testing.T) { + testLimiter(t, testLimits, true) + }) +} + +func testLimiter(t *testing.T, testCases []*testKeyLimit, blocking bool) { + lr := limiting.New(getLimits(testCases)) + tasks := createTestTasks(testCases, lr, blocking) + + t.Run("first run", func(t *testing.T) { + executeTasks(tasks...) + verifyResults(t, testCases, blocking) + }) + + t.Run("repeated run", func(t *testing.T) { + resetFailCounts(testCases) + executeTasks(tasks...) + verifyResults(t, testCases, blocking) + }) +} + +func getLimits(testCases []*testKeyLimit) []limiting.KeyLimit { + var limits []limiting.KeyLimit + for _, tc := range testCases { + if tc.withoutLimit { + continue + } + limits = append(limits, limiting.KeyLimit{ + Keys: tc.keys, + Limit: uint64(tc.limit), + }) + } + return limits +} + +func createTestTasks(testCases []*testKeyLimit, lr *limiting.Limiter, blocking bool) []func() { + var tasks []func() + for _, tc := range testCases { + for _, key := range tc.keys { + tasks = append(tasks, func() { + executeTaskN(operationCount, func() { acquireAndExecute(tc, lr, key, blocking) }) + }) + } + } + return tasks +} + +func acquireAndExecute(tc *testKeyLimit, lr *limiting.Limiter, key string, blocking bool) { + if blocking { + release, err := lr.Acquire(context.Background(), key) + if err != nil { + tc.failCount.Add(1) + return + } + defer release() + } else { + release, ok := lr.TryAcquire(key) + if !ok { + tc.failCount.Add(1) + return + } + defer release() + } + time.Sleep(operationDuration) +} + +func executeTasks(tasks ...func()) { + var g sync.WaitGroup + + g.Add(len(tasks)) + for _, task := range tasks { + go func() { + defer g.Done() + task() + }() + } + g.Wait() +} + +func executeTaskN(N int, task func()) { + tasks := make([]func(), N) + for i := range N { + tasks[i] = task + } + executeTasks(tasks...) +} + +func verifyResults(t *testing.T, testCases []*testKeyLimit, blocking bool) { + for _, tl := range testCases { + var expectedFailCount int + if blocking || tl.withoutLimit { + expectedFailCount = 0 + } else { + expectedFailCount = max(operationCount*len(tl.keys)-tl.limit, 0) + } + actualFailCount := int(tl.failCount.Load()) + require.Equal(t, expectedFailCount, actualFailCount) + } +} + +func resetFailCounts(testLimits []*testKeyLimit) { + for _, tl := range testLimits { + tl.failCount.Store(0) + } +}