restic/internal/restic/rewind_reader_test.go
2018-03-04 10:40:42 +01:00

154 lines
3 KiB
Go

package restic
import (
"bytes"
"io"
"io/ioutil"
"math/rand"
"os"
"path/filepath"
"testing"
"time"
"github.com/restic/restic/internal/test"
)
// TestByteReader runs the shared RewindReader checks against readers
// created by NewByteReader over a fixed in-memory byte slice.
func TestByteReader(t *testing.T) {
	content := []byte("foobar")

	// The factory hands out a new reader over the same bytes on every call.
	newReader := func() RewindReader {
		return NewByteReader(content)
	}

	testRewindReader(t, newReader, content)
}
// TestFileReader runs the shared RewindReader checks against readers
// created by NewFileReader on a file written into a temporary directory.
func TestFileReader(t *testing.T) {
	content := []byte("foobar")

	dir, removeDir := test.TempDir(t)
	defer removeDir()

	path := filepath.Join(dir, "file-reader-test")
	if err := ioutil.WriteFile(path, content, 0600); err != nil {
		t.Fatal(err)
	}

	file, err := os.Open(path)
	if err != nil {
		t.Fatal(err)
	}
	// Deferred after removeDir, so the file is closed before the
	// temporary directory cleanup runs.
	defer func() {
		if err := file.Close(); err != nil {
			t.Fatal(err)
		}
	}()

	// The factory wraps the same open file handle in a new reader on
	// every call.
	newReader := func() RewindReader {
		reader, err := NewFileReader(file)
		if err != nil {
			t.Fatal(err)
		}
		return reader
	}

	testRewindReader(t, newReader, content)
}
// testRewindReader runs a set of checks against readers produced by fn.
// Each check runs as its own subtest with a fresh reader obtained by
// calling fn, and expects the reader to yield exactly data.
//
// The checks are:
//   - read everything, rewind, read everything again, verifying the
//     content and that Length() equals len(data) at every step;
//   - read a random-length prefix, rewind, read another random-length
//     prefix followed by the remainder.
//
// The random source is seeded from the current time; the seed is logged
// so a failing run can be reproduced.
func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) {
	seed := time.Now().UnixNano()
	t.Logf("seed is %d", seed)
	rnd := rand.New(rand.NewSource(seed))

	// randLen returns a random length in [0, n). rand.Intn panics for
	// n <= 0, so empty input yields 0 instead of crashing the test.
	randLen := func(n int) int {
		if n <= 0 {
			return 0
		}
		return rnd.Intn(n)
	}

	// checkLength fails the test unless rd reports the length of data.
	checkLength := func(t testing.TB, rd RewindReader, data []byte) {
		t.Helper()
		if rd.Length() != int64(len(data)) {
			t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length())
		}
	}

	// readExpect reads len(want) bytes from rd and fails the test unless
	// they match want.
	readExpect := func(t testing.TB, rd RewindReader, want []byte) {
		t.Helper()
		buf := make([]byte, len(want))
		if _, err := io.ReadFull(rd, buf); err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(buf, want) {
			t.Fatalf("wrong data returned")
		}
	}

	type ReaderTestFunc func(t testing.TB, r RewindReader, data []byte)

	var tests = []ReaderTestFunc{
		// Full read, rewind, full read; Length() must stay constant.
		func(t testing.TB, rd RewindReader, data []byte) {
			checkLength(t, rd, data)
			readExpect(t, rd, data)
			checkLength(t, rd, data)

			if err := rd.Rewind(); err != nil {
				t.Fatal(err)
			}
			checkLength(t, rd, data)

			readExpect(t, rd, data)
			checkLength(t, rd, data)
		},
		// Partial read, rewind, partial read, then the remainder.
		func(t testing.TB, rd RewindReader, data []byte) {
			// read first bytes
			first := randLen(len(data))
			readExpect(t, rd, data[:first])

			if err := rd.Rewind(); err != nil {
				t.Fatal(err)
			}

			// after the rewind, reading must start at the beginning again
			second := randLen(len(data))
			readExpect(t, rd, data[:second])

			// read remainder
			readExpect(t, rd, data[second:])
		},
	}

	// The loop variable is not named "test" to avoid shadowing the
	// imported test package.
	for _, run := range tests {
		t.Run("", func(t *testing.T) {
			rd := fn()
			run(t, rd, data)
		})
	}
}