forked from TrueCloudLab/rclone
pool: Make RW thread safe so can read and write at the same time
This commit is contained in:
parent
e686e34f89
commit
cb2d2d72a0
2 changed files with 170 additions and 10 deletions
|
@ -3,6 +3,7 @@ package pool
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RWAccount is a function which will be called after every read
|
// RWAccount is a function which will be called after every read
|
||||||
|
@ -12,15 +13,25 @@ import (
|
||||||
type RWAccount func(n int) error
|
type RWAccount func(n int) error
|
||||||
|
|
||||||
// RW contains the state for the read/writer
|
// RW contains the state for the read/writer
|
||||||
|
//
|
||||||
|
// It can be used as a FIFO to read data from a source and write it out again.
|
||||||
type RW struct {
|
type RW struct {
|
||||||
pool *Pool // pool to get pages from
|
// Variables written once during initialization
|
||||||
pages [][]byte // backing store
|
pool *Pool // pool to get pages from
|
||||||
size int // size written
|
account RWAccount // account for a read
|
||||||
out int // offset we are reading from
|
accountOn int // only account on or after this read
|
||||||
lastOffset int // size in last page
|
|
||||||
account RWAccount // account for a read
|
// Shared variables between Read and Write
|
||||||
reads int // count how many times the data has been read
|
// Write updates these but Read reads from them
|
||||||
accountOn int // only account on or after this read
|
// They must all stay in sync together
|
||||||
|
mu sync.Mutex // protect the shared variables
|
||||||
|
pages [][]byte // backing store
|
||||||
|
size int // size written
|
||||||
|
lastOffset int // size in last page
|
||||||
|
|
||||||
|
// Read side Variables
|
||||||
|
out int // offset we are reading from
|
||||||
|
reads int // count how many times the data has been read
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -47,6 +58,8 @@ func NewRW(pool *Pool) *RW {
|
||||||
// called after every read from the RW.
|
// called after every read from the RW.
|
||||||
//
|
//
|
||||||
// It may return an error which will be passed back to the user.
|
// It may return an error which will be passed back to the user.
|
||||||
|
//
|
||||||
|
// Not thread safe - call in initialization only.
|
||||||
func (rw *RW) SetAccounting(account RWAccount) *RW {
|
func (rw *RW) SetAccounting(account RWAccount) *RW {
|
||||||
rw.account = account
|
rw.account = account
|
||||||
return rw
|
return rw
|
||||||
|
@ -73,6 +86,8 @@ type DelayAccountinger interface {
|
||||||
// e.g. when calculating hashes.
|
// e.g. when calculating hashes.
|
||||||
//
|
//
|
||||||
// Set this to 0 to account everything.
|
// Set this to 0 to account everything.
|
||||||
|
//
|
||||||
|
// Not thread safe - call in initialization only.
|
||||||
func (rw *RW) DelayAccounting(i int) {
|
func (rw *RW) DelayAccounting(i int) {
|
||||||
rw.accountOn = i
|
rw.accountOn = i
|
||||||
rw.reads = 0
|
rw.reads = 0
|
||||||
|
@ -82,6 +97,8 @@ func (rw *RW) DelayAccounting(i int) {
|
||||||
//
|
//
|
||||||
// Ensure there are pages before calling this.
|
// Ensure there are pages before calling this.
|
||||||
func (rw *RW) readPage(i int) (page []byte) {
|
func (rw *RW) readPage(i int) (page []byte) {
|
||||||
|
rw.mu.Lock()
|
||||||
|
defer rw.mu.Unlock()
|
||||||
// Count a read of the data if we read the first page
|
// Count a read of the data if we read the first page
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
rw.reads++
|
rw.reads++
|
||||||
|
@ -111,6 +128,13 @@ func (rw *RW) accountRead(n int) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// eof returns true if we have read to EOF
|
||||||
|
func (rw *RW) eof() bool {
|
||||||
|
rw.mu.Lock()
|
||||||
|
defer rw.mu.Unlock()
|
||||||
|
return rw.out >= rw.size
|
||||||
|
}
|
||||||
|
|
||||||
// Read reads up to len(p) bytes into p. It returns the number of
|
// Read reads up to len(p) bytes into p. It returns the number of
|
||||||
// bytes read (0 <= n <= len(p)) and any error encountered. If some
|
// bytes read (0 <= n <= len(p)) and any error encountered. If some
|
||||||
// data is available but not len(p) bytes, Read returns what is
|
// data is available but not len(p) bytes, Read returns what is
|
||||||
|
@ -121,7 +145,7 @@ func (rw *RW) Read(p []byte) (n int, err error) {
|
||||||
page []byte
|
page []byte
|
||||||
)
|
)
|
||||||
for len(p) > 0 {
|
for len(p) > 0 {
|
||||||
if rw.out >= rw.size {
|
if rw.eof() {
|
||||||
return n, io.EOF
|
return n, io.EOF
|
||||||
}
|
}
|
||||||
page = rw.readPage(rw.out)
|
page = rw.readPage(rw.out)
|
||||||
|
@ -148,7 +172,7 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
nn int
|
nn int
|
||||||
page []byte
|
page []byte
|
||||||
)
|
)
|
||||||
for rw.out < rw.size {
|
for !rw.eof() {
|
||||||
page = rw.readPage(rw.out)
|
page = rw.readPage(rw.out)
|
||||||
nn, err = w.Write(page)
|
nn, err = w.Write(page)
|
||||||
n += int64(nn)
|
n += int64(nn)
|
||||||
|
@ -166,6 +190,8 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
|
|
||||||
// Get the page we are writing to
|
// Get the page we are writing to
|
||||||
func (rw *RW) writePage() (page []byte) {
|
func (rw *RW) writePage() (page []byte) {
|
||||||
|
rw.mu.Lock()
|
||||||
|
defer rw.mu.Unlock()
|
||||||
if len(rw.pages) > 0 && rw.lastOffset < rw.pool.bufferSize {
|
if len(rw.pages) > 0 && rw.lastOffset < rw.pool.bufferSize {
|
||||||
return rw.pages[len(rw.pages)-1][rw.lastOffset:]
|
return rw.pages[len(rw.pages)-1][rw.lastOffset:]
|
||||||
}
|
}
|
||||||
|
@ -187,8 +213,10 @@ func (rw *RW) Write(p []byte) (n int, err error) {
|
||||||
nn = copy(page, p)
|
nn = copy(page, p)
|
||||||
p = p[nn:]
|
p = p[nn:]
|
||||||
n += nn
|
n += nn
|
||||||
|
rw.mu.Lock()
|
||||||
rw.size += nn
|
rw.size += nn
|
||||||
rw.lastOffset += nn
|
rw.lastOffset += nn
|
||||||
|
rw.mu.Unlock()
|
||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
@ -208,8 +236,10 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
page = rw.writePage()
|
page = rw.writePage()
|
||||||
nn, err = r.Read(page)
|
nn, err = r.Read(page)
|
||||||
n += int64(nn)
|
n += int64(nn)
|
||||||
|
rw.mu.Lock()
|
||||||
rw.size += nn
|
rw.size += nn
|
||||||
rw.lastOffset += nn
|
rw.lastOffset += nn
|
||||||
|
rw.mu.Unlock()
|
||||||
}
|
}
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
err = nil
|
err = nil
|
||||||
|
@ -229,7 +259,9 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
// beyond the end of the written data is an error.
|
// beyond the end of the written data is an error.
|
||||||
func (rw *RW) Seek(offset int64, whence int) (int64, error) {
|
func (rw *RW) Seek(offset int64, whence int) (int64, error) {
|
||||||
var abs int64
|
var abs int64
|
||||||
|
rw.mu.Lock()
|
||||||
size := int64(rw.size)
|
size := int64(rw.size)
|
||||||
|
rw.mu.Unlock()
|
||||||
switch whence {
|
switch whence {
|
||||||
case io.SeekStart:
|
case io.SeekStart:
|
||||||
abs = offset
|
abs = offset
|
||||||
|
@ -252,6 +284,8 @@ func (rw *RW) Seek(offset int64, whence int) (int64, error) {
|
||||||
|
|
||||||
// Close the buffer returning memory to the pool
|
// Close the buffer returning memory to the pool
|
||||||
func (rw *RW) Close() error {
|
func (rw *RW) Close() error {
|
||||||
|
rw.mu.Lock()
|
||||||
|
defer rw.mu.Unlock()
|
||||||
for _, page := range rw.pages {
|
for _, page := range rw.pages {
|
||||||
rw.pool.Put(page)
|
rw.pool.Put(page)
|
||||||
}
|
}
|
||||||
|
@ -261,6 +295,8 @@ func (rw *RW) Close() error {
|
||||||
|
|
||||||
// Size returns the number of bytes in the buffer
|
// Size returns the number of bytes in the buffer
|
||||||
func (rw *RW) Size() int64 {
|
func (rw *RW) Size() int64 {
|
||||||
|
rw.mu.Lock()
|
||||||
|
defer rw.mu.Unlock()
|
||||||
return int64(rw.size)
|
return int64(rw.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,10 +4,12 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rclone/rclone/lib/random"
|
"github.com/rclone/rclone/lib/random"
|
||||||
|
"github.com/rclone/rclone/lib/readers"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -489,3 +491,125 @@ func TestRWBoundaryConditions(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The RW should be thread safe for reading and writing concurrently
|
||||||
|
func TestRWConcurrency(t *testing.T) {
|
||||||
|
const bufSize = 1024
|
||||||
|
|
||||||
|
// Write size bytes of data using Write
|
||||||
|
write := func(rw *RW, size int64) {
|
||||||
|
in := readers.NewPatternReader(size)
|
||||||
|
buf := make([]byte, bufSize)
|
||||||
|
nn := int64(0)
|
||||||
|
for {
|
||||||
|
nr, inErr := in.Read(buf)
|
||||||
|
if inErr != nil && inErr != io.EOF {
|
||||||
|
require.NoError(t, inErr)
|
||||||
|
}
|
||||||
|
nw, rwErr := rw.Write(buf[:nr])
|
||||||
|
require.NoError(t, rwErr)
|
||||||
|
assert.Equal(t, nr, nw)
|
||||||
|
nn += int64(nw)
|
||||||
|
if inErr == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, size, nn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the data using ReadFrom
|
||||||
|
readFrom := func(rw *RW, size int64) {
|
||||||
|
in := readers.NewPatternReader(size)
|
||||||
|
nn, err := rw.ReadFrom(in)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, size, nn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the data back from in and check it is OK
|
||||||
|
check := func(in io.Reader, size int64) {
|
||||||
|
ck := readers.NewPatternReader(size)
|
||||||
|
ckBuf := make([]byte, bufSize)
|
||||||
|
rwBuf := make([]byte, bufSize)
|
||||||
|
nn := int64(0)
|
||||||
|
for {
|
||||||
|
nck, ckErr := ck.Read(ckBuf)
|
||||||
|
if ckErr != io.EOF {
|
||||||
|
require.NoError(t, ckErr)
|
||||||
|
}
|
||||||
|
var nin int
|
||||||
|
var inErr error
|
||||||
|
for {
|
||||||
|
var nnin int
|
||||||
|
nnin, inErr = in.Read(rwBuf[nin:])
|
||||||
|
if inErr != io.EOF {
|
||||||
|
require.NoError(t, inErr)
|
||||||
|
}
|
||||||
|
nin += nnin
|
||||||
|
nn += int64(nnin)
|
||||||
|
if nin >= len(rwBuf) || nn >= size || inErr != io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.Equal(t, ckBuf[:nck], rwBuf[:nin])
|
||||||
|
if ckErr == io.EOF && inErr == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, size, nn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the data back and check it is OK
|
||||||
|
read := func(rw *RW, size int64) {
|
||||||
|
check(rw, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the data back using WriteTo and check it is OK
|
||||||
|
writeTo := func(rw *RW, size int64) {
|
||||||
|
in, out := io.Pipe()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
check(in, size)
|
||||||
|
}()
|
||||||
|
var n int64
|
||||||
|
for n < size {
|
||||||
|
nn, err := rw.WriteTo(out)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
n += nn
|
||||||
|
}
|
||||||
|
assert.Equal(t, size, n)
|
||||||
|
require.NoError(t, out.Close())
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
type test struct {
|
||||||
|
name string
|
||||||
|
fn func(*RW, int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
const size = blockSize*255 + 255
|
||||||
|
|
||||||
|
// Read and write the data concurrently with each combination of write and read functions
|
||||||
|
for _, write := range []test{{"Write", write}, {"ReadFrom", readFrom}} {
|
||||||
|
t.Run(write.name, func(t *testing.T) {
|
||||||
|
for _, read := range []test{{"Read", read}, {"WriteTo", writeTo}} {
|
||||||
|
t.Run(read.name, func(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
rw := NewRW(rwPool)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
read.fn(rw, size)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
write.fn(rw, size)
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue