lib/pool: add SetAccounting to RW
This commit is contained in:
parent
25703ad20e
commit
f4b1a51af6
2 changed files with 116 additions and 5 deletions
|
@ -5,13 +5,20 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
// RWAccount is a function which will be called after every read
|
||||
// from the RW.
|
||||
//
|
||||
// It may return an error which will be passed back to the user.
|
||||
type RWAccount func(n int) error
|
||||
|
||||
// RW contains the state for the read/writer
|
||||
type RW struct {
|
||||
pool *Pool // pool to get pages from
|
||||
pages [][]byte // backing store
|
||||
size int // size written
|
||||
out int // offset we are reading from
|
||||
lastOffset int // size in last page
|
||||
pool *Pool // pool to get pages from
|
||||
pages [][]byte // backing store
|
||||
size int // size written
|
||||
out int // offset we are reading from
|
||||
lastOffset int // size in last page
|
||||
account RWAccount // account for a read
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -34,6 +41,15 @@ func NewRW(pool *Pool) *RW {
|
|||
}
|
||||
}
|
||||
|
||||
// SetAccounting should be provided with a function which will be
|
||||
// called after every read from the RW.
|
||||
//
|
||||
// It may return an error which will be passed back to the user.
|
||||
func (rw *RW) SetAccounting(account RWAccount) *RW {
|
||||
rw.account = account
|
||||
return rw
|
||||
}
|
||||
|
||||
// Returns the page and offset of i for reading.
|
||||
//
|
||||
// Ensure there are pages before calling this.
|
||||
|
@ -48,6 +64,14 @@ func (rw *RW) readPage(i int) (page []byte) {
|
|||
return page[offset:]
|
||||
}
|
||||
|
||||
// account for n bytes being read
|
||||
func (rw *RW) accountRead(n int) error {
|
||||
if rw.account == nil {
|
||||
return nil
|
||||
}
|
||||
return rw.account(n)
|
||||
}
|
||||
|
||||
// Read reads up to len(p) bytes into p. It returns the number of
|
||||
// bytes read (0 <= n <= len(p)) and any error encountered. If some
|
||||
// data is available but not len(p) bytes, Read returns what is
|
||||
|
@ -66,6 +90,10 @@ func (rw *RW) Read(p []byte) (n int, err error) {
|
|||
p = p[nn:]
|
||||
n += nn
|
||||
rw.out += nn
|
||||
err = rw.accountRead(nn)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
@ -89,6 +117,10 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
|
|||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
err = rw.accountRead(nn)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
|
|
@ -177,6 +177,74 @@ func TestRW(t *testing.T) {
|
|||
assert.Equal(t, 3, n)
|
||||
assert.Equal(t, testData[7:10], dst)
|
||||
})
|
||||
|
||||
errBoom := errors.New("accounting error")
|
||||
|
||||
t.Run("AccountRead", func(t *testing.T) {
|
||||
// Test accounting errors
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
|
||||
var total int
|
||||
rw.SetAccounting(func(n int) error {
|
||||
total += n
|
||||
return nil
|
||||
})
|
||||
|
||||
dst = make([]byte, 3)
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 3, n)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, total)
|
||||
})
|
||||
|
||||
t.Run("AccountWriteTo", func(t *testing.T) {
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
var b bytes.Buffer
|
||||
|
||||
var total int
|
||||
rw.SetAccounting(func(n int) error {
|
||||
total += n
|
||||
return nil
|
||||
})
|
||||
|
||||
n, err := rw.WriteTo(&b)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, total)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
})
|
||||
|
||||
t.Run("AccountReadError", func(t *testing.T) {
|
||||
// Test accounting errors
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
|
||||
rw.SetAccounting(func(n int) error {
|
||||
return errBoom
|
||||
})
|
||||
|
||||
dst = make([]byte, 3)
|
||||
n, err = rw.Read(dst)
|
||||
assert.Equal(t, 3, n)
|
||||
assert.Equal(t, errBoom, err)
|
||||
})
|
||||
|
||||
t.Run("AccountWriteToError", func(t *testing.T) {
|
||||
rw := newRW()
|
||||
defer close(rw)
|
||||
rw.SetAccounting(func(n int) error {
|
||||
return errBoom
|
||||
})
|
||||
var b bytes.Buffer
|
||||
|
||||
n, err := rw.WriteTo(&b)
|
||||
assert.Equal(t, errBoom, err)
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, testData, b.Bytes())
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// A reader to read in chunkSize chunks
|
||||
|
@ -220,6 +288,12 @@ func (w *testWriter) Write(p []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
func TestRWBoundaryConditions(t *testing.T) {
|
||||
var accounted int
|
||||
account := func(n int) error {
|
||||
accounted += n
|
||||
return nil
|
||||
}
|
||||
|
||||
maxSize := 3 * blockSize
|
||||
buf := []byte(random.String(maxSize))
|
||||
|
||||
|
@ -298,10 +372,15 @@ func TestRWBoundaryConditions(t *testing.T) {
|
|||
//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
|
||||
rw := NewRW(rwPool)
|
||||
assert.Equal(t, int64(0), rw.Size())
|
||||
accounted = 0
|
||||
rw.SetAccounting(account)
|
||||
assert.Equal(t, 0, accounted)
|
||||
writeFn(rw, data, chunkSize)
|
||||
assert.Equal(t, int64(size), rw.Size())
|
||||
assert.Equal(t, 0, accounted)
|
||||
readFn(rw, data, chunkSize)
|
||||
assert.NoError(t, rw.Close())
|
||||
assert.Equal(t, size, accounted)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue