diff --git a/lib/pool/reader_writer.go b/lib/pool/reader_writer.go index e7ed535a2..18bd11e8e 100644 --- a/lib/pool/reader_writer.go +++ b/lib/pool/reader_writer.go @@ -19,6 +19,8 @@ type RW struct { out int // offset we are reading from lastOffset int // size in last page account RWAccount // account for a read + reads int // count how many times the data has been read + accountOn int // only account on or after this read } var ( @@ -50,10 +52,40 @@ func (rw *RW) SetAccounting(account RWAccount) *RW { return rw } +// DelayAccountinger enables an accounting delay +type DelayAccountinger interface { + // DelayAccounting makes sure the accounting function only + // gets called on the i-th or later read of the data from this + // point (counting from 1). + // + // This is useful so that we don't account initial reads of + // the data e.g. when calculating hashes. + // + // Set this to 0 to account everything. + DelayAccounting(i int) +} + +// DelayAccounting makes sure the accounting function only gets called +// on the i-th or later read of the data from this point (counting +// from 1). +// +// This is useful so that we don't account initial reads of the data +// e.g. when calculating hashes. +// +// Set this to 0 to account everything. +func (rw *RW) DelayAccounting(i int) { + rw.accountOn = i + rw.reads = 0 +} + // Returns the page and offset of i for reading. // // Ensure there are pages before calling this. func (rw *RW) readPage(i int) (page []byte) { + // Count a read of the data if we read the first page + if i == 0 { + rw.reads++ + } pageNumber := i / rw.pool.bufferSize offset := i % rw.pool.bufferSize page = rw.pages[pageNumber] @@ -69,7 +101,14 @@ func (rw *RW) accountRead(n int) error { if rw.account == nil { return nil } - return rw.account(n) + // Don't start accounting until we've reached this many reads + // + // rw.reads will be 1 the first time this is called + // rw.accountOn 2 means start accounting on the 2nd read through + if rw.reads >= rw.accountOn { + return rw.account(n) + } + return nil } // Read reads up to len(p) bytes into p. It returns the number of @@ -227,10 +266,11 @@ func (rw *RW) Size() int64 { // Check interfaces var ( - _ io.Reader = (*RW)(nil) - _ io.ReaderFrom = (*RW)(nil) - _ io.Writer = (*RW)(nil) - _ io.WriterTo = (*RW)(nil) - _ io.Seeker = (*RW)(nil) - _ io.Closer = (*RW)(nil) + _ io.Reader = (*RW)(nil) + _ io.ReaderFrom = (*RW)(nil) + _ io.Writer = (*RW)(nil) + _ io.WriterTo = (*RW)(nil) + _ io.Seeker = (*RW)(nil) + _ io.Closer = (*RW)(nil) + _ DelayAccountinger = (*RW)(nil) ) diff --git a/lib/pool/reader_writer_test.go b/lib/pool/reader_writer_test.go index cd7592e45..e9e02c22e 100644 --- a/lib/pool/reader_writer_test.go +++ b/lib/pool/reader_writer_test.go @@ -9,6 +9,7 @@ import ( "github.com/rclone/rclone/lib/random" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const blockSize = 4096 @@ -178,71 +179,164 @@ func TestRW(t *testing.T) { assert.Equal(t, testData[7:10], dst) }) - errBoom := errors.New("accounting error") + t.Run("Account", func(t *testing.T) { + errBoom := errors.New("accounting error") - t.Run("AccountRead", func(t *testing.T) { - // Test accounting errors - rw := newRW() - defer close(rw) + t.Run("Read", func(t *testing.T) { + rw := newRW() + defer close(rw) - var total int - rw.SetAccounting(func(n int) error { - total += n - return nil + var total int + rw.SetAccounting(func(n int) error { + total += n + return nil + }) + + dst = make([]byte, 3) + n, err = rw.Read(dst) + assert.Equal(t, 3, n) + assert.NoError(t, err) + assert.Equal(t, 3, total) }) - dst = make([]byte, 3) - n, err = rw.Read(dst) - assert.Equal(t, 3, n) - assert.NoError(t, err) - assert.Equal(t, 3, total) - }) + t.Run("WriteTo", func(t *testing.T) { + rw := newRW() + defer close(rw) + var b bytes.Buffer - t.Run("AccountWriteTo", func(t *testing.T) { - rw := newRW() - defer close(rw) - var b bytes.Buffer + var total int + rw.SetAccounting(func(n int) error { + total += n + return nil + }) - var total int - rw.SetAccounting(func(n int) error { - total += n - return nil + n, err := rw.WriteTo(&b) + assert.NoError(t, err) + assert.Equal(t, 10, total) + assert.Equal(t, int64(10), n) + assert.Equal(t, testData, b.Bytes()) }) - n, err := rw.WriteTo(&b) - assert.NoError(t, err) - assert.Equal(t, 10, total) - assert.Equal(t, int64(10), n) - assert.Equal(t, testData, b.Bytes()) - }) + t.Run("ReadDelay", func(t *testing.T) { + rw := newRW() + defer close(rw) - t.Run("AccountReadError", func(t *testing.T) { - // Test accounting errors - rw := newRW() - defer close(rw) + var total int + rw.SetAccounting(func(n int) error { + total += n + return nil + }) - rw.SetAccounting(func(n int) error { - return errBoom + rewind := func() { + _, err := rw.Seek(0, io.SeekStart) + require.NoError(t, err) + } + + rw.DelayAccounting(3) + + dst = make([]byte, 16) + + n, err = rw.Read(dst) + assert.Equal(t, 10, n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, total) + rewind() + + n, err = rw.Read(dst) + assert.Equal(t, 10, n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, total) + rewind() + + n, err = rw.Read(dst) + assert.Equal(t, 10, n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 10, total) + rewind() + + n, err = rw.Read(dst) + assert.Equal(t, 10, n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 20, total) + rewind() }) - dst = make([]byte, 3) - n, err = rw.Read(dst) - assert.Equal(t, 3, n) - assert.Equal(t, errBoom, err) - }) + t.Run("WriteToDelay", func(t *testing.T) { + rw := newRW() + defer close(rw) + var b bytes.Buffer - t.Run("AccountWriteToError", func(t *testing.T) { - rw := newRW() - defer close(rw) - rw.SetAccounting(func(n int) error { - return errBoom + var total int + rw.SetAccounting(func(n int) error { + total += n + return nil + }) + + rw.DelayAccounting(3) + + rewind := func() { + _, err := rw.Seek(0, io.SeekStart) + require.NoError(t, err) + b.Reset() + } + + n, err := rw.WriteTo(&b) + assert.NoError(t, err) + assert.Equal(t, 0, total) + assert.Equal(t, int64(10), n) + assert.Equal(t, testData, b.Bytes()) + rewind() + + n, err = rw.WriteTo(&b) + assert.NoError(t, err) + assert.Equal(t, 0, total) + assert.Equal(t, int64(10), n) + assert.Equal(t, testData, b.Bytes()) + rewind() + + n, err = rw.WriteTo(&b) + assert.NoError(t, err) + assert.Equal(t, 10, total) + assert.Equal(t, int64(10), n) + assert.Equal(t, testData, b.Bytes()) + rewind() + + n, err = rw.WriteTo(&b) + assert.NoError(t, err) + assert.Equal(t, 20, total) + assert.Equal(t, int64(10), n) + assert.Equal(t, testData, b.Bytes()) + rewind() }) - var b bytes.Buffer - n, err := rw.WriteTo(&b) - assert.Equal(t, errBoom, err) - assert.Equal(t, int64(10), n) - assert.Equal(t, testData, b.Bytes()) + t.Run("ReadError", func(t *testing.T) { + // Test accounting errors + rw := newRW() + defer close(rw) + + rw.SetAccounting(func(n int) error { + return errBoom + }) + + dst = make([]byte, 3) + n, err = rw.Read(dst) + assert.Equal(t, 3, n) + assert.Equal(t, errBoom, err) + }) + + t.Run("WriteToError", func(t *testing.T) { + rw := newRW() + defer close(rw) + rw.SetAccounting(func(n int) error { + return errBoom + }) + var b bytes.Buffer + + n, err := rw.WriteTo(&b) + assert.Equal(t, errBoom, err) + assert.Equal(t, int64(10), n) + assert.Equal(t, testData, b.Bytes()) + }) }) } @@ -363,26 +457,35 @@ func TestRWBoundaryConditions(t *testing.T) { assert.Equal(t, int64(len(data)), nn) } + type test struct { + name string + fn func(*RW, []byte, int) + } + // Read and Write the data with a range of block sizes and functions - for _, writeFn := range []func(*RW, []byte, int){write, readFrom} { - for _, readFn := range []func(*RW, []byte, int){read, writeTo} { - for _, size := range sizes { - data := buf[:size] - for _, chunkSize := range sizes { - //t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize) - rw := NewRW(rwPool) - assert.Equal(t, int64(0), rw.Size()) - accounted = 0 - rw.SetAccounting(account) - assert.Equal(t, 0, accounted) - writeFn(rw, data, chunkSize) - assert.Equal(t, int64(size), rw.Size()) - assert.Equal(t, 0, accounted) - readFn(rw, data, chunkSize) - assert.NoError(t, rw.Close()) - assert.Equal(t, size, accounted) - } + for _, write := range []test{{"Write", write}, {"ReadFrom", readFrom}} { + t.Run(write.name, func(t *testing.T) { + for _, read := range []test{{"Read", read}, {"WriteTo", writeTo}} { + t.Run(read.name, func(t *testing.T) { + for _, size := range sizes { + data := buf[:size] + for _, chunkSize := range sizes { + //t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize) + rw := NewRW(rwPool) + assert.Equal(t, int64(0), rw.Size()) + accounted = 0 + rw.SetAccounting(account) + assert.Equal(t, 0, accounted) + write.fn(rw, data, chunkSize) + assert.Equal(t, int64(size), rw.Size()) + assert.Equal(t, 0, accounted) + read.fn(rw, data, chunkSize) + assert.NoError(t, rw.Close()) + assert.Equal(t, size, accounted) + } + } + }) } - } + }) } }