forked from TrueCloudLab/rclone
pool: Make RW thread safe so can read and write at the same time
This commit is contained in:
parent
e686e34f89
commit
cb2d2d72a0
2 changed files with 170 additions and 10 deletions
|
@ -3,6 +3,7 @@ package pool
|
|||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RWAccount is a function which will be called after every read
|
||||
|
@ -12,15 +13,25 @@ import (
|
|||
type RWAccount func(n int) error
|
||||
|
||||
// RW contains the state for the read/writer
|
||||
//
|
||||
// It can be used as a FIFO to read data from a source and write it out again.
|
||||
type RW struct {
|
||||
pool *Pool // pool to get pages from
|
||||
pages [][]byte // backing store
|
||||
size int // size written
|
||||
out int // offset we are reading from
|
||||
lastOffset int // size in last page
|
||||
account RWAccount // account for a read
|
||||
reads int // count how many times the data has been read
|
||||
accountOn int // only account on or after this read
|
||||
// Written once variables in initialization
|
||||
pool *Pool // pool to get pages from
|
||||
account RWAccount // account for a read
|
||||
accountOn int // only account on or after this read
|
||||
|
||||
// Shared variables between Read and Write
|
||||
// Write updates these but Read reads from them
|
||||
// They must all stay in sync together
|
||||
mu sync.Mutex // protect the shared variables
|
||||
pages [][]byte // backing store
|
||||
size int // size written
|
||||
lastOffset int // size in last page
|
||||
|
||||
// Read side Variables
|
||||
out int // offset we are reading from
|
||||
reads int // count how many times the data has been read
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -47,6 +58,8 @@ func NewRW(pool *Pool) *RW {
|
|||
// called after every read from the RW.
|
||||
//
|
||||
// It may return an error which will be passed back to the user.
|
||||
//
|
||||
// Not thread safe - call in initialization only.
|
||||
func (rw *RW) SetAccounting(account RWAccount) *RW {
|
||||
rw.account = account
|
||||
return rw
|
||||
|
@ -73,6 +86,8 @@ type DelayAccountinger interface {
|
|||
// e.g. when calculating hashes.
|
||||
//
|
||||
// Set this to 0 to account everything.
|
||||
//
|
||||
// Not thread safe - call in initialization only.
|
||||
func (rw *RW) DelayAccounting(i int) {
|
||||
rw.accountOn = i
|
||||
rw.reads = 0
|
||||
|
@ -82,6 +97,8 @@ func (rw *RW) DelayAccounting(i int) {
|
|||
//
|
||||
// Ensure there are pages before calling this.
|
||||
func (rw *RW) readPage(i int) (page []byte) {
|
||||
rw.mu.Lock()
|
||||
defer rw.mu.Unlock()
|
||||
// Count a read of the data if we read the first page
|
||||
if i == 0 {
|
||||
rw.reads++
|
||||
|
@ -111,6 +128,13 @@ func (rw *RW) accountRead(n int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Returns true if we have read to EOF
|
||||
func (rw *RW) eof() bool {
|
||||
rw.mu.Lock()
|
||||
defer rw.mu.Unlock()
|
||||
return rw.out >= rw.size
|
||||
}
|
||||
|
||||
// Read reads up to len(p) bytes into p. It returns the number of
|
||||
// bytes read (0 <= n <= len(p)) and any error encountered. If some
|
||||
// data is available but not len(p) bytes, Read returns what is
|
||||
|
@ -121,7 +145,7 @@ func (rw *RW) Read(p []byte) (n int, err error) {
|
|||
page []byte
|
||||
)
|
||||
for len(p) > 0 {
|
||||
if rw.out >= rw.size {
|
||||
if rw.eof() {
|
||||
return n, io.EOF
|
||||
}
|
||||
page = rw.readPage(rw.out)
|
||||
|
@ -148,7 +172,7 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
|
|||
nn int
|
||||
page []byte
|
||||
)
|
||||
for rw.out < rw.size {
|
||||
for !rw.eof() {
|
||||
page = rw.readPage(rw.out)
|
||||
nn, err = w.Write(page)
|
||||
n += int64(nn)
|
||||
|
@ -166,6 +190,8 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
|
|||
|
||||
// Get the page we are writing to
|
||||
func (rw *RW) writePage() (page []byte) {
|
||||
rw.mu.Lock()
|
||||
defer rw.mu.Unlock()
|
||||
if len(rw.pages) > 0 && rw.lastOffset < rw.pool.bufferSize {
|
||||
return rw.pages[len(rw.pages)-1][rw.lastOffset:]
|
||||
}
|
||||
|
@ -187,8 +213,10 @@ func (rw *RW) Write(p []byte) (n int, err error) {
|
|||
nn = copy(page, p)
|
||||
p = p[nn:]
|
||||
n += nn
|
||||
rw.mu.Lock()
|
||||
rw.size += nn
|
||||
rw.lastOffset += nn
|
||||
rw.mu.Unlock()
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
@ -208,8 +236,10 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
page = rw.writePage()
|
||||
nn, err = r.Read(page)
|
||||
n += int64(nn)
|
||||
rw.mu.Lock()
|
||||
rw.size += nn
|
||||
rw.lastOffset += nn
|
||||
rw.mu.Unlock()
|
||||
}
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
|
@ -229,7 +259,9 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) {
|
|||
// beyond the end of the written data is an error.
|
||||
func (rw *RW) Seek(offset int64, whence int) (int64, error) {
|
||||
var abs int64
|
||||
rw.mu.Lock()
|
||||
size := int64(rw.size)
|
||||
rw.mu.Unlock()
|
||||
switch whence {
|
||||
case io.SeekStart:
|
||||
abs = offset
|
||||
|
@ -252,6 +284,8 @@ func (rw *RW) Seek(offset int64, whence int) (int64, error) {
|
|||
|
||||
// Close the buffer returning memory to the pool
|
||||
func (rw *RW) Close() error {
|
||||
rw.mu.Lock()
|
||||
defer rw.mu.Unlock()
|
||||
for _, page := range rw.pages {
|
||||
rw.pool.Put(page)
|
||||
}
|
||||
|
@ -261,6 +295,8 @@ func (rw *RW) Close() error {
|
|||
|
||||
// Size returns the number of bytes in the buffer
|
||||
func (rw *RW) Size() int64 {
|
||||
rw.mu.Lock()
|
||||
defer rw.mu.Unlock()
|
||||
return int64(rw.size)
|
||||
}
|
||||
|
||||
|
|
|
@ -4,10 +4,12 @@ import (
|
|||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rclone/rclone/lib/random"
|
||||
"github.com/rclone/rclone/lib/readers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -489,3 +491,125 @@ func TestRWBoundaryConditions(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// The RW should be thread safe for reading and writing concurrently
|
||||
func TestRWConcurrency(t *testing.T) {
|
||||
const bufSize = 1024
|
||||
|
||||
// Write data of size using Write
|
||||
write := func(rw *RW, size int64) {
|
||||
in := readers.NewPatternReader(size)
|
||||
buf := make([]byte, bufSize)
|
||||
nn := int64(0)
|
||||
for {
|
||||
nr, inErr := in.Read(buf)
|
||||
if inErr != nil && inErr != io.EOF {
|
||||
require.NoError(t, inErr)
|
||||
}
|
||||
nw, rwErr := rw.Write(buf[:nr])
|
||||
require.NoError(t, rwErr)
|
||||
assert.Equal(t, nr, nw)
|
||||
nn += int64(nw)
|
||||
if inErr == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Equal(t, size, nn)
|
||||
}
|
||||
|
||||
// Write the data using ReadFrom
|
||||
readFrom := func(rw *RW, size int64) {
|
||||
in := readers.NewPatternReader(size)
|
||||
nn, err := rw.ReadFrom(in)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, size, nn)
|
||||
}
|
||||
|
||||
// Read the data back from inP and check it is OK
|
||||
check := func(in io.Reader, size int64) {
|
||||
ck := readers.NewPatternReader(size)
|
||||
ckBuf := make([]byte, bufSize)
|
||||
rwBuf := make([]byte, bufSize)
|
||||
nn := int64(0)
|
||||
for {
|
||||
nck, ckErr := ck.Read(ckBuf)
|
||||
if ckErr != io.EOF {
|
||||
require.NoError(t, ckErr)
|
||||
}
|
||||
var nin int
|
||||
var inErr error
|
||||
for {
|
||||
var nnin int
|
||||
nnin, inErr = in.Read(rwBuf[nin:])
|
||||
if inErr != io.EOF {
|
||||
require.NoError(t, inErr)
|
||||
}
|
||||
nin += nnin
|
||||
nn += int64(nnin)
|
||||
if nin >= len(rwBuf) || nn >= size || inErr != io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Equal(t, ckBuf[:nck], rwBuf[:nin])
|
||||
if ckErr == io.EOF && inErr == io.EOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Equal(t, size, nn)
|
||||
}
|
||||
|
||||
// Read the data back and check it is OK
|
||||
read := func(rw *RW, size int64) {
|
||||
check(rw, size)
|
||||
}
|
||||
|
||||
// Read the data back and check it is OK in using WriteTo
|
||||
writeTo := func(rw *RW, size int64) {
|
||||
in, out := io.Pipe()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
check(in, size)
|
||||
}()
|
||||
var n int64
|
||||
for n < size {
|
||||
nn, err := rw.WriteTo(out)
|
||||
assert.NoError(t, err)
|
||||
n += nn
|
||||
}
|
||||
assert.Equal(t, size, n)
|
||||
require.NoError(t, out.Close())
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
fn func(*RW, int64)
|
||||
}
|
||||
|
||||
const size = blockSize*255 + 255
|
||||
|
||||
// Read and Write the data with a range of block sizes and functions
|
||||
for _, write := range []test{{"Write", write}, {"ReadFrom", readFrom}} {
|
||||
t.Run(write.name, func(t *testing.T) {
|
||||
for _, read := range []test{{"Read", read}, {"WriteTo", writeTo}} {
|
||||
t.Run(read.name, func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
rw := NewRW(rwPool)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
read.fn(rw, size)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
write.fn(rw, size)
|
||||
}()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue