lib/pool: add SetAccounting to RW

This commit is contained in:
Nick Craig-Wood 2023-08-24 15:28:40 +01:00
parent 25703ad20e
commit f4b1a51af6
2 changed files with 116 additions and 5 deletions

View file

@ -5,6 +5,12 @@ import (
"io" "io"
) )
// RWAccount is a function which will be called after every read
// from the RW.
//
// It may return an error which will be passed back to the user.
type RWAccount func(n int) error
// RW contains the state for the read/writer // RW contains the state for the read/writer
type RW struct { type RW struct {
pool *Pool // pool to get pages from pool *Pool // pool to get pages from
@ -12,6 +18,7 @@ type RW struct {
size int // size written size int // size written
out int // offset we are reading from out int // offset we are reading from
lastOffset int // size in last page lastOffset int // size in last page
account RWAccount // account for a read
} }
var ( var (
@ -34,6 +41,15 @@ func NewRW(pool *Pool) *RW {
} }
} }
// SetAccounting should be provided with a function which will be
// called after every read from the RW.
//
// It may return an error which will be passed back to the user.
func (rw *RW) SetAccounting(account RWAccount) *RW {
rw.account = account
return rw
}
// Returns the page and offset of i for reading. // Returns the page and offset of i for reading.
// //
// Ensure there are pages before calling this. // Ensure there are pages before calling this.
@ -48,6 +64,14 @@ func (rw *RW) readPage(i int) (page []byte) {
return page[offset:] return page[offset:]
} }
// account for n bytes being read
func (rw *RW) accountRead(n int) error {
if rw.account == nil {
return nil
}
return rw.account(n)
}
// Read reads up to len(p) bytes into p. It returns the number of // Read reads up to len(p) bytes into p. It returns the number of
// bytes read (0 <= n <= len(p)) and any error encountered. If some // bytes read (0 <= n <= len(p)) and any error encountered. If some
// data is available but not len(p) bytes, Read returns what is // data is available but not len(p) bytes, Read returns what is
@ -66,6 +90,10 @@ func (rw *RW) Read(p []byte) (n int, err error) {
p = p[nn:] p = p[nn:]
n += nn n += nn
rw.out += nn rw.out += nn
err = rw.accountRead(nn)
if err != nil {
return n, err
}
} }
return n, nil return n, nil
} }
@ -89,6 +117,10 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) {
if err != nil { if err != nil {
return n, err return n, err
} }
err = rw.accountRead(nn)
if err != nil {
return n, err
}
} }
return n, nil return n, nil
} }

View file

@ -177,6 +177,74 @@ func TestRW(t *testing.T) {
assert.Equal(t, 3, n) assert.Equal(t, 3, n)
assert.Equal(t, testData[7:10], dst) assert.Equal(t, testData[7:10], dst)
}) })
errBoom := errors.New("accounting error")
t.Run("AccountRead", func(t *testing.T) {
// Test accounting errors
rw := newRW()
defer close(rw)
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
})
dst = make([]byte, 3)
n, err = rw.Read(dst)
assert.Equal(t, 3, n)
assert.NoError(t, err)
assert.Equal(t, 3, total)
})
t.Run("AccountWriteTo", func(t *testing.T) {
rw := newRW()
defer close(rw)
var b bytes.Buffer
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
})
n, err := rw.WriteTo(&b)
assert.NoError(t, err)
assert.Equal(t, 10, total)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
})
t.Run("AccountReadError", func(t *testing.T) {
// Test accounting errors
rw := newRW()
defer close(rw)
rw.SetAccounting(func(n int) error {
return errBoom
})
dst = make([]byte, 3)
n, err = rw.Read(dst)
assert.Equal(t, 3, n)
assert.Equal(t, errBoom, err)
})
t.Run("AccountWriteToError", func(t *testing.T) {
rw := newRW()
defer close(rw)
rw.SetAccounting(func(n int) error {
return errBoom
})
var b bytes.Buffer
n, err := rw.WriteTo(&b)
assert.Equal(t, errBoom, err)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
})
} }
// A reader to read in chunkSize chunks // A reader to read in chunkSize chunks
@ -220,6 +288,12 @@ func (w *testWriter) Write(p []byte) (n int, err error) {
} }
func TestRWBoundaryConditions(t *testing.T) { func TestRWBoundaryConditions(t *testing.T) {
var accounted int
account := func(n int) error {
accounted += n
return nil
}
maxSize := 3 * blockSize maxSize := 3 * blockSize
buf := []byte(random.String(maxSize)) buf := []byte(random.String(maxSize))
@ -298,10 +372,15 @@ func TestRWBoundaryConditions(t *testing.T) {
//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize) //t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
rw := NewRW(rwPool) rw := NewRW(rwPool)
assert.Equal(t, int64(0), rw.Size()) assert.Equal(t, int64(0), rw.Size())
accounted = 0
rw.SetAccounting(account)
assert.Equal(t, 0, accounted)
writeFn(rw, data, chunkSize) writeFn(rw, data, chunkSize)
assert.Equal(t, int64(size), rw.Size()) assert.Equal(t, int64(size), rw.Size())
assert.Equal(t, 0, accounted)
readFn(rw, data, chunkSize) readFn(rw, data, chunkSize)
assert.NoError(t, rw.Close()) assert.NoError(t, rw.Close())
assert.Equal(t, size, accounted)
} }
} }
} }