diff --git a/backend/union/entry.go b/backend/union/entry.go index 2f5b38b0d..93184399c 100644 --- a/backend/union/entry.go +++ b/backend/union/entry.go @@ -1,7 +1,6 @@ package union import ( - "bufio" "context" "io" "sync" @@ -67,31 +66,9 @@ func (o *Object) Update(ctx context.Context, in io.Reader, src fs.ObjectInfo, op obj := entries[0].(*upstream.Object) return obj.Update(ctx, in, src, options...) } - // Get multiple reader - readers := make([]io.Reader, len(entries)) - writers := make([]io.Writer, len(entries)) - errs := Errors(make([]error, len(entries)+1)) - for i := range entries { - r, w := io.Pipe() - bw := bufio.NewWriter(w) - readers[i], writers[i] = r, bw - defer func() { - err := w.Close() - if err != nil { - panic(err) - } - }() - } - go func() { - mw := io.MultiWriter(writers...) - es := make([]error, len(writers)+1) - _, es[len(es)-1] = io.Copy(mw, in) - for i, bw := range writers { - es[i] = bw.(*bufio.Writer).Flush() - } - errs[len(entries)] = Errors(es).Err() - }() // Multi-threading + readers, errChan := multiReader(len(entries), in) + errs := Errors(make([]error, len(entries)+1)) multithread(len(entries), func(i int) { if o, ok := entries[i].(*upstream.Object); ok { err := o.Update(ctx, readers[i], src, options...) @@ -100,6 +77,7 @@ func (o *Object) Update(ctx context.Context, in io.Reader, src fs.ObjectInfo, op errs[i] = fs.ErrorNotAFile } }) + errs[len(entries)] = <-errChan return errs.Err() } diff --git a/backend/union/union.go b/backend/union/union.go index 1c178d9ab..17e988956 100644 --- a/backend/union/union.go +++ b/backend/union/union.go @@ -385,6 +385,37 @@ func (f *Fs) DirCacheFlush() { }) } +// Tee in into n outputs +// +// When finished read the error from the channel +func multiReader(n int, in io.Reader) ([]io.Reader, <-chan error) { + readers := make([]io.Reader, n) + pipeWriters := make([]*io.PipeWriter, n) + writers := make([]io.Writer, n) + errChan := make(chan error, 1) + for i := range writers { + r, w := io.Pipe() + bw := bufio.NewWriter(w) + readers[i], pipeWriters[i], writers[i] = r, w, bw + } + go func() { + mw := io.MultiWriter(writers...) + es := make([]error, 2*n+1) + _, copyErr := io.Copy(mw, in) + es[2*n] = copyErr + // Flush the buffers + for i, bw := range writers { + es[i] = bw.(*bufio.Writer).Flush() + } + // Close the underlying pipes + for i, pw := range pipeWriters { + es[2*i] = pw.CloseWithError(copyErr) + } + errChan <- Errors(es).Err() + }() + return readers, errChan +} + func (f *Fs) put(ctx context.Context, in io.Reader, src fs.ObjectInfo, stream bool, options ...fs.OpenOption) (fs.Object, error) { srcPath := src.Remote() upstreams, err := f.create(ctx, srcPath) @@ -412,31 +443,9 @@ func (f *Fs) put(ctx context.Context, in io.Reader, src fs.ObjectInfo, stream bo e, err := f.wrapEntries(u.WrapObject(o)) return e.(*Object), err } - errs := Errors(make([]error, len(upstreams)+1)) - // Get multiple reader - readers := make([]io.Reader, len(upstreams)) - writers := make([]io.Writer, len(upstreams)) - for i := range writers { - r, w := io.Pipe() - bw := bufio.NewWriter(w) - readers[i], writers[i] = r, bw - defer func() { - err := w.Close() - if err != nil { - panic(err) - } - }() - } - go func() { - mw := io.MultiWriter(writers...) - es := make([]error, len(writers)+1) - _, es[len(es)-1] = io.Copy(mw, in) - for i, bw := range writers { - es[i] = bw.(*bufio.Writer).Flush() - } - errs[len(upstreams)] = Errors(es).Err() - }() // Multi-threading + readers, errChan := multiReader(len(upstreams), in) + errs := Errors(make([]error, len(upstreams)+1)) objs := make([]upstream.Entry, len(upstreams)) multithread(len(upstreams), func(i int) { u := upstreams[i] @@ -453,6 +462,7 @@ func (f *Fs) put(ctx context.Context, in io.Reader, src fs.ObjectInfo, stream bo } objs[i] = u.WrapObject(o) }) + errs[len(upstreams)] = <-errChan err = errs.Err() if err != nil { return nil, err diff --git a/backend/union/union_test.go b/backend/union/union_test.go index 98d34a68b..676088220 100644 --- a/backend/union/union_test.go +++ b/backend/union/union_test.go @@ -153,3 +153,29 @@ func TestPolicy2(t *testing.T) { UnimplementableObjectMethods: []string{"MimeType"}, }) } + +func TestPolicy3(t *testing.T) { + if *fstest.RemoteName != "" { + t.Skip("Skipping as -remote set") + } + tempdir1 := filepath.Join(os.TempDir(), "rclone-union-test-policy31") + tempdir2 := filepath.Join(os.TempDir(), "rclone-union-test-policy32") + tempdir3 := filepath.Join(os.TempDir(), "rclone-union-test-policy33") + require.NoError(t, os.MkdirAll(tempdir1, 0744)) + require.NoError(t, os.MkdirAll(tempdir2, 0744)) + require.NoError(t, os.MkdirAll(tempdir3, 0744)) + upstreams := tempdir1 + " " + tempdir2 + " " + tempdir3 + name := "TestUnionPolicy3" + fstests.Run(t, &fstests.Opt{ + RemoteName: name + ":", + ExtraConfig: []fstests.ExtraConfigItem{ + {Name: name, Key: "type", Value: "union"}, + {Name: name, Key: "upstreams", Value: upstreams}, + {Name: name, Key: "action_policy", Value: "all"}, + {Name: name, Key: "create_policy", Value: "all"}, + {Name: name, Key: "search_policy", Value: "all"}, + }, + UnimplementableFsMethods: []string{"OpenWriterAt", "DuplicateFiles"}, + UnimplementableObjectMethods: []string{"MimeType"}, + }) +}