From fc57648b7533571a8c5965ddce58efa5a0e76688 Mon Sep 17 00:00:00 2001 From: Nick Craig-Wood Date: Thu, 25 Mar 2021 15:35:08 +0000 Subject: [PATCH] lib/rest: fix multipart uploads stopping on context cancel Before this change when the context was cancelled (due to --max-duration for example) this could deadlock when uploading multipart uploads. This change fixes the problem by introducing another go routine to monitor the context and close the pipe with an error when the context errors. --- backend/pcloud/pcloud.go | 2 +- backend/seafile/webapi.go | 2 +- backend/zoho/zoho.go | 2 +- fs/operations/rc_test.go | 5 +++-- lib/rest/rest.go | 17 +++++++++++++++-- 5 files changed, 21 insertions(+), 7 deletions(-) diff --git a/backend/pcloud/pcloud.go b/backend/pcloud/pcloud.go index 36caf7bf3..3cb0d3aed 100644 --- a/backend/pcloud/pcloud.go +++ b/backend/pcloud/pcloud.go @@ -1137,7 +1137,7 @@ func (o *Object) Update(ctx context.Context, in io.Reader, src fs.ObjectInfo, op // opts.Body=0), so upload it as a multipart form POST with // Content-Length set. if size == 0 { - formReader, contentType, overhead, err := rest.MultipartUpload(in, opts.Parameters, "content", leaf) + formReader, contentType, overhead, err := rest.MultipartUpload(ctx, in, opts.Parameters, "content", leaf) if err != nil { return errors.Wrap(err, "failed to make multipart upload for 0 length file") } diff --git a/backend/seafile/webapi.go b/backend/seafile/webapi.go index f91a9ae2b..563cd1896 100644 --- a/backend/seafile/webapi.go +++ b/backend/seafile/webapi.go @@ -682,7 +682,7 @@ func (f *Fs) upload(ctx context.Context, in io.Reader, uploadLink, filePath stri "need_idx_progress": {"true"}, "replace": {"1"}, } - formReader, contentType, _, err := rest.MultipartUpload(in, parameters, "file", f.opt.Enc.FromStandardName(filename)) + formReader, contentType, _, err := rest.MultipartUpload(ctx, in, parameters, "file", f.opt.Enc.FromStandardName(filename)) if err != nil { return nil, errors.Wrap(err, "failed to make multipart upload") } diff --git a/backend/zoho/zoho.go b/backend/zoho/zoho.go index c68c170b3..f8a325469 100644 --- a/backend/zoho/zoho.go +++ b/backend/zoho/zoho.go @@ -647,7 +647,7 @@ func (f *Fs) upload(ctx context.Context, name string, parent string, size int64, params.Set("filename", name) params.Set("parent_id", parent) params.Set("override-name-exist", strconv.FormatBool(true)) - formReader, contentType, overhead, err := rest.MultipartUpload(in, nil, "content", name) + formReader, contentType, overhead, err := rest.MultipartUpload(ctx, in, nil, "content", name) if err != nil { return nil, errors.Wrap(err, "failed to make multipart upload") } diff --git a/fs/operations/rc_test.go b/fs/operations/rc_test.go index bf42a6f65..5420f3148 100644 --- a/fs/operations/rc_test.go +++ b/fs/operations/rc_test.go @@ -450,6 +450,7 @@ func TestRcFsInfo(t *testing.T) { func TestUploadFile(t *testing.T) { r, call := rcNewRun(t, "operations/uploadfile") defer r.Finalise() + ctx := context.Background() testFileName := "test.txt" testFileContent := "Hello World" @@ -460,7 +461,7 @@ func TestUploadFile(t *testing.T) { currentFile, err := os.Open(path.Join(r.LocalName, testFileName)) require.NoError(t, err) - formReader, contentType, _, err := rest.MultipartUpload(currentFile, url.Values{}, "file", testFileName) + formReader, contentType, _, err := rest.MultipartUpload(ctx, currentFile, url.Values{}, "file", testFileName) require.NoError(t, err) httpReq := httptest.NewRequest("POST", "/", formReader) @@ -482,7 +483,7 @@ func TestUploadFile(t *testing.T) { currentFile, err = os.Open(path.Join(r.LocalName, testFileName)) require.NoError(t, err) - formReader, contentType, _, err = rest.MultipartUpload(currentFile, url.Values{}, "file", testFileName) + formReader, contentType, _, err = rest.MultipartUpload(ctx, currentFile, url.Values{}, "file", testFileName) require.NoError(t, err) httpReq = httptest.NewRequest("POST", "/", formReader) diff --git a/lib/rest/rest.go b/lib/rest/rest.go index e72275a73..4490d617a 100644 --- a/lib/rest/rest.go +++ b/lib/rest/rest.go @@ -308,7 +308,7 @@ func (api *Client) Call(ctx context.Context, opts *Opts) (resp *http.Response, e // the int64 returned is the overhead in addition to the file contents, in case Content-Length is required // // NB This doesn't allow setting the content type of the attachment -func MultipartUpload(in io.Reader, params url.Values, contentName, fileName string) (io.ReadCloser, string, int64, error) { +func MultipartUpload(ctx context.Context, in io.Reader, params url.Values, contentName, fileName string) (io.ReadCloser, string, int64, error) { bodyReader, bodyWriter := io.Pipe() writer := multipart.NewWriter(bodyWriter) contentType := writer.FormDataContentType() @@ -343,8 +343,21 @@ func MultipartUpload(in io.Reader, params url.Values, contentName, fileName stri multipartLength := int64(buf.Len()) + // Make sure we close the pipe writer to release the reader on context cancel + quit := make(chan struct{}) + go func() { + select { + case <-quit: + break + case <-ctx.Done(): + _ = bodyWriter.CloseWithError(ctx.Err()) + } + }() + // Pump the data in the background go func() { + defer close(quit) + var err error for key, vals := range params { @@ -452,7 +465,7 @@ func (api *Client) callCodec(ctx context.Context, opts *Opts, request interface{ opts = opts.Copy() var overhead int64 - opts.Body, opts.ContentType, overhead, err = MultipartUpload(opts.Body, params, opts.MultipartContentName, opts.MultipartFileName) + opts.Body, opts.ContentType, overhead, err = MultipartUpload(ctx, opts.Body, params, opts.MultipartContentName, opts.MultipartFileName) if err != nil { return nil, err }