Merge pull request #4625 from MichaelEischer/refactor-streampacks

Refactor repository.StreamPacks
This commit is contained in:
Michael Eischer 2024-01-19 21:48:37 +01:00 committed by GitHub
commit 62111f4379
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 494 additions and 480 deletions

View file

@ -116,7 +116,7 @@ func repairPacks(ctx context.Context, gopts GlobalOptions, repo *repository.Repo
continue
}
err = repository.StreamPack(wgCtx, repo.Backend().Load, repo.Key(), b.PackID, blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
err = repo.LoadBlobsFromPack(wgCtx, b.PackID, blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
if err != nil {
// Fallback path
buf, err = repo.LoadBlob(wgCtx, blob.Type, blob.ID, nil)

View file

@ -10,6 +10,7 @@ import (
"sort"
"sync"
"github.com/klauspost/compress/zstd"
"github.com/minio/sha256-simd"
"github.com/restic/restic/internal/backend"
"github.com/restic/restic/internal/backend/s3"
@ -526,7 +527,7 @@ func (c *Checker) GetPacks() map[restic.ID]int64 {
}
// checkPack reads a pack and checks the integrity of all blobs.
func checkPack(ctx context.Context, r restic.Repository, id restic.ID, blobs []restic.Blob, size int64, bufRd *bufio.Reader) error {
func checkPack(ctx context.Context, r restic.Repository, id restic.ID, blobs []restic.Blob, size int64, bufRd *bufio.Reader, dec *zstd.Decoder) error {
debug.Log("checking pack %v", id.String())
if len(blobs) == 0 {
@ -557,49 +558,44 @@ func checkPack(ctx context.Context, r restic.Repository, id restic.ID, blobs []r
// calculate hash on-the-fly while reading the pack and capture pack header
var hash restic.ID
var hdrBuf []byte
hashingLoader := func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error {
return r.Backend().Load(ctx, h, int(size), 0, func(rd io.Reader) error {
hrd := hashing.NewReader(rd, sha256.New())
bufRd.Reset(hrd)
h := backend.Handle{Type: backend.PackFile, Name: id.String()}
err := r.Backend().Load(ctx, h, int(size), 0, func(rd io.Reader) error {
hrd := hashing.NewReader(rd, sha256.New())
bufRd.Reset(hrd)
// skip to start of first blob, offset == 0 for correct pack files
_, err := bufRd.Discard(int(offset))
if err != nil {
it := repository.NewPackBlobIterator(id, bufRd, 0, blobs, r.Key(), dec)
for {
val, err := it.Next()
if err == repository.ErrPackEOF {
break
} else if err != nil {
return err
}
err = fn(bufRd)
if err != nil {
return err
debug.Log(" check blob %v: %v", val.Handle.ID, val.Handle)
if val.Err != nil {
	// val.Err carries the per-blob verification failure. The iterator's own
	// err is always nil at this point (non-nil values returned above), so
	// reporting it would hide the actual cause.
	debug.Log(" error verifying blob %v: %v", val.Handle.ID, val.Err)
	errs = append(errs, errors.Errorf("blob %v: %v", val.Handle.ID, val.Err))
}
// skip enough bytes until we reach the possible header start
curPos := length + int(offset)
minHdrStart := int(size) - pack.MaxHeaderSize
if minHdrStart > curPos {
_, err := bufRd.Discard(minHdrStart - curPos)
if err != nil {
return err
}
}
// read remainder, which should be the pack header
hdrBuf, err = io.ReadAll(bufRd)
if err != nil {
return err
}
hash = restic.IDFromHash(hrd.Sum(nil))
return nil
})
}
err := repository.StreamPack(ctx, hashingLoader, r.Key(), id, blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
debug.Log(" check blob %v: %v", blob.ID, blob)
if err != nil {
debug.Log(" error verifying blob %v: %v", blob.ID, err)
errs = append(errs, errors.Errorf("blob %v: %v", blob.ID, err))
}
// skip enough bytes until we reach the possible header start
curPos := lastBlobEnd
minHdrStart := int(size) - pack.MaxHeaderSize
if minHdrStart > curPos {
_, err := bufRd.Discard(minHdrStart - curPos)
if err != nil {
return err
}
}
// read remainder, which should be the pack header
var err error
hdrBuf, err = io.ReadAll(bufRd)
if err != nil {
return err
}
hash = restic.IDFromHash(hrd.Sum(nil))
return nil
})
if err != nil {
@ -670,6 +666,11 @@ func (c *Checker) ReadPacks(ctx context.Context, packs map[restic.ID]int64, p *p
// create a buffer that is large enough to be reused by repository.StreamPack
// this ensures that we can read the pack header later on
bufRd := bufio.NewReaderSize(nil, repository.MaxStreamBufferSize)
// a single decoder is shared by all checkPack calls issued from this loop
dec, err := zstd.NewReader(nil)
if err != nil {
	// panic with the construction error, not with the decoder value
	panic(err)
}
defer dec.Close()
for {
var ps checkTask
var ok bool
@ -683,7 +684,7 @@ func (c *Checker) ReadPacks(ctx context.Context, packs map[restic.ID]int64, p *p
}
}
err := checkPack(ctx, c.repo, ps.id, ps.blobs, ps.size, bufRd)
err := checkPack(ctx, c.repo, ps.id, ps.blobs, ps.size, bufRd, dec)
p.Add(1)
if err == nil {
continue

View file

@ -77,7 +77,7 @@ func repack(ctx context.Context, repo restic.Repository, dstRepo restic.Reposito
worker := func() error {
for t := range downloadQueue {
err := StreamPack(wgCtx, repo.Backend().Load, repo.Key(), t.PackID, t.Blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
err := repo.LoadBlobsFromPack(wgCtx, t.PackID, t.Blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
if err != nil {
var ierr error
// check whether we can get a valid copy somewhere else

View file

@ -875,16 +875,20 @@ func (r *Repository) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte
return newID, known, size, err
}
type BackendLoadFn func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error
type backendLoadFn func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error
// Skip sections with more than 4MB unused blobs
const maxUnusedRange = 4 * 1024 * 1024
// StreamPack loads the listed blobs from the specified pack file. The plaintext blob is passed to
// LoadBlobsFromPack loads the listed blobs from the specified pack file. The plaintext blob is passed to
// the handleBlobFn callback or an error if decryption failed or the blob hash does not match.
// handleBlobFn is never called multiple times for the same blob. If the callback returns an error,
// then StreamPack will abort and not retry it.
func StreamPack(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
// handleBlobFn is called at most once for each blob. If the callback returns an error,
// then LoadBlobsFromPack will abort and not retry it.
func (r *Repository) LoadBlobsFromPack(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
// delegate to streamPack, supplying this repository's backend loader and key
return streamPack(ctx, r.Backend().Load, r.key, packID, blobs, handleBlobFn)
}
func streamPack(ctx context.Context, beLoad backendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
if len(blobs) == 0 {
// nothing to do
return nil
@ -915,7 +919,7 @@ func StreamPack(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key, pack
return streamPackPart(ctx, beLoad, key, packID, blobs[lowerIdx:], handleBlobFn)
}
func streamPackPart(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
func streamPackPart(ctx context.Context, beLoad backendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
h := backend.Handle{Type: restic.PackFile, Name: packID.String(), IsMetadata: false}
dataStart := blobs[0].Offset
@ -940,72 +944,18 @@ func streamPackPart(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key,
if bufferSize > MaxStreamBufferSize {
bufferSize = MaxStreamBufferSize
}
// create reader here to allow reusing the buffered reader from checker.checkData
bufRd := bufio.NewReaderSize(rd, bufferSize)
currentBlobEnd := dataStart
var buf []byte
var decode []byte
for len(blobs) > 0 {
entry := blobs[0]
it := NewPackBlobIterator(packID, bufRd, dataStart, blobs, key, dec)
skipBytes := int(entry.Offset - currentBlobEnd)
if skipBytes < 0 {
return errors.Errorf("overlapping blobs in pack %v", packID)
}
_, err := bufRd.Discard(skipBytes)
if err != nil {
for {
val, err := it.Next()
if err == ErrPackEOF {
break
} else if err != nil {
return err
}
h := restic.BlobHandle{ID: entry.ID, Type: entry.Type}
debug.Log(" process blob %v, skipped %d, %v", h, skipBytes, entry)
if uint(cap(buf)) < entry.Length {
buf = make([]byte, entry.Length)
}
buf = buf[:entry.Length]
n, err := io.ReadFull(bufRd, buf)
if err != nil {
debug.Log(" read error %v", err)
return errors.Wrap(err, "ReadFull")
}
if n != len(buf) {
return errors.Errorf("read blob %v from %v: not enough bytes read, want %v, got %v",
h, packID.Str(), len(buf), n)
}
currentBlobEnd = entry.Offset + entry.Length
if int(entry.Length) <= key.NonceSize() {
debug.Log("%v", blobs)
return errors.Errorf("invalid blob length %v", entry)
}
// decryption errors are likely permanent, give the caller a chance to skip them
nonce, ciphertext := buf[:key.NonceSize()], buf[key.NonceSize():]
plaintext, err := key.Open(ciphertext[:0], nonce, ciphertext, nil)
if err == nil && entry.IsCompressed() {
// DecodeAll will allocate a slice if it is not large enough since it
// knows the decompressed size (because we're using EncodeAll)
decode, err = dec.DecodeAll(plaintext, decode[:0])
plaintext = decode
if err != nil {
err = errors.Errorf("decompressing blob %v failed: %v", h, err)
}
}
if err == nil {
id := restic.Hash(plaintext)
if !id.Equal(entry.ID) {
debug.Log("read blob %v/%v from %v: wrong data returned, hash is %v",
h.Type, h.ID, packID.Str(), id)
err = errors.Errorf("read blob %v from %v: wrong data returned, hash is %v",
h, packID.Str(), id)
}
}
err = handleBlobFn(entry.BlobHandle, plaintext, err)
err = handleBlobFn(val.Handle, val.Plaintext, val.Err)
if err != nil {
cancel()
return backoff.Permanent(err)
@ -1018,6 +968,109 @@ func streamPackPart(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key,
return errors.Wrap(err, "StreamPack")
}
// PackBlobIterator returns the blobs of a single pack file one after another,
// reading them sequentially from a buffered reader. Use NewPackBlobIterator to
// create one and call Next until it returns ErrPackEOF.
type PackBlobIterator struct {
// packID identifies the pack file; it is used in error and log messages
packID restic.ID
// rd is the buffered source the blob bytes are read from
rd *bufio.Reader
// currentOffset is the pack offset rd is currently positioned at
currentOffset uint
// blobs holds the entries that have not been returned yet
blobs []restic.Blob
// key authenticates and decrypts each blob
key *crypto.Key
// dec decompresses blobs that are marked as compressed
dec *zstd.Decoder
// buf and decode are scratch buffers that Next reuses between calls
buf []byte
decode []byte
}
// PackBlobValue is the result of one PackBlobIterator.Next call.
// NOTE: Next builds this struct with a positional literal, so the field order
// must not be changed without updating it.
type PackBlobValue struct {
// Handle is the type and ID of the blob as listed in the requested entry
Handle restic.BlobHandle
// Plaintext is the decrypted (and, if applicable, decompressed) blob content
Plaintext []byte
// Err is set when decryption, decompression or hash verification of this blob failed
Err error
}
// ErrPackEOF is returned by PackBlobIterator.Next once all requested blobs
// have been returned.
var ErrPackEOF = errors.New("reached EOF of pack file")
// NewPackBlobIterator creates an iterator that returns the given blobs of the
// pack packID by reading them from rd. currentOffset is the pack offset rd is
// positioned at when iteration starts. key and dec are used to decrypt and
// decompress the blobs.
func NewPackBlobIterator(packID restic.ID, rd *bufio.Reader, currentOffset uint,
	blobs []restic.Blob, key *crypto.Key, dec *zstd.Decoder) *PackBlobIterator {

	it := new(PackBlobIterator)
	it.packID = packID
	it.rd = rd
	it.currentOffset = currentOffset
	it.blobs = blobs
	it.key = key
	it.dec = dec
	return it
}
// Next reads, decrypts, optionally decompresses and verifies the next blob.
//
// It returns ErrPackEOF as the error once all blobs have been consumed. Any
// other non-nil error means the stream itself could not be processed
// (overlapping entries, a read failure or an impossible blob length).
//
// Failures that only affect the current blob (decryption, decompression or a
// hash mismatch) are reported through PackBlobValue.Err together with a nil
// error, which gives the caller a chance to skip the blob and continue.
//
// The returned Plaintext is backed by buffers that the iterator reuses, so it
// should be treated as valid only until the next call to Next.
func (b *PackBlobIterator) Next() (PackBlobValue, error) {
	if len(b.blobs) == 0 {
		return PackBlobValue{}, ErrPackEOF
	}

	entry := b.blobs[0]
	b.blobs = b.blobs[1:]

	// a negative distance means the entry starts before the current read position
	skipBytes := int(entry.Offset - b.currentOffset)
	if skipBytes < 0 {
		return PackBlobValue{}, errors.Errorf("overlapping blobs in pack %v", b.packID)
	}

	_, err := b.rd.Discard(skipBytes)
	if err != nil {
		return PackBlobValue{}, err
	}
	b.currentOffset = entry.Offset

	h := restic.BlobHandle{ID: entry.ID, Type: entry.Type}
	debug.Log(" process blob %v, skipped %d, %v", h, skipBytes, entry)

	// grow the read buffer only when the current blob does not fit
	if uint(cap(b.buf)) < entry.Length {
		b.buf = make([]byte, entry.Length)
	}
	b.buf = b.buf[:entry.Length]

	// io.ReadFull returns an error whenever fewer than len(b.buf) bytes were
	// read, so a separate short-read check would be unreachable.
	_, err = io.ReadFull(b.rd, b.buf)
	if err != nil {
		debug.Log(" read error %v", err)
		return PackBlobValue{}, errors.Wrap(err, "ReadFull")
	}
	b.currentOffset = entry.Offset + entry.Length

	// a blob has to be larger than the nonce that prefixes it
	if int(entry.Length) <= b.key.NonceSize() {
		debug.Log("%v", b.blobs)
		return PackBlobValue{}, errors.Errorf("invalid blob length %v", entry)
	}

	// decryption errors are likely permanent, give the caller a chance to skip them
	nonce, ciphertext := b.buf[:b.key.NonceSize()], b.buf[b.key.NonceSize():]
	plaintext, err := b.key.Open(ciphertext[:0], nonce, ciphertext, nil)
	if err == nil && entry.IsCompressed() {
		// DecodeAll will allocate a slice if it is not large enough since it
		// knows the decompressed size (because we're using EncodeAll)
		b.decode, err = b.dec.DecodeAll(plaintext, b.decode[:0])
		plaintext = b.decode
		if err != nil {
			err = errors.Errorf("decompressing blob %v failed: %v", h, err)
		}
	}

	// only verify the content hash if decryption and decompression succeeded
	if err == nil {
		id := restic.Hash(plaintext)
		if !id.Equal(entry.ID) {
			debug.Log("read blob %v/%v from %v: wrong data returned, hash is %v",
				h.Type, h.ID, b.packID.Str(), id)
			err = errors.Errorf("read blob %v from %v: wrong data returned, hash is %v",
				h, b.packID.Str(), id)
		}
	}

	return PackBlobValue{entry.BlobHandle, plaintext, err}, nil
}
var zeroChunkOnce sync.Once
var zeroChunkID restic.ID

View file

@ -1,11 +1,21 @@
package repository
import (
"bytes"
"context"
"encoding/json"
"io"
"math/rand"
"sort"
"strings"
"testing"
"github.com/cenkalti/backoff/v4"
"github.com/google/go-cmp/cmp"
"github.com/klauspost/compress/zstd"
"github.com/restic/restic/internal/backend"
"github.com/restic/restic/internal/crypto"
"github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/restic"
rtest "github.com/restic/restic/internal/test"
)
@ -73,3 +83,271 @@ func BenchmarkSortCachedPacksFirst(b *testing.B) {
sortCachedPacksFirst(cache, cpy[:])
}
}
// buildPackfileWithoutHeader assembles the blob section of a pack file by hand
// (no header is appended). One blob of random data is generated per entry of
// blobSizes, optionally zstd-compressed, and then sealed with key. The
// returned blobs describe where each sealed blob lives inside packfile.
func buildPackfileWithoutHeader(blobSizes []int, key *crypto.Key, compress bool) (blobs []restic.Blob, packfile []byte) {
	enc, err := zstd.NewWriter(nil,
		// Set the compression level configured.
		zstd.WithEncoderLevel(zstd.SpeedDefault),
		// Disable CRC, we have enough checks in place, makes the
		// compressed data four bytes shorter.
		zstd.WithEncoderCRC(false),
		// Set a window of 512kbyte, so we have good lookbehind for usual
		// blob sizes.
		zstd.WithWindowSize(512*1024),
	)
	if err != nil {
		panic(err)
	}

	for idx, blobSize := range blobSizes {
		payload := rtest.Random(800+idx, blobSize)
		// the blob ID is always the hash of the uncompressed data
		blobID := restic.Hash(payload)

		var rawLength uint
		if compress {
			rawLength = uint(len(payload))
			payload = enc.EncodeAll(payload, nil)
		}

		// a fixed nonce keeps the generated pack file deterministic; only
		// the last byte varies and holds the blob index
		nonce := []byte{
			0x15, 0x98, 0xc0, 0xf7, 0xb9, 0x65, 0x97, 0x74,
			0x12, 0xdc, 0xd3, 0x62, 0xa9, 0x6e, 0x20, byte(idx),
		}

		// the blob starts where the pack file currently ends
		start := len(packfile)
		packfile = append(packfile, nonce...)
		packfile = key.Seal(packfile, nonce, payload, nil)

		blobs = append(blobs, restic.Blob{
			BlobHandle: restic.BlobHandle{
				Type: restic.DataBlob,
				ID:   blobID,
			},
			Length:             uint(len(packfile) - start),
			UncompressedLength: rawLength,
			Offset:             uint(start),
		})
	}

	return blobs, packfile
}
// TestStreamPack runs testStreamPack through TestAllVersions.
func TestStreamPack(t *testing.T) {
TestAllVersions(t, testStreamPack)
}
// testStreamPack builds an in-memory pack file and checks that streamPack
// passes exactly the requested blobs to the callback, and that malformed blob
// lists result in the expected errors. Version 1 uses uncompressed blobs,
// version 2 compressed ones; any other version fails the test.
func testStreamPack(t *testing.T, version uint) {
// always use the same key for deterministic output
const jsonKey = `{"mac":{"k":"eQenuI8adktfzZMuC8rwdA==","r":"k8cfAly2qQSky48CQK7SBA=="},"encrypt":"MKO9gZnRiQFl8mDUurSDa9NMjiu9MUifUrODTHS05wo="}`
var key crypto.Key
err := json.Unmarshal([]byte(jsonKey), &key)
if err != nil {
t.Fatal(err)
}
// plaintext sizes (in bytes) of the blobs placed in the test pack file
blobSizes := []int{
5522811,
10,
5231,
18812,
123123,
13522811,
12301,
892242,
28616,
13351,
252287,
188883,
3522811,
18883,
}
var compress bool
switch version {
case 1:
compress = false
case 2:
compress = true
default:
t.Fatal("test does not support repository version", version)
}
packfileBlobs, packfile := buildPackfileWithoutHeader(blobSizes, &key, compress)
// loadCalls counts invocations of load; shortFirstLoad makes the next load
// return only half of the requested data
loadCalls := 0
shortFirstLoad := false
// loadBytes returns the requested range of the pack file, clamped to the
// available data. If shortFirstLoad is set, the length is halved once and
// the flag is cleared.
loadBytes := func(length int, offset int64) []byte {
data := packfile
if offset > int64(len(data)) {
offset = 0
length = 0
}
data = data[offset:]
if length > len(data) {
length = len(data)
}
if shortFirstLoad {
length /= 2
shortFirstLoad = false
}
return data[:length]
}
// load stands in for the backend load function. It passes the data to fn
// and, if fn fails with an error that is not a backoff.PermanentError,
// retries exactly once. The retry is not counted in loadCalls.
load := func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error {
data := loadBytes(length, offset)
// NOTE(review): loadBytes has already cleared shortFirstLoad at this
// point, so this branch looks unreachable - confirm before relying on it
if shortFirstLoad {
data = data[:len(data)/2]
shortFirstLoad = false
}
loadCalls++
err := fn(bytes.NewReader(data))
if err == nil {
return nil
}
var permanent *backoff.PermanentError
if errors.As(err, &permanent) {
return err
}
// retry loading once
return fn(bytes.NewReader(loadBytes(length, offset)))
}
// first, test regular usage
t.Run("regular", func(t *testing.T) {
// blobs is the list passed to streamPack, calls the expected value of
// loadCalls afterwards, and shortFirstLoad whether the first load is cut short
tests := []struct {
blobs []restic.Blob
calls int
shortFirstLoad bool
}{
{packfileBlobs[1:2], 1, false},
{packfileBlobs[2:5], 1, false},
{packfileBlobs[2:8], 1, false},
{[]restic.Blob{
packfileBlobs[0],
packfileBlobs[4],
packfileBlobs[2],
}, 1, false},
{[]restic.Blob{
packfileBlobs[0],
packfileBlobs[len(packfileBlobs)-1],
}, 2, false},
{packfileBlobs[:], 1, true},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// gotBlobs records how often each blob ID was handed to the callback
gotBlobs := make(map[restic.ID]int)
handleBlob := func(blob restic.BlobHandle, buf []byte, err error) error {
gotBlobs[blob.ID]++
id := restic.Hash(buf)
if !id.Equal(blob.ID) {
t.Fatalf("wrong id %v for blob %s returned", id, blob.ID)
}
return err
}
// every requested blob is expected exactly once
wantBlobs := make(map[restic.ID]int)
for _, blob := range test.blobs {
wantBlobs[blob.ID] = 1
}
loadCalls = 0
shortFirstLoad = test.shortFirstLoad
err = streamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob)
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(wantBlobs, gotBlobs) {
t.Fatal(cmp.Diff(wantBlobs, gotBlobs))
}
rtest.Equals(t, test.calls, loadCalls)
})
}
})
shortFirstLoad = false
// next, test invalid uses, which should return an error
t.Run("invalid", func(t *testing.T) {
// err is a substring that the error returned by streamPack must contain
tests := []struct {
blobs []restic.Blob
err string
}{
{
// pass one blob several times
blobs: []restic.Blob{
packfileBlobs[3],
packfileBlobs[8],
packfileBlobs[3],
packfileBlobs[4],
},
err: "overlapping blobs in pack",
},
{
// pass something that's not a valid blob in the current pack file
blobs: []restic.Blob{
{
Offset: 123,
Length: 20000,
},
},
err: "ciphertext verification failed",
},
{
// pass a blob that's too small
blobs: []restic.Blob{
{
Offset: 123,
Length: 10,
},
},
err: "invalid blob length",
},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// hand per-blob errors straight back so that streamPack returns them
handleBlob := func(blob restic.BlobHandle, buf []byte, err error) error {
return err
}
err = streamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob)
if err == nil {
t.Fatalf("wanted error %v, got nil", test.err)
}
if !strings.Contains(err.Error(), test.err) {
t.Fatalf("wrong error returned, it should contain %q but was %q", test.err, err)
}
})
}
})
}

View file

@ -4,8 +4,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
@ -15,9 +13,6 @@ import (
"testing"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/go-cmp/cmp"
"github.com/klauspost/compress/zstd"
"github.com/restic/restic/internal/backend"
"github.com/restic/restic/internal/backend/local"
"github.com/restic/restic/internal/crypto"
@ -430,274 +425,6 @@ func testRepositoryIncrementalIndex(t *testing.T, version uint) {
}
// buildPackfileWithoutHeader returns a manually built pack file without a header.
func buildPackfileWithoutHeader(blobSizes []int, key *crypto.Key, compress bool) (blobs []restic.Blob, packfile []byte) {
opts := []zstd.EOption{
// Set the compression level configured.
zstd.WithEncoderLevel(zstd.SpeedDefault),
// Disable CRC, we have enough checks in place, makes the
// compressed data four bytes shorter.
zstd.WithEncoderCRC(false),
// Set a window of 512kbyte, so we have good lookbehind for usual
// blob sizes.
zstd.WithWindowSize(512 * 1024),
}
enc, err := zstd.NewWriter(nil, opts...)
if err != nil {
panic(err)
}
var offset uint
for i, size := range blobSizes {
plaintext := rtest.Random(800+i, size)
id := restic.Hash(plaintext)
uncompressedLength := uint(0)
if compress {
uncompressedLength = uint(len(plaintext))
plaintext = enc.EncodeAll(plaintext, nil)
}
// we use a deterministic nonce here so the whole process is
// deterministic, last byte is the blob index
var nonce = []byte{
0x15, 0x98, 0xc0, 0xf7, 0xb9, 0x65, 0x97, 0x74,
0x12, 0xdc, 0xd3, 0x62, 0xa9, 0x6e, 0x20, byte(i),
}
before := len(packfile)
packfile = append(packfile, nonce...)
packfile = key.Seal(packfile, nonce, plaintext, nil)
after := len(packfile)
ciphertextLength := after - before
blobs = append(blobs, restic.Blob{
BlobHandle: restic.BlobHandle{
Type: restic.DataBlob,
ID: id,
},
Length: uint(ciphertextLength),
UncompressedLength: uncompressedLength,
Offset: offset,
})
offset = uint(len(packfile))
}
return blobs, packfile
}
func TestStreamPack(t *testing.T) {
repository.TestAllVersions(t, testStreamPack)
}
func testStreamPack(t *testing.T, version uint) {
// always use the same key for deterministic output
const jsonKey = `{"mac":{"k":"eQenuI8adktfzZMuC8rwdA==","r":"k8cfAly2qQSky48CQK7SBA=="},"encrypt":"MKO9gZnRiQFl8mDUurSDa9NMjiu9MUifUrODTHS05wo="}`
var key crypto.Key
err := json.Unmarshal([]byte(jsonKey), &key)
if err != nil {
t.Fatal(err)
}
blobSizes := []int{
5522811,
10,
5231,
18812,
123123,
13522811,
12301,
892242,
28616,
13351,
252287,
188883,
3522811,
18883,
}
var compress bool
switch version {
case 1:
compress = false
case 2:
compress = true
default:
t.Fatal("test does not support repository version", version)
}
packfileBlobs, packfile := buildPackfileWithoutHeader(blobSizes, &key, compress)
loadCalls := 0
shortFirstLoad := false
loadBytes := func(length int, offset int64) []byte {
data := packfile
if offset > int64(len(data)) {
offset = 0
length = 0
}
data = data[offset:]
if length > len(data) {
length = len(data)
}
if shortFirstLoad {
length /= 2
shortFirstLoad = false
}
return data[:length]
}
load := func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error {
data := loadBytes(length, offset)
if shortFirstLoad {
data = data[:len(data)/2]
shortFirstLoad = false
}
loadCalls++
err := fn(bytes.NewReader(data))
if err == nil {
return nil
}
var permanent *backoff.PermanentError
if errors.As(err, &permanent) {
return err
}
// retry loading once
return fn(bytes.NewReader(loadBytes(length, offset)))
}
// first, test regular usage
t.Run("regular", func(t *testing.T) {
tests := []struct {
blobs []restic.Blob
calls int
shortFirstLoad bool
}{
{packfileBlobs[1:2], 1, false},
{packfileBlobs[2:5], 1, false},
{packfileBlobs[2:8], 1, false},
{[]restic.Blob{
packfileBlobs[0],
packfileBlobs[4],
packfileBlobs[2],
}, 1, false},
{[]restic.Blob{
packfileBlobs[0],
packfileBlobs[len(packfileBlobs)-1],
}, 2, false},
{packfileBlobs[:], 1, true},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
gotBlobs := make(map[restic.ID]int)
handleBlob := func(blob restic.BlobHandle, buf []byte, err error) error {
gotBlobs[blob.ID]++
id := restic.Hash(buf)
if !id.Equal(blob.ID) {
t.Fatalf("wrong id %v for blob %s returned", id, blob.ID)
}
return err
}
wantBlobs := make(map[restic.ID]int)
for _, blob := range test.blobs {
wantBlobs[blob.ID] = 1
}
loadCalls = 0
shortFirstLoad = test.shortFirstLoad
err = repository.StreamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob)
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(wantBlobs, gotBlobs) {
t.Fatal(cmp.Diff(wantBlobs, gotBlobs))
}
rtest.Equals(t, test.calls, loadCalls)
})
}
})
shortFirstLoad = false
// next, test invalid uses, which should return an error
t.Run("invalid", func(t *testing.T) {
tests := []struct {
blobs []restic.Blob
err string
}{
{
// pass one blob several times
blobs: []restic.Blob{
packfileBlobs[3],
packfileBlobs[8],
packfileBlobs[3],
packfileBlobs[4],
},
err: "overlapping blobs in pack",
},
{
// pass something that's not a valid blob in the current pack file
blobs: []restic.Blob{
{
Offset: 123,
Length: 20000,
},
},
err: "ciphertext verification failed",
},
{
// pass a blob that's too small
blobs: []restic.Blob{
{
Offset: 123,
Length: 10,
},
},
err: "invalid blob length",
},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
handleBlob := func(blob restic.BlobHandle, buf []byte, err error) error {
return err
}
err = repository.StreamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob)
if err == nil {
t.Fatalf("wanted error %v, got nil", test.err)
}
if !strings.Contains(err.Error(), test.err) {
t.Fatalf("wrong error returned, it should contain %q but was %q", test.err, err)
}
})
}
})
}
func TestInvalidCompression(t *testing.T) {
var comp repository.CompressionMode
err := comp.Set("nope")

View file

@ -44,6 +44,7 @@ type Repository interface {
ListPack(context.Context, ID, int64) ([]Blob, uint32, error)
LoadBlob(context.Context, BlobType, ID, []byte) ([]byte, error)
LoadBlobsFromPack(ctx context.Context, packID ID, blobs []Blob, handleBlobFn func(blob BlobHandle, buf []byte, err error) error) error
SaveBlob(context.Context, BlobType, []byte, ID, bool) (ID, bool, int, error)
// StartPackUploader start goroutines to upload new pack files. The errgroup

View file

@ -7,7 +7,6 @@ import (
"golang.org/x/sync/errgroup"
"github.com/restic/restic/internal/crypto"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/repository"
@ -45,11 +44,12 @@ type packInfo struct {
files map[*fileInfo]struct{} // set of files that use blobs from this pack
}
type blobsLoaderFn func(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error
// fileRestorer restores set of files
type fileRestorer struct {
key *crypto.Key
idx func(restic.BlobHandle) []restic.PackedBlob
packLoader repository.BackendLoadFn
idx func(restic.BlobHandle) []restic.PackedBlob
blobsLoader blobsLoaderFn
workerCount int
filesWriter *filesWriter
@ -63,8 +63,7 @@ type fileRestorer struct {
}
func newFileRestorer(dst string,
packLoader repository.BackendLoadFn,
key *crypto.Key,
blobsLoader blobsLoaderFn,
idx func(restic.BlobHandle) []restic.PackedBlob,
connections uint,
sparse bool,
@ -74,9 +73,8 @@ func newFileRestorer(dst string,
workerCount := int(connections)
return &fileRestorer{
key: key,
idx: idx,
packLoader: packLoader,
blobsLoader: blobsLoader,
filesWriter: newFilesWriter(workerCount),
zeroChunk: repository.ZeroChunk(),
sparse: sparse,
@ -310,7 +308,7 @@ func (r *fileRestorer) downloadBlobs(ctx context.Context, packID restic.ID,
for _, entry := range blobs {
blobList = append(blobList, entry.blob)
}
return repository.StreamPack(ctx, r.packLoader, r.key, packID, blobList,
return r.blobsLoader(ctx, packID, blobList,
func(h restic.BlobHandle, blobData []byte, err error) error {
processedBlobs.Insert(h)
blob := blobs[h.ID]

View file

@ -4,14 +4,11 @@ import (
"bytes"
"context"
"fmt"
"io"
"os"
"sort"
"testing"
"github.com/restic/restic/internal/backend"
"github.com/restic/restic/internal/crypto"
"github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/repository"
"github.com/restic/restic/internal/restic"
rtest "github.com/restic/restic/internal/test"
)
@ -27,11 +24,6 @@ type TestFile struct {
}
type TestRepo struct {
key *crypto.Key
// pack names and ids
packsNameToID map[string]restic.ID
packsIDToName map[restic.ID]string
packsIDToData map[restic.ID][]byte
// blobs and files
@ -40,7 +32,7 @@ type TestRepo struct {
filesPathToContent map[string]string
//
loader repository.BackendLoadFn
loader blobsLoaderFn
}
func (i *TestRepo) Lookup(bh restic.BlobHandle) []restic.PackedBlob {
@ -59,16 +51,6 @@ func newTestRepo(content []TestFile) *TestRepo {
blobs map[restic.ID]restic.Blob
}
packs := make(map[string]Pack)
key := crypto.NewRandomKey()
seal := func(data []byte) []byte {
ciphertext := crypto.NewBlobBuffer(len(data))
ciphertext = ciphertext[:0] // truncate the slice
nonce := crypto.NewRandomNonce()
ciphertext = append(ciphertext, nonce...)
return key.Seal(ciphertext, nonce, data, nil)
}
filesPathToContent := make(map[string]string)
for _, file := range content {
@ -86,14 +68,15 @@ func newTestRepo(content []TestFile) *TestRepo {
// calculate blob id and add to the pack as necessary
blobID := restic.Hash([]byte(blob.data))
if _, found := pack.blobs[blobID]; !found {
blobData := seal([]byte(blob.data))
blobData := []byte(blob.data)
pack.blobs[blobID] = restic.Blob{
BlobHandle: restic.BlobHandle{
Type: restic.DataBlob,
ID: blobID,
},
Length: uint(len(blobData)),
Offset: uint(len(pack.data)),
Length: uint(len(blobData)),
UncompressedLength: uint(len(blobData)),
Offset: uint(len(pack.data)),
}
pack.data = append(pack.data, blobData...)
}
@ -104,15 +87,11 @@ func newTestRepo(content []TestFile) *TestRepo {
}
blobs := make(map[restic.ID][]restic.PackedBlob)
packsIDToName := make(map[restic.ID]string)
packsIDToData := make(map[restic.ID][]byte)
packsNameToID := make(map[string]restic.ID)
for _, pack := range packs {
packID := restic.Hash(pack.data)
packsIDToName[packID] = pack.name
packsIDToData[packID] = pack.data
packsNameToID[pack.name] = packID
for blobID, blob := range pack.blobs {
blobs[blobID] = append(blobs[blobID], restic.PackedBlob{Blob: blob, PackID: packID})
}
@ -128,30 +107,44 @@ func newTestRepo(content []TestFile) *TestRepo {
}
repo := &TestRepo{
key: key,
packsIDToName: packsIDToName,
packsIDToData: packsIDToData,
packsNameToID: packsNameToID,
blobs: blobs,
files: files,
filesPathToContent: filesPathToContent,
}
repo.loader = func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error {
packID, err := restic.ParseID(h.Name)
if err != nil {
return err
repo.loader = func(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
blobs = append([]restic.Blob{}, blobs...)
sort.Slice(blobs, func(i, j int) bool {
return blobs[i].Offset < blobs[j].Offset
})
for _, blob := range blobs {
found := false
for _, e := range repo.blobs[blob.ID] {
if packID == e.PackID {
found = true
buf := repo.packsIDToData[packID][e.Offset : e.Offset+e.Length]
err := handleBlobFn(e.BlobHandle, buf, nil)
if err != nil {
return err
}
}
}
if !found {
return fmt.Errorf("missing blob: %v", blob)
}
}
rd := bytes.NewReader(repo.packsIDToData[packID][int(offset) : int(offset)+length])
return fn(rd)
return nil
}
return repo
}
func restoreAndVerify(t *testing.T, tempdir string, content []TestFile, files map[string]bool, sparse bool) {
t.Helper()
repo := newTestRepo(content)
r := newFileRestorer(tempdir, repo.loader, repo.key, repo.Lookup, 2, sparse, nil)
r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, sparse, nil)
if files == nil {
r.files = repo.files
@ -170,6 +163,7 @@ func restoreAndVerify(t *testing.T, tempdir string, content []TestFile, files ma
}
func verifyRestore(t *testing.T, r *fileRestorer, repo *TestRepo) {
t.Helper()
for _, file := range r.files {
target := r.targetPath(file.location)
data, err := os.ReadFile(target)
@ -283,62 +277,17 @@ func TestErrorRestoreFiles(t *testing.T) {
loadError := errors.New("load error")
// loader always returns an error
repo.loader = func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error {
repo.loader = func(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
return loadError
}
r := newFileRestorer(tempdir, repo.loader, repo.key, repo.Lookup, 2, false, nil)
r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, false, nil)
r.files = repo.files
err := r.restoreFiles(context.TODO())
rtest.Assert(t, errors.Is(err, loadError), "got %v, expected contained error %v", err, loadError)
}
func TestDownloadError(t *testing.T) {
for i := 0; i < 100; i += 10 {
testPartialDownloadError(t, i)
}
}
func testPartialDownloadError(t *testing.T, part int) {
tempdir := rtest.TempDir(t)
content := []TestFile{
{
name: "file1",
blobs: []TestBlob{
{"data1-1", "pack1"},
{"data1-2", "pack1"},
{"data1-3", "pack1"},
},
}}
repo := newTestRepo(content)
// loader always returns an error
loader := repo.loader
repo.loader = func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error {
// only load partial data to exercise fault handling in different places
err := loader(ctx, h, length*part/100, offset, fn)
if err == nil {
return nil
}
fmt.Println("Retry after error", err)
return loader(ctx, h, length, offset, fn)
}
r := newFileRestorer(tempdir, repo.loader, repo.key, repo.Lookup, 2, false, nil)
r.files = repo.files
r.Error = func(s string, e error) error {
// ignore errors as in the `restore` command
fmt.Println("error during restore", s, e)
return nil
}
err := r.restoreFiles(context.TODO())
rtest.OK(t, err)
verifyRestore(t, r, repo)
}
func TestFatalDownloadError(t *testing.T) {
tempdir := rtest.TempDir(t)
content := []TestFile{
@ -361,12 +310,19 @@ func TestFatalDownloadError(t *testing.T) {
repo := newTestRepo(content)
loader := repo.loader
repo.loader = func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error {
// only return half the data to break file2
return loader(ctx, h, length/2, offset, fn)
repo.loader = func(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
ctr := 0
return loader(ctx, packID, blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
if ctr < 2 {
ctr++
return handleBlobFn(blob, buf, err)
}
// break file2
return errors.New("failed to load blob")
})
}
r := newFileRestorer(tempdir, repo.loader, repo.key, repo.Lookup, 2, false, nil)
r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, false, nil)
r.files = repo.files
var errors []string

View file

@ -231,7 +231,7 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error {
}
idx := NewHardlinkIndex[string]()
filerestorer := newFileRestorer(dst, res.repo.Backend().Load, res.repo.Key(), res.repo.Index().Lookup,
filerestorer := newFileRestorer(dst, res.repo.LoadBlobsFromPack, res.repo.Index().Lookup,
res.repo.Connections(), res.sparse, res.progress)
filerestorer.Error = res.Error