lib/rest: fix multipart uploads stopping on context cancel

Before this change when the context was cancelled (due to
--max-duration for example) this could deadlock when uploading
multipart uploads.

This change fixes the problem by introducing another go routine to
monitor the context and close the pipe with an error when the context
errors.
This commit is contained in:
Nick Craig-Wood 2021-03-25 15:35:08 +00:00
parent 8c5c91e68f
commit fc57648b75
5 changed files with 21 additions and 7 deletions

View file

@ -1137,7 +1137,7 @@ func (o *Object) Update(ctx context.Context, in io.Reader, src fs.ObjectInfo, op
// opts.Body=0), so upload it as a multipart form POST with // opts.Body=0), so upload it as a multipart form POST with
// Content-Length set. // Content-Length set.
if size == 0 { if size == 0 {
formReader, contentType, overhead, err := rest.MultipartUpload(in, opts.Parameters, "content", leaf) formReader, contentType, overhead, err := rest.MultipartUpload(ctx, in, opts.Parameters, "content", leaf)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to make multipart upload for 0 length file") return errors.Wrap(err, "failed to make multipart upload for 0 length file")
} }

View file

@ -682,7 +682,7 @@ func (f *Fs) upload(ctx context.Context, in io.Reader, uploadLink, filePath stri
"need_idx_progress": {"true"}, "need_idx_progress": {"true"},
"replace": {"1"}, "replace": {"1"},
} }
formReader, contentType, _, err := rest.MultipartUpload(in, parameters, "file", f.opt.Enc.FromStandardName(filename)) formReader, contentType, _, err := rest.MultipartUpload(ctx, in, parameters, "file", f.opt.Enc.FromStandardName(filename))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to make multipart upload") return nil, errors.Wrap(err, "failed to make multipart upload")
} }

View file

@ -647,7 +647,7 @@ func (f *Fs) upload(ctx context.Context, name string, parent string, size int64,
params.Set("filename", name) params.Set("filename", name)
params.Set("parent_id", parent) params.Set("parent_id", parent)
params.Set("override-name-exist", strconv.FormatBool(true)) params.Set("override-name-exist", strconv.FormatBool(true))
formReader, contentType, overhead, err := rest.MultipartUpload(in, nil, "content", name) formReader, contentType, overhead, err := rest.MultipartUpload(ctx, in, nil, "content", name)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to make multipart upload") return nil, errors.Wrap(err, "failed to make multipart upload")
} }

View file

@ -450,6 +450,7 @@ func TestRcFsInfo(t *testing.T) {
func TestUploadFile(t *testing.T) { func TestUploadFile(t *testing.T) {
r, call := rcNewRun(t, "operations/uploadfile") r, call := rcNewRun(t, "operations/uploadfile")
defer r.Finalise() defer r.Finalise()
ctx := context.Background()
testFileName := "test.txt" testFileName := "test.txt"
testFileContent := "Hello World" testFileContent := "Hello World"
@ -460,7 +461,7 @@ func TestUploadFile(t *testing.T) {
currentFile, err := os.Open(path.Join(r.LocalName, testFileName)) currentFile, err := os.Open(path.Join(r.LocalName, testFileName))
require.NoError(t, err) require.NoError(t, err)
formReader, contentType, _, err := rest.MultipartUpload(currentFile, url.Values{}, "file", testFileName) formReader, contentType, _, err := rest.MultipartUpload(ctx, currentFile, url.Values{}, "file", testFileName)
require.NoError(t, err) require.NoError(t, err)
httpReq := httptest.NewRequest("POST", "/", formReader) httpReq := httptest.NewRequest("POST", "/", formReader)
@ -482,7 +483,7 @@ func TestUploadFile(t *testing.T) {
currentFile, err = os.Open(path.Join(r.LocalName, testFileName)) currentFile, err = os.Open(path.Join(r.LocalName, testFileName))
require.NoError(t, err) require.NoError(t, err)
formReader, contentType, _, err = rest.MultipartUpload(currentFile, url.Values{}, "file", testFileName) formReader, contentType, _, err = rest.MultipartUpload(ctx, currentFile, url.Values{}, "file", testFileName)
require.NoError(t, err) require.NoError(t, err)
httpReq = httptest.NewRequest("POST", "/", formReader) httpReq = httptest.NewRequest("POST", "/", formReader)

View file

@ -308,7 +308,7 @@ func (api *Client) Call(ctx context.Context, opts *Opts) (resp *http.Response, e
// the int64 returned is the overhead in addition to the file contents, in case Content-Length is required // the int64 returned is the overhead in addition to the file contents, in case Content-Length is required
// //
// NB This doesn't allow setting the content type of the attachment // NB This doesn't allow setting the content type of the attachment
func MultipartUpload(in io.Reader, params url.Values, contentName, fileName string) (io.ReadCloser, string, int64, error) { func MultipartUpload(ctx context.Context, in io.Reader, params url.Values, contentName, fileName string) (io.ReadCloser, string, int64, error) {
bodyReader, bodyWriter := io.Pipe() bodyReader, bodyWriter := io.Pipe()
writer := multipart.NewWriter(bodyWriter) writer := multipart.NewWriter(bodyWriter)
contentType := writer.FormDataContentType() contentType := writer.FormDataContentType()
@ -343,8 +343,21 @@ func MultipartUpload(in io.Reader, params url.Values, contentName, fileName stri
multipartLength := int64(buf.Len()) multipartLength := int64(buf.Len())
// Make sure we close the pipe writer to release the reader on context cancel
quit := make(chan struct{})
go func() {
select {
case <-quit:
break
case <-ctx.Done():
_ = bodyWriter.CloseWithError(ctx.Err())
}
}()
// Pump the data in the background // Pump the data in the background
go func() { go func() {
defer close(quit)
var err error var err error
for key, vals := range params { for key, vals := range params {
@ -452,7 +465,7 @@ func (api *Client) callCodec(ctx context.Context, opts *Opts, request interface{
opts = opts.Copy() opts = opts.Copy()
var overhead int64 var overhead int64
opts.Body, opts.ContentType, overhead, err = MultipartUpload(opts.Body, params, opts.MultipartContentName, opts.MultipartFileName) opts.Body, opts.ContentType, overhead, err = MultipartUpload(ctx, opts.Body, params, opts.MultipartContentName, opts.MultipartFileName)
if err != nil { if err != nil {
return nil, err return nil, err
} }