[#4] limiting: Add check for duplicated keys

Signed-off-by: Aleksey Savchuk <a.savchuk@yadro.com>
This commit is contained in:
Aleksey Savchuk 2025-02-10 13:00:29 +03:00
parent 3e7881e3fe
commit 311ce63094
Signed by: a-savchuk
GPG key ID: 70C0A7FF6F9C4639
3 changed files with 173 additions and 110 deletions

View file

@ -1,41 +1,28 @@
package limiting package limiting
import ( import (
"context" "fmt"
) )
type semaphore struct { type ReleaseFunc func()
ch chan struct{}
type Limiter interface {
// TryAcquire attempts to reserve a slot without blocking.
//
// Returns a release function and true if successful, otherwise false.
// The release function must be called exactly once.
//
// If the key was not defined in the limiter, no limit is applied.
Acquire(key string) (ReleaseFunc, bool)
} }
func newSemaphore(size uint64) *semaphore { type semaphore interface {
return &semaphore{make(chan struct{}, size)} acquire() bool
release()
} }
func (s *semaphore) acquire(ctx context.Context) error { type semaphoreLimiter[T semaphore] struct {
select { m map[string]T
case <-ctx.Done():
return ctx.Err()
case s.ch <- struct{}{}:
return nil
}
}
func (s *semaphore) tryAcquire() bool {
select {
case s.ch <- struct{}{}:
return true
default:
return false
}
}
func (s *semaphore) release() {
<-s.ch
}
type Limiter struct {
m map[string]*semaphore // for read-only access
} }
// KeyLimit defines a concurrency limit for a set of keys. // KeyLimit defines a concurrency limit for a set of keys.
@ -46,53 +33,54 @@ type Limiter struct {
// Sets must not overlap. // Sets must not overlap.
type KeyLimit struct { type KeyLimit struct {
Keys []string Keys []string
Limit uint64 Limit int64
} }
type ReleaseFunc func() var NewAtomicLimiter = func(limits []KeyLimit) (Limiter, error) {
return newSemaphoreLimiter(limits, newAtomicSemaphore)
func New(limits []KeyLimit) *Limiter {
lr := Limiter{m: make(map[string]*semaphore)}
for _, l := range limits {
sem := newSemaphore(l.Limit)
for _, key := range l.Keys {
lr.m[key] = sem
}
}
return &lr
} }
// Acquire reserves a slot for the given key, blocking if necessary. var NewBurstAtomicLimiter = func(limits []KeyLimit) (Limiter, error) {
// return newSemaphoreLimiter(limits, newBurstAtomicSemaphore)
// If the context is canceled before reservation, returns an error;
// otherwise, returns a release function that must be called exactly once.
//
// If the key was not defined in the limiter, no limit is applied.
func (lr *Limiter) Acquire(ctx context.Context, key string) (ReleaseFunc, error) {
sem, ok := lr.m[key]
if !ok {
return func() {}, nil
} }
if err := sem.acquire(ctx); err != nil { var NewChannelLimiter = func(limits []KeyLimit) (Limiter, error) {
return newSemaphoreLimiter(limits, newChannelSemaphore)
}
func newSemaphoreLimiter[T semaphore](limits []KeyLimit, newSemaphore func(size int64) T) (*semaphoreLimiter[T], error) {
lr := semaphoreLimiter[T]{make(map[string]T)}
for _, limit := range limits {
if err := lr.addLimit(&limit, newSemaphore); err != nil {
return nil, err return nil, err
} }
return func() { sem.release() }, nil }
return &lr, nil
} }
// TryAcquire attempts to reserve a slot without blocking. func (lr *semaphoreLimiter[T]) addLimit(limit *KeyLimit, newSemaphore func(size int64) T) error {
// if limit.Limit < 0 {
// Returns a release function and true if successful, otherwise false. return fmt.Errorf("invalid limit %d", limit.Limit)
// The release function must be called exactly once. }
//
// If the key was not defined in the limiter, no limit is applied. sem := newSemaphore(limit.Limit)
func (lr *Limiter) TryAcquire(key string) (ReleaseFunc, bool) { for _, key := range limit.Keys {
if _, exists := lr.m[key]; exists {
return fmt.Errorf("duplicate key %q", key)
}
lr.m[key] = sem
}
return nil
}
func (lr *semaphoreLimiter[T]) Acquire(key string) (ReleaseFunc, bool) {
sem, ok := lr.m[key] sem, ok := lr.m[key]
if !ok { if !ok {
return func() {}, true return func() {}, true
} }
if ok := sem.tryAcquire(); ok { if ok := sem.acquire(); ok {
return func() { sem.release() }, true return func() { sem.release() }, true
} }
return nil, false return nil, false

View file

@ -1,7 +1,6 @@
package limiting_test package limiting_test
import ( import (
"context"
"sync" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
@ -12,19 +11,41 @@ import (
) )
const ( const (
operationDuration = 100 * time.Millisecond operationDuration = 10 * time.Millisecond
operationCount = 64 operationCount = 64
) )
type testKeyLimit struct { type testCase struct {
keys []string keys []string
limit int limit int64
withoutLimit bool withoutLimit bool
failCount atomic.Int32 failCount atomic.Int64
} }
func TestLimiter(t *testing.T) { func TestLimiter(t *testing.T) {
testLimits := []*testKeyLimit{ t.Run("atomic limiter", func(t *testing.T) {
testLimiter(t, limiting.NewAtomicLimiter)
})
t.Run("burst atomic limiter", func(t *testing.T) {
testLimiter(t, limiting.NewBurstAtomicLimiter)
})
t.Run("channel limiter", func(t *testing.T) {
testLimiter(t, limiting.NewChannelLimiter)
})
}
func testLimiter(t *testing.T, getLimiter func([]limiting.KeyLimit) (limiting.Limiter, error)) {
t.Run("duplicate key", func(t *testing.T) {
_, err := getLimiter([]limiting.KeyLimit{
{[]string{"A", "B"}, 10},
{[]string{"B", "C"}, 10},
})
require.Error(t, err)
})
testCases := []*testCase{
{keys: []string{"A"}, limit: operationCount / 4}, {keys: []string{"A"}, limit: operationCount / 4},
{keys: []string{"B"}, limit: operationCount / 2}, {keys: []string{"B"}, limit: operationCount / 2},
{keys: []string{"C", "D"}, limit: operationCount / 4}, {keys: []string{"C", "D"}, limit: operationCount / 4},
@ -32,34 +53,25 @@ func TestLimiter(t *testing.T) {
{keys: []string{"F"}, withoutLimit: true}, {keys: []string{"F"}, withoutLimit: true},
} }
t.Run("non-blocking mode", func(t *testing.T) { lr, err := getLimiter(getLimits(testCases))
testLimiter(t, testLimits, false) require.NoError(t, err)
})
resetFailCounts(testLimits) tasks := createTestTasks(testCases, lr)
t.Run("blocking mode", func(t *testing.T) {
testLimiter(t, testLimits, true)
})
}
func testLimiter(t *testing.T, testCases []*testKeyLimit, blocking bool) {
lr := limiting.New(getLimits(testCases))
tasks := createTestTasks(testCases, lr, blocking)
t.Run("first run", func(t *testing.T) { t.Run("first run", func(t *testing.T) {
executeTasks(tasks...) executeTasks(tasks...)
verifyResults(t, testCases, blocking) verifyResults(t, testCases)
}) })
t.Run("repeated run", func(t *testing.T) {
resetFailCounts(testCases) resetFailCounts(testCases)
t.Run("repeated run", func(t *testing.T) {
executeTasks(tasks...) executeTasks(tasks...)
verifyResults(t, testCases, blocking) verifyResults(t, testCases)
}) })
} }
func getLimits(testCases []*testKeyLimit) []limiting.KeyLimit { func getLimits(testCases []*testCase) []limiting.KeyLimit {
var limits []limiting.KeyLimit var limits []limiting.KeyLimit
for _, tc := range testCases { for _, tc := range testCases {
if tc.withoutLimit { if tc.withoutLimit {
@ -67,40 +79,31 @@ func getLimits(testCases []*testKeyLimit) []limiting.KeyLimit {
} }
limits = append(limits, limiting.KeyLimit{ limits = append(limits, limiting.KeyLimit{
Keys: tc.keys, Keys: tc.keys,
Limit: uint64(tc.limit), Limit: int64(tc.limit),
}) })
} }
return limits return limits
} }
func createTestTasks(testCases []*testKeyLimit, lr *limiting.Limiter, blocking bool) []func() { func createTestTasks(testCases []*testCase, lr limiting.Limiter) []func() {
var tasks []func() var tasks []func()
for _, tc := range testCases { for _, tc := range testCases {
for _, key := range tc.keys { for _, key := range tc.keys {
tasks = append(tasks, func() { tasks = append(tasks, func() {
executeTaskN(operationCount, func() { acquireAndExecute(tc, lr, key, blocking) }) executeTaskN(operationCount, func() { acquireAndExecute(tc, lr, key) })
}) })
} }
} }
return tasks return tasks
} }
func acquireAndExecute(tc *testKeyLimit, lr *limiting.Limiter, key string, blocking bool) { func acquireAndExecute(tc *testCase, lr limiting.Limiter, key string) {
if blocking { release, ok := lr.Acquire(key)
release, err := lr.Acquire(context.Background(), key)
if err != nil {
tc.failCount.Add(1)
return
}
defer release()
} else {
release, ok := lr.TryAcquire(key)
if !ok { if !ok {
tc.failCount.Add(1) tc.failCount.Add(1)
return return
} }
defer release() defer release()
}
time.Sleep(operationDuration) time.Sleep(operationDuration)
} }
@ -125,21 +128,19 @@ func executeTaskN(N int, task func()) {
executeTasks(tasks...) executeTasks(tasks...)
} }
func verifyResults(t *testing.T, testCases []*testKeyLimit, blocking bool) { func verifyResults(t *testing.T, testCases []*testCase) {
for _, tc := range testCases { for _, tc := range testCases {
var expectedFailCount int var expectedFailCount int64
if blocking || tc.withoutLimit { if !tc.withoutLimit {
expectedFailCount = 0 numKeys := int64(len(tc.keys))
} else { expectedFailCount = max(operationCount*numKeys-tc.limit, 0)
expectedFailCount = max(operationCount*len(tc.keys)-tc.limit, 0)
} }
actualFailCount := int(tc.failCount.Load()) require.Equal(t, expectedFailCount, tc.failCount.Load())
require.Equal(t, expectedFailCount, actualFailCount)
} }
} }
func resetFailCounts(testLimits []*testKeyLimit) { func resetFailCounts(testCases []*testCase) {
for _, tc := range testLimits { for _, tc := range testCases {
tc.failCount.Store(0) tc.failCount.Store(0)
} }
} }

74
limiting/semaphore.go Normal file
View file

@ -0,0 +1,74 @@
package limiting
import (
"sync/atomic"
)
type atomicSemaphore struct {
countDown atomic.Int64
}
func newAtomicSemaphore(size int64) *atomicSemaphore {
sem := new(atomicSemaphore)
sem.countDown.Store(size)
return sem
}
func (s *atomicSemaphore) acquire() bool {
for {
v := s.countDown.Load()
if v == 0 {
return false
}
if s.countDown.CompareAndSwap(v, v-1) {
return true
}
}
}
func (s *atomicSemaphore) release() {
s.countDown.Add(1)
}
type burstAtomicSemaphore struct {
count atomic.Int64
limit int64
}
func newBurstAtomicSemaphore(size int64) *burstAtomicSemaphore {
return &burstAtomicSemaphore{limit: size}
}
func (s *burstAtomicSemaphore) acquire() bool {
v := s.count.Add(1)
if v > s.limit {
s.count.Add(-1)
return false
}
return true
}
func (s *burstAtomicSemaphore) release() {
s.count.Add(-1)
}
type channelSemaphore struct {
ch chan struct{}
}
func newChannelSemaphore(size int64) *channelSemaphore {
return &channelSemaphore{make(chan struct{}, size)}
}
func (s *channelSemaphore) acquire() bool {
select {
case s.ch <- struct{}{}:
return true
default:
return false
}
}
func (s *channelSemaphore) release() {
<-s.ch
}