forked from TrueCloudLab/rclone
lib/pool: add DelayAccounting() to fix accounting when reading hashes
This commit is contained in:
parent
f4b1a51af6
commit
bc986b44b2
2 changed files with 219 additions and 76 deletions
|
@ -19,6 +19,8 @@ type RW struct {
|
|||
out int // offset we are reading from
|
||||
lastOffset int // size in last page
|
||||
account RWAccount // account for a read
|
||||
reads int // count how many times the data has been read
|
||||
accountOn int // only account on or after this read
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -50,10 +52,40 @@ func (rw *RW) SetAccounting(account RWAccount) *RW {
|
|||
return rw
|
||||
}
|
||||
|
||||
// DelayAccountinger enables an accounting delay
|
||||
type DelayAccountinger interface {
|
||||
// DelayAccounting makes sure the accounting function only
|
||||
// gets called on the i-th or later read of the data from this
|
||||
// point (counting from 1).
|
||||
//
|
||||
// This is useful so that we don't account initial reads of
|
||||
// the data e.g. when calculating hashes.
|
||||
//
|
||||
// Set this to 0 to account everything.
|
||||
DelayAccounting(i int)
|
||||
}
|
||||
|
||||
// DelayAccounting makes sure the accounting function only gets called
|
||||
// on the i-th or later read of the data from this point (counting
|
||||
// from 1).
|
||||
//
|
||||
// This is useful so that we don't account initial reads of the data
|
||||
// e.g. when calculating hashes.
|
||||
//
|
||||
// Set this to 0 to account everything.
|
||||
func (rw *RW) DelayAccounting(i int) {
|
||||
rw.accountOn = i
|
||||
rw.reads = 0
|
||||
}
|
||||
|
||||
// Returns the page and offset of i for reading.
|
||||
//
|
||||
// Ensure there are pages before calling this.
|
||||
func (rw *RW) readPage(i int) (page []byte) {
|
||||
// Count a read of the data if we read the first page
|
||||
if i == 0 {
|
||||
rw.reads++
|
||||
}
|
||||
pageNumber := i / rw.pool.bufferSize
|
||||
offset := i % rw.pool.bufferSize
|
||||
page = rw.pages[pageNumber]
|
||||
|
@ -69,7 +101,14 @@ func (rw *RW) accountRead(n int) error {
|
|||
if rw.account == nil {
|
||||
return nil
|
||||
}
|
||||
return rw.account(n)
|
||||
// Don't start accounting until we've reached this many reads
|
||||
//
|
||||
// rw.reads will be 1 the first time this is called
|
||||
// rw.accountOn 2 means start accounting on the 2nd read through
|
||||
if rw.reads >= rw.accountOn {
|
||||
return rw.account(n)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read reads up to len(p) bytes into p. It returns the number of
|
||||
|
@ -227,10 +266,11 @@ func (rw *RW) Size() int64 {
|
|||
|
||||
// Check interfaces
|
||||
var (
|
||||
_ io.Reader = (*RW)(nil)
|
||||
_ io.ReaderFrom = (*RW)(nil)
|
||||
_ io.Writer = (*RW)(nil)
|
||||
_ io.WriterTo = (*RW)(nil)
|
||||
_ io.Seeker = (*RW)(nil)
|
||||
_ io.Closer = (*RW)(nil)
|
||||
_ io.Reader = (*RW)(nil)
|
||||
_ io.ReaderFrom = (*RW)(nil)
|
||||
_ io.Writer = (*RW)(nil)
|
||||
_ io.WriterTo = (*RW)(nil)
|
||||
_ io.Seeker = (*RW)(nil)
|
||||
_ io.Closer = (*RW)(nil)
|
||||
_ DelayAccountinger = (*RW)(nil)
|
||||
)
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/rclone/rclone/lib/random"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const blockSize = 4096
|
||||
|
@ -178,71 +179,164 @@ func TestRW(t *testing.T) {
|
|||
assert.Equal(t, testData[7:10], dst)
|
||||
})
|
||||
|
||||
errBoom := errors.New("accounting error")
|
||||
t.Run("Account", func(t *testing.T) {
|
||||
errBoom := errors.New("accounting error")
|
||||
|
||||
t.Run("AccountRead", func(t *testing.T) {
|
||||
// Test accounting errors
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
t.Run("Read", func(t *testing.T) {
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
|
||||
var total int
|
||||
rw.SetAccounting(func(n int) error {
|
||||
total += n
|
||||
return nil
|
||||
var total int
|
||||
rw.SetAccounting(func(n int) error {
|
||||
total += n
|
||||
return nil
|
||||
})
|
||||
|
||||
dst = make([]byte, 3)
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 3, n)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, total)
|
||||
})
|
||||
|
||||
dst = make([]byte, 3)
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 3, n)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, total)
|
||||
})
|
||||
t.Run("WriteTo", func(t *testing.T) {
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
var b bytes.Buffer
|
||||
|
||||
t.Run("AccountWriteTo", func(t *testing.T) {
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
var b bytes.Buffer
|
||||
var total int
|
||||
rw.SetAccounting(func(n int) error {
|
||||
total += n
|
||||
return nil
|
||||
})
|
||||
|
||||
var total int
|
||||
rw.SetAccounting(func(n int) error {
|
||||
total += n
|
||||
return nil
|
||||
n, err := rw.WriteTo(&b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, total)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
})
|
||||
|
||||
n, err := rw.WriteTo(&b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, total)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
})
|
||||
t.Run("ReadDelay", func(t *testing.T) {
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
|
||||
t.Run("AccountReadError", func(t *testing.T) {
|
||||
// Test accounting errors
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
var total int
|
||||
rw.SetAccounting(func(n int) error {
|
||||
total += n
|
||||
return nil
|
||||
})
|
||||
|
||||
rw.SetAccounting(func(n int) error {
|
||||
return errBoom
|
||||
rewind := func() {
|
||||
_, err := rw.Seek(0, io.SeekStart)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
rw.DelayAccounting(3)
|
||||
|
||||
dst = make([]byte, 16)
|
||||
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 10, n)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, 0, total)
|
||||
rewind()
|
||||
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 10, n)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, 0, total)
|
||||
rewind()
|
||||
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 10, n)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, 10, total)
|
||||
rewind()
|
||||
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 10, n)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, 20, total)
|
||||
rewind()
|
||||
})
|
||||
|
||||
dst = make([]byte, 3)
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 3, n)
|
||||
assert.Equal(t, errBoom, err)
|
||||
})
|
||||
t.Run("WriteToDelay", func(t *testing.T) {
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
var b bytes.Buffer
|
||||
|
||||
t.Run("AccountWriteToError", func(t *testing.T) {
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
rw.SetAccounting(func(n int) error {
|
||||
return errBoom
|
||||
var total int
|
||||
rw.SetAccounting(func(n int) error {
|
||||
total += n
|
||||
return nil
|
||||
})
|
||||
|
||||
rw.DelayAccounting(3)
|
||||
|
||||
rewind := func() {
|
||||
_, err := rw.Seek(0, io.SeekStart)
|
||||
require.NoError(t, err)
|
||||
b.Reset()
|
||||
}
|
||||
|
||||
n, err := rw.WriteTo(&b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, total)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
rewind()
|
||||
|
||||
n, err = rw.WriteTo(&b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, total)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
rewind()
|
||||
|
||||
n, err = rw.WriteTo(&b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, total)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
rewind()
|
||||
|
||||
n, err = rw.WriteTo(&b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 20, total)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
rewind()
|
||||
})
|
||||
var b bytes.Buffer
|
||||
|
||||
n, err := rw.WriteTo(&b)
|
||||
assert.Equal(t, errBoom, err)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
t.Run("ReadError", func(t *testing.T) {
|
||||
// Test accounting errors
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
|
||||
rw.SetAccounting(func(n int) error {
|
||||
return errBoom
|
||||
})
|
||||
|
||||
dst = make([]byte, 3)
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 3, n)
|
||||
assert.Equal(t, errBoom, err)
|
||||
})
|
||||
|
||||
t.Run("WriteToError", func(t *testing.T) {
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
rw.SetAccounting(func(n int) error {
|
||||
return errBoom
|
||||
})
|
||||
var b bytes.Buffer
|
||||
|
||||
n, err := rw.WriteTo(&b)
|
||||
assert.Equal(t, errBoom, err)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
})
|
||||
})
|
||||
|
||||
}
|
||||
|
@ -363,26 +457,35 @@ func TestRWBoundaryConditions(t *testing.T) {
|
|||
assert.Equal(t, int64(len(data)), nn)
|
||||
}
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
fn func(*RW, []byte, int)
|
||||
}
|
||||
|
||||
// Read and Write the data with a range of block sizes and functions
|
||||
for _, writeFn := range []func(*RW, []byte, int){write, readFrom} {
|
||||
for _, readFn := range []func(*RW, []byte, int){read, writeTo} {
|
||||
for _, size := range sizes {
|
||||
data := buf[:size]
|
||||
for _, chunkSize := range sizes {
|
||||
//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
|
||||
rw := NewRW(rwPool)
|
||||
assert.Equal(t, int64(0), rw.Size())
|
||||
accounted = 0
|
||||
rw.SetAccounting(account)
|
||||
assert.Equal(t, 0, accounted)
|
||||
writeFn(rw, data, chunkSize)
|
||||
assert.Equal(t, int64(size), rw.Size())
|
||||
assert.Equal(t, 0, accounted)
|
||||
readFn(rw, data, chunkSize)
|
||||
assert.NoError(t, rw.Close())
|
||||
assert.Equal(t, size, accounted)
|
||||
}
|
||||
for _, write := range []test{{"Write", write}, {"ReadFrom", readFrom}} {
|
||||
t.Run(write.name, func(t *testing.T) {
|
||||
for _, read := range []test{{"Read", read}, {"WriteTo", writeTo}} {
|
||||
t.Run(read.name, func(t *testing.T) {
|
||||
for _, size := range sizes {
|
||||
data := buf[:size]
|
||||
for _, chunkSize := range sizes {
|
||||
//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
|
||||
rw := NewRW(rwPool)
|
||||
assert.Equal(t, int64(0), rw.Size())
|
||||
accounted = 0
|
||||
rw.SetAccounting(account)
|
||||
assert.Equal(t, 0, accounted)
|
||||
write.fn(rw, data, chunkSize)
|
||||
assert.Equal(t, int64(size), rw.Size())
|
||||
assert.Equal(t, 0, accounted)
|
||||
read.fn(rw, data, chunkSize)
|
||||
assert.NoError(t, rw.Close())
|
||||
assert.Equal(t, size, accounted)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue