forked from TrueCloudLab/rclone
lib/pool: add SetAccounting to RW
This commit is contained in:
parent
25703ad20e
commit
f4b1a51af6
2 changed files with 116 additions and 5 deletions
|
@ -5,6 +5,12 @@ import (
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RWAccount is a function which will be called after every read
|
||||||
|
// from the RW.
|
||||||
|
//
|
||||||
|
// It may return an error which will be passed back to the user.
|
||||||
|
type RWAccount func(n int) error
|
||||||
|
|
||||||
// RW contains the state for the read/writer
|
// RW contains the state for the read/writer
|
||||||
type RW struct {
|
type RW struct {
|
||||||
pool *Pool // pool to get pages from
|
pool *Pool // pool to get pages from
|
||||||
|
@ -12,6 +18,7 @@ type RW struct {
|
||||||
size int // size written
|
size int // size written
|
||||||
out int // offset we are reading from
|
out int // offset we are reading from
|
||||||
lastOffset int // size in last page
|
lastOffset int // size in last page
|
||||||
|
account RWAccount // account for a read
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -34,6 +41,15 @@ func NewRW(pool *Pool) *RW {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAccounting should be provided with a function which will be
|
||||||
|
// called after every read from the RW.
|
||||||
|
//
|
||||||
|
// It may return an error which will be passed back to the user.
|
||||||
|
func (rw *RW) SetAccounting(account RWAccount) *RW {
|
||||||
|
rw.account = account
|
||||||
|
return rw
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the page and offset of i for reading.
|
// Returns the page and offset of i for reading.
|
||||||
//
|
//
|
||||||
// Ensure there are pages before calling this.
|
// Ensure there are pages before calling this.
|
||||||
|
@ -48,6 +64,14 @@ func (rw *RW) readPage(i int) (page []byte) {
|
||||||
return page[offset:]
|
return page[offset:]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// account for n bytes being read
|
||||||
|
func (rw *RW) accountRead(n int) error {
|
||||||
|
if rw.account == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return rw.account(n)
|
||||||
|
}
|
||||||
|
|
||||||
// Read reads up to len(p) bytes into p. It returns the number of
|
// Read reads up to len(p) bytes into p. It returns the number of
|
||||||
// bytes read (0 <= n <= len(p)) and any error encountered. If some
|
// bytes read (0 <= n <= len(p)) and any error encountered. If some
|
||||||
// data is available but not len(p) bytes, Read returns what is
|
// data is available but not len(p) bytes, Read returns what is
|
||||||
|
@ -66,6 +90,10 @@ func (rw *RW) Read(p []byte) (n int, err error) {
|
||||||
p = p[nn:]
|
p = p[nn:]
|
||||||
n += nn
|
n += nn
|
||||||
rw.out += nn
|
rw.out += nn
|
||||||
|
err = rw.accountRead(nn)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
@ -89,6 +117,10 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
err = rw.accountRead(nn)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -177,6 +177,74 @@ func TestRW(t *testing.T) {
|
||||||
assert.Equal(t, 3, n)
|
assert.Equal(t, 3, n)
|
||||||
assert.Equal(t, testData[7:10], dst)
|
assert.Equal(t, testData[7:10], dst)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
errBoom := errors.New("accounting error")
|
||||||
|
|
||||||
|
t.Run("AccountRead", func(t *testing.T) {
|
||||||
|
// Test accounting errors
|
||||||
|
rw := newRW()
|
||||||
|
defer close(rw)
|
||||||
|
|
||||||
|
var total int
|
||||||
|
rw.SetAccounting(func(n int) error {
|
||||||
|
total += n
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
dst = make([]byte, 3)
|
||||||
|
n, err = rw.Read(dst)
|
||||||
|
assert.Equal(t, 3, n)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, total)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AccountWriteTo", func(t *testing.T) {
|
||||||
|
rw := newRW()
|
||||||
|
defer close(rw)
|
||||||
|
var b bytes.Buffer
|
||||||
|
|
||||||
|
var total int
|
||||||
|
rw.SetAccounting(func(n int) error {
|
||||||
|
total += n
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
n, err := rw.WriteTo(&b)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 10, total)
|
||||||
|
assert.Equal(t, int64(10), n)
|
||||||
|
assert.Equal(t, testData, b.Bytes())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AccountReadError", func(t *testing.T) {
|
||||||
|
// Test accounting errors
|
||||||
|
rw := newRW()
|
||||||
|
defer close(rw)
|
||||||
|
|
||||||
|
rw.SetAccounting(func(n int) error {
|
||||||
|
return errBoom
|
||||||
|
})
|
||||||
|
|
||||||
|
dst = make([]byte, 3)
|
||||||
|
n, err = rw.Read(dst)
|
||||||
|
assert.Equal(t, 3, n)
|
||||||
|
assert.Equal(t, errBoom, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AccountWriteToError", func(t *testing.T) {
|
||||||
|
rw := newRW()
|
||||||
|
defer close(rw)
|
||||||
|
rw.SetAccounting(func(n int) error {
|
||||||
|
return errBoom
|
||||||
|
})
|
||||||
|
var b bytes.Buffer
|
||||||
|
|
||||||
|
n, err := rw.WriteTo(&b)
|
||||||
|
assert.Equal(t, errBoom, err)
|
||||||
|
assert.Equal(t, int64(10), n)
|
||||||
|
assert.Equal(t, testData, b.Bytes())
|
||||||
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// A reader to read in chunkSize chunks
|
// A reader to read in chunkSize chunks
|
||||||
|
@ -220,6 +288,12 @@ func (w *testWriter) Write(p []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRWBoundaryConditions(t *testing.T) {
|
func TestRWBoundaryConditions(t *testing.T) {
|
||||||
|
var accounted int
|
||||||
|
account := func(n int) error {
|
||||||
|
accounted += n
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
maxSize := 3 * blockSize
|
maxSize := 3 * blockSize
|
||||||
buf := []byte(random.String(maxSize))
|
buf := []byte(random.String(maxSize))
|
||||||
|
|
||||||
|
@ -298,10 +372,15 @@ func TestRWBoundaryConditions(t *testing.T) {
|
||||||
//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
|
//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
|
||||||
rw := NewRW(rwPool)
|
rw := NewRW(rwPool)
|
||||||
assert.Equal(t, int64(0), rw.Size())
|
assert.Equal(t, int64(0), rw.Size())
|
||||||
|
accounted = 0
|
||||||
|
rw.SetAccounting(account)
|
||||||
|
assert.Equal(t, 0, accounted)
|
||||||
writeFn(rw, data, chunkSize)
|
writeFn(rw, data, chunkSize)
|
||||||
assert.Equal(t, int64(size), rw.Size())
|
assert.Equal(t, int64(size), rw.Size())
|
||||||
|
assert.Equal(t, 0, accounted)
|
||||||
readFn(rw, data, chunkSize)
|
readFn(rw, data, chunkSize)
|
||||||
assert.NoError(t, rw.Close())
|
assert.NoError(t, rw.Close())
|
||||||
|
assert.Equal(t, size, accounted)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue