forked from TrueCloudLab/rclone
lib/readers: add FakeSeeker to adapt io.Reader to io.ReadSeeker #5422
This commit is contained in:
parent
389a29b017
commit
50a0c3482d
2 changed files with 159 additions and 0 deletions
72
lib/readers/fakeseeker.go
Normal file
72
lib/readers/fakeseeker.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package readers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// FakeSeeker adapts an io.Seeker into an io.ReadSeeker
|
||||
type FakeSeeker struct {
|
||||
in io.Reader
|
||||
readErr error
|
||||
length int64
|
||||
offset int64
|
||||
read bool
|
||||
}
|
||||
|
||||
// NewFakeSeeker creates a fake io.ReadSeeker from an io.Reader
|
||||
//
|
||||
// This can be seeked before reading to discover the length passed in.
|
||||
func NewFakeSeeker(in io.Reader, length int64) io.ReadSeeker {
|
||||
if rs, ok := in.(io.ReadSeeker); ok {
|
||||
return rs
|
||||
}
|
||||
return &FakeSeeker{
|
||||
in: in,
|
||||
length: length,
|
||||
}
|
||||
}
|
||||
|
||||
// Seek the stream - possible only before reading
|
||||
func (r *FakeSeeker) Seek(offset int64, whence int) (abs int64, err error) {
|
||||
if r.readErr != nil {
|
||||
return 0, r.readErr
|
||||
}
|
||||
if r.read {
|
||||
return 0, fmt.Errorf("FakeSeeker: can't Seek(%d, %d) after reading", offset, whence)
|
||||
}
|
||||
switch whence {
|
||||
case io.SeekStart:
|
||||
abs = offset
|
||||
case io.SeekCurrent:
|
||||
abs = r.offset + offset
|
||||
case io.SeekEnd:
|
||||
abs = r.length + offset
|
||||
default:
|
||||
return 0, errors.New("FakeSeeker: invalid whence")
|
||||
}
|
||||
if abs < 0 {
|
||||
return 0, errors.New("FakeSeeker: negative position")
|
||||
}
|
||||
r.offset = abs
|
||||
return abs, nil
|
||||
}
|
||||
|
||||
// Read data from the stream. Will give an error if seeked.
|
||||
func (r *FakeSeeker) Read(p []byte) (n int, err error) {
|
||||
if r.readErr != nil {
|
||||
return 0, r.readErr
|
||||
}
|
||||
if !r.read && r.offset != 0 {
|
||||
return 0, errors.New("FakeSeeker: not at start: can't read")
|
||||
}
|
||||
n, err = r.in.Read(p)
|
||||
if n != 0 {
|
||||
r.read = true
|
||||
}
|
||||
if err != nil {
|
||||
r.readErr = err
|
||||
}
|
||||
return n, err
|
||||
}
|
87
lib/readers/fakeseeker_test.go
Normal file
87
lib/readers/fakeseeker_test.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package readers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Check interface
|
||||
var _ io.ReadSeeker = &FakeSeeker{}
|
||||
|
||||
func TestFakeSeeker(t *testing.T) {
|
||||
// Test that passing in an io.ReadSeeker just passes it through
|
||||
bufReader := bytes.NewReader([]byte{1})
|
||||
r := NewFakeSeeker(bufReader, 5)
|
||||
assert.Equal(t, r, bufReader)
|
||||
|
||||
in := bytes.NewBufferString("hello")
|
||||
buf := make([]byte, 16)
|
||||
r = NewFakeSeeker(in, 5)
|
||||
assert.NotEqual(t, r, in)
|
||||
|
||||
// check the seek offset is as passed in
|
||||
checkPos := func(pos int64) {
|
||||
abs, err := r.Seek(0, io.SeekCurrent)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pos, abs)
|
||||
}
|
||||
|
||||
// Test some seeking
|
||||
checkPos(0)
|
||||
|
||||
abs, err := r.Seek(2, io.SeekStart)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), abs)
|
||||
checkPos(2)
|
||||
|
||||
abs, err = r.Seek(-1, io.SeekEnd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(4), abs)
|
||||
checkPos(4)
|
||||
|
||||
// Check can't read if not at start
|
||||
_, err = r.Read(buf)
|
||||
require.ErrorContains(t, err, "not at start")
|
||||
|
||||
// Seek back to start
|
||||
abs, err = r.Seek(-4, io.SeekCurrent)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), abs)
|
||||
checkPos(0)
|
||||
|
||||
_, err = r.Seek(42, 17)
|
||||
require.ErrorContains(t, err, "invalid whence")
|
||||
|
||||
_, err = r.Seek(-1, io.SeekStart)
|
||||
require.ErrorContains(t, err, "negative position")
|
||||
|
||||
// Test reading now seeked back to the start
|
||||
n, err := r.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, n)
|
||||
assert.Equal(t, []byte("hello"), buf[:5])
|
||||
|
||||
// Seeking should give an error now
|
||||
_, err = r.Seek(-1, io.SeekEnd)
|
||||
require.ErrorContains(t, err, "after reading")
|
||||
}
|
||||
|
||||
func TestFakeSeekerError(t *testing.T) {
|
||||
in := bytes.NewBufferString("hello")
|
||||
r := NewFakeSeeker(in, 5)
|
||||
assert.NotEqual(t, r, in)
|
||||
|
||||
buf, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("hello"), buf)
|
||||
|
||||
_, err = r.Read(buf)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
|
||||
_, err = r.Seek(0, io.SeekStart)
|
||||
assert.Equal(t, io.EOF, err)
|
||||
}
|
Loading…
Reference in a new issue