lib/readers: add Seek method to PatternReader

This commit is contained in:
Nick Craig-Wood 2020-06-13 10:35:36 +01:00
parent 7622506fe2
commit 7d0783aad5
2 changed files with 93 additions and 5 deletions

View file

@ -1,29 +1,59 @@
package readers
import "io"
import (
"io"
"github.com/pkg/errors"
)
// This is the smallest prime less than 256
//
// Using a prime here means we are less likely to hit repeating patterns
const patternReaderModulo = 251
// NewPatternReader creates a reader, that returns a deterministic byte pattern.
// After length bytes are read
func NewPatternReader(length int64) io.Reader {
func NewPatternReader(length int64) io.ReadSeeker {
return &patternReader{
length: length,
}
}
type patternReader struct {
offset int64
length int64
c byte
}
func (r *patternReader) Read(p []byte) (n int, err error) {
for i := range p {
if r.length <= 0 {
if r.offset >= r.length {
return n, io.EOF
}
p[i] = r.c
r.c = (r.c + 1) % 253
r.length--
r.c = (r.c + 1) % patternReaderModulo
r.offset++
n++
}
return
}
// Seek implements the io.Seeker interface.
func (r *patternReader) Seek(offset int64, whence int) (abs int64, err error) {
switch whence {
case io.SeekStart:
abs = offset
case io.SeekCurrent:
abs = r.offset + offset
case io.SeekEnd:
abs = r.length + offset
default:
return 0, errors.New("patternReader: invalid whence")
}
if abs < 0 {
return 0, errors.New("patternReader: negative position")
}
r.offset = abs
r.c = byte(abs % patternReaderModulo)
return abs, nil
}

View file

@ -28,3 +28,61 @@ func TestPatternReader(t *testing.T) {
require.Equal(t, io.EOF, err)
require.Equal(t, 0, n)
}
func TestPatternReaderSeek(t *testing.T) {
r := NewPatternReader(1024)
b, err := ioutil.ReadAll(r)
require.NoError(t, err)
for i := range b {
assert.Equal(t, byte(i%251), b[i])
}
n, err := r.Seek(1, io.SeekStart)
require.NoError(t, err)
assert.Equal(t, int64(1), n)
// pos 1
b2 := make([]byte, 10)
nn, err := r.Read(b2)
require.NoError(t, err)
assert.Equal(t, 10, nn)
assert.Equal(t, b[1:11], b2)
// pos 11
n, err = r.Seek(9, io.SeekCurrent)
require.NoError(t, err)
assert.Equal(t, int64(20), n)
// pos 20
nn, err = r.Read(b2)
require.NoError(t, err)
assert.Equal(t, 10, nn)
assert.Equal(t, b[20:30], b2)
n, err = r.Seek(-24, io.SeekEnd)
require.NoError(t, err)
assert.Equal(t, int64(1000), n)
// pos 1000
nn, err = r.Read(b2)
require.NoError(t, err)
assert.Equal(t, 10, nn)
assert.Equal(t, b[1000:1010], b2)
// Now test errors
n, err = r.Seek(1, 400)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid whence")
assert.Equal(t, int64(0), n)
n, err = r.Seek(-1, io.SeekStart)
require.Error(t, err)
assert.Contains(t, err.Error(), "negative position")
assert.Equal(t, int64(0), n)
}