109 lines
2.8 KiB
Go
109 lines
2.8 KiB
Go
// Copyright (C) 2020 Storj Labs, Inc.
|
|
// See LICENSE for copying information.
|
|
|
|
package sync2
|
|
|
|
import (
|
|
"context"
|
|
"math"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/zeebo/errs"
|
|
)
|
|
|
|
// SuccessThreshold tracks task formed by a known amount of concurrent tasks.
|
|
// It notifies the caller when reached a specific successful threshold without
|
|
// interrupting the remaining tasks.
|
|
type SuccessThreshold struct {
|
|
noCopy noCopy // nolint: structcheck
|
|
|
|
toSucceed int64
|
|
pending int64
|
|
|
|
successes int64
|
|
failures int64
|
|
|
|
done chan struct{}
|
|
once sync.Once
|
|
}
|
|
|
|
// NewSuccessThreshold creates a SuccessThreshold with the tasks number and
|
|
// successThreshold.
|
|
//
|
|
// It returns an error if tasks is less or equal than 1 or successThreshold
|
|
// is less or equal than 0 or greater or equal than 1.
|
|
func NewSuccessThreshold(tasks int, successThreshold float64) (*SuccessThreshold, error) {
|
|
switch {
|
|
case tasks <= 1:
|
|
return nil, errs.New(
|
|
"invalid number of tasks. It must be greater than 1, got %d", tasks,
|
|
)
|
|
case successThreshold <= 0 || successThreshold > 1:
|
|
return nil, errs.New(
|
|
"invalid successThreshold. It must be greater than 0 and less or equal to 1, got %f", successThreshold,
|
|
)
|
|
}
|
|
|
|
tasksToSuccess := int64(math.Ceil(float64(tasks) * successThreshold))
|
|
|
|
// just in case of floating point issues
|
|
if tasksToSuccess > int64(tasks) {
|
|
tasksToSuccess = int64(tasks)
|
|
}
|
|
|
|
return &SuccessThreshold{
|
|
toSucceed: tasksToSuccess,
|
|
pending: int64(tasks),
|
|
done: make(chan struct{}),
|
|
}, nil
|
|
}
|
|
|
|
// Success tells the SuccessThreshold that one tasks was successful.
|
|
func (threshold *SuccessThreshold) Success() {
|
|
atomic.AddInt64(&threshold.successes, 1)
|
|
|
|
if atomic.AddInt64(&threshold.toSucceed, -1) <= 0 {
|
|
threshold.markAsDone()
|
|
}
|
|
|
|
if atomic.AddInt64(&threshold.pending, -1) <= 0 {
|
|
threshold.markAsDone()
|
|
}
|
|
}
|
|
|
|
// Failure tells the SuccessThreshold that one task was a failure.
|
|
func (threshold *SuccessThreshold) Failure() {
|
|
atomic.AddInt64(&threshold.failures, 1)
|
|
|
|
if atomic.AddInt64(&threshold.pending, -1) <= 0 {
|
|
threshold.markAsDone()
|
|
}
|
|
}
|
|
|
|
// Wait blocks the caller until the successThreshold is reached or all the tasks
|
|
// have finished.
|
|
func (threshold *SuccessThreshold) Wait(ctx context.Context) {
|
|
select {
|
|
case <-ctx.Done():
|
|
case <-threshold.done:
|
|
}
|
|
}
|
|
|
|
// markAsDone finalizes threshold closing the completed channel just once.
|
|
// It's safe to be called multiple times.
|
|
func (threshold *SuccessThreshold) markAsDone() {
|
|
threshold.once.Do(func() {
|
|
close(threshold.done)
|
|
})
|
|
}
|
|
|
|
// SuccessCount returns the number of successes so far.
|
|
func (threshold *SuccessThreshold) SuccessCount() int {
|
|
return int(atomic.LoadInt64(&threshold.successes))
|
|
}
|
|
|
|
// FailureCount returns the number of failures so far.
|
|
func (threshold *SuccessThreshold) FailureCount() int {
|
|
return int(atomic.LoadInt64(&threshold.failures))
|
|
}
|