pool: Make RW thread safe so can read and write at the same time

This commit is contained in:
Nick Craig-Wood 2024-03-13 11:59:17 +00:00
parent e686e34f89
commit cb2d2d72a0
2 changed files with 170 additions and 10 deletions

View file

@ -3,6 +3,7 @@ package pool
import ( import (
"errors" "errors"
"io" "io"
"sync"
) )
// RWAccount is a function which will be called after every read // RWAccount is a function which will be called after every read
@ -12,15 +13,25 @@ import (
type RWAccount func(n int) error type RWAccount func(n int) error
// RW contains the state for the read/writer // RW contains the state for the read/writer
//
// It can be used as a FIFO to read data from a source and write it out again.
type RW struct { type RW struct {
// Written once variables in initialization
pool *Pool // pool to get pages from pool *Pool // pool to get pages from
account RWAccount // account for a read
accountOn int // only account on or after this read
// Shared variables between Read and Write
// Write updates these but Read reads from them
// They must all stay in sync together
mu sync.Mutex // protect the shared variables
pages [][]byte // backing store pages [][]byte // backing store
size int // size written size int // size written
out int // offset we are reading from
lastOffset int // size in last page lastOffset int // size in last page
account RWAccount // account for a read
// Read side Variables
out int // offset we are reading from
reads int // count how many times the data has been read reads int // count how many times the data has been read
accountOn int // only account on or after this read
} }
var ( var (
@ -47,6 +58,8 @@ func NewRW(pool *Pool) *RW {
// called after every read from the RW. // called after every read from the RW.
// //
// It may return an error which will be passed back to the user. // It may return an error which will be passed back to the user.
//
// Not thread safe - call in initialization only.
func (rw *RW) SetAccounting(account RWAccount) *RW { func (rw *RW) SetAccounting(account RWAccount) *RW {
rw.account = account rw.account = account
return rw return rw
@ -73,6 +86,8 @@ type DelayAccountinger interface {
// e.g. when calculating hashes. // e.g. when calculating hashes.
// //
// Set this to 0 to account everything. // Set this to 0 to account everything.
//
// Not thread safe - call in initialization only.
func (rw *RW) DelayAccounting(i int) { func (rw *RW) DelayAccounting(i int) {
rw.accountOn = i rw.accountOn = i
rw.reads = 0 rw.reads = 0
@ -82,6 +97,8 @@ func (rw *RW) DelayAccounting(i int) {
// //
// Ensure there are pages before calling this. // Ensure there are pages before calling this.
func (rw *RW) readPage(i int) (page []byte) { func (rw *RW) readPage(i int) (page []byte) {
rw.mu.Lock()
defer rw.mu.Unlock()
// Count a read of the data if we read the first page // Count a read of the data if we read the first page
if i == 0 { if i == 0 {
rw.reads++ rw.reads++
@ -111,6 +128,13 @@ func (rw *RW) accountRead(n int) error {
return nil return nil
} }
// Returns true if we have read to EOF
func (rw *RW) eof() bool {
rw.mu.Lock()
defer rw.mu.Unlock()
return rw.out >= rw.size
}
// Read reads up to len(p) bytes into p. It returns the number of // Read reads up to len(p) bytes into p. It returns the number of
// bytes read (0 <= n <= len(p)) and any error encountered. If some // bytes read (0 <= n <= len(p)) and any error encountered. If some
// data is available but not len(p) bytes, Read returns what is // data is available but not len(p) bytes, Read returns what is
@ -121,7 +145,7 @@ func (rw *RW) Read(p []byte) (n int, err error) {
page []byte page []byte
) )
for len(p) > 0 { for len(p) > 0 {
if rw.out >= rw.size { if rw.eof() {
return n, io.EOF return n, io.EOF
} }
page = rw.readPage(rw.out) page = rw.readPage(rw.out)
@ -148,7 +172,7 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
nn int nn int
page []byte page []byte
) )
for rw.out < rw.size { for !rw.eof() {
page = rw.readPage(rw.out) page = rw.readPage(rw.out)
nn, err = w.Write(page) nn, err = w.Write(page)
n += int64(nn) n += int64(nn)
@ -166,6 +190,8 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
// Get the page we are writing to // Get the page we are writing to
func (rw *RW) writePage() (page []byte) { func (rw *RW) writePage() (page []byte) {
rw.mu.Lock()
defer rw.mu.Unlock()
if len(rw.pages) > 0 && rw.lastOffset < rw.pool.bufferSize { if len(rw.pages) > 0 && rw.lastOffset < rw.pool.bufferSize {
return rw.pages[len(rw.pages)-1][rw.lastOffset:] return rw.pages[len(rw.pages)-1][rw.lastOffset:]
} }
@ -187,8 +213,10 @@ func (rw *RW) Write(p []byte) (n int, err error) {
nn = copy(page, p) nn = copy(page, p)
p = p[nn:] p = p[nn:]
n += nn n += nn
rw.mu.Lock()
rw.size += nn rw.size += nn
rw.lastOffset += nn rw.lastOffset += nn
rw.mu.Unlock()
} }
return n, nil return n, nil
} }
@ -208,8 +236,10 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) {
page = rw.writePage() page = rw.writePage()
nn, err = r.Read(page) nn, err = r.Read(page)
n += int64(nn) n += int64(nn)
rw.mu.Lock()
rw.size += nn rw.size += nn
rw.lastOffset += nn rw.lastOffset += nn
rw.mu.Unlock()
} }
if err == io.EOF { if err == io.EOF {
err = nil err = nil
@ -229,7 +259,9 @@ func (rw *RW) ReadFrom(r io.Reader) (n int64, err error) {
// beyond the end of the written data is an error. // beyond the end of the written data is an error.
func (rw *RW) Seek(offset int64, whence int) (int64, error) { func (rw *RW) Seek(offset int64, whence int) (int64, error) {
var abs int64 var abs int64
rw.mu.Lock()
size := int64(rw.size) size := int64(rw.size)
rw.mu.Unlock()
switch whence { switch whence {
case io.SeekStart: case io.SeekStart:
abs = offset abs = offset
@ -252,6 +284,8 @@ func (rw *RW) Seek(offset int64, whence int) (int64, error) {
// Close the buffer returning memory to the pool // Close the buffer returning memory to the pool
func (rw *RW) Close() error { func (rw *RW) Close() error {
rw.mu.Lock()
defer rw.mu.Unlock()
for _, page := range rw.pages { for _, page := range rw.pages {
rw.pool.Put(page) rw.pool.Put(page)
} }
@ -261,6 +295,8 @@ func (rw *RW) Close() error {
// Size returns the number of bytes in the buffer // Size returns the number of bytes in the buffer
func (rw *RW) Size() int64 { func (rw *RW) Size() int64 {
rw.mu.Lock()
defer rw.mu.Unlock()
return int64(rw.size) return int64(rw.size)
} }

View file

@ -4,10 +4,12 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"sync"
"testing" "testing"
"time" "time"
"github.com/rclone/rclone/lib/random" "github.com/rclone/rclone/lib/random"
"github.com/rclone/rclone/lib/readers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -489,3 +491,125 @@ func TestRWBoundaryConditions(t *testing.T) {
}) })
} }
} }
// The RW should be thread safe for reading and writing concurrently
func TestRWConcurrency(t *testing.T) {
const bufSize = 1024
// Write data of size using Write
write := func(rw *RW, size int64) {
in := readers.NewPatternReader(size)
buf := make([]byte, bufSize)
nn := int64(0)
for {
nr, inErr := in.Read(buf)
if inErr != nil && inErr != io.EOF {
require.NoError(t, inErr)
}
nw, rwErr := rw.Write(buf[:nr])
require.NoError(t, rwErr)
assert.Equal(t, nr, nw)
nn += int64(nw)
if inErr == io.EOF {
break
}
}
assert.Equal(t, size, nn)
}
// Write the data using ReadFrom
readFrom := func(rw *RW, size int64) {
in := readers.NewPatternReader(size)
nn, err := rw.ReadFrom(in)
assert.NoError(t, err)
assert.Equal(t, size, nn)
}
// Read the data back from inP and check it is OK
check := func(in io.Reader, size int64) {
ck := readers.NewPatternReader(size)
ckBuf := make([]byte, bufSize)
rwBuf := make([]byte, bufSize)
nn := int64(0)
for {
nck, ckErr := ck.Read(ckBuf)
if ckErr != io.EOF {
require.NoError(t, ckErr)
}
var nin int
var inErr error
for {
var nnin int
nnin, inErr = in.Read(rwBuf[nin:])
if inErr != io.EOF {
require.NoError(t, inErr)
}
nin += nnin
nn += int64(nnin)
if nin >= len(rwBuf) || nn >= size || inErr != io.EOF {
break
}
}
require.Equal(t, ckBuf[:nck], rwBuf[:nin])
if ckErr == io.EOF && inErr == io.EOF {
break
}
}
assert.Equal(t, size, nn)
}
// Read the data back and check it is OK
read := func(rw *RW, size int64) {
check(rw, size)
}
// Read the data back and check it is OK in using WriteTo
writeTo := func(rw *RW, size int64) {
in, out := io.Pipe()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
check(in, size)
}()
var n int64
for n < size {
nn, err := rw.WriteTo(out)
assert.NoError(t, err)
n += nn
}
assert.Equal(t, size, n)
require.NoError(t, out.Close())
wg.Wait()
}
type test struct {
name string
fn func(*RW, int64)
}
const size = blockSize*255 + 255
// Read and Write the data with a range of block sizes and functions
for _, write := range []test{{"Write", write}, {"ReadFrom", readFrom}} {
t.Run(write.name, func(t *testing.T) {
for _, read := range []test{{"Read", read}, {"WriteTo", writeTo}} {
t.Run(read.name, func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(2)
rw := NewRW(rwPool)
go func() {
defer wg.Done()
read.fn(rw, size)
}()
go func() {
defer wg.Done()
write.fn(rw, size)
}()
wg.Wait()
})
}
})
}
}