repository: split StreamPack implementation

Move the actual decoding of the pack data into a separate iterator.
This commit is contained in:
Michael Eischer 2023-12-31 00:18:41 +01:00
parent 54c5c72e5a
commit fb422497af

View file

@ -882,7 +882,7 @@ const maxUnusedRange = 4 * 1024 * 1024
// StreamPack loads the listed blobs from the specified pack file. The plaintext blob is passed to
// the handleBlobFn callback or an error if decryption failed or the blob hash does not match.
// handleBlobFn is never called multiple times for the same blob. If the callback returns an error,
// handleBlobFn is called at most once for each blob. If the callback returns an error,
// then StreamPack will abort and not retry it.
func StreamPack(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
if len(blobs) == 0 {
@ -940,72 +940,18 @@ func streamPackPart(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key,
if bufferSize > MaxStreamBufferSize {
bufferSize = MaxStreamBufferSize
}
// create reader here to allow reusing the buffered reader from checker.checkData
bufRd := bufio.NewReaderSize(rd, bufferSize)
currentBlobEnd := dataStart
var buf []byte
var decode []byte
for len(blobs) > 0 {
entry := blobs[0]
it := NewPackBlobIterator(packID, bufRd, dataStart, blobs, key, dec)
skipBytes := int(entry.Offset - currentBlobEnd)
if skipBytes < 0 {
return errors.Errorf("overlapping blobs in pack %v", packID)
}
_, err := bufRd.Discard(skipBytes)
if err != nil {
for {
val, err := it.Next()
if err == ErrPackEOF {
break
} else if err != nil {
return err
}
h := restic.BlobHandle{ID: entry.ID, Type: entry.Type}
debug.Log(" process blob %v, skipped %d, %v", h, skipBytes, entry)
if uint(cap(buf)) < entry.Length {
buf = make([]byte, entry.Length)
}
buf = buf[:entry.Length]
n, err := io.ReadFull(bufRd, buf)
if err != nil {
debug.Log(" read error %v", err)
return errors.Wrap(err, "ReadFull")
}
if n != len(buf) {
return errors.Errorf("read blob %v from %v: not enough bytes read, want %v, got %v",
h, packID.Str(), len(buf), n)
}
currentBlobEnd = entry.Offset + entry.Length
if int(entry.Length) <= key.NonceSize() {
debug.Log("%v", blobs)
return errors.Errorf("invalid blob length %v", entry)
}
// decryption errors are likely permanent, give the caller a chance to skip them
nonce, ciphertext := buf[:key.NonceSize()], buf[key.NonceSize():]
plaintext, err := key.Open(ciphertext[:0], nonce, ciphertext, nil)
if err == nil && entry.IsCompressed() {
// DecodeAll will allocate a slice if it is not large enough since it
// knows the decompressed size (because we're using EncodeAll)
decode, err = dec.DecodeAll(plaintext, decode[:0])
plaintext = decode
if err != nil {
err = errors.Errorf("decompressing blob %v failed: %v", h, err)
}
}
if err == nil {
id := restic.Hash(plaintext)
if !id.Equal(entry.ID) {
debug.Log("read blob %v/%v from %v: wrong data returned, hash is %v",
h.Type, h.ID, packID.Str(), id)
err = errors.Errorf("read blob %v from %v: wrong data returned, hash is %v",
h, packID.Str(), id)
}
}
err = handleBlobFn(entry.BlobHandle, plaintext, err)
err = handleBlobFn(val.Handle, val.Plaintext, val.Err)
if err != nil {
cancel()
return backoff.Permanent(err)
@ -1018,6 +964,109 @@ func streamPackPart(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key,
return errors.Wrap(err, "StreamPack")
}
type PackBlobIterator struct {
packID restic.ID
rd *bufio.Reader
currentOffset uint
blobs []restic.Blob
key *crypto.Key
dec *zstd.Decoder
buf []byte
decode []byte
}
type PackBlobValue struct {
Handle restic.BlobHandle
Plaintext []byte
Err error
}
var ErrPackEOF = errors.New("reached EOF of pack file")
func NewPackBlobIterator(packID restic.ID, rd *bufio.Reader, currentOffset uint,
blobs []restic.Blob, key *crypto.Key, dec *zstd.Decoder) *PackBlobIterator {
return &PackBlobIterator{
packID: packID,
rd: rd,
currentOffset: currentOffset,
blobs: blobs,
key: key,
dec: dec,
}
}
// Next returns the next blob, an error or ErrPackEOF if all blobs were read
func (b *PackBlobIterator) Next() (PackBlobValue, error) {
if len(b.blobs) == 0 {
return PackBlobValue{}, ErrPackEOF
}
entry := b.blobs[0]
b.blobs = b.blobs[1:]
skipBytes := int(entry.Offset - b.currentOffset)
if skipBytes < 0 {
return PackBlobValue{}, errors.Errorf("overlapping blobs in pack %v", b.packID)
}
_, err := b.rd.Discard(skipBytes)
if err != nil {
return PackBlobValue{}, err
}
b.currentOffset = entry.Offset
h := restic.BlobHandle{ID: entry.ID, Type: entry.Type}
debug.Log(" process blob %v, skipped %d, %v", h, skipBytes, entry)
if uint(cap(b.buf)) < entry.Length {
b.buf = make([]byte, entry.Length)
}
b.buf = b.buf[:entry.Length]
n, err := io.ReadFull(b.rd, b.buf)
if err != nil {
debug.Log(" read error %v", err)
return PackBlobValue{}, errors.Wrap(err, "ReadFull")
}
if n != len(b.buf) {
return PackBlobValue{}, errors.Errorf("read blob %v from %v: not enough bytes read, want %v, got %v",
h, b.packID.Str(), len(b.buf), n)
}
b.currentOffset = entry.Offset + entry.Length
if int(entry.Length) <= b.key.NonceSize() {
debug.Log("%v", b.blobs)
return PackBlobValue{}, errors.Errorf("invalid blob length %v", entry)
}
// decryption errors are likely permanent, give the caller a chance to skip them
nonce, ciphertext := b.buf[:b.key.NonceSize()], b.buf[b.key.NonceSize():]
plaintext, err := b.key.Open(ciphertext[:0], nonce, ciphertext, nil)
if err == nil && entry.IsCompressed() {
// DecodeAll will allocate a slice if it is not large enough since it
// knows the decompressed size (because we're using EncodeAll)
b.decode, err = b.dec.DecodeAll(plaintext, b.decode[:0])
plaintext = b.decode
if err != nil {
err = errors.Errorf("decompressing blob %v failed: %v", h, err)
}
}
if err == nil {
id := restic.Hash(plaintext)
if !id.Equal(entry.ID) {
debug.Log("read blob %v/%v from %v: wrong data returned, hash is %v",
h.Type, h.ID, b.packID.Str(), id)
err = errors.Errorf("read blob %v from %v: wrong data returned, hash is %v",
h, b.packID.Str(), id)
}
}
return PackBlobValue{entry.BlobHandle, plaintext, err}, nil
}
var zeroChunkOnce sync.Once
var zeroChunkID restic.ID