diff --git a/fs/accounting/accounting.go b/fs/accounting/accounting.go
index b6d512a5d..8a541c9b1 100644
--- a/fs/accounting/accounting.go
+++ b/fs/accounting/accounting.go
@@ -171,22 +171,39 @@ func (acc *Account) averageLoop() {
 	}
 }
 
-// Check the read is valid returning the number of bytes it is over
-func (acc *Account) checkRead() (over int64, err error) {
+// Check the read before it has happened is valid returning the number
+// of bytes remaining to read.
+func (acc *Account) checkReadBefore() (bytesUntilLimit int64, err error) {
 	acc.statmu.Lock()
 	if acc.max >= 0 {
-		over = acc.stats.GetBytes() - acc.max
-		if over >= 0 {
+		bytesUntilLimit = acc.max - acc.stats.GetBytes()
+		if bytesUntilLimit < 0 {
 			acc.statmu.Unlock()
-			return over, ErrorMaxTransferLimitReachedFatal
+			return bytesUntilLimit, ErrorMaxTransferLimitReachedFatal
 		}
+	} else {
+		bytesUntilLimit = 1 << 62
 	}
 	// Set start time.
 	if acc.start.IsZero() {
 		acc.start = time.Now()
 	}
 	acc.statmu.Unlock()
-	return over, nil
+	return bytesUntilLimit, nil
+}
+
+// Check the read call after the read has happened
+func checkReadAfter(bytesUntilLimit int64, n int, err error) (outN int, outErr error) {
+	bytesUntilLimit -= int64(n)
+	if bytesUntilLimit < 0 {
+		// chop the overage off
+		n += int(bytesUntilLimit)
+		if n < 0 {
+			n = 0
+		}
+		err = ErrorMaxTransferLimitReachedFatal
+	}
+	return n, err
 }
 
 // ServerSideCopyStart should be called at the start of a server side copy
@@ -226,18 +243,11 @@ func (acc *Account) accountRead(n int) {
 
 // read bytes from the io.Reader passed in and account them
 func (acc *Account) read(in io.Reader, p []byte) (n int, err error) {
-	_, err = acc.checkRead()
+	bytesUntilLimit, err := acc.checkReadBefore()
 	if err == nil {
 		n, err = in.Read(p)
 		acc.accountRead(n)
-		if over, checkErr := acc.checkRead(); checkErr == ErrorMaxTransferLimitReachedFatal {
-			// chop the overage off
-			n -= int(over)
-			if n < 0 {
-				n = 0
-			}
-			err = checkErr
-		}
+		n, err = checkReadAfter(bytesUntilLimit, n, err)
 	}
 	return n, err
 }
@@ -263,9 +273,10 @@ type accountWriteTo struct {
 //
 // Implementations must not retain p.
 func (awt *accountWriteTo) Write(p []byte) (n int, err error) {
-	_, err = awt.acc.checkRead()
+	bytesUntilLimit, err := awt.acc.checkReadBefore()
 	if err == nil {
 		n, err = awt.w.Write(p)
+		n, err = checkReadAfter(bytesUntilLimit, n, err)
 		awt.acc.accountRead(n)
 	}
 	return n, err
 }
@@ -291,8 +302,9 @@ func (acc *Account) WriteTo(w io.Writer) (n int64, err error) {
 func (acc *Account) AccountRead(n int) (err error) {
 	acc.mu.Lock()
 	defer acc.mu.Unlock()
-	_, err = acc.checkRead()
+	bytesUntilLimit, err := acc.checkReadBefore()
 	if err == nil {
+		n, err = checkReadAfter(bytesUntilLimit, n, err)
 		acc.accountRead(n)
 	}
 	return err
diff --git a/fs/accounting/accounting_test.go b/fs/accounting/accounting_test.go
index 7c11f22a3..4ec11fa51 100644
--- a/fs/accounting/accounting_test.go
+++ b/fs/accounting/accounting_test.go
@@ -12,6 +12,7 @@ import (
 	"github.com/rclone/rclone/fs"
 	"github.com/rclone/rclone/fs/asyncreader"
 	"github.com/rclone/rclone/fs/fserrors"
+	"github.com/rclone/rclone/lib/readers"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -278,6 +279,27 @@ func TestAccountMaxTransfer(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestAccountMaxTransferWriteTo(t *testing.T) {
+	old := fs.Config.MaxTransfer
+	oldMode := fs.Config.CutoffMode
+
+	fs.Config.MaxTransfer = 15
+	defer func() {
+		fs.Config.MaxTransfer = old
+		fs.Config.CutoffMode = oldMode
+	}()
+
+	in := ioutil.NopCloser(readers.NewPatternReader(1024))
+	stats := NewStats()
+	acc := newAccountSizeName(stats, in, 1, "test")
+
+	var b bytes.Buffer
+
+	n, err := acc.WriteTo(&b)
+	assert.Equal(t, int64(15), n)
+	assert.Equal(t, ErrorMaxTransferLimitReachedFatal, err)
+}
+
 func TestShortenName(t *testing.T) {
 	for _, test := range []struct {
 		in string
diff --git a/fs/operations/operations_test.go b/fs/operations/operations_test.go
index 8b8cae348..51713a829 100644
--- a/fs/operations/operations_test.go
+++ b/fs/operations/operations_test.go
@@ -1539,56 +1539,55 @@ func TestCopyFileMaxTransfer(t *testing.T) {
 	ctx := context.Background()
 	const sizeCutoff = 2048
 
-	file1 := r.WriteFile("file1", "file1 contents", t1)
-	file2 := r.WriteFile("file2", "file2 contents"+random.String(sizeCutoff), t2)
-
-	rfile1 := file1
-	rfile1.Path = "sub/file1"
-	rfile2a := file2
-	rfile2a.Path = "sub/file2a"
-	rfile2b := file2
-	rfile2b.Path = "sub/file2b"
-	rfile2c := file2
-	rfile2c.Path = "sub/file2c"
+	file1 := r.WriteFile("TestCopyFileMaxTransfer/file1", "file1 contents", t1)
+	file2 := r.WriteFile("TestCopyFileMaxTransfer/file2", "file2 contents"+random.String(sizeCutoff), t2)
+	file3 := r.WriteFile("TestCopyFileMaxTransfer/file3", "file3 contents"+random.String(sizeCutoff), t2)
+	file4 := r.WriteFile("TestCopyFileMaxTransfer/file4", "file4 contents"+random.String(sizeCutoff), t2)
 
+	// Cutoff mode: Hard
 	fs.Config.MaxTransfer = sizeCutoff
 	fs.Config.CutoffMode = fs.CutoffModeHard
-	accounting.Stats(ctx).ResetCounters()
 
-	err := operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile1.Path, file1.Path)
+	// file1: Show a small file gets transferred OK
+	accounting.Stats(ctx).ResetCounters()
+	err := operations.CopyFile(ctx, r.Fremote, r.Flocal, file1.Path, file1.Path)
 	require.NoError(t, err)
-	fstest.CheckItems(t, r.Flocal, file1, file2)
-	fstest.CheckItems(t, r.Fremote, rfile1)
+	fstest.CheckItems(t, r.Flocal, file1, file2, file3, file4)
+	fstest.CheckItems(t, r.Fremote, file1)
 
+	// file2: show a large file does not get transferred
 	accounting.Stats(ctx).ResetCounters()
-
-	err = operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile2a.Path, file2.Path)
-	require.NotNil(t, err)
+	err = operations.CopyFile(ctx, r.Fremote, r.Flocal, file2.Path, file2.Path)
+	require.NotNil(t, err, "Did not get expected max transfer limit error")
 	assert.Contains(t, err.Error(), "Max transfer limit reached")
 	assert.True(t, fserrors.IsFatalError(err))
-	fstest.CheckItems(t, r.Flocal, file1, file2)
-	fstest.CheckItems(t, r.Fremote, rfile1)
+	fstest.CheckItems(t, r.Flocal, file1, file2, file3, file4)
+	fstest.CheckItems(t, r.Fremote, file1)
 
+	// Cutoff mode: Cautious
 	fs.Config.CutoffMode = fs.CutoffModeCautious
-	accounting.Stats(ctx).ResetCounters()
 
-	err = operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile2b.Path, file2.Path)
+	// file3: show a large file does not get transferred
+	accounting.Stats(ctx).ResetCounters()
+	err = operations.CopyFile(ctx, r.Fremote, r.Flocal, file3.Path, file3.Path)
 	require.NotNil(t, err)
 	assert.Contains(t, err.Error(), "Max transfer limit reached")
 	assert.True(t, fserrors.IsFatalError(err))
-	fstest.CheckItems(t, r.Flocal, file1, file2)
-	fstest.CheckItems(t, r.Fremote, rfile1)
+	fstest.CheckItems(t, r.Flocal, file1, file2, file3, file4)
+	fstest.CheckItems(t, r.Fremote, file1)
 
 	if strings.HasPrefix(r.Fremote.Name(), "TestChunker") {
 		t.Log("skipping remainder of test for chunker as it involves multiple transfers")
 		return
 	}
 
+	// Cutoff mode: Soft
 	fs.Config.CutoffMode = fs.CutoffModeSoft
-	accounting.Stats(ctx).ResetCounters()
 
-	err = operations.CopyFile(ctx, r.Fremote, r.Flocal, rfile2c.Path, file2.Path)
+	// file4: show a large file does get transferred this time
+	accounting.Stats(ctx).ResetCounters()
+	err = operations.CopyFile(ctx, r.Fremote, r.Flocal, file4.Path, file4.Path)
 	require.NoError(t, err)
-	fstest.CheckItems(t, r.Flocal, file1, file2)
-	fstest.CheckItems(t, r.Fremote, rfile1, rfile2c)
+	fstest.CheckItems(t, r.Flocal, file1, file2, file3, file4)
+	fstest.CheckItems(t, r.Fremote, file1, file4)
 }
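For reference, here is a minimal, self-contained sketch (not rclone code) of the before/after pattern this patch introduces: compute the remaining byte budget with a checkReadBefore-style call, do the read or write, then clip the returned count with a checkReadAfter-style call so the accounted total never overshoots the --max-transfer limit. The limitedReader, checkBefore, checkAfter and errLimitReached names are invented for this illustration.

```go
// Minimal sketch of the before/after limit check pattern, assuming a plain
// io.Reader and a hypothetical errLimitReached in place of rclone's Account
// and ErrorMaxTransferLimitReachedFatal.
package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
)

var errLimitReached = errors.New("max transfer limit reached")

type limitedReader struct {
	in   io.Reader
	max  int64 // -1 means unlimited
	read int64 // bytes accounted so far
}

// checkBefore mirrors checkReadBefore: report how many bytes may still be
// read, or an error if the limit has already been hit.
func (l *limitedReader) checkBefore() (bytesUntilLimit int64, err error) {
	if l.max < 0 {
		return 1 << 62, nil // effectively unlimited
	}
	bytesUntilLimit = l.max - l.read
	if bytesUntilLimit < 0 {
		return bytesUntilLimit, errLimitReached
	}
	return bytesUntilLimit, nil
}

// checkAfter mirrors checkReadAfter: clip n so the caller never sees more
// bytes than the budget allowed, and return the limit error.
func checkAfter(bytesUntilLimit int64, n int, err error) (int, error) {
	bytesUntilLimit -= int64(n)
	if bytesUntilLimit < 0 {
		n += int(bytesUntilLimit) // chop the overage off
		if n < 0 {
			n = 0
		}
		err = errLimitReached
	}
	return n, err
}

func (l *limitedReader) Read(p []byte) (n int, err error) {
	bytesUntilLimit, err := l.checkBefore()
	if err != nil {
		return 0, err
	}
	n, err = l.in.Read(p)
	n, err = checkAfter(bytesUntilLimit, n, err)
	l.read += int64(n) // account only the clipped byte count
	return n, err
}

func main() {
	src := bytes.NewReader(make([]byte, 1024))
	r := &limitedReader{in: src, max: 15}
	var dst bytes.Buffer
	n, err := io.Copy(&dst, r)
	fmt.Println(n, err) // 15 max transfer limit reached
}
```

With a 15 byte budget and a 1 KiB source this prints "15 max transfer limit reached", which is the same shape of result the new TestAccountMaxTransferWriteTo asserts for Account.WriteTo: exactly 15 bytes delivered plus ErrorMaxTransferLimitReachedFatal, rather than an overshoot that is only detected on the next read.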