diff --git a/backend/pcloud/pcloud.go b/backend/pcloud/pcloud.go index 36caf7bf3..3cb0d3aed 100644 --- a/backend/pcloud/pcloud.go +++ b/backend/pcloud/pcloud.go @@ -1137,7 +1137,7 @@ func (o *Object) Update(ctx context.Context, in io.Reader, src fs.ObjectInfo, op // opts.Body=0), so upload it as a multipart form POST with // Content-Length set. if size == 0 { - formReader, contentType, overhead, err := rest.MultipartUpload(in, opts.Parameters, "content", leaf) + formReader, contentType, overhead, err := rest.MultipartUpload(ctx, in, opts.Parameters, "content", leaf) if err != nil { return errors.Wrap(err, "failed to make multipart upload for 0 length file") } diff --git a/backend/seafile/webapi.go b/backend/seafile/webapi.go index f91a9ae2b..563cd1896 100644 --- a/backend/seafile/webapi.go +++ b/backend/seafile/webapi.go @@ -682,7 +682,7 @@ func (f *Fs) upload(ctx context.Context, in io.Reader, uploadLink, filePath stri "need_idx_progress": {"true"}, "replace": {"1"}, } - formReader, contentType, _, err := rest.MultipartUpload(in, parameters, "file", f.opt.Enc.FromStandardName(filename)) + formReader, contentType, _, err := rest.MultipartUpload(ctx, in, parameters, "file", f.opt.Enc.FromStandardName(filename)) if err != nil { return nil, errors.Wrap(err, "failed to make multipart upload") } diff --git a/backend/zoho/zoho.go b/backend/zoho/zoho.go index c68c170b3..f8a325469 100644 --- a/backend/zoho/zoho.go +++ b/backend/zoho/zoho.go @@ -647,7 +647,7 @@ func (f *Fs) upload(ctx context.Context, name string, parent string, size int64, params.Set("filename", name) params.Set("parent_id", parent) params.Set("override-name-exist", strconv.FormatBool(true)) - formReader, contentType, overhead, err := rest.MultipartUpload(in, nil, "content", name) + formReader, contentType, overhead, err := rest.MultipartUpload(ctx, in, nil, "content", name) if err != nil { return nil, errors.Wrap(err, "failed to make multipart upload") } diff --git a/fs/operations/rc_test.go b/fs/operations/rc_test.go index bf42a6f65..5420f3148 100644 --- a/fs/operations/rc_test.go +++ b/fs/operations/rc_test.go @@ -450,6 +450,7 @@ func TestRcFsInfo(t *testing.T) { func TestUploadFile(t *testing.T) { r, call := rcNewRun(t, "operations/uploadfile") defer r.Finalise() + ctx := context.Background() testFileName := "test.txt" testFileContent := "Hello World" @@ -460,7 +461,7 @@ func TestUploadFile(t *testing.T) { currentFile, err := os.Open(path.Join(r.LocalName, testFileName)) require.NoError(t, err) - formReader, contentType, _, err := rest.MultipartUpload(currentFile, url.Values{}, "file", testFileName) + formReader, contentType, _, err := rest.MultipartUpload(ctx, currentFile, url.Values{}, "file", testFileName) require.NoError(t, err) httpReq := httptest.NewRequest("POST", "/", formReader) @@ -482,7 +483,7 @@ func TestUploadFile(t *testing.T) { currentFile, err = os.Open(path.Join(r.LocalName, testFileName)) require.NoError(t, err) - formReader, contentType, _, err = rest.MultipartUpload(currentFile, url.Values{}, "file", testFileName) + formReader, contentType, _, err = rest.MultipartUpload(ctx, currentFile, url.Values{}, "file", testFileName) require.NoError(t, err) httpReq = httptest.NewRequest("POST", "/", formReader) diff --git a/lib/rest/rest.go b/lib/rest/rest.go index e72275a73..4490d617a 100644 --- a/lib/rest/rest.go +++ b/lib/rest/rest.go @@ -308,7 +308,7 @@ func (api *Client) Call(ctx context.Context, opts *Opts) (resp *http.Response, e // the int64 returned is the overhead in addition to the file contents, in case Content-Length is required // // NB This doesn't allow setting the content type of the attachment -func MultipartUpload(in io.Reader, params url.Values, contentName, fileName string) (io.ReadCloser, string, int64, error) { +func MultipartUpload(ctx context.Context, in io.Reader, params url.Values, contentName, fileName string) (io.ReadCloser, string, int64, error) { bodyReader, bodyWriter := io.Pipe() writer := multipart.NewWriter(bodyWriter) contentType := writer.FormDataContentType() @@ -343,8 +343,21 @@ func MultipartUpload(in io.Reader, params url.Values, contentName, fileName stri multipartLength := int64(buf.Len()) + // Make sure we close the pipe writer to release the reader on context cancel + quit := make(chan struct{}) + go func() { + select { + case <-quit: + break + case <-ctx.Done(): + _ = bodyWriter.CloseWithError(ctx.Err()) + } + }() + // Pump the data in the background go func() { + defer close(quit) + var err error for key, vals := range params { @@ -452,7 +465,7 @@ func (api *Client) callCodec(ctx context.Context, opts *Opts, request interface{ opts = opts.Copy() var overhead int64 - opts.Body, opts.ContentType, overhead, err = MultipartUpload(opts.Body, params, opts.MultipartContentName, opts.MultipartFileName) + opts.Body, opts.ContentType, overhead, err = MultipartUpload(ctx, opts.Body, params, opts.MultipartContentName, opts.MultipartFileName) if err != nil { return nil, err }