From f4b1a51af6272903ae24f776d7c85ce0fd75ae43 Mon Sep 17 00:00:00 2001 From: Nick Craig-Wood Date: Thu, 24 Aug 2023 15:28:40 +0100 Subject: [PATCH] lib/pool: add SetAccounting to RW --- lib/pool/reader_writer.go | 42 +++++++++++++++--- lib/pool/reader_writer_test.go | 79 ++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 5 deletions(-) diff --git a/lib/pool/reader_writer.go b/lib/pool/reader_writer.go index e74acc4e3..e7ed535a2 100644 --- a/lib/pool/reader_writer.go +++ b/lib/pool/reader_writer.go @@ -5,13 +5,20 @@ import ( "io" ) +// RWAccount is a function which will be called after every read +// from the RW. +// +// It may return an error which will be passed back to the user. +type RWAccount func(n int) error + // RW contains the state for the read/writer type RW struct { - pool *Pool // pool to get pages from - pages [][]byte // backing store - size int // size written - out int // offset we are reading from - lastOffset int // size in last page + pool *Pool // pool to get pages from + pages [][]byte // backing store + size int // size written + out int // offset we are reading from + lastOffset int // size in last page + account RWAccount // account for a read } var ( @@ -34,6 +41,15 @@ func NewRW(pool *Pool) *RW { } } +// SetAccounting should be provided with a function which will be +// called after every read from the RW. +// +// It may return an error which will be passed back to the user. +func (rw *RW) SetAccounting(account RWAccount) *RW { + rw.account = account + return rw +} + // Returns the page and offset of i for reading. // // Ensure there are pages before calling this. @@ -48,6 +64,14 @@ func (rw *RW) readPage(i int) (page []byte) { return page[offset:] } +// account for n bytes being read +func (rw *RW) accountRead(n int) error { + if rw.account == nil { + return nil + } + return rw.account(n) +} + // Read reads up to len(p) bytes into p. It returns the number of // bytes read (0 <= n <= len(p)) and any error encountered. If some // data is available but not len(p) bytes, Read returns what is @@ -66,6 +90,10 @@ func (rw *RW) Read(p []byte) (n int, err error) { p = p[nn:] n += nn rw.out += nn + err = rw.accountRead(nn) + if err != nil { + return n, err + } } return n, nil } @@ -89,6 +117,10 @@ func (rw *RW) WriteTo(w io.Writer) (n int64, err error) { if err != nil { return n, err } + err = rw.accountRead(nn) + if err != nil { + return n, err + } } return n, nil } diff --git a/lib/pool/reader_writer_test.go b/lib/pool/reader_writer_test.go index cd4dfd497..cd7592e45 100644 --- a/lib/pool/reader_writer_test.go +++ b/lib/pool/reader_writer_test.go @@ -177,6 +177,74 @@ func TestRW(t *testing.T) { assert.Equal(t, 3, n) assert.Equal(t, testData[7:10], dst) }) + + errBoom := errors.New("accounting error") + + t.Run("AccountRead", func(t *testing.T) { + // Test accounting errors + rw := newRW() + defer close(rw) + + var total int + rw.SetAccounting(func(n int) error { + total += n + return nil + }) + + dst = make([]byte, 3) + n, err = rw.Read(dst) + assert.Equal(t, 3, n) + assert.NoError(t, err) + assert.Equal(t, 3, total) + }) + + t.Run("AccountWriteTo", func(t *testing.T) { + rw := newRW() + defer close(rw) + var b bytes.Buffer + + var total int + rw.SetAccounting(func(n int) error { + total += n + return nil + }) + + n, err := rw.WriteTo(&b) + assert.NoError(t, err) + assert.Equal(t, 10, total) + assert.Equal(t, int64(10), n) + assert.Equal(t, testData, b.Bytes()) + }) + + t.Run("AccountReadError", func(t *testing.T) { + // Test accounting errors + rw := newRW() + defer close(rw) + + rw.SetAccounting(func(n int) error { + return errBoom + }) + + dst = make([]byte, 3) + n, err = rw.Read(dst) + assert.Equal(t, 3, n) + assert.Equal(t, errBoom, err) + }) + + t.Run("AccountWriteToError", func(t *testing.T) { + rw := newRW() + defer close(rw) + rw.SetAccounting(func(n int) error { + return errBoom + }) + var b bytes.Buffer + + n, err := rw.WriteTo(&b) + assert.Equal(t, errBoom, err) + assert.Equal(t, int64(10), n) + assert.Equal(t, testData, b.Bytes()) + }) + } // A reader to read in chunkSize chunks @@ -220,6 +288,12 @@ func (w *testWriter) Write(p []byte) (n int, err error) { } func TestRWBoundaryConditions(t *testing.T) { + var accounted int + account := func(n int) error { + accounted += n + return nil + } + maxSize := 3 * blockSize buf := []byte(random.String(maxSize)) @@ -298,10 +372,15 @@ func TestRWBoundaryConditions(t *testing.T) { //t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize) rw := NewRW(rwPool) assert.Equal(t, int64(0), rw.Size()) + accounted = 0 + rw.SetAccounting(account) + assert.Equal(t, 0, accounted) writeFn(rw, data, chunkSize) assert.Equal(t, int64(size), rw.Size()) + assert.Equal(t, 0, accounted) readFn(rw, data, chunkSize) assert.NoError(t, rw.Close()) + assert.Equal(t, size, accounted) } } }