diff --git a/pkg/core/quota/limiter.go b/pkg/core/quota/limiter.go new file mode 100644 index 000000000..a4da2d62f --- /dev/null +++ b/pkg/core/quota/limiter.go @@ -0,0 +1,292 @@ +package quota + +import ( + "container/heap" + "context" + "fmt" + "math/rand/v2" + "sync" + "sync/atomic" +) + +type Quota struct { + Read int64 + Write int64 +} + +type Priority struct { + Class string + Value byte +} + +type Release func() + +type limit struct { + read, write int64 + maxRead, maxWrite int64 +} + +type queueItem struct { + priority byte + ts uint64 + index int +} + +type queue struct { + l *limit + w int32 + items []*queueItem +} + +type QuotaLimiter struct { + l *limit + classes map[string]*queue + queues []*queue + nextQueue *queue + cond *sync.Cond + ts *atomic.Uint64 +} + +func (ql *QuotaLimiter) Acquire(ctx context.Context, p Priority, q Quota) (Release, error) { + queue, ok := ql.classes[p.Class] + if !ok { + return nil, fmt.Errorf("unknown class '%s'", p.Class) + } + + if ql.l.maxRead > 0 && q.Read > ql.l.maxRead { + return nil, fmt.Errorf("read quota %d exceeds total limit %d", q.Read, ql.l.maxRead) + } + + if ql.l.maxWrite > 0 && q.Write > ql.l.maxWrite { + return nil, fmt.Errorf("read quota %d exceeds total limit %d", q.Write, ql.l.maxWrite) + } + + if queue.l.maxRead > 0 && q.Read > queue.l.maxRead { + return nil, fmt.Errorf("read quota %d exceeds queue limit %d", q.Read, queue.l.maxRead) + } + + if queue.l.maxWrite > 0 && q.Write > queue.l.maxWrite { + return nil, fmt.Errorf("read quota %d exceeds queue limit %d", q.Write, queue.l.maxWrite) + } + + ts := ql.ts.Add(1) + + ql.cond.L.Lock() + defer ql.cond.L.Unlock() + + stop := context.AfterFunc(ctx, func() { + ql.cond.Broadcast() + }) + defer stop() + + allow := ql.nextQueue == nil && // no scheduled queue + hasQuota(q, queue.l) && // queue limit + hasQuota(q, ql.l) // global lomit + + if allow { + applyQuota(q, queue.l) + applyQuota(q, ql.l) + return func() { ql.release(p, q) }, nil + } + + qi := &queueItem{ + priority: p.Value, + ts: ts, + } + + queue.push(qi) + if queue.count() == 1 { + ql.resetNextQueue() + } + + var hasGlobalQuota, hasQueueQuota, isNextItem bool + for !allow { + ql.cond.Wait() + + if err := ctx.Err(); err != nil { + queue.drop(qi) + if queue.count() == 0 { + ql.resetNextQueue() + } + return nil, ctx.Err() + } + + hasGlobalQuota = hasQuota(q, ql.l) + hasQueueQuota = hasQuota(q, queue.l) + isNextItem = ql.nextQueue == queue && queue.top() == qi + + if hasGlobalQuota && !hasQueueQuota && isNextItem { + ql.changeNextQueue() + } + allow = hasGlobalQuota && hasQueueQuota && isNextItem + } + + applyQuota(q, queue.l) + applyQuota(q, ql.l) + queue.pop() + ql.resetNextQueue() + return func() { ql.release(p, q) }, nil +} + +func (ql *QuotaLimiter) release(p Priority, q Quota) { + queue, ok := ql.classes[p.Class] + if !ok { + panic("unknown class " + p.Class) + } + + ql.cond.L.Lock() + defer ql.cond.L.Unlock() + + releaseQuota(q, queue.l) + releaseQuota(q, ql.l) + + ql.cond.Broadcast() +} + +func (ql *QuotaLimiter) resetNextQueue() { + var nonEmptyQueues []*queue + var totalWeight int64 + for _, q := range ql.queues { + if q.count() > 0 { + nonEmptyQueues = append(nonEmptyQueues, q) + totalWeight += int64(q.weight()) + } + } + if len(nonEmptyQueues) == 0 { + ql.nextQueue = nil + return + } + ql.selectNextQueue(nonEmptyQueues, totalWeight) +} + +func (ql *QuotaLimiter) changeNextQueue() { + var nonEmptyQueues []*queue + var totalWeight int64 + for _, q := range ql.queues { + if q.count() > 0 && q != ql.nextQueue { + nonEmptyQueues = append(nonEmptyQueues, q) + totalWeight += int64(q.weight()) + } + } + if len(nonEmptyQueues) == 0 { + return + } + ql.selectNextQueue(nonEmptyQueues, totalWeight) +} + +func (ql *QuotaLimiter) selectNextQueue(nonEmptyQueues []*queue, totalWeight int64) { + if totalWeight == 0 { + ql.nextQueue = nonEmptyQueues[rand.IntN(len(nonEmptyQueues))] + return + } + weight := rand.Int64N(totalWeight) + var low, up int64 + for _, q := range nonEmptyQueues { + low = up + up += int64(q.weight()) + if weight >= low && weight < up { + ql.nextQueue = q + return + } + } + panic("undefined next queue") +} + +func hasQuota(q Quota, l *limit) bool { + if q.Read > 0 && l.maxRead > 0 && q.Read+l.read > l.maxRead { + return false + } + if q.Write > 0 && l.maxWrite > 0 && q.Write+l.write > l.write { + return false + } + return true +} + +func applyQuota(q Quota, l *limit) { + if q.Read > 0 && l.maxRead > 0 { + l.read += q.Read + } + if q.Write > 0 && l.maxWrite > 0 { + l.write += q.Write + } +} + +func releaseQuota(q Quota, l *limit) { + if q.Read > 0 && l.maxRead > 0 { + l.read -= q.Read + if l.read < 0 { + panic("invalid read limit after release") + } + } + if q.Write > 0 && l.maxWrite > 0 { + l.write -= q.Write + if l.write < 0 { + panic("invalid write limit after release") + } + } +} + +func (q *queue) push(qi *queueItem) { + heap.Push(q, qi) +} + +func (q *queue) pop() { + heap.Pop(q) +} + +func (q *queue) drop(qi *queueItem) { + heap.Remove(q, qi.index) +} + +func (q *queue) top() *queueItem { + if len(q.items) > 0 { + return q.items[0] + } + + return nil +} + +func (q *queue) count() int { + return len(q.items) +} + +func (q *queue) weight() int32 { + return q.w +} + +// Len implements heap.Interface. +func (q *queue) Len() int { + return q.count() +} + +// Less implements heap.Interface. +func (q *queue) Less(i int, j int) bool { + if q.items[i].priority == q.items[j].priority { + return q.items[i].ts < q.items[j].ts + } + return q.items[i].priority > q.items[j].priority +} + +// Pop implements heap.Interface. +func (q *queue) Pop() any { + n := len(q.items) + item := q.items[n-1] + q.items[n-1] = nil + q.items = q.items[0 : n-1] + item.index = -1 + return item +} + +// Push implements heap.Interface. +func (q *queue) Push(x any) { + it := x.(*queueItem) + it.index = q.Len() + q.items = append(q.items, it) +} + +// Swap implements heap.Interface. +func (q *queue) Swap(i int, j int) { + q.items[i], q.items[j] = q.items[j], q.items[i] + q.items[i].index = i + q.items[j].index = j +} diff --git a/pkg/core/quota/limiter_test.go b/pkg/core/quota/limiter_test.go new file mode 100644 index 000000000..a3ae36609 --- /dev/null +++ b/pkg/core/quota/limiter_test.go @@ -0,0 +1,100 @@ +package quota + +import ( + "math/rand/v2" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestQueue(t *testing.T) { + t.Run("different priority", func(t *testing.T) { + q := &queue{} + const count = 12345 + for i := range count { + priority := i % 256 + q.push(&queueItem{ + priority: byte(priority), + }) + } + testQueueInvariant(t, q) + }) + t.Run("same priority, different ts, inc", func(t *testing.T) { + q := &queue{} + var ts uint64 + const count = 10000 + for range count { + q.push(&queueItem{ + priority: 100, + ts: ts, + }) + ts++ + } + testQueueInvariant(t, q) + }) + t.Run("same priority, different ts, dec", func(t *testing.T) { + q := &queue{} + + const count = 10000 + ts := uint64(count) + for range count { + q.push(&queueItem{ + priority: 100, + ts: ts, + }) + ts-- + } + testQueueInvariant(t, q) + }) + t.Run("drop, inc", func(t *testing.T) { + q := &queue{} + var ts uint64 + const count = 12345 + for i := range count { + q.push(&queueItem{ + priority: byte(i % 256), + ts: ts, + }) + ts++ + } + for q.Len() > 0 { + idx := rand.IntN(q.count()) + it := q.items[idx] + q.drop(it) + testQueueInvariant(t, q) + } + }) + t.Run("drop, dec", func(t *testing.T) { + q := &queue{} + const count = 12345 + ts := uint64(count) + for i := range count { + q.push(&queueItem{ + priority: byte(i % 256), + ts: ts, + }) + ts-- + } + for q.Len() > 0 { + idx := rand.IntN(q.count()) + it := q.items[idx] + q.drop(it) + testQueueInvariant(t, q) + } + }) +} + +func testQueueInvariant(t *testing.T, q *queue) { + var previous *queueItem + for q.count() > 0 { + current := q.top() + if previous != nil { + require.True(t, previous.priority > current.priority || + (previous.priority == current.priority && + (previous.ts == current.ts || previous.ts < current.ts))) + } + previous = current + q.pop() + } + require.Equal(t, 0, q.count()) +}