forked from TrueCloudLab/restic
Add worker pool
A worker pool is needed whenever something should be done concurrently. This small library makes it easy to create a worker pool by specifying channels, concurrency and a function that should be executed for each job and returns a result and an error.
This commit is contained in:
parent
1e0b7dbdd2
commit
e5ee4eba53
3 changed files with 231 additions and 0 deletions
2
src/restic/worker/doc.go
Normal file
2
src/restic/worker/doc.go
Normal file
|
@ -0,0 +1,2 @@
|
|||
// Package worker implements a worker pool.
|
||||
package worker
|
95
src/restic/worker/pool.go
Normal file
95
src/restic/worker/pool.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package worker
|
||||
|
||||
import "sync"
|
||||
|
||||
// Job is one unit of work.
|
||||
type Job interface{}
|
||||
|
||||
// Result is something the worker function returned, including the original job
|
||||
// and an (eventual) error.
|
||||
type Result struct {
|
||||
Job Job
|
||||
Result interface{}
|
||||
Error error
|
||||
}
|
||||
|
||||
// Func does the actual work within a Pool.
|
||||
type Func func(job Job, done <-chan struct{}) (result interface{}, err error)
|
||||
|
||||
// Pool implements a worker pool.
|
||||
type Pool struct {
|
||||
f Func
|
||||
done chan struct{}
|
||||
wg *sync.WaitGroup
|
||||
jobCh <-chan Job
|
||||
resCh chan<- Result
|
||||
}
|
||||
|
||||
// New returns a new worker pool with n goroutines, each running the function
|
||||
// f. The workers are started immediately.
|
||||
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Result) *Pool {
|
||||
p := &Pool{
|
||||
f: f,
|
||||
done: make(chan struct{}),
|
||||
wg: &sync.WaitGroup{},
|
||||
jobCh: jobChan,
|
||||
resCh: resultChan,
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
p.wg.Add(1)
|
||||
go p.runWorker(i)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// runWorker runs a worker function.
|
||||
func (p *Pool) runWorker(numWorker int) {
|
||||
defer p.wg.Done()
|
||||
|
||||
var (
|
||||
// enable the input channel when starting up a new goroutine
|
||||
inCh = p.jobCh
|
||||
// but do not enable the output channel until we have a result
|
||||
outCh chan<- Result
|
||||
|
||||
job Job
|
||||
res Result
|
||||
ok bool
|
||||
)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
|
||||
case job, ok = <-inCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
r, err := p.f(job, p.done)
|
||||
res = Result{Job: job, Result: r, Error: err}
|
||||
|
||||
inCh = nil
|
||||
outCh = p.resCh
|
||||
|
||||
case outCh <- res:
|
||||
outCh = nil
|
||||
inCh = p.jobCh
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel signals termination to all worker goroutines.
|
||||
func (p *Pool) Cancel() {
|
||||
close(p.done)
|
||||
}
|
||||
|
||||
// Wait waits for all worker goroutines to terminate, afterwards the output
|
||||
// channel is closed.
|
||||
func (p *Pool) Wait() {
|
||||
p.wg.Wait()
|
||||
close(p.resCh)
|
||||
}
|
134
src/restic/worker/pool_test.go
Normal file
134
src/restic/worker/pool_test.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package worker_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"restic/worker"
|
||||
)
|
||||
|
||||
const concurrency = 10
|
||||
|
||||
var errTooLarge = errors.New("too large")
|
||||
|
||||
func square(job worker.Job, done <-chan struct{}) (interface{}, error) {
|
||||
n := job.(int)
|
||||
if n > 2000 {
|
||||
return nil, errTooLarge
|
||||
}
|
||||
return n * n, nil
|
||||
}
|
||||
|
||||
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Result, *worker.Pool) {
|
||||
inCh := make(chan worker.Job, bufsize)
|
||||
outCh := make(chan worker.Result, bufsize)
|
||||
|
||||
return inCh, outCh, worker.New(n, f, inCh, outCh)
|
||||
}
|
||||
|
||||
func TestPool(t *testing.T) {
|
||||
inCh, outCh, p := newBufferedPool(200, concurrency, square)
|
||||
|
||||
for i := 0; i < 150; i++ {
|
||||
inCh <- i
|
||||
}
|
||||
|
||||
close(inCh)
|
||||
p.Wait()
|
||||
|
||||
for res := range outCh {
|
||||
if res.Error != nil {
|
||||
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
|
||||
}
|
||||
|
||||
n := res.Job.(int)
|
||||
m := res.Result.(int)
|
||||
|
||||
if m != n*n {
|
||||
t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolErrors(t *testing.T) {
|
||||
inCh, outCh, p := newBufferedPool(200, concurrency, square)
|
||||
|
||||
for i := 0; i < 150; i++ {
|
||||
inCh <- i + 1900
|
||||
}
|
||||
|
||||
close(inCh)
|
||||
p.Wait()
|
||||
|
||||
for res := range outCh {
|
||||
n := res.Job.(int)
|
||||
|
||||
if n > 2000 {
|
||||
if res.Error == nil {
|
||||
t.Errorf("expected error not found, result is %v", res)
|
||||
continue
|
||||
}
|
||||
|
||||
if res.Error != errTooLarge {
|
||||
t.Errorf("unexpected error found, result is %v", res)
|
||||
}
|
||||
|
||||
continue
|
||||
} else {
|
||||
if res.Error != nil {
|
||||
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
m := res.Result.(int)
|
||||
if m != n*n {
|
||||
t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var errCancelled = errors.New("cancelled")
|
||||
|
||||
func wait(job worker.Job, done <-chan struct{}) (interface{}, error) {
|
||||
d := job.(time.Duration)
|
||||
select {
|
||||
case <-time.After(d):
|
||||
return time.Now(), nil
|
||||
case <-done:
|
||||
return nil, errCancelled
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolCancel(t *testing.T) {
|
||||
jobCh, resCh, p := newBufferedPool(20, concurrency, wait)
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
jobCh <- 10 * time.Millisecond
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
p.Cancel()
|
||||
p.Wait()
|
||||
|
||||
foundResult := false
|
||||
foundCancelError := false
|
||||
for res := range resCh {
|
||||
if res.Error == nil {
|
||||
foundResult = true
|
||||
}
|
||||
|
||||
if res.Error == errCancelled {
|
||||
foundCancelError = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundResult {
|
||||
t.Error("did not find one expected result")
|
||||
}
|
||||
|
||||
if !foundCancelError {
|
||||
t.Error("did not find one expected cancel error")
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue