lib/pool: add SetAccounting to RW

This commit is contained in:
Nick Craig-Wood 2023-08-24 15:28:40 +01:00
parent 25703ad20e
commit f4b1a51af6
2 changed files with 116 additions and 5 deletions

View file

@ -5,13 +5,20 @@ import (
"io"
)
// RWAccount is a function which will be called after every read
// from the RW.
//
// It may return an error which will be passed back to the user.
type RWAccount func(n int) error
// RW contains the state for the read/writer
type RW struct {
pool *Pool // pool to get pages from
pages [][]byte // backing store
size int // size written
out int // offset we are reading from
lastOffset int // size in last page
pool *Pool // pool to get pages from
pages [][]byte // backing store
size int // size written
out int // offset we are reading from
lastOffset int // size in last page
account RWAccount // account for a read
}
var (
@ -34,6 +41,15 @@ func NewRW(pool *Pool) *RW {
}
}
// SetAccounting should be provided with a function which will be
// called after every read from the RW.
//
// It may return an error which will be passed back to the user.
func (rw *RW) SetAccounting(account RWAccount) *RW {
rw.account = account
return rw
}
// Returns the page and offset of i for reading.
//
// Ensure there are pages before calling this.
@ -48,6 +64,14 @@ func (rw *RW) readPage(i int) (page []byte) {
return page[offset:]
}
// account for n bytes being read
func (rw *RW) accountRead(n int) error {
if rw.account == nil {
return nil
}
return rw.account(n)
}
// Read reads up to len(p) bytes into p. It returns the number of
// bytes read (0 <= n <= len(p)) and any error encountered. If some
// data is available but not len(p) bytes, Read returns what is
@ -66,6 +90,10 @@ func (rw *RW) Read(p []byte) (n int, err error) {
p = p[nn:]
n += nn
rw.out += nn
err = rw.accountRead(nn)
if err != nil {
return n, err
}
}
return n, nil
}
@ -89,6 +117,10 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
if err != nil {
return n, err
}
err = rw.accountRead(nn)
if err != nil {
return n, err
}
}
return n, nil
}

View file

@ -177,6 +177,74 @@ func TestRW(t *testing.T) {
assert.Equal(t, 3, n)
assert.Equal(t, testData[7:10], dst)
})
errBoom := errors.New("accounting error")
t.Run("AccountRead", func(t *testing.T) {
// Test accounting errors
rw := newRW()
defer close(rw)
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
})
dst = make([]byte, 3)
n, err = rw.Read(dst)
assert.Equal(t, 3, n)
assert.NoError(t, err)
assert.Equal(t, 3, total)
})
t.Run("AccountWriteTo", func(t *testing.T) {
rw := newRW()
defer close(rw)
var b bytes.Buffer
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
})
n, err := rw.WriteTo(&b)
assert.NoError(t, err)
assert.Equal(t, 10, total)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
})
t.Run("AccountReadError", func(t *testing.T) {
// Test accounting errors
rw := newRW()
defer close(rw)
rw.SetAccounting(func(n int) error {
return errBoom
})
dst = make([]byte, 3)
n, err = rw.Read(dst)
assert.Equal(t, 3, n)
assert.Equal(t, errBoom, err)
})
t.Run("AccountWriteToError", func(t *testing.T) {
rw := newRW()
defer close(rw)
rw.SetAccounting(func(n int) error {
return errBoom
})
var b bytes.Buffer
n, err := rw.WriteTo(&b)
assert.Equal(t, errBoom, err)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
})
}
// A reader to read in chunkSize chunks
@ -220,6 +288,12 @@ func (w *testWriter) Write(p []byte) (n int, err error) {
}
func TestRWBoundaryConditions(t *testing.T) {
var accounted int
account := func(n int) error {
accounted += n
return nil
}
maxSize := 3 * blockSize
buf := []byte(random.String(maxSize))
@ -298,10 +372,15 @@ func TestRWBoundaryConditions(t *testing.T) {
//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
rw := NewRW(rwPool)
assert.Equal(t, int64(0), rw.Size())
accounted = 0
rw.SetAccounting(account)
assert.Equal(t, 0, accounted)
writeFn(rw, data, chunkSize)
assert.Equal(t, int64(size), rw.Size())
assert.Equal(t, 0, accounted)
readFn(rw, data, chunkSize)
assert.NoError(t, rw.Close())
assert.Equal(t, size, accounted)
}
}
}