package limiting_test import ( "sync" "sync/atomic" "testing" "time" "git.frostfs.info/TrueCloudLab/frostfs-qos/limiting" "github.com/stretchr/testify/require" ) const ( operationDuration = 10 * time.Millisecond operationCount = 64 ) type testCase struct { keys []string limit int64 withoutLimit bool failCount atomic.Int64 } func TestLimiter(t *testing.T) { testLimiter(t, func(kl []limiting.KeyLimit) (limiting.Limiter, error) { return limiting.NewSemaphoreLimiter(kl) }) } func testLimiter(t *testing.T, getLimiter func([]limiting.KeyLimit) (limiting.Limiter, error)) { t.Run("duplicate key", func(t *testing.T) { _, err := getLimiter([]limiting.KeyLimit{ {[]string{"A", "B"}, 10}, {[]string{"B", "C"}, 10}, }) require.Error(t, err) }) testCases := []*testCase{ {keys: []string{"A"}, limit: operationCount / 4}, {keys: []string{"B"}, limit: operationCount / 2}, {keys: []string{"C", "D"}, limit: operationCount / 4}, {keys: []string{"E"}, limit: 2 * operationCount}, {keys: []string{"F"}, withoutLimit: true}, } lr, err := getLimiter(getLimits(testCases)) require.NoError(t, err) tasks := createTestTasks(testCases, lr) t.Run("first run", func(t *testing.T) { executeTasks(tasks...) verifyResults(t, testCases) }) resetFailCounts(testCases) t.Run("repeated run", func(t *testing.T) { executeTasks(tasks...) verifyResults(t, testCases) }) } func getLimits(testCases []*testCase) []limiting.KeyLimit { var limits []limiting.KeyLimit for _, tc := range testCases { if tc.withoutLimit { continue } limits = append(limits, limiting.KeyLimit{ Keys: tc.keys, Limit: int64(tc.limit), }) } return limits } func createTestTasks(testCases []*testCase, lr limiting.Limiter) []func() { var tasks []func() for _, tc := range testCases { for _, key := range tc.keys { tasks = append(tasks, func() { executeTaskN(operationCount, func() { acquireAndExecute(tc, lr, key) }) }) } } return tasks } func acquireAndExecute(tc *testCase, lr limiting.Limiter, key string) { release, ok := lr.Acquire(key) if !ok { tc.failCount.Add(1) return } defer release() time.Sleep(operationDuration) } func executeTasks(tasks ...func()) { var g sync.WaitGroup g.Add(len(tasks)) for _, task := range tasks { go func() { defer g.Done() task() }() } g.Wait() } func executeTaskN(N int, task func()) { tasks := make([]func(), N) for i := range N { tasks[i] = task } executeTasks(tasks...) } func verifyResults(t *testing.T, testCases []*testCase) { for _, tc := range testCases { var expectedFailCount int64 if !tc.withoutLimit { numKeys := int64(len(tc.keys)) expectedFailCount = max(operationCount*numKeys-tc.limit, 0) } require.Equal(t, expectedFailCount, tc.failCount.Load()) } } func resetFailCounts(testCases []*testCase) { for _, tc := range testCases { tc.failCount.Store(0) } }