diff --git a/internal/limiter/limiter.go b/internal/limiter/limiter.go index 410bc7f64..8cbe297fe 100644 --- a/internal/limiter/limiter.go +++ b/internal/limiter/limiter.go @@ -20,6 +20,10 @@ type Limiter interface { // for downloads. Downstream(r io.Reader) io.Reader + // Downstream returns a rate limited reader that is intended to be used + // for downloads. + DownstreamWriter(r io.Writer) io.Writer + // Transport returns an http.RoundTripper limited with the limiter. Transport(http.RoundTripper) http.RoundTripper } diff --git a/internal/limiter/limiter_backend.go b/internal/limiter/limiter_backend.go index b2351a8fd..d074a5a0e 100644 --- a/internal/limiter/limiter_backend.go +++ b/internal/limiter/limiter_backend.go @@ -42,20 +42,34 @@ func (l limitedRewindReader) Read(b []byte) (int, error) { func (r rateLimitedBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64, consumer func(rd io.Reader) error) error { return r.Backend.Load(ctx, h, length, offset, func(rd io.Reader) error { - lrd := limitedReadCloser{ - limited: r.limiter.Downstream(rd), - } - return consumer(lrd) + return consumer(newDownstreamLimitedReadCloser(rd, r.limiter, nil)) }) } type limitedReadCloser struct { + io.Reader original io.ReadCloser - limited io.Reader } -func (l limitedReadCloser) Read(b []byte) (n int, err error) { - return l.limited.Read(b) +type limitedReadWriteToCloser struct { + limitedReadCloser + writerTo io.WriterTo + limiter Limiter +} + +func newDownstreamLimitedReadCloser(rd io.Reader, limiter Limiter, original io.ReadCloser) io.ReadCloser { + lrd := limitedReadCloser{ + Reader: limiter.Downstream(rd), + original: original, + } + if _, ok := rd.(io.WriterTo); ok { + return &limitedReadWriteToCloser{ + limitedReadCloser: lrd, + writerTo: rd.(io.WriterTo), + limiter: limiter, + } + } + return &lrd } func (l limitedReadCloser) Close() error { @@ -65,4 +79,8 @@ func (l limitedReadCloser) Close() error { return l.original.Close() } +func (l limitedReadWriteToCloser) WriteTo(w io.Writer) (int64, error) { + return l.writerTo.WriteTo(l.limiter.DownstreamWriter(w)) +} + var _ restic.Backend = (*rateLimitedBackend)(nil) diff --git a/internal/limiter/static_limiter.go b/internal/limiter/static_limiter.go index 5df7a84da..e9b2b8285 100644 --- a/internal/limiter/static_limiter.go +++ b/internal/limiter/static_limiter.go @@ -46,6 +46,10 @@ func (l staticLimiter) Downstream(r io.Reader) io.Reader { return l.limitReader(r, l.downstream) } +func (l staticLimiter) DownstreamWriter(w io.Writer) io.Writer { + return l.limitWriter(w, l.downstream) +} + type roundTripper func(*http.Request) (*http.Response, error) func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -55,7 +59,7 @@ func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*http.Response, error) { if req.Body != nil { req.Body = limitedReadCloser{ - limited: l.Upstream(req.Body), + Reader: l.Upstream(req.Body), original: req.Body, } } @@ -64,7 +68,7 @@ func (l staticLimiter) roundTripper(rt http.RoundTripper, req *http.Request) (*h if res != nil && res.Body != nil { res.Body = limitedReadCloser{ - limited: l.Downstream(res.Body), + Reader: l.Downstream(res.Body), original: res.Body, } }