accounting: Allow transfers to be canceled with context #3257
This makes all transfers cancelable even if the backend doesn't support context as all transfers are done using the Accounting framework.
This commit is contained in:
parent
421585dd72
commit
122a47fba6
3 changed files with 27 additions and 6 deletions
|
@ -216,6 +216,10 @@ func (acc *Account) averageLoop() {
|
|||
// Check the read before it has happened is valid returning the number
|
||||
// of bytes remaining to read.
|
||||
func (acc *Account) checkReadBefore() (bytesUntilLimit int64, err error) {
|
||||
// Check to see if context is cancelled
|
||||
if err = acc.ctx.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
acc.values.mu.Lock()
|
||||
if acc.values.max >= 0 {
|
||||
bytesUntilLimit = acc.values.max - acc.stats.GetBytes()
|
||||
|
@ -235,7 +239,7 @@ func (acc *Account) checkReadBefore() (bytesUntilLimit int64, err error) {
|
|||
}
|
||||
|
||||
// Check the read call after the read has happened
|
||||
func checkReadAfter(bytesUntilLimit int64, n int, err error) (outN int, outErr error) {
|
||||
func (acc *Account) checkReadAfter(bytesUntilLimit int64, n int, err error) (outN int, outErr error) {
|
||||
bytesUntilLimit -= int64(n)
|
||||
if bytesUntilLimit < 0 {
|
||||
// chop the overage off
|
||||
|
@ -304,7 +308,7 @@ func (acc *Account) read(in io.Reader, p []byte) (n int, err error) {
|
|||
if err == nil {
|
||||
n, err = in.Read(p)
|
||||
acc.accountRead(n)
|
||||
n, err = checkReadAfter(bytesUntilLimit, n, err)
|
||||
n, err = acc.checkReadAfter(bytesUntilLimit, n, err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
@ -333,7 +337,7 @@ func (awt *accountWriteTo) Write(p []byte) (n int, err error) {
|
|||
bytesUntilLimit, err := awt.acc.checkReadBefore()
|
||||
if err == nil {
|
||||
n, err = awt.w.Write(p)
|
||||
n, err = checkReadAfter(bytesUntilLimit, n, err)
|
||||
n, err = awt.acc.checkReadAfter(bytesUntilLimit, n, err)
|
||||
awt.acc.accountRead(n)
|
||||
}
|
||||
return n, err
|
||||
|
@ -361,7 +365,7 @@ func (acc *Account) AccountRead(n int) (err error) {
|
|||
defer acc.mu.Unlock()
|
||||
bytesUntilLimit, err := acc.checkReadBefore()
|
||||
if err == nil {
|
||||
n, err = checkReadAfter(bytesUntilLimit, n, err)
|
||||
n, err = acc.checkReadAfter(bytesUntilLimit, n, err)
|
||||
acc.accountRead(n)
|
||||
}
|
||||
return err
|
||||
|
|
|
@ -312,6 +312,25 @@ func TestAccountMaxTransferWriteTo(t *testing.T) {
|
|||
assert.Equal(t, ErrorMaxTransferLimitReachedFatal, err)
|
||||
}
|
||||
|
||||
func TestAccountReadCtx(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
in := ioutil.NopCloser(bytes.NewBuffer(make([]byte, 100)))
|
||||
stats := NewStats()
|
||||
acc := newAccountSizeName(ctx, stats, in, 1, "test")
|
||||
|
||||
var b = make([]byte, 10)
|
||||
|
||||
n, err := acc.Read(b)
|
||||
assert.Equal(t, 10, n)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cancel()
|
||||
|
||||
n, err = acc.Read(b)
|
||||
assert.Equal(t, 0, n)
|
||||
assert.Equal(t, context.Canceled, err)
|
||||
}
|
||||
|
||||
func TestShortenName(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
in string
|
||||
|
|
|
@ -1040,8 +1040,6 @@ func TestSyncWithMaxDuration(t *testing.T) {
|
|||
startTime := time.Now()
|
||||
err := Sync(context.Background(), r.Fremote, r.Flocal, false)
|
||||
require.Equal(t, context.DeadlineExceeded, errors.Cause(err))
|
||||
err = accounting.GlobalStats().GetLastError()
|
||||
require.NoError(t, err)
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
maxTransferTime := (time.Duration(len(testFiles)) * 60 * time.Second) / time.Duration(bytesPerSecond)
|
||||
|
|
Loading…
Reference in a new issue