Make worker pools input/output chans symmetric

Input and output channel are now both of type `chan Job`, this makes it
possible to chain multiple worker pools together.
This commit is contained in:
Alexander Neumann 2016-02-05 22:22:24 +01:00
parent e5ee4eba53
commit ee422110c8
2 changed files with 26 additions and 27 deletions

View file

@ -1,14 +1,14 @@
package worker
import "sync"
import (
"fmt"
"sync"
)
// Job is one unit of work.
type Job interface{}
// Result is something the worker function returned, including the original job
// and an (eventual) error.
type Result struct {
Job Job
// Job is one unit of work. It is given to a Func, and the returned result and
// error are stored in Result and Error.
type Job struct {
Data interface{}
Result interface{}
Error error
}
@ -22,12 +22,12 @@ type Pool struct {
done chan struct{}
wg *sync.WaitGroup
jobCh <-chan Job
resCh chan<- Result
resCh chan<- Job
}
// New returns a new worker pool with n goroutines, each running the function
// f. The workers are started immediately.
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Result) *Pool {
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool {
p := &Pool{
f: f,
done: make(chan struct{}),
@ -52,10 +52,9 @@ func (p *Pool) runWorker(numWorker int) {
// enable the input channel when starting up a new goroutine
inCh = p.jobCh
// but do not enable the output channel until we have a result
outCh chan<- Result
outCh chan<- Job
job Job
res Result
ok bool
)
@ -66,16 +65,15 @@ func (p *Pool) runWorker(numWorker int) {
case job, ok = <-inCh:
if !ok {
fmt.Printf("in channel closed, worker exiting\n")
return
}
r, err := p.f(job, p.done)
res = Result{Job: job, Result: r, Error: err}
job.Result, job.Error = p.f(job, p.done)
inCh = nil
outCh = p.resCh
case outCh <- res:
case outCh <- job:
outCh = nil
inCh = p.jobCh
}

View file

@ -13,16 +13,16 @@ const concurrency = 10
var errTooLarge = errors.New("too large")
func square(job worker.Job, done <-chan struct{}) (interface{}, error) {
n := job.(int)
n := job.Data.(int)
if n > 2000 {
return nil, errTooLarge
}
return n * n, nil
}
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Result, *worker.Pool) {
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) {
inCh := make(chan worker.Job, bufsize)
outCh := make(chan worker.Result, bufsize)
outCh := make(chan worker.Job, bufsize)
return inCh, outCh, worker.New(n, f, inCh, outCh)
}
@ -31,7 +31,7 @@ func TestPool(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square)
for i := 0; i < 150; i++ {
inCh <- i
inCh <- worker.Job{Data: i}
}
close(inCh)
@ -39,10 +39,11 @@ func TestPool(t *testing.T) {
for res := range outCh {
if res.Error != nil {
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
continue
}
n := res.Job.(int)
n := res.Data.(int)
m := res.Result.(int)
if m != n*n {
@ -55,14 +56,14 @@ func TestPoolErrors(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square)
for i := 0; i < 150; i++ {
inCh <- i + 1900
inCh <- worker.Job{Data: i + 1900}
}
close(inCh)
p.Wait()
for res := range outCh {
n := res.Job.(int)
n := res.Data.(int)
if n > 2000 {
if res.Error == nil {
@ -77,7 +78,7 @@ func TestPoolErrors(t *testing.T) {
continue
} else {
if res.Error != nil {
t.Errorf("unexpected error for job %v received: %v", res.Job, res.Error)
t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
continue
}
}
@ -92,7 +93,7 @@ func TestPoolErrors(t *testing.T) {
var errCancelled = errors.New("cancelled")
func wait(job worker.Job, done <-chan struct{}) (interface{}, error) {
d := job.(time.Duration)
d := job.Data.(time.Duration)
select {
case <-time.After(d):
return time.Now(), nil
@ -105,7 +106,7 @@ func TestPoolCancel(t *testing.T) {
jobCh, resCh, p := newBufferedPool(20, concurrency, wait)
for i := 0; i < 20; i++ {
jobCh <- 10 * time.Millisecond
jobCh <- worker.Job{Data: 10 * time.Millisecond}
}
time.Sleep(20 * time.Millisecond)