From bc986b44b2f1aef9c75d89435d26c465a35d6b7f Mon Sep 17 00:00:00 2001
From: Nick Craig-Wood <nick@craig-wood.com>
Date: Thu, 24 Aug 2023 16:42:09 +0100
Subject: [PATCH] lib/pool: add DelayAccounting() to fix accounting when
 reading hashes

---
 lib/pool/reader_writer.go      |  54 +++++++-
 lib/pool/reader_writer_test.go | 241 +++++++++++++++++++++++----------
 2 files changed, 219 insertions(+), 76 deletions(-)

diff --git a/lib/pool/reader_writer.go b/lib/pool/reader_writer.go
index e7ed535a2..18bd11e8e 100644
--- a/lib/pool/reader_writer.go
+++ b/lib/pool/reader_writer.go
@@ -19,6 +19,8 @@ type RW struct {
 	out        int       // offset we are reading from
 	lastOffset int       // size in last page
 	account    RWAccount // account for a read
+	reads      int       // count how many times the data has been read
+	accountOn  int       // only account on or after this read
 }
 
 var (
@@ -50,10 +52,40 @@ func (rw *RW) SetAccounting(account RWAccount) *RW {
 	return rw
 }
 
+// DelayAccountinger is implemented by types which can delay accounting.
+type DelayAccountinger interface {
+	// DelayAccounting makes sure the accounting function only
+	// gets called on the i-th or later read of the data from this
+	// point (counting from 1).
+	//
+	// This is useful so that we don't account initial reads of
+	// the data, e.g. when calculating hashes.
+	//
+	// Set this to 0 to account everything.
+	DelayAccounting(i int)
+}
+
+// DelayAccounting makes sure the accounting function only gets called
+// on the i-th or later read of the data from this point (counting
+// from 1). A read is counted when readPage is asked for offset 0.
+//
+// This is useful so that we don't account initial reads of the data,
+// e.g. when calculating hashes.
+//
+// Set this to 0 to account everything; calling it resets the count.
+func (rw *RW) DelayAccounting(i int) {
+	rw.accountOn = i // accountRead only accounts once reads >= accountOn
+	rw.reads = 0     // restart counting reads from this point
+}
+
 // Returns the page and offset of i for reading.
 //
 // Ensure there are pages before calling this.
 func (rw *RW) readPage(i int) (page []byte) {
+	// Count a new read of the data each time we read from offset 0
+	if i == 0 {
+		rw.reads++
+	}
 	pageNumber := i / rw.pool.bufferSize
 	offset := i % rw.pool.bufferSize
 	page = rw.pages[pageNumber]
@@ -69,7 +101,14 @@ func (rw *RW) accountRead(n int) error {
 	if rw.account == nil {
 		return nil
 	}
-	return rw.account(n)
+	// Don't start accounting until rw.reads has reached rw.accountOn.
+	//
+	// rw.reads will be 1 the first time this is called, so
+	// rw.accountOn 2 means start accounting on the 2nd read through.
+	if rw.reads >= rw.accountOn {
+		return rw.account(n)
+	}
+	return nil
 }
 
 // Read reads up to len(p) bytes into p. It returns the number of
@@ -227,10 +266,11 @@ func (rw *RW) Size() int64 {
 
 // Check interfaces
 var (
-	_ io.Reader     = (*RW)(nil)
-	_ io.ReaderFrom = (*RW)(nil)
-	_ io.Writer     = (*RW)(nil)
-	_ io.WriterTo   = (*RW)(nil)
-	_ io.Seeker     = (*RW)(nil)
-	_ io.Closer     = (*RW)(nil)
+	_ io.Reader         = (*RW)(nil)
+	_ io.ReaderFrom     = (*RW)(nil)
+	_ io.Writer         = (*RW)(nil)
+	_ io.WriterTo       = (*RW)(nil)
+	_ io.Seeker         = (*RW)(nil)
+	_ io.Closer         = (*RW)(nil)
+	_ DelayAccountinger = (*RW)(nil)
 )
diff --git a/lib/pool/reader_writer_test.go b/lib/pool/reader_writer_test.go
index cd7592e45..e9e02c22e 100644
--- a/lib/pool/reader_writer_test.go
+++ b/lib/pool/reader_writer_test.go
@@ -9,6 +9,7 @@ import (
 
 	"github.com/rclone/rclone/lib/random"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 const blockSize = 4096
@@ -178,71 +179,164 @@ func TestRW(t *testing.T) {
 		assert.Equal(t, testData[7:10], dst)
 	})
 
-	errBoom := errors.New("accounting error")
+	t.Run("Account", func(t *testing.T) {
+		errBoom := errors.New("accounting error")
 
-	t.Run("AccountRead", func(t *testing.T) {
-		// Test accounting errors
-		rw := newRW()
-		defer close(rw)
+		t.Run("Read", func(t *testing.T) {
+			rw := newRW()
+			defer close(rw)
 
-		var total int
-		rw.SetAccounting(func(n int) error {
-			total += n
-			return nil
+			var total int
+			rw.SetAccounting(func(n int) error {
+				total += n
+				return nil
+			})
+
+			dst = make([]byte, 3)
+			n, err = rw.Read(dst)
+			assert.Equal(t, 3, n)
+			assert.NoError(t, err)
+			assert.Equal(t, 3, total)
 		})
 
-		dst = make([]byte, 3)
-		n, err = rw.Read(dst)
-		assert.Equal(t, 3, n)
-		assert.NoError(t, err)
-		assert.Equal(t, 3, total)
-	})
+		t.Run("WriteTo", func(t *testing.T) {
+			rw := newRW()
+			defer close(rw)
+			var b bytes.Buffer
 
-	t.Run("AccountWriteTo", func(t *testing.T) {
-		rw := newRW()
-		defer close(rw)
-		var b bytes.Buffer
+			var total int
+			rw.SetAccounting(func(n int) error {
+				total += n
+				return nil
+			})
 
-		var total int
-		rw.SetAccounting(func(n int) error {
-			total += n
-			return nil
+			n, err := rw.WriteTo(&b)
+			assert.NoError(t, err)
+			assert.Equal(t, 10, total)
+			assert.Equal(t, int64(10), n)
+			assert.Equal(t, testData, b.Bytes())
 		})
 
-		n, err := rw.WriteTo(&b)
-		assert.NoError(t, err)
-		assert.Equal(t, 10, total)
-		assert.Equal(t, int64(10), n)
-		assert.Equal(t, testData, b.Bytes())
-	})
+		t.Run("ReadDelay", func(t *testing.T) {
+			rw := newRW()
+			defer close(rw)
 
-	t.Run("AccountReadError", func(t *testing.T) {
-		// Test accounting errors
-		rw := newRW()
-		defer close(rw)
+			var total int
+			rw.SetAccounting(func(n int) error {
+				total += n
+				return nil
+			})
 
-		rw.SetAccounting(func(n int) error {
-			return errBoom
+			rewind := func() {
+				_, err := rw.Seek(0, io.SeekStart)
+				require.NoError(t, err)
+			}
+
+			rw.DelayAccounting(3)
+
+			dst = make([]byte, 16)
+
+			n, err = rw.Read(dst)
+			assert.Equal(t, 10, n)
+			assert.Equal(t, io.EOF, err)
+			assert.Equal(t, 0, total)
+			rewind()
+
+			n, err = rw.Read(dst)
+			assert.Equal(t, 10, n)
+			assert.Equal(t, io.EOF, err)
+			assert.Equal(t, 0, total)
+			rewind()
+
+			n, err = rw.Read(dst)
+			assert.Equal(t, 10, n)
+			assert.Equal(t, io.EOF, err)
+			assert.Equal(t, 10, total)
+			rewind()
+
+			n, err = rw.Read(dst)
+			assert.Equal(t, 10, n)
+			assert.Equal(t, io.EOF, err)
+			assert.Equal(t, 20, total)
+			rewind()
 		})
 
-		dst = make([]byte, 3)
-		n, err = rw.Read(dst)
-		assert.Equal(t, 3, n)
-		assert.Equal(t, errBoom, err)
-	})
+		t.Run("WriteToDelay", func(t *testing.T) {
+			rw := newRW()
+			defer close(rw)
+			var b bytes.Buffer
 
-	t.Run("AccountWriteToError", func(t *testing.T) {
-		rw := newRW()
-		defer close(rw)
-		rw.SetAccounting(func(n int) error {
-			return errBoom
+			var total int
+			rw.SetAccounting(func(n int) error {
+				total += n
+				return nil
+			})
+
+			rw.DelayAccounting(3)
+
+			rewind := func() {
+				_, err := rw.Seek(0, io.SeekStart)
+				require.NoError(t, err)
+				b.Reset()
+			}
+
+			n, err := rw.WriteTo(&b)
+			assert.NoError(t, err)
+			assert.Equal(t, 0, total)
+			assert.Equal(t, int64(10), n)
+			assert.Equal(t, testData, b.Bytes())
+			rewind()
+
+			n, err = rw.WriteTo(&b)
+			assert.NoError(t, err)
+			assert.Equal(t, 0, total)
+			assert.Equal(t, int64(10), n)
+			assert.Equal(t, testData, b.Bytes())
+			rewind()
+
+			n, err = rw.WriteTo(&b)
+			assert.NoError(t, err)
+			assert.Equal(t, 10, total)
+			assert.Equal(t, int64(10), n)
+			assert.Equal(t, testData, b.Bytes())
+			rewind()
+
+			n, err = rw.WriteTo(&b)
+			assert.NoError(t, err)
+			assert.Equal(t, 20, total)
+			assert.Equal(t, int64(10), n)
+			assert.Equal(t, testData, b.Bytes())
+			rewind()
 		})
-		var b bytes.Buffer
 
-		n, err := rw.WriteTo(&b)
-		assert.Equal(t, errBoom, err)
-		assert.Equal(t, int64(10), n)
-		assert.Equal(t, testData, b.Bytes())
+		t.Run("ReadError", func(t *testing.T) {
+			// Test accounting errors
+			rw := newRW()
+			defer close(rw)
+
+			rw.SetAccounting(func(n int) error {
+				return errBoom
+			})
+
+			dst = make([]byte, 3)
+			n, err = rw.Read(dst)
+			assert.Equal(t, 3, n)
+			assert.Equal(t, errBoom, err)
+		})
+
+		t.Run("WriteToError", func(t *testing.T) {
+			rw := newRW()
+			defer close(rw)
+			rw.SetAccounting(func(n int) error {
+				return errBoom
+			})
+			var b bytes.Buffer
+
+			n, err := rw.WriteTo(&b)
+			assert.Equal(t, errBoom, err)
+			assert.Equal(t, int64(10), n)
+			assert.Equal(t, testData, b.Bytes())
+		})
 	})
 
 }
@@ -363,26 +457,35 @@ func TestRWBoundaryConditions(t *testing.T) {
 		assert.Equal(t, int64(len(data)), nn)
 	}
 
+	type test struct {
+		name string
+		fn   func(*RW, []byte, int)
+	}
+
 	// Read and Write the data with a range of block sizes and functions
-	for _, writeFn := range []func(*RW, []byte, int){write, readFrom} {
-		for _, readFn := range []func(*RW, []byte, int){read, writeTo} {
-			for _, size := range sizes {
-				data := buf[:size]
-				for _, chunkSize := range sizes {
-					//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
-					rw := NewRW(rwPool)
-					assert.Equal(t, int64(0), rw.Size())
-					accounted = 0
-					rw.SetAccounting(account)
-					assert.Equal(t, 0, accounted)
-					writeFn(rw, data, chunkSize)
-					assert.Equal(t, int64(size), rw.Size())
-					assert.Equal(t, 0, accounted)
-					readFn(rw, data, chunkSize)
-					assert.NoError(t, rw.Close())
-					assert.Equal(t, size, accounted)
-				}
+	for _, write := range []test{{"Write", write}, {"ReadFrom", readFrom}} {
+		t.Run(write.name, func(t *testing.T) {
+			for _, read := range []test{{"Read", read}, {"WriteTo", writeTo}} {
+				t.Run(read.name, func(t *testing.T) {
+					for _, size := range sizes {
+						data := buf[:size]
+						for _, chunkSize := range sizes {
+							//t.Logf("Testing size=%d chunkSize=%d", size, chunkSize)
+							rw := NewRW(rwPool)
+							assert.Equal(t, int64(0), rw.Size())
+							accounted = 0
+							rw.SetAccounting(account)
+							assert.Equal(t, 0, accounted)
+							write.fn(rw, data, chunkSize)
+							assert.Equal(t, int64(size), rw.Size())
+							assert.Equal(t, 0, accounted)
+							read.fn(rw, data, chunkSize)
+							assert.NoError(t, rw.Close())
+							assert.Equal(t, size, accounted)
+						}
+					}
+				})
 			}
-		}
+		})
 	}
 }