395 lines
10 KiB
Go
395 lines
10 KiB
Go
package operations
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"testing"
|
|
|
|
"github.com/rclone/rclone/fs"
|
|
"github.com/rclone/rclone/fs/hash"
|
|
"github.com/rclone/rclone/fstest/mockobject"
|
|
"github.com/rclone/rclone/lib/pool"
|
|
"github.com/rclone/rclone/lib/readers"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// check interfaces
|
|
var (
|
|
_ io.ReadSeekCloser = (*ReOpen)(nil)
|
|
_ pool.DelayAccountinger = (*ReOpen)(nil)
|
|
)
|
|
|
|
var (
	// errorTestError is the error injected by reOpenTestObject when one of
	// its breakpoints is hit; tests compare returned errors against it.
	errorTestError = errors.New("test error")
)
|
|
|
|
// this is a wrapper for a mockobject with a custom Open function
|
|
//
|
|
// breaks indicate the number of bytes to read before returning an
|
|
// error
|
|
type reOpenTestObject struct {
|
|
fs.Object
|
|
t *testing.T
|
|
wantStart int64
|
|
breaks []int64
|
|
unknownSize bool
|
|
}
|
|
|
|
// Open opens the file for read. Call Close() on the returned io.ReadCloser
|
|
//
|
|
// This will break after reading the number of bytes in breaks
|
|
func (o *reOpenTestObject) Open(ctx context.Context, options ...fs.OpenOption) (io.ReadCloser, error) {
|
|
gotHash := false
|
|
gotRange := false
|
|
startPos := int64(0)
|
|
for _, option := range options {
|
|
switch x := option.(type) {
|
|
case *fs.HashesOption:
|
|
gotHash = true
|
|
case *fs.RangeOption:
|
|
gotRange = true
|
|
startPos = x.Start
|
|
if o.unknownSize {
|
|
assert.Equal(o.t, int64(-1), x.End)
|
|
}
|
|
case *fs.SeekOption:
|
|
startPos = x.Offset
|
|
}
|
|
}
|
|
assert.Equal(o.t, o.wantStart, startPos)
|
|
// Check if ranging, mustn't have hash if offset != 0
|
|
if gotHash && gotRange {
|
|
assert.Equal(o.t, int64(0), startPos)
|
|
}
|
|
rc, err := o.Object.Open(ctx, options...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(o.breaks) > 0 {
|
|
// Pop a breakpoint off
|
|
N := o.breaks[0]
|
|
o.breaks = o.breaks[1:]
|
|
o.wantStart += N
|
|
// If 0 then return an error immediately
|
|
if N == 0 {
|
|
return nil, errorTestError
|
|
}
|
|
// Read N bytes then an error
|
|
r := io.MultiReader(&io.LimitedReader{R: rc, N: N}, readers.ErrorReader{Err: errorTestError})
|
|
// Wrap with Close in a new readCloser
|
|
rc = readCloser{Reader: r, Closer: rc}
|
|
}
|
|
return rc, nil
|
|
}
|
|
|
|
func TestReOpen(t *testing.T) {
|
|
for _, testName := range []string{"Normal", "WithRangeOption", "WithSeekOption", "UnknownSize"} {
|
|
t.Run(testName, func(t *testing.T) {
|
|
// Contents for the mock object
|
|
var (
|
|
reOpenTestcontents = []byte("0123456789")
|
|
expectedRead = reOpenTestcontents
|
|
rangeOption *fs.RangeOption
|
|
seekOption *fs.SeekOption
|
|
unknownSize = false
|
|
)
|
|
switch testName {
|
|
case "Normal":
|
|
case "WithRangeOption":
|
|
rangeOption = &fs.RangeOption{Start: 1, End: 7} // range is inclusive
|
|
expectedRead = reOpenTestcontents[1:8]
|
|
case "WithSeekOption":
|
|
seekOption = &fs.SeekOption{Offset: 2}
|
|
expectedRead = reOpenTestcontents[2:]
|
|
case "UnknownSize":
|
|
rangeOption = &fs.RangeOption{Start: 1, End: -1}
|
|
expectedRead = reOpenTestcontents[1:]
|
|
unknownSize = true
|
|
default:
|
|
panic("bad test name")
|
|
}
|
|
|
|
// Start the test with the given breaks
|
|
testReOpen := func(breaks []int64, maxRetries int) (*ReOpen, *reOpenTestObject, error) {
|
|
srcOrig := mockobject.New("potato").WithContent(reOpenTestcontents, mockobject.SeekModeNone)
|
|
srcOrig.SetUnknownSize(unknownSize)
|
|
src := &reOpenTestObject{
|
|
Object: srcOrig,
|
|
t: t,
|
|
breaks: breaks,
|
|
unknownSize: unknownSize,
|
|
}
|
|
opts := []fs.OpenOption{}
|
|
if rangeOption == nil && seekOption == nil {
|
|
opts = append(opts, &fs.HashesOption{Hashes: hash.NewHashSet(hash.MD5)})
|
|
}
|
|
if rangeOption != nil {
|
|
opts = append(opts, rangeOption)
|
|
src.wantStart = rangeOption.Start
|
|
}
|
|
if seekOption != nil {
|
|
opts = append(opts, seekOption)
|
|
src.wantStart = seekOption.Offset
|
|
}
|
|
rc, err := NewReOpen(context.Background(), src, maxRetries, opts...)
|
|
return rc, src, err
|
|
}
|
|
|
|
t.Run("Basics", func(t *testing.T) {
|
|
// open
|
|
h, _, err := testReOpen(nil, 10)
|
|
assert.NoError(t, err)
|
|
|
|
// Check contents read correctly
|
|
got, err := io.ReadAll(h)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expectedRead, got)
|
|
|
|
// Check read after end
|
|
var buf = make([]byte, 1)
|
|
n, err := h.Read(buf)
|
|
assert.Equal(t, 0, n)
|
|
assert.Equal(t, io.EOF, err)
|
|
|
|
// Rewind the stream
|
|
_, err = h.Seek(0, io.SeekStart)
|
|
require.NoError(t, err)
|
|
|
|
// Check contents read correctly
|
|
got, err = io.ReadAll(h)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expectedRead, got)
|
|
|
|
// Check close
|
|
assert.NoError(t, h.Close())
|
|
|
|
// Check double close
|
|
assert.Equal(t, errFileClosed, h.Close())
|
|
|
|
// Check read after close
|
|
n, err = h.Read(buf)
|
|
assert.Equal(t, 0, n)
|
|
assert.Equal(t, errFileClosed, err)
|
|
})
|
|
|
|
t.Run("ErrorAtStart", func(t *testing.T) {
|
|
// open with immediate breaking
|
|
h, _, err := testReOpen([]int64{0}, 10)
|
|
assert.Equal(t, errorTestError, err)
|
|
assert.Nil(t, h)
|
|
})
|
|
|
|
t.Run("WithErrors", func(t *testing.T) {
|
|
// open with a few break points but less than the max
|
|
h, _, err := testReOpen([]int64{2, 1, 3}, 10)
|
|
assert.NoError(t, err)
|
|
|
|
// check contents
|
|
got, err := io.ReadAll(h)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expectedRead, got)
|
|
|
|
// check close
|
|
assert.NoError(t, h.Close())
|
|
})
|
|
|
|
t.Run("TooManyErrors", func(t *testing.T) {
|
|
// open with a few break points but >= the max
|
|
h, _, err := testReOpen([]int64{2, 1, 3}, 3)
|
|
assert.NoError(t, err)
|
|
|
|
// check contents
|
|
got, err := io.ReadAll(h)
|
|
assert.Equal(t, errorTestError, err)
|
|
assert.Equal(t, expectedRead[:6], got)
|
|
|
|
// check old error is returned
|
|
var buf = make([]byte, 1)
|
|
n, err := h.Read(buf)
|
|
assert.Equal(t, 0, n)
|
|
assert.Equal(t, errTooManyTries, err)
|
|
|
|
// Check close
|
|
assert.Equal(t, errFileClosed, h.Close())
|
|
})
|
|
|
|
t.Run("Seek", func(t *testing.T) {
|
|
// open
|
|
h, src, err := testReOpen([]int64{2, 1, 3}, 10)
|
|
assert.NoError(t, err)
|
|
|
|
// Seek to end
|
|
pos, err := h.Seek(int64(len(expectedRead)), io.SeekStart)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int64(len(expectedRead)), pos)
|
|
|
|
// Seek to start
|
|
pos, err = h.Seek(0, io.SeekStart)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int64(0), pos)
|
|
|
|
// Should not allow seek past end
|
|
pos, err = h.Seek(int64(len(expectedRead))+1, io.SeekCurrent)
|
|
if !unknownSize {
|
|
assert.Equal(t, errSeekPastEnd, err)
|
|
assert.Equal(t, len(expectedRead), int(pos))
|
|
} else {
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, len(expectedRead)+1, int(pos))
|
|
|
|
// Seek back to start to get tests in sync
|
|
pos, err = h.Seek(0, io.SeekStart)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, int64(0), pos)
|
|
}
|
|
|
|
// Should not allow seek to negative position start
|
|
pos, err = h.Seek(-1, io.SeekCurrent)
|
|
assert.Equal(t, errNegativeSeek, err)
|
|
assert.Equal(t, 0, int(pos))
|
|
|
|
// Should not allow seek with invalid whence
|
|
pos, err = h.Seek(0, 3)
|
|
assert.Equal(t, errInvalidWhence, err)
|
|
assert.Equal(t, 0, int(pos))
|
|
|
|
// check read
|
|
dst := make([]byte, 5)
|
|
n, err := h.Read(dst)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 5, n)
|
|
assert.Equal(t, expectedRead[:5], dst)
|
|
|
|
// Test io.SeekCurrent
|
|
pos, err = h.Seek(-3, io.SeekCurrent)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 2, int(pos))
|
|
|
|
// Reset the start after a seek, taking into account the offset
|
|
setWantStart := func(x int64) {
|
|
src.wantStart = x
|
|
if rangeOption != nil {
|
|
src.wantStart += rangeOption.Start
|
|
} else if seekOption != nil {
|
|
src.wantStart += seekOption.Offset
|
|
}
|
|
}
|
|
|
|
// check read
|
|
setWantStart(2)
|
|
n, err = h.Read(dst)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 5, n)
|
|
assert.Equal(t, expectedRead[2:7], dst)
|
|
|
|
pos, err = h.Seek(-2, io.SeekCurrent)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 5, int(pos))
|
|
|
|
// Test io.SeekEnd
|
|
pos, err = h.Seek(-3, io.SeekEnd)
|
|
if !unknownSize {
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, len(expectedRead)-3, int(pos))
|
|
} else {
|
|
assert.Equal(t, errBadEndSeek, err)
|
|
assert.Equal(t, 0, int(pos))
|
|
|
|
// sync
|
|
pos, err = h.Seek(1, io.SeekCurrent)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 6, int(pos))
|
|
}
|
|
|
|
// check read
|
|
dst = make([]byte, 3)
|
|
setWantStart(int64(len(expectedRead) - 3))
|
|
n, err = h.Read(dst)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, 3, n)
|
|
assert.Equal(t, expectedRead[len(expectedRead)-3:], dst)
|
|
|
|
// check close
|
|
assert.NoError(t, h.Close())
|
|
_, err = h.Seek(0, io.SeekCurrent)
|
|
assert.Equal(t, errFileClosed, err)
|
|
})
|
|
|
|
t.Run("AccountRead", func(t *testing.T) {
|
|
h, _, err := testReOpen(nil, 10)
|
|
assert.NoError(t, err)
|
|
|
|
var total int
|
|
h.SetAccounting(func(n int) error {
|
|
total += n
|
|
return nil
|
|
})
|
|
|
|
dst := make([]byte, 3)
|
|
n, err := h.Read(dst)
|
|
assert.Equal(t, 3, n)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 3, total)
|
|
})
|
|
|
|
t.Run("AccountReadDelay", func(t *testing.T) {
|
|
h, _, err := testReOpen(nil, 10)
|
|
assert.NoError(t, err)
|
|
|
|
var total int
|
|
h.SetAccounting(func(n int) error {
|
|
total += n
|
|
return nil
|
|
})
|
|
|
|
rewind := func() {
|
|
_, err := h.Seek(0, io.SeekStart)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
h.DelayAccounting(3)
|
|
|
|
dst := make([]byte, 16)
|
|
|
|
n, err := h.Read(dst)
|
|
assert.Equal(t, len(expectedRead), n)
|
|
assert.Equal(t, io.EOF, err)
|
|
assert.Equal(t, 0, total)
|
|
rewind()
|
|
|
|
n, err = h.Read(dst)
|
|
assert.Equal(t, len(expectedRead), n)
|
|
assert.Equal(t, io.EOF, err)
|
|
assert.Equal(t, 0, total)
|
|
rewind()
|
|
|
|
n, err = h.Read(dst)
|
|
assert.Equal(t, len(expectedRead), n)
|
|
assert.Equal(t, io.EOF, err)
|
|
assert.Equal(t, len(expectedRead), total)
|
|
rewind()
|
|
|
|
n, err = h.Read(dst)
|
|
assert.Equal(t, len(expectedRead), n)
|
|
assert.Equal(t, io.EOF, err)
|
|
assert.Equal(t, 2*len(expectedRead), total)
|
|
rewind()
|
|
})
|
|
|
|
t.Run("AccountReadError", func(t *testing.T) {
|
|
// Test accounting errors
|
|
h, _, err := testReOpen(nil, 10)
|
|
assert.NoError(t, err)
|
|
|
|
h.SetAccounting(func(n int) error {
|
|
return errorTestError
|
|
})
|
|
|
|
dst := make([]byte, 3)
|
|
n, err := h.Read(dst)
|
|
assert.Equal(t, 3, n)
|
|
assert.Equal(t, errorTestError, err)
|
|
})
|
|
})
|
|
}
|
|
}
|