diff --git a/limiting/limiter.go b/limiting/limiter.go index ad1edd6..e02c444 100644 --- a/limiting/limiter.go +++ b/limiting/limiter.go @@ -2,6 +2,7 @@ package limiting import ( "context" + "fmt" ) type semaphore struct { @@ -51,15 +52,24 @@ type KeyLimit struct { type ReleaseFunc func() -func New(limits []KeyLimit) *Limiter { +func New(limits []KeyLimit) (*Limiter, error) { lr := Limiter{m: make(map[string]*semaphore)} for _, l := range limits { - sem := newSemaphore(l.Limit) - for _, key := range l.Keys { - lr.m[key] = sem + if err := addLimit(&lr, l.Keys, newSemaphore(l.Limit)); err != nil { + return nil, err } } - return &lr + return &lr, nil +} + +func addLimit(lr *Limiter, keys []string, sem *semaphore) error { + for _, key := range keys { + if _, exists := lr.m[key]; exists { + return fmt.Errorf("duplicate key %q", key) + } + lr.m[key] = sem + } + return nil } // Acquire reserves a slot for the given key, blocking if necessary. diff --git a/limiting/limiter_test.go b/limiting/limiter_test.go index 72af2ba..22440f8 100644 --- a/limiting/limiter_test.go +++ b/limiting/limiter_test.go @@ -24,6 +24,14 @@ type testKeyLimit struct { } func TestLimiter(t *testing.T) { + t.Run("duplicate key", func(t *testing.T) { + _, err := limiting.New([]limiting.KeyLimit{ + {[]string{"A", "B"}, 10}, + {[]string{"B", "C"}, 10}, + }) + require.Error(t, err) + }) + testLimits := []*testKeyLimit{ {keys: []string{"A"}, limit: operationCount / 4}, {keys: []string{"B"}, limit: operationCount / 2}, @@ -44,7 +52,9 @@ func TestLimiter(t *testing.T) { } func testLimiter(t *testing.T, testCases []*testKeyLimit, blocking bool) { - lr := limiting.New(getLimits(testCases)) + lr, err := limiting.New(getLimits(testCases)) + require.NoError(t, err) + tasks := createTestTasks(testCases, lr, blocking) t.Run("first run", func(t *testing.T) {