forked from TrueCloudLab/restic
Make worker pools input/output chans symmetric
Input and output channel are now both of type `chan Job`, this makes it possible to chain multiple worker pools together.
This commit is contained in:
parent
e5ee4eba53
commit
ee422110c8
2 changed files with 26 additions and 27 deletions
|
@ -1,14 +1,14 @@
|
|||
package worker
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Job is one unit of work.
|
||||
type Job interface{}
|
||||
|
||||
// Result is something the worker function returned, including the original job
|
||||
// and an (eventual) error.
|
||||
type Result struct {
|
||||
Job Job
|
||||
// Job is one unit of work. It is given to a Func, and the returned result and
|
||||
// error are stored in Result and Error.
|
||||
type Job struct {
|
||||
Data interface{}
|
||||
Result interface{}
|
||||
Error error
|
||||
}
|
||||
|
@ -22,12 +22,12 @@ type Pool struct {
|
|||
done chan struct{}
|
||||
wg *sync.WaitGroup
|
||||
jobCh <-chan Job
|
||||
resCh chan<- Result
|
||||
resCh chan<- Job
|
||||
}
|
||||
|
||||
// New returns a new worker pool with n goroutines, each running the function
|
||||
// f. The workers are started immediately.
|
||||
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Result) *Pool {
|
||||
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool {
|
||||
p := &Pool{
|
||||
f: f,
|
||||
done: make(chan struct{}),
|
||||
|
@ -52,10 +52,9 @@ func (p *Pool) runWorker(numWorker int) {
|
|||
// enable the input channel when starting up a new goroutine
|
||||
inCh = p.jobCh
|
||||
// but do not enable the output channel until we have a result
|
||||
outCh chan<- Result
|
||||
outCh chan<- Job
|
||||
|
||||
job Job
|
||||
res Result
|
||||
ok bool
|
||||
)
|
||||
|
||||
|
@ -66,16 +65,15 @@ func (p *Pool) runWorker(numWorker int) {
|
|||
|
||||
case job, ok = <-inCh:
|
||||
if !ok {
|
||||
fmt.Printf("in channel closed, worker exiting\n")
|
||||
return
|
||||
}
|
||||
|
||||
r, err := p.f(job, p.done)
|
||||
res = Result{Job: job, Result: r, Error: err}
|
||||
|
||||
job.Result, job.Error = p.f(job, p.done)
|
||||
inCh = nil
|
||||
outCh = p.resCh
|
||||
|
||||
case outCh <- res:
|
||||
case outCh <- job:
|
||||
outCh = nil
|
||||
inCh = p.jobCh
|
||||
}
|
||||
|
|
|
@ -13,16 +13,16 @@ const concurrency = 10
|
|||
var errTooLarge = errors.New("too large")
|
||||
|
||||
func square(job worker.Job, done <-chan struct{}) (interface{}, error) {
|
||||
n := job.(int)
|
||||
n := job.Data.(int)
|
||||
if n > 2000 {
|
||||
return nil, errTooLarge
|
||||
}
|
||||
return n * n, nil
|
||||
}
|
||||
|
||||
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Result, *worker.Pool) {
|
||||
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) {
|
||||
inCh := make(chan worker.Job, bufsize)
|
||||
outCh := make(chan worker.Result, bufsize)
|
||||
outCh := make(chan worker.Job, bufsize)
|
||||
|
||||
return inCh, outCh, worker.New(n, f, inCh, outCh)
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func TestPool(t *testing.T) {
|
|||
inCh, outCh, p := newBufferedPool(200, concurrency, square)
|
||||
|
||||
for i := 0; i < 150; i++ {
|
||||
inCh <- i
|
||||
inCh <- worker.Job{Data: i}
|
||||
}
|
||||
|
||||
close(inCh)
|
||||
|
@ -39,10 +39,11 @@ func TestPool(t *testing.T) {
|
|||
|
||||
for res := range outCh {
|
||||
if res.Error != nil {
|
||||
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
|
||||
t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
|
||||
continue
|
||||
}
|
||||
|
||||
n := res.Job.(int)
|
||||
n := res.Data.(int)
|
||||
m := res.Result.(int)
|
||||
|
||||
if m != n*n {
|
||||
|
@ -55,14 +56,14 @@ func TestPoolErrors(t *testing.T) {
|
|||
inCh, outCh, p := newBufferedPool(200, concurrency, square)
|
||||
|
||||
for i := 0; i < 150; i++ {
|
||||
inCh <- i + 1900
|
||||
inCh <- worker.Job{Data: i + 1900}
|
||||
}
|
||||
|
||||
close(inCh)
|
||||
p.Wait()
|
||||
|
||||
for res := range outCh {
|
||||
n := res.Job.(int)
|
||||
n := res.Data.(int)
|
||||
|
||||
if n > 2000 {
|
||||
if res.Error == nil {
|
||||
|
@ -77,7 +78,7 @@ func TestPoolErrors(t *testing.T) {
|
|||
continue
|
||||
} else {
|
||||
if res.Error != nil {
|
||||
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
|
||||
t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
@ -92,7 +93,7 @@ func TestPoolErrors(t *testing.T) {
|
|||
var errCancelled = errors.New("cancelled")
|
||||
|
||||
func wait(job worker.Job, done <-chan struct{}) (interface{}, error) {
|
||||
d := job.(time.Duration)
|
||||
d := job.Data.(time.Duration)
|
||||
select {
|
||||
case <-time.After(d):
|
||||
return time.Now(), nil
|
||||
|
@ -105,7 +106,7 @@ func TestPoolCancel(t *testing.T) {
|
|||
jobCh, resCh, p := newBufferedPool(20, concurrency, wait)
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
jobCh <- 10 * time.Millisecond
|
||||
jobCh <- worker.Job{Data: 10 * time.Millisecond}
|
||||
}
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
|
Loading…
Reference in a new issue