lib/readers: add Seek method to PatternReader
This commit is contained in:
parent
7622506fe2
commit
7d0783aad5
2 changed files with 93 additions and 5 deletions
|
@ -1,29 +1,59 @@
|
||||||
package readers
|
package readers
|
||||||
|
|
||||||
import "io"
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is the smallest prime less than 256
|
||||||
|
//
|
||||||
|
// Using a prime here means we are less likely to hit repeating patterns
|
||||||
|
const patternReaderModulo = 251
|
||||||
|
|
||||||
// NewPatternReader creates a reader, that returns a deterministic byte pattern.
|
// NewPatternReader creates a reader, that returns a deterministic byte pattern.
|
||||||
// After length bytes are read
|
// After length bytes are read
|
||||||
func NewPatternReader(length int64) io.Reader {
|
func NewPatternReader(length int64) io.ReadSeeker {
|
||||||
return &patternReader{
|
return &patternReader{
|
||||||
length: length,
|
length: length,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type patternReader struct {
|
type patternReader struct {
|
||||||
|
offset int64
|
||||||
length int64
|
length int64
|
||||||
c byte
|
c byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *patternReader) Read(p []byte) (n int, err error) {
|
func (r *patternReader) Read(p []byte) (n int, err error) {
|
||||||
for i := range p {
|
for i := range p {
|
||||||
if r.length <= 0 {
|
if r.offset >= r.length {
|
||||||
return n, io.EOF
|
return n, io.EOF
|
||||||
}
|
}
|
||||||
p[i] = r.c
|
p[i] = r.c
|
||||||
r.c = (r.c + 1) % 253
|
r.c = (r.c + 1) % patternReaderModulo
|
||||||
r.length--
|
r.offset++
|
||||||
n++
|
n++
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Seek implements the io.Seeker interface.
|
||||||
|
func (r *patternReader) Seek(offset int64, whence int) (abs int64, err error) {
|
||||||
|
switch whence {
|
||||||
|
case io.SeekStart:
|
||||||
|
abs = offset
|
||||||
|
case io.SeekCurrent:
|
||||||
|
abs = r.offset + offset
|
||||||
|
case io.SeekEnd:
|
||||||
|
abs = r.length + offset
|
||||||
|
default:
|
||||||
|
return 0, errors.New("patternReader: invalid whence")
|
||||||
|
}
|
||||||
|
if abs < 0 {
|
||||||
|
return 0, errors.New("patternReader: negative position")
|
||||||
|
}
|
||||||
|
r.offset = abs
|
||||||
|
r.c = byte(abs % patternReaderModulo)
|
||||||
|
return abs, nil
|
||||||
|
}
|
||||||
|
|
|
@ -28,3 +28,61 @@ func TestPatternReader(t *testing.T) {
|
||||||
require.Equal(t, io.EOF, err)
|
require.Equal(t, io.EOF, err)
|
||||||
require.Equal(t, 0, n)
|
require.Equal(t, 0, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPatternReaderSeek(t *testing.T) {
|
||||||
|
r := NewPatternReader(1024)
|
||||||
|
b, err := ioutil.ReadAll(r)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for i := range b {
|
||||||
|
assert.Equal(t, byte(i%251), b[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := r.Seek(1, io.SeekStart)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), n)
|
||||||
|
|
||||||
|
// pos 1
|
||||||
|
|
||||||
|
b2 := make([]byte, 10)
|
||||||
|
nn, err := r.Read(b2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 10, nn)
|
||||||
|
assert.Equal(t, b[1:11], b2)
|
||||||
|
|
||||||
|
// pos 11
|
||||||
|
|
||||||
|
n, err = r.Seek(9, io.SeekCurrent)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(20), n)
|
||||||
|
|
||||||
|
// pos 20
|
||||||
|
|
||||||
|
nn, err = r.Read(b2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 10, nn)
|
||||||
|
assert.Equal(t, b[20:30], b2)
|
||||||
|
|
||||||
|
n, err = r.Seek(-24, io.SeekEnd)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1000), n)
|
||||||
|
|
||||||
|
// pos 1000
|
||||||
|
|
||||||
|
nn, err = r.Read(b2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 10, nn)
|
||||||
|
assert.Equal(t, b[1000:1010], b2)
|
||||||
|
|
||||||
|
// Now test errors
|
||||||
|
|
||||||
|
n, err = r.Seek(1, 400)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid whence")
|
||||||
|
assert.Equal(t, int64(0), n)
|
||||||
|
|
||||||
|
n, err = r.Seek(-1, io.SeekStart)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "negative position")
|
||||||
|
assert.Equal(t, int64(0), n)
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue