fs/buffer: Fix panic on concurrent Read/Close - fixes #1213

This commit is contained in:
Nick Craig-Wood 2017-03-06 19:22:17 +00:00
parent 8dc7bf883d
commit b2a4ea9304
2 changed files with 82 additions and 3 deletions

View file

@ -16,6 +16,8 @@ var asyncBufferPool = sync.Pool{
New: func() interface{} { return newBuffer() }, New: func() interface{} { return newBuffer() },
} }
var errorStreamAbandoned = errors.New("stream abandoned")
// asyncReader will do async read-ahead from the input reader // asyncReader will do async read-ahead from the input reader
// and make the data available as an io.Reader. // and make the data available as an io.Reader.
// This should be fully transparent, except that once an error // This should be fully transparent, except that once an error
@ -31,6 +33,7 @@ type asyncReader struct {
exited chan struct{} // Channel is closed been the async reader shuts down exited chan struct{} // Channel is closed been the async reader shuts down
size int // size of buffer to use size int // size of buffer to use
closed bool // whether we have closed the underlying stream closed bool // whether we have closed the underlying stream
mu sync.Mutex // lock for Read/WriteTo/Abandon/Close
} }
// newAsyncReader returns a reader that will asynchronously read from // newAsyncReader returns a reader that will asynchronously read from
@ -39,7 +42,7 @@ type asyncReader struct {
// function has returned. // function has returned.
// The input can be read from the returned reader. // The input can be read from the returned reader.
// When done use Close to release the buffers and close the supplied input. // When done use Close to release the buffers and close the supplied input.
func newAsyncReader(rd io.ReadCloser, buffers int) (io.ReadCloser, error) { func newAsyncReader(rd io.ReadCloser, buffers int) (*asyncReader, error) {
if buffers <= 0 { if buffers <= 0 {
return nil, errors.New("number of buffers too small") return nil, errors.New("number of buffers too small")
} }
@ -113,6 +116,10 @@ func (a *asyncReader) fill() (err error) {
} }
b, ok := <-a.ready b, ok := <-a.ready
if !ok { if !ok {
// Return an error to show fill failed
if a.err == nil {
return errorStreamAbandoned
}
return a.err return a.err
} }
a.cur = b a.cur = b
@ -122,6 +129,9 @@ func (a *asyncReader) fill() (err error) {
// Read will return the next available data. // Read will return the next available data.
func (a *asyncReader) Read(p []byte) (n int, err error) { func (a *asyncReader) Read(p []byte) (n int, err error) {
a.mu.Lock()
defer a.mu.Unlock()
// Swap buffer and maybe return error // Swap buffer and maybe return error
err = a.fill() err = a.fill()
if err != nil { if err != nil {
@ -144,6 +154,9 @@ func (a *asyncReader) Read(p []byte) (n int, err error) {
// The return value n is the number of bytes written. // The return value n is the number of bytes written.
// Any error encountered during the write is also returned. // Any error encountered during the write is also returned.
func (a *asyncReader) WriteTo(w io.Writer) (n int64, err error) { func (a *asyncReader) WriteTo(w io.Writer) (n int64, err error) {
a.mu.Lock()
defer a.mu.Unlock()
n = 0 n = 0
for { for {
err = a.fill() err = a.fill()
@ -175,6 +188,9 @@ func (a *asyncReader) Abandon() {
// Close and wait for go routine // Close and wait for go routine
close(a.exit) close(a.exit)
<-a.exited <-a.exited
// take the lock to wait for Read/WriteTo to complete
a.mu.Lock()
defer a.mu.Unlock()
// Return any outstanding buffers to the Pool // Return any outstanding buffers to the Pool
if a.cur != nil { if a.cur != nil {
a.putBuffer(a.cur) a.putBuffer(a.cur)

View file

@ -6,8 +6,10 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"strings" "strings"
"sync"
"testing" "testing"
"testing/iotest" "testing/iotest"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -196,8 +198,7 @@ func TestAsyncReaderWriteTo(t *testing.T) {
buf := bufio.NewReaderSize(read, bufsize) buf := bufio.NewReaderSize(read, bufsize)
ar, _ := newAsyncReader(ioutil.NopCloser(buf), l) ar, _ := newAsyncReader(ioutil.NopCloser(buf), l)
dst := &bytes.Buffer{} dst := &bytes.Buffer{}
wt := ar.(io.WriterTo) _, err := ar.WriteTo(dst)
_, err := wt.WriteTo(dst)
if err != nil && err != io.EOF && err != iotest.ErrTimeout { if err != nil && err != io.EOF && err != iotest.ErrTimeout {
t.Fatal("Copy:", err) t.Fatal("Copy:", err)
} }
@ -215,3 +216,65 @@ func TestAsyncReaderWriteTo(t *testing.T) {
} }
} }
} }
// Read an infinite number of zeros
type zeroReader struct {
closed bool
}
func (z *zeroReader) Read(p []byte) (n int, err error) {
if z.closed {
return 0, io.EOF
}
for i := range p {
p[i] = 0
}
return len(p), nil
}
func (z *zeroReader) Close() error {
if z.closed {
panic("double close on zeroReader")
}
z.closed = true
return nil
}
// Test closing and abandoning
func testAsyncReaderClose(t *testing.T, writeto bool) {
zr := &zeroReader{}
a, err := newAsyncReader(zr, 16)
require.NoError(t, err)
var copyN int64
var copyErr error
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if true {
// exercise the WriteTo path
copyN, copyErr = a.WriteTo(ioutil.Discard)
} else {
// exercise the Read path
buf := make([]byte, 64*1024)
for {
var n int
n, copyErr = a.Read(buf)
copyN += int64(n)
if copyErr != nil {
break
}
}
}
}()
// Do some copying
time.Sleep(100 * time.Millisecond)
// Abandon the copy
a.Abandon()
wg.Wait()
assert.Equal(t, errorStreamAbandoned, copyErr)
// t.Logf("Copied %d bytes, err %v", copyN, copyErr)
assert.True(t, copyN > 0)
}
func TestAsyncReaderCloseRead(t *testing.T) { testAsyncReaderClose(t, false) }
func TestAsyncReaderCloseWriteTo(t *testing.T) { testAsyncReaderClose(t, true) }