diff --git a/limiting/limiter.go b/limiting/limiter.go index ad1edd6..cc4a0ea 100644 --- a/limiting/limiter.go +++ b/limiting/limiter.go @@ -1,41 +1,28 @@ package limiting import ( - "context" + "fmt" ) -type semaphore struct { - ch chan struct{} +type ReleaseFunc func() + +type Limiter interface { + // TryAcquire attempts to reserve a slot without blocking. + // + // Returns a release function and true if successful, otherwise false. + // The release function must be called exactly once. + // + // If the key was not defined in the limiter, no limit is applied. + Acquire(key string) (ReleaseFunc, bool) } -func newSemaphore(size uint64) *semaphore { - return &semaphore{make(chan struct{}, size)} +type semaphore interface { + acquire() bool + release() } -func (s *semaphore) acquire(ctx context.Context) error { - select { - case <-ctx.Done(): - return ctx.Err() - case s.ch <- struct{}{}: - return nil - } -} - -func (s *semaphore) tryAcquire() bool { - select { - case s.ch <- struct{}{}: - return true - default: - return false - } -} - -func (s *semaphore) release() { - <-s.ch -} - -type Limiter struct { - m map[string]*semaphore // for read-only access +type semaphoreLimiter[T semaphore] struct { + m map[string]T } // KeyLimit defines a concurrency limit for a set of keys. @@ -46,53 +33,54 @@ type Limiter struct { // Sets must not overlap. type KeyLimit struct { Keys []string - Limit uint64 + Limit int64 } -type ReleaseFunc func() +var NewAtomicLimiter = func(limits []KeyLimit) (Limiter, error) { + return newSemaphoreLimiter(limits, newAtomicSemaphore) +} -func New(limits []KeyLimit) *Limiter { - lr := Limiter{m: make(map[string]*semaphore)} - for _, l := range limits { - sem := newSemaphore(l.Limit) - for _, key := range l.Keys { - lr.m[key] = sem +var NewBurstAtomicLimiter = func(limits []KeyLimit) (Limiter, error) { + return newSemaphoreLimiter(limits, newBurstAtomicSemaphore) +} + +var NewChannelLimiter = func(limits []KeyLimit) (Limiter, error) { + return newSemaphoreLimiter(limits, newChannelSemaphore) +} + +func newSemaphoreLimiter[T semaphore](limits []KeyLimit, newSemaphore func(size int64) T) (*semaphoreLimiter[T], error) { + lr := semaphoreLimiter[T]{make(map[string]T)} + for _, limit := range limits { + if err := lr.addLimit(&limit, newSemaphore); err != nil { + return nil, err } } - return &lr + return &lr, nil } -// Acquire reserves a slot for the given key, blocking if necessary. -// -// If the context is canceled before reservation, returns an error; -// otherwise, returns a release function that must be called exactly once. -// -// If the key was not defined in the limiter, no limit is applied. -func (lr *Limiter) Acquire(ctx context.Context, key string) (ReleaseFunc, error) { - sem, ok := lr.m[key] - if !ok { - return func() {}, nil +func (lr *semaphoreLimiter[T]) addLimit(limit *KeyLimit, newSemaphore func(size int64) T) error { + if limit.Limit < 0 { + return fmt.Errorf("invalid limit %d", limit.Limit) } - if err := sem.acquire(ctx); err != nil { - return nil, err + sem := newSemaphore(limit.Limit) + for _, key := range limit.Keys { + if _, exists := lr.m[key]; exists { + return fmt.Errorf("duplicate key %q", key) + } + lr.m[key] = sem } - return func() { sem.release() }, nil + + return nil } -// TryAcquire attempts to reserve a slot without blocking. -// -// Returns a release function and true if successful, otherwise false. -// The release function must be called exactly once. -// -// If the key was not defined in the limiter, no limit is applied. -func (lr *Limiter) TryAcquire(key string) (ReleaseFunc, bool) { +func (lr *semaphoreLimiter[T]) Acquire(key string) (ReleaseFunc, bool) { sem, ok := lr.m[key] if !ok { return func() {}, true } - if ok := sem.tryAcquire(); ok { + if ok := sem.acquire(); ok { return func() { sem.release() }, true } return nil, false diff --git a/limiting/limiter_test.go b/limiting/limiter_test.go index 72af2ba..cf6088b 100644 --- a/limiting/limiter_test.go +++ b/limiting/limiter_test.go @@ -1,7 +1,6 @@ package limiting_test import ( - "context" "sync" "sync/atomic" "testing" @@ -12,19 +11,41 @@ import ( ) const ( - operationDuration = 100 * time.Millisecond + operationDuration = 10 * time.Millisecond operationCount = 64 ) -type testKeyLimit struct { +type testCase struct { keys []string - limit int + limit int64 withoutLimit bool - failCount atomic.Int32 + failCount atomic.Int64 } func TestLimiter(t *testing.T) { - testLimits := []*testKeyLimit{ + t.Run("atomic limiter", func(t *testing.T) { + testLimiter(t, limiting.NewAtomicLimiter) + }) + + t.Run("burst atomic limiter", func(t *testing.T) { + testLimiter(t, limiting.NewBurstAtomicLimiter) + }) + + t.Run("channel limiter", func(t *testing.T) { + testLimiter(t, limiting.NewChannelLimiter) + }) +} + +func testLimiter(t *testing.T, getLimiter func([]limiting.KeyLimit) (limiting.Limiter, error)) { + t.Run("duplicate key", func(t *testing.T) { + _, err := getLimiter([]limiting.KeyLimit{ + {[]string{"A", "B"}, 10}, + {[]string{"B", "C"}, 10}, + }) + require.Error(t, err) + }) + + testCases := []*testCase{ {keys: []string{"A"}, limit: operationCount / 4}, {keys: []string{"B"}, limit: operationCount / 2}, {keys: []string{"C", "D"}, limit: operationCount / 4}, @@ -32,34 +53,25 @@ func TestLimiter(t *testing.T) { {keys: []string{"F"}, withoutLimit: true}, } - t.Run("non-blocking mode", func(t *testing.T) { - testLimiter(t, testLimits, false) - }) + lr, err := getLimiter(getLimits(testCases)) + require.NoError(t, err) - resetFailCounts(testLimits) - - t.Run("blocking mode", func(t *testing.T) { - testLimiter(t, testLimits, true) - }) -} - -func testLimiter(t *testing.T, testCases []*testKeyLimit, blocking bool) { - lr := limiting.New(getLimits(testCases)) - tasks := createTestTasks(testCases, lr, blocking) + tasks := createTestTasks(testCases, lr) t.Run("first run", func(t *testing.T) { executeTasks(tasks...) - verifyResults(t, testCases, blocking) + verifyResults(t, testCases) }) + resetFailCounts(testCases) + t.Run("repeated run", func(t *testing.T) { - resetFailCounts(testCases) executeTasks(tasks...) - verifyResults(t, testCases, blocking) + verifyResults(t, testCases) }) } -func getLimits(testCases []*testKeyLimit) []limiting.KeyLimit { +func getLimits(testCases []*testCase) []limiting.KeyLimit { var limits []limiting.KeyLimit for _, tc := range testCases { if tc.withoutLimit { @@ -67,40 +79,31 @@ func getLimits(testCases []*testKeyLimit) []limiting.KeyLimit { } limits = append(limits, limiting.KeyLimit{ Keys: tc.keys, - Limit: uint64(tc.limit), + Limit: int64(tc.limit), }) } return limits } -func createTestTasks(testCases []*testKeyLimit, lr *limiting.Limiter, blocking bool) []func() { +func createTestTasks(testCases []*testCase, lr limiting.Limiter) []func() { var tasks []func() for _, tc := range testCases { for _, key := range tc.keys { tasks = append(tasks, func() { - executeTaskN(operationCount, func() { acquireAndExecute(tc, lr, key, blocking) }) + executeTaskN(operationCount, func() { acquireAndExecute(tc, lr, key) }) }) } } return tasks } -func acquireAndExecute(tc *testKeyLimit, lr *limiting.Limiter, key string, blocking bool) { - if blocking { - release, err := lr.Acquire(context.Background(), key) - if err != nil { - tc.failCount.Add(1) - return - } - defer release() - } else { - release, ok := lr.TryAcquire(key) - if !ok { - tc.failCount.Add(1) - return - } - defer release() +func acquireAndExecute(tc *testCase, lr limiting.Limiter, key string) { + release, ok := lr.Acquire(key) + if !ok { + tc.failCount.Add(1) + return } + defer release() time.Sleep(operationDuration) } @@ -125,21 +128,19 @@ func executeTaskN(N int, task func()) { executeTasks(tasks...) } -func verifyResults(t *testing.T, testCases []*testKeyLimit, blocking bool) { +func verifyResults(t *testing.T, testCases []*testCase) { for _, tc := range testCases { - var expectedFailCount int - if blocking || tc.withoutLimit { - expectedFailCount = 0 - } else { - expectedFailCount = max(operationCount*len(tc.keys)-tc.limit, 0) + var expectedFailCount int64 + if !tc.withoutLimit { + numKeys := int64(len(tc.keys)) + expectedFailCount = max(operationCount*numKeys-tc.limit, 0) } - actualFailCount := int(tc.failCount.Load()) - require.Equal(t, expectedFailCount, actualFailCount) + require.Equal(t, expectedFailCount, tc.failCount.Load()) } } -func resetFailCounts(testLimits []*testKeyLimit) { - for _, tc := range testLimits { +func resetFailCounts(testCases []*testCase) { + for _, tc := range testCases { tc.failCount.Store(0) } } diff --git a/limiting/semaphore.go b/limiting/semaphore.go new file mode 100644 index 0000000..91a32f0 --- /dev/null +++ b/limiting/semaphore.go @@ -0,0 +1,74 @@ +package limiting + +import ( + "sync/atomic" +) + +type atomicSemaphore struct { + countDown atomic.Int64 +} + +func newAtomicSemaphore(size int64) *atomicSemaphore { + sem := new(atomicSemaphore) + sem.countDown.Store(size) + return sem +} + +func (s *atomicSemaphore) acquire() bool { + for { + v := s.countDown.Load() + if v == 0 { + return false + } + if s.countDown.CompareAndSwap(v, v-1) { + return true + } + } +} + +func (s *atomicSemaphore) release() { + s.countDown.Add(1) +} + +type burstAtomicSemaphore struct { + count atomic.Int64 + limit int64 +} + +func newBurstAtomicSemaphore(size int64) *burstAtomicSemaphore { + return &burstAtomicSemaphore{limit: size} +} + +func (s *burstAtomicSemaphore) acquire() bool { + v := s.count.Add(1) + if v > s.limit { + s.count.Add(-1) + return false + } + return true +} + +func (s *burstAtomicSemaphore) release() { + s.count.Add(-1) +} + +type channelSemaphore struct { + ch chan struct{} +} + +func newChannelSemaphore(size int64) *channelSemaphore { + return &channelSemaphore{make(chan struct{}, size)} +} + +func (s *channelSemaphore) acquire() bool { + select { + case s.ch <- struct{}{}: + return true + default: + return false + } +} + +func (s *channelSemaphore) release() { + <-s.ch +}