diff --git a/fs/accounting/accounting.go b/fs/accounting/accounting.go index 1693ed2e9..b6d512a5d 100644 --- a/fs/accounting/accounting.go +++ b/fs/accounting/accounting.go @@ -249,6 +249,44 @@ func (acc *Account) Read(p []byte) (n int, err error) { return acc.read(acc.in, p) } +// Thin wrapper for w +type accountWriteTo struct { + w io.Writer + acc *Account +} + +// Write writes len(p) bytes from p to the underlying data stream. It +// returns the number of bytes written from p (0 <= n <= len(p)) and +// any error encountered that caused the write to stop early. Write +// must return a non-nil error if it returns n < len(p). Write must +// not modify the slice data, even temporarily. +// +// Implementations must not retain p. +func (awt *accountWriteTo) Write(p []byte) (n int, err error) { + _, err = awt.acc.checkRead() + if err == nil { + n, err = awt.w.Write(p) + awt.acc.accountRead(n) + } + return n, err +} + +// WriteTo writes data to w until there's no more data to write or +// when an error occurs. The return value n is the number of bytes +// written. Any error encountered during the write is also returned. +func (acc *Account) WriteTo(w io.Writer) (n int64, err error) { + acc.mu.Lock() + in := acc.in + acc.mu.Unlock() + wrappedWriter := accountWriteTo{w: w, acc: acc} + if do, ok := in.(io.WriterTo); ok { + n, err = do.WriteTo(&wrappedWriter) + } else { + n, err = io.Copy(&wrappedWriter, in) + } + return +} + // AccountRead account having read n bytes func (acc *Account) AccountRead(n int) (err error) { acc.mu.Lock() diff --git a/fs/accounting/accounting_test.go b/fs/accounting/accounting_test.go index 4b876e1cd..7c11f22a3 100644 --- a/fs/accounting/accounting_test.go +++ b/fs/accounting/accounting_test.go @@ -19,6 +19,7 @@ import ( // Check it satisfies the interfaces var ( _ io.ReadCloser = &Account{} + _ io.WriterTo = &Account{} _ io.Reader = &accountStream{} _ Accounter = &Account{} _ Accounter = &accountStream{} @@ -117,6 +118,46 @@ func TestAccountRead(t *testing.T) { assert.NoError(t, acc.Close()) } +func testAccountWriteTo(t *testing.T, withBuffer bool) { + buf := make([]byte, 2*asyncreader.BufferSize+1) + for i := range buf { + buf[i] = byte(i % 251) + } + in := ioutil.NopCloser(bytes.NewBuffer(buf)) + stats := NewStats() + acc := newAccountSizeName(stats, in, int64(len(buf)), "test") + if withBuffer { + acc = acc.WithBuffer() + } + + assert.True(t, acc.start.IsZero()) + assert.Equal(t, 0, acc.lpBytes) + assert.Equal(t, int64(0), acc.bytes) + assert.Equal(t, int64(0), stats.bytes) + + var out bytes.Buffer + + n, err := acc.WriteTo(&out) + assert.NoError(t, err) + assert.Equal(t, int64(len(buf)), n) + assert.Equal(t, buf, out.Bytes()) + + assert.False(t, acc.start.IsZero()) + assert.Equal(t, len(buf), acc.lpBytes) + assert.Equal(t, int64(len(buf)), acc.bytes) + assert.Equal(t, int64(len(buf)), stats.bytes) + + assert.NoError(t, acc.Close()) +} + +func TestAccountWriteTo(t *testing.T) { + testAccountWriteTo(t, false) +} + +func TestAccountWriteToWithBuffer(t *testing.T) { + testAccountWriteTo(t, true) +} + func TestAccountString(t *testing.T) { in := ioutil.NopCloser(bytes.NewBuffer([]byte{1, 2, 3})) stats := NewStats()