diff --git a/pkg/core/quota/mclock.go b/pkg/core/quota/mclock.go index 9ec6d3f86..9c70cedc9 100644 --- a/pkg/core/quota/mclock.go +++ b/pkg/core/quota/mclock.go @@ -3,6 +3,7 @@ package quota import ( "container/heap" "context" + "errors" "math" "sync" "sync/atomic" @@ -14,6 +15,8 @@ const ( undefinedReservation float64 = -1.0 ) +var ErrSchedulerClosed = errors.New("mClock scheduler is closed") + type mQueueItem interface { ts() float64 setIndex(idx int) @@ -50,6 +53,7 @@ type TagInfo struct { type clock interface { now() float64 runAt(ts float64, f func()) + close() } type MClockQueue struct { @@ -67,6 +71,7 @@ type MClockQueue struct { limitQueue *mQueue sharesQueue *mQueue readyQueue *mQueue + closed bool } func NewMClockQueue(limit uint64, tagInfo map[string]TagInfo, idleTimeout float64) *MClockQueue { @@ -106,13 +111,31 @@ func NewMClockQueue(limit uint64, tagInfo map[string]TagInfo, idleTimeout float6 } func (q *MClockQueue) RequestArrival(ctx context.Context, tag string) (Release, error) { - req, release := q.pushRequest(tag) + req, release, err := q.pushRequest(tag) + if err != nil { + return nil, err + } select { case <-ctx.Done(): q.dropRequest(req) return nil, ctx.Err() case <-req.scheduled: return release, nil + case <-req.canceled: + return nil, ErrSchedulerClosed + } +} + +func (q *MClockQueue) Close() { + q.mtx.Lock() + defer q.mtx.Unlock() + + q.closed = true + q.clock.close() + for q.limitQueue.Len() > 0 { + item := heap.Pop(q.limitQueue).(*limitMQueueItem) + close(item.r.canceled) + q.removeFromQueues(item.r) } } @@ -132,10 +155,14 @@ func (q *MClockQueue) dropRequest(req *request) { q.removeFromQueues(req) } -func (q *MClockQueue) pushRequest(tag string) (*request, Release) { +func (q *MClockQueue) pushRequest(tag string) (*request, Release, error) { q.mtx.Lock() defer q.mtx.Unlock() + if q.closed { + return nil, nil, ErrSchedulerClosed + } + now := q.clock.now() tagInfo, ok := q.tagInfo[tag] if !ok { @@ -182,7 +209,7 @@ func (q *MClockQueue) 
pushRequest(tag string) (*request, Release) { heap.Push(q.limitQueue, &limitMQueueItem{r: r}) q.scheduleRequest(true) - return r, q.requestCompleted + return r, q.requestCompleted, nil } func (q *MClockQueue) adjustTags(now float64, idleTag string) { @@ -317,6 +344,10 @@ func (q *MClockQueue) requestCompleted() { q.mtx.Lock() defer q.mtx.Unlock() + if q.closed { + return + } + if q.inProgress == 0 { panic("invalid requests count") } diff --git a/pkg/core/quota/mclock_test.go b/pkg/core/quota/mclock_test.go index 81c3dcf50..0a95f2a95 100644 --- a/pkg/core/quota/mclock_test.go +++ b/pkg/core/quota/mclock_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) func TestMClockSharesScheduling(t *testing.T) { @@ -23,13 +24,15 @@ func TestMClockSharesScheduling(t *testing.T) { var requests []*request tag := "class1" for i := 0; i < reqCount/2; i++ { - req, release := q.pushRequest(tag) + req, release, err := q.pushRequest(tag) + require.NoError(t, err) requests = append(requests, req) releases = append(releases, release) } tag = "class2" for i := 0; i < reqCount/2; i++ { - req, release := q.pushRequest(tag) + req, release, err := q.pushRequest(tag) + require.NoError(t, err) requests = append(requests, req) releases = append(releases, release) } @@ -85,6 +88,8 @@ func (n *noopClock) runAt(ts float64, f func()) { return } +func (n *noopClock) close() {} + func TestMClockRequestCancel(t *testing.T) { q := NewMClockQueue(1, map[string]TagInfo{ "class1": {shares: 2}, @@ -124,13 +129,15 @@ func TestMClockLimitScheduling(t *testing.T) { var requests []*request tag := "class1" for i := 0; i < reqCount/2; i++ { - req, release := q.pushRequest(tag) + req, release, err := q.pushRequest(tag) + require.NoError(t, err) requests = append(requests, req) releases = append(releases, release) } tag = "class2" for i := 0; i < reqCount/2; i++ { - req, release := q.pushRequest(tag) + req, release, err := q.pushRequest(tag) + 
require.NoError(t, err) requests = append(requests, req) releases = append(releases, release) } @@ -206,13 +213,15 @@ func TestMClockReservationScheduling(t *testing.T) { var requests []*request tag := "class1" for i := 0; i < reqCount/2; i++ { - req, release := q.pushRequest(tag) + req, release, err := q.pushRequest(tag) + require.NoError(t, err) requests = append(requests, req) releases = append(releases, release) } tag = "class2" for i := 0; i < reqCount/2; i++ { - req, release := q.pushRequest(tag) + req, release, err := q.pushRequest(tag) + require.NoError(t, err) requests = append(requests, req) releases = append(releases, release) } @@ -268,7 +277,8 @@ func TestMClockIdleTag(t *testing.T) { tag := "class1" for i := 0; i < reqCount/2; i++ { cl.v += idleTimeout / 2 - req, _ := q.pushRequest(tag) + req, _, err := q.pushRequest(tag) + require.NoError(t, err) requests = append(requests, req) } @@ -277,7 +287,8 @@ func TestMClockIdleTag(t *testing.T) { cl.v += 2 * idleTimeout tag = "class2" - req, _ := q.pushRequest(tag) + req, _, err := q.pushRequest(tag) + require.NoError(t, err) requests = append(requests, req) // class2 must be defined as idle, so all shares tags must be adjusted. 
@@ -290,3 +301,48 @@ func TestMClockIdleTag(t *testing.T) { } } } + +func TestMClockClose(t *testing.T) { + q := NewMClockQueue(1, map[string]TagInfo{ + "class1": {shares: 1}, + }, 1000) + q.clock = &noopClock{} + + requestRunning := make(chan struct{}) + checkDone := make(chan struct{}) + eg, ctx := errgroup.WithContext(context.Background()) + tag := "class1" + eg.Go(func() error { + release, err := q.RequestArrival(ctx, tag) + if err != nil { + return err + } + defer release() + close(requestRunning) + <-checkDone + return nil + }) + <-requestRunning + + eg.Go(func() error { + release, err := q.RequestArrival(ctx, tag) + require.Nil(t, release) + require.ErrorIs(t, err, ErrSchedulerClosed) + return nil + }) + + // wait until the second request is enqueued and blocked waiting to be scheduled + for q.limitQueue.Len() == 0 { + time.Sleep(1 * time.Second) + } + + q.Close() + + release, err := q.RequestArrival(context.Background(), tag) + require.Nil(t, release) + require.ErrorIs(t, err, ErrSchedulerClosed) + + close(checkDone) + + require.NoError(t, eg.Wait()) +}