Add plumbing to calculate backend specific file hash for upload

This enables the backends to request the calculation of a
backend-specific hash. For the currently supported backends this will
always be MD5. The hash calculation happens as early as possible, for
pack files this is during assembly of the pack file. That way the hash
would even capture corruptions of the temporary pack file on disk.
This commit is contained in:
Michael Eischer 2020-12-19 12:39:48 +01:00
parent ee2f14eaf0
commit 9aa2eff384
28 changed files with 219 additions and 48 deletions

View file

@ -3,6 +3,7 @@ package azure
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"hash"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -112,6 +113,11 @@ func (be *Backend) Location() string {
return be.Join(be.container.Name, be.prefix) return be.Join(be.container.Name, be.prefix)
} }
// Hasher may return a hash function for calculating a content hash for
// the backend. A nil result means this backend does not request a
// backend-specific content hash for uploads.
func (be *Backend) Hasher() hash.Hash {
return nil
}
// Path returns the path in the bucket that is used for this backend. // Path returns the path in the bucket that is used for this backend.
func (be *Backend) Path() string { func (be *Backend) Path() string {
return be.prefix return be.prefix

View file

@ -172,7 +172,7 @@ func TestUploadLargeFile(t *testing.T) {
t.Logf("hash of %d bytes: %v", len(data), id) t.Logf("hash of %d bytes: %v", len(data), id)
err = be.Save(ctx, h, restic.NewByteReader(data)) err = be.Save(ctx, h, restic.NewByteReader(data, be.Hasher()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -2,6 +2,7 @@ package b2
import ( import (
"context" "context"
"hash"
"io" "io"
"net/http" "net/http"
"path" "path"
@ -137,6 +138,11 @@ func (be *b2Backend) Location() string {
return be.cfg.Bucket return be.cfg.Bucket
} }
// Hasher may return a hash function for calculating a content hash for
// the backend. A nil result means this backend does not request a
// backend-specific content hash for uploads.
func (be *b2Backend) Hasher() hash.Hash {
return nil
}
// IsNotExist returns true if the error is caused by a non-existing file. // IsNotExist returns true if the error is caused by a non-existing file.
func (be *b2Backend) IsNotExist(err error) bool { func (be *b2Backend) IsNotExist(err error) bool {
return b2.IsNotExist(errors.Cause(err)) return b2.IsNotExist(errors.Cause(err))

View file

@ -36,7 +36,7 @@ func TestBackendSaveRetry(t *testing.T) {
retryBackend := NewRetryBackend(be, 10, nil) retryBackend := NewRetryBackend(be, 10, nil)
data := test.Random(23, 5*1024*1024+11241) data := test.Random(23, 5*1024*1024+11241)
err := retryBackend.Save(context.TODO(), restic.Handle{}, restic.NewByteReader(data)) err := retryBackend.Save(context.TODO(), restic.Handle{}, restic.NewByteReader(data, be.Hasher()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -256,7 +256,7 @@ func TestBackendCanceledContext(t *testing.T) {
_, err = retryBackend.Stat(ctx, h) _, err = retryBackend.Stat(ctx, h)
assertIsCanceled(t, err) assertIsCanceled(t, err)
err = retryBackend.Save(ctx, h, restic.NewByteReader([]byte{})) err = retryBackend.Save(ctx, h, restic.NewByteReader([]byte{}, nil))
assertIsCanceled(t, err) assertIsCanceled(t, err)
err = retryBackend.Remove(ctx, h) err = retryBackend.Remove(ctx, h)
assertIsCanceled(t, err) assertIsCanceled(t, err)

View file

@ -2,6 +2,7 @@ package dryrun
import ( import (
"context" "context"
"hash"
"io" "io"
"github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/debug"
@ -58,6 +59,10 @@ func (be *Backend) Close() error {
return be.b.Close() return be.b.Close()
} }
// Hasher delegates to the wrapped backend, so a dry run reports the same
// content-hash requirement as the real backend it wraps.
func (be *Backend) Hasher() hash.Hash {
return be.b.Hasher()
}
func (be *Backend) IsNotExist(err error) bool { func (be *Backend) IsNotExist(err error) bool {
return be.b.IsNotExist(err) return be.b.IsNotExist(err)
} }

View file

@ -71,7 +71,7 @@ func TestDry(t *testing.T) {
handle := restic.Handle{Type: restic.PackFile, Name: step.fname} handle := restic.Handle{Type: restic.PackFile, Name: step.fname}
switch step.op { switch step.op {
case "save": case "save":
err = step.be.Save(ctx, handle, restic.NewByteReader([]byte(step.content))) err = step.be.Save(ctx, handle, restic.NewByteReader([]byte(step.content), step.be.Hasher()))
case "test": case "test":
boolRes, err = step.be.Test(ctx, handle) boolRes, err = step.be.Test(ctx, handle)
if boolRes != (step.content != "") { if boolRes != (step.content != "") {

View file

@ -3,6 +3,7 @@ package gs
import ( import (
"context" "context"
"hash"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -188,6 +189,11 @@ func (be *Backend) Location() string {
return be.Join(be.bucketName, be.prefix) return be.Join(be.bucketName, be.prefix)
} }
// Hasher may return a hash function for calculating a content hash for
// the backend. A nil result means this backend does not request a
// backend-specific content hash for uploads.
func (be *Backend) Hasher() hash.Hash {
return nil
}
// Path returns the path in the bucket that is used for this backend. // Path returns the path in the bucket that is used for this backend.
func (be *Backend) Path() string { func (be *Backend) Path() string {
return be.prefix return be.prefix

View file

@ -2,6 +2,7 @@ package local
import ( import (
"context" "context"
"hash"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
@ -77,6 +78,11 @@ func (b *Local) Location() string {
return b.Path return b.Path
} }
// Hasher may return a hash function for calculating a content hash for
// the backend. A nil result means the local backend does not request a
// backend-specific content hash for uploads.
func (b *Local) Hasher() hash.Hash {
return nil
}
// IsNotExist returns true if the error is caused by a non existing file. // IsNotExist returns true if the error is caused by a non existing file.
func (b *Local) IsNotExist(err error) bool { func (b *Local) IsNotExist(err error) bool {
return errors.Is(err, os.ErrNotExist) return errors.Is(err, os.ErrNotExist)

View file

@ -3,6 +3,7 @@ package mem
import ( import (
"bytes" "bytes"
"context" "context"
"hash"
"io" "io"
"io/ioutil" "io/ioutil"
"sync" "sync"
@ -214,6 +215,11 @@ func (be *MemoryBackend) Location() string {
return "RAM" return "RAM"
} }
// Hasher may return a hash function for calculating a content hash for
// the backend. A nil result means the in-memory backend does not request
// a backend-specific content hash for uploads.
func (be *MemoryBackend) Hasher() hash.Hash {
return nil
}
// Delete removes all data in the backend. // Delete removes all data in the backend.
func (be *MemoryBackend) Delete(ctx context.Context) error { func (be *MemoryBackend) Delete(ctx context.Context) error {
be.m.Lock() be.m.Lock()

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"hash"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -109,6 +110,11 @@ func (b *Backend) Location() string {
return b.url.String() return b.url.String()
} }
// Hasher may return a hash function for calculating a content hash for
// the backend. A nil result means the REST backend does not request a
// backend-specific content hash for uploads.
func (b *Backend) Hasher() hash.Hash {
return nil
}
// Save stores data in the backend at the handle. // Save stores data in the backend at the handle.
func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error { func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) error {
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {

View file

@ -3,6 +3,7 @@ package s3
import ( import (
"context" "context"
"fmt" "fmt"
"hash"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -250,6 +251,11 @@ func (be *Backend) Location() string {
return be.Join(be.cfg.Bucket, be.cfg.Prefix) return be.Join(be.cfg.Bucket, be.cfg.Prefix)
} }
// Hasher may return a hash function for calculating a content hash for
// the backend. A nil result means this backend does not request a
// backend-specific content hash for uploads.
func (be *Backend) Hasher() hash.Hash {
return nil
}
// Path returns the path in the bucket that is used for this backend. // Path returns the path in the bucket that is used for this backend.
func (be *Backend) Path() string { func (be *Backend) Path() string {
return be.cfg.Prefix return be.cfg.Prefix

View file

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"context" "context"
"fmt" "fmt"
"hash"
"io" "io"
"os" "os"
"os/exec" "os/exec"
@ -240,6 +241,11 @@ func (r *SFTP) Location() string {
return r.p return r.p
} }
// Hasher may return a hash function for calculating a content hash for
// the backend. A nil result means the sftp backend does not request a
// backend-specific content hash for uploads.
func (r *SFTP) Hasher() hash.Hash {
return nil
}
// Join joins the given paths and cleans them afterwards. This always uses // Join joins the given paths and cleans them afterwards. This always uses
// forward slashes, which is required by sftp. // forward slashes, which is required by sftp.
func Join(parts ...string) string { func Join(parts ...string) string {

View file

@ -3,6 +3,7 @@ package swift
import ( import (
"context" "context"
"fmt" "fmt"
"hash"
"io" "io"
"net/http" "net/http"
"path" "path"
@ -115,6 +116,11 @@ func (be *beSwift) Location() string {
return be.container return be.container
} }
// Hasher may return a hash function for calculating a content hash for
// the backend. A nil result means the swift backend does not request a
// backend-specific content hash for uploads.
func (be *beSwift) Hasher() hash.Hash {
return nil
}
// Load runs fn with a reader that yields the contents of the file at h at the // Load runs fn with a reader that yields the contents of the file at h at the
// given offset. // given offset.
func (be *beSwift) Load(ctx context.Context, h restic.Handle, length int, offset int64, fn func(rd io.Reader) error) error { func (be *beSwift) Load(ctx context.Context, h restic.Handle, length int, offset int64, fn func(rd io.Reader) error) error {

View file

@ -14,7 +14,7 @@ func saveRandomFile(t testing.TB, be restic.Backend, length int) ([]byte, restic
data := test.Random(23, length) data := test.Random(23, length)
id := restic.Hash(data) id := restic.Hash(data)
handle := restic.Handle{Type: restic.PackFile, Name: id.String()} handle := restic.Handle{Type: restic.PackFile, Name: id.String()}
err := be.Save(context.TODO(), handle, restic.NewByteReader(data)) err := be.Save(context.TODO(), handle, restic.NewByteReader(data, be.Hasher()))
if err != nil { if err != nil {
t.Fatalf("Save() error: %+v", err) t.Fatalf("Save() error: %+v", err)
} }
@ -148,7 +148,7 @@ func (s *Suite) BenchmarkSave(t *testing.B) {
id := restic.Hash(data) id := restic.Hash(data)
handle := restic.Handle{Type: restic.PackFile, Name: id.String()} handle := restic.Handle{Type: restic.PackFile, Name: id.String()}
rd := restic.NewByteReader(data) rd := restic.NewByteReader(data, be.Hasher())
t.SetBytes(int64(length)) t.SetBytes(int64(length))
t.ResetTimer() t.ResetTimer()

View file

@ -84,7 +84,7 @@ func (s *Suite) TestConfig(t *testing.T) {
t.Fatalf("did not get expected error for non-existing config") t.Fatalf("did not get expected error for non-existing config")
} }
err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, restic.NewByteReader([]byte(testString))) err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, restic.NewByteReader([]byte(testString), b.Hasher()))
if err != nil { if err != nil {
t.Fatalf("Save() error: %+v", err) t.Fatalf("Save() error: %+v", err)
} }
@ -134,7 +134,7 @@ func (s *Suite) TestLoad(t *testing.T) {
id := restic.Hash(data) id := restic.Hash(data)
handle := restic.Handle{Type: restic.PackFile, Name: id.String()} handle := restic.Handle{Type: restic.PackFile, Name: id.String()}
err = b.Save(context.TODO(), handle, restic.NewByteReader(data)) err = b.Save(context.TODO(), handle, restic.NewByteReader(data, b.Hasher()))
if err != nil { if err != nil {
t.Fatalf("Save() error: %+v", err) t.Fatalf("Save() error: %+v", err)
} }
@ -253,7 +253,7 @@ func (s *Suite) TestList(t *testing.T) {
data := test.Random(rand.Int(), rand.Intn(100)+55) data := test.Random(rand.Int(), rand.Intn(100)+55)
id := restic.Hash(data) id := restic.Hash(data)
h := restic.Handle{Type: restic.PackFile, Name: id.String()} h := restic.Handle{Type: restic.PackFile, Name: id.String()}
err := b.Save(context.TODO(), h, restic.NewByteReader(data)) err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -343,7 +343,7 @@ func (s *Suite) TestListCancel(t *testing.T) {
data := []byte(fmt.Sprintf("random test blob %v", i)) data := []byte(fmt.Sprintf("random test blob %v", i))
id := restic.Hash(data) id := restic.Hash(data)
h := restic.Handle{Type: restic.PackFile, Name: id.String()} h := restic.Handle{Type: restic.PackFile, Name: id.String()}
err := b.Save(context.TODO(), h, restic.NewByteReader(data)) err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -447,6 +447,7 @@ type errorCloser struct {
io.ReadSeeker io.ReadSeeker
l int64 l int64
t testing.TB t testing.TB
h []byte
} }
func (ec errorCloser) Close() error { func (ec errorCloser) Close() error {
@ -458,6 +459,10 @@ func (ec errorCloser) Length() int64 {
return ec.l return ec.l
} }
// Hash returns the stored content hash, satisfying the RewindReader
// interface; it is pre-computed by the caller over the test data.
func (ec errorCloser) Hash() []byte {
return ec.h
}
func (ec errorCloser) Rewind() error { func (ec errorCloser) Rewind() error {
_, err := ec.ReadSeeker.Seek(0, io.SeekStart) _, err := ec.ReadSeeker.Seek(0, io.SeekStart)
return err return err
@ -486,7 +491,7 @@ func (s *Suite) TestSave(t *testing.T) {
Type: restic.PackFile, Type: restic.PackFile,
Name: fmt.Sprintf("%s-%d", id, i), Name: fmt.Sprintf("%s-%d", id, i),
} }
err := b.Save(context.TODO(), h, restic.NewByteReader(data)) err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher()))
test.OK(t, err) test.OK(t, err)
buf, err := backend.LoadAll(context.TODO(), nil, b, h) buf, err := backend.LoadAll(context.TODO(), nil, b, h)
@ -538,7 +543,19 @@ func (s *Suite) TestSave(t *testing.T) {
// wrap the tempfile in an errorCloser, so we can detect if the backend // wrap the tempfile in an errorCloser, so we can detect if the backend
// closes the reader // closes the reader
err = b.Save(context.TODO(), h, errorCloser{t: t, l: int64(length), ReadSeeker: tmpfile}) var beHash []byte
if b.Hasher() != nil {
beHasher := b.Hasher()
// must never fail according to interface
_, _ = beHasher.Write(data)
beHash = beHasher.Sum(nil)
}
err = b.Save(context.TODO(), h, errorCloser{
t: t,
l: int64(length),
ReadSeeker: tmpfile,
h: beHash,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -583,7 +600,7 @@ func (s *Suite) TestSaveError(t *testing.T) {
// test that incomplete uploads fail // test that incomplete uploads fail
h := restic.Handle{Type: restic.PackFile, Name: id.String()} h := restic.Handle{Type: restic.PackFile, Name: id.String()}
err := b.Save(context.TODO(), h, &incompleteByteReader{ByteReader: *restic.NewByteReader(data)}) err := b.Save(context.TODO(), h, &incompleteByteReader{ByteReader: *restic.NewByteReader(data, b.Hasher())})
// try to delete possible leftovers // try to delete possible leftovers
_ = s.delayedRemove(t, b, h) _ = s.delayedRemove(t, b, h)
if err == nil { if err == nil {
@ -610,7 +627,7 @@ func (s *Suite) TestSaveFilenames(t *testing.T) {
for i, test := range filenameTests { for i, test := range filenameTests {
h := restic.Handle{Name: test.name, Type: restic.PackFile} h := restic.Handle{Name: test.name, Type: restic.PackFile}
err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(test.data))) err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(test.data), b.Hasher()))
if err != nil { if err != nil {
t.Errorf("test %d failed: Save() returned %+v", i, err) t.Errorf("test %d failed: Save() returned %+v", i, err)
continue continue
@ -647,7 +664,7 @@ var testStrings = []struct {
func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle { func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle {
id := restic.Hash(data) id := restic.Hash(data)
h := restic.Handle{Name: id.String(), Type: tpe} h := restic.Handle{Name: id.String(), Type: tpe}
err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(data))) err := b.Save(context.TODO(), h, restic.NewByteReader([]byte(data), b.Hasher()))
test.OK(t, err) test.OK(t, err)
return h return h
} }
@ -801,7 +818,7 @@ func (s *Suite) TestBackend(t *testing.T) {
test.Assert(t, !ok, "removed blob still present") test.Assert(t, !ok, "removed blob still present")
// create blob // create blob
err = b.Save(context.TODO(), h, restic.NewByteReader([]byte(ts.data))) err = b.Save(context.TODO(), h, restic.NewByteReader([]byte(ts.data), b.Hasher()))
test.OK(t, err) test.OK(t, err)
// list items // list items

View file

@ -26,7 +26,7 @@ func TestLoadAll(t *testing.T) {
id := restic.Hash(data) id := restic.Hash(data)
h := restic.Handle{Name: id.String(), Type: restic.PackFile} h := restic.Handle{Name: id.String(), Type: restic.PackFile}
err := b.Save(context.TODO(), h, restic.NewByteReader(data)) err := b.Save(context.TODO(), h, restic.NewByteReader(data, b.Hasher()))
rtest.OK(t, err) rtest.OK(t, err)
buf, err := backend.LoadAll(context.TODO(), buf, b, restic.Handle{Type: restic.PackFile, Name: id.String()}) buf, err := backend.LoadAll(context.TODO(), buf, b, restic.Handle{Type: restic.PackFile, Name: id.String()})
@ -47,7 +47,7 @@ func TestLoadAll(t *testing.T) {
func save(t testing.TB, be restic.Backend, buf []byte) restic.Handle { func save(t testing.TB, be restic.Backend, buf []byte) restic.Handle {
id := restic.Hash(buf) id := restic.Hash(buf)
h := restic.Handle{Name: id.String(), Type: restic.PackFile} h := restic.Handle{Name: id.String(), Type: restic.PackFile}
err := be.Save(context.TODO(), h, restic.NewByteReader(buf)) err := be.Save(context.TODO(), h, restic.NewByteReader(buf, be.Hasher()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -32,7 +32,7 @@ func loadAndCompare(t testing.TB, be restic.Backend, h restic.Handle, data []byt
} }
func save(t testing.TB, be restic.Backend, h restic.Handle, data []byte) { func save(t testing.TB, be restic.Backend, h restic.Handle, data []byte) {
err := be.Save(context.TODO(), h, restic.NewByteReader(data)) err := be.Save(context.TODO(), h, restic.NewByteReader(data, be.Hasher()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -16,6 +16,7 @@ import (
"github.com/restic/restic/internal/archiver" "github.com/restic/restic/internal/archiver"
"github.com/restic/restic/internal/checker" "github.com/restic/restic/internal/checker"
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/hashing"
"github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/repository"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/test" "github.com/restic/restic/internal/test"
@ -218,10 +219,16 @@ func TestModifiedIndex(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
}() }()
wr := io.Writer(tmpfile)
var hw *hashing.Writer
if repo.Backend().Hasher() != nil {
hw = hashing.NewWriter(wr, repo.Backend().Hasher())
wr = hw
}
// read the file from the backend // read the file from the backend
err = repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error { err = repo.Backend().Load(context.TODO(), h, 0, 0, func(rd io.Reader) error {
_, err := io.Copy(tmpfile, rd) _, err := io.Copy(wr, rd)
return err return err
}) })
test.OK(t, err) test.OK(t, err)
@ -233,7 +240,11 @@ func TestModifiedIndex(t *testing.T) {
Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd",
} }
rd, err := restic.NewFileReader(tmpfile) var hash []byte
if hw != nil {
hash = hw.Sum(nil)
}
rd, err := restic.NewFileReader(tmpfile, hash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -39,7 +39,7 @@ func TestLimitBackendSave(t *testing.T) {
limiter := NewStaticLimiter(42*1024, 42*1024) limiter := NewStaticLimiter(42*1024, 42*1024)
limbe := LimitBackend(be, limiter) limbe := LimitBackend(be, limiter)
rd := restic.NewByteReader(data) rd := restic.NewByteReader(data, nil)
err := limbe.Save(context.TODO(), testHandle, rd) err := limbe.Save(context.TODO(), testHandle, rd)
rtest.OK(t, err) rtest.OK(t, err)
} }

View file

@ -2,6 +2,7 @@ package mock
import ( import (
"context" "context"
"hash"
"io" "io"
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
@ -20,6 +21,7 @@ type Backend struct {
TestFn func(ctx context.Context, h restic.Handle) (bool, error) TestFn func(ctx context.Context, h restic.Handle) (bool, error)
DeleteFn func(ctx context.Context) error DeleteFn func(ctx context.Context) error
LocationFn func() string LocationFn func() string
HasherFn func() hash.Hash
} }
// NewBackend returns new mock Backend instance // NewBackend returns new mock Backend instance
@ -46,6 +48,15 @@ func (m *Backend) Location() string {
return m.LocationFn() return m.LocationFn()
} }
// Hasher may return a hash function for calculating a content hash for the
// backend. When no HasherFn hook is configured, the mock reports nil, i.e.
// no backend-specific content hash is requested.
func (m *Backend) Hasher() hash.Hash {
	if fn := m.HasherFn; fn != nil {
		return fn()
	}
	return nil
}
// IsNotExist returns true if the error is caused by a missing file. // IsNotExist returns true if the error is caused by a missing file.
func (m *Backend) IsNotExist(err error) bool { func (m *Backend) IsNotExist(err error) bool {
if m.IsNotExistFn == nil { if m.IsNotExistFn == nil {

View file

@ -127,7 +127,7 @@ func TestUnpackReadSeeker(t *testing.T) {
id := restic.Hash(packData) id := restic.Hash(packData)
handle := restic.Handle{Type: restic.PackFile, Name: id.String()} handle := restic.Handle{Type: restic.PackFile, Name: id.String()}
rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData, b.Hasher())))
verifyBlobs(t, bufs, k, restic.ReaderAt(context.TODO(), b, handle), packSize) verifyBlobs(t, bufs, k, restic.ReaderAt(context.TODO(), b, handle), packSize)
} }
@ -140,6 +140,6 @@ func TestShortPack(t *testing.T) {
id := restic.Hash(packData) id := restic.Hash(packData)
handle := restic.Handle{Type: restic.PackFile, Name: id.String()} handle := restic.Handle{Type: restic.PackFile, Name: id.String()}
rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData))) rtest.OK(t, b.Save(context.TODO(), handle, restic.NewByteReader(packData, b.Hasher())))
verifyBlobs(t, bufs, k, restic.ReaderAt(context.TODO(), b, handle), packSize) verifyBlobs(t, bufs, k, restic.ReaderAt(context.TODO(), b, handle), packSize)
} }

View file

@ -279,7 +279,7 @@ func AddKey(ctx context.Context, s *Repository, password, username, hostname str
Name: restic.Hash(buf).String(), Name: restic.Hash(buf).String(),
} }
err = s.be.Save(ctx, h, restic.NewByteReader(buf)) err = s.be.Save(ctx, h, restic.NewByteReader(buf, s.be.Hasher()))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -2,6 +2,8 @@ package repository
import ( import (
"context" "context"
"hash"
"io"
"os" "os"
"sync" "sync"
@ -20,12 +22,14 @@ import (
// Saver implements saving data in a backend. // Saver implements saving data in a backend.
type Saver interface { type Saver interface {
Save(context.Context, restic.Handle, restic.RewindReader) error Save(context.Context, restic.Handle, restic.RewindReader) error
Hasher() hash.Hash
} }
// Packer holds a pack.Packer together with a hash writer. // Packer holds a pack.Packer together with a hash writer.
type Packer struct { type Packer struct {
*pack.Packer *pack.Packer
hw *hashing.Writer hw *hashing.Writer
beHw *hashing.Writer
tmpfile *os.File tmpfile *os.File
} }
@ -71,10 +75,19 @@ func (r *packerManager) findPacker() (packer *Packer, err error) {
return nil, errors.Wrap(err, "fs.TempFile") return nil, errors.Wrap(err, "fs.TempFile")
} }
hw := hashing.NewWriter(tmpfile, sha256.New()) w := io.Writer(tmpfile)
beHasher := r.be.Hasher()
var beHw *hashing.Writer
if beHasher != nil {
beHw = hashing.NewWriter(w, beHasher)
w = beHw
}
hw := hashing.NewWriter(w, sha256.New())
p := pack.NewPacker(r.key, hw) p := pack.NewPacker(r.key, hw)
packer = &Packer{ packer = &Packer{
Packer: p, Packer: p,
beHw: beHw,
hw: hw, hw: hw,
tmpfile: tmpfile, tmpfile: tmpfile,
} }
@ -101,8 +114,11 @@ func (r *Repository) savePacker(ctx context.Context, t restic.BlobType, p *Packe
id := restic.IDFromHash(p.hw.Sum(nil)) id := restic.IDFromHash(p.hw.Sum(nil))
h := restic.Handle{Type: restic.PackFile, Name: id.String()} h := restic.Handle{Type: restic.PackFile, Name: id.String()}
var beHash []byte
rd, err := restic.NewFileReader(p.tmpfile) if p.beHw != nil {
beHash = p.beHw.Sum(nil)
}
rd, err := restic.NewFileReader(p.tmpfile, beHash)
if err != nil { if err != nil {
return err return err
} }

View file

@ -33,11 +33,11 @@ func min(a, b int) int {
return b return b
} }
func saveFile(t testing.TB, be Saver, length int, f *os.File, id restic.ID) { func saveFile(t testing.TB, be Saver, length int, f *os.File, id restic.ID, hash []byte) {
h := restic.Handle{Type: restic.PackFile, Name: id.String()} h := restic.Handle{Type: restic.PackFile, Name: id.String()}
t.Logf("save file %v", h) t.Logf("save file %v", h)
rd, err := restic.NewFileReader(f) rd, err := restic.NewFileReader(f, hash)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -90,7 +90,11 @@ func fillPacks(t testing.TB, rnd *rand.Rand, be Saver, pm *packerManager, buf []
} }
packID := restic.IDFromHash(packer.hw.Sum(nil)) packID := restic.IDFromHash(packer.hw.Sum(nil))
saveFile(t, be, int(packer.Size()), packer.tmpfile, packID) var beHash []byte
if packer.beHw != nil {
beHash = packer.beHw.Sum(nil)
}
saveFile(t, be, int(packer.Size()), packer.tmpfile, packID, beHash)
} }
return bytes return bytes
@ -106,7 +110,11 @@ func flushRemainingPacks(t testing.TB, be Saver, pm *packerManager) (bytes int)
bytes += int(n) bytes += int(n)
packID := restic.IDFromHash(packer.hw.Sum(nil)) packID := restic.IDFromHash(packer.hw.Sum(nil))
saveFile(t, be, int(packer.Size()), packer.tmpfile, packID) var beHash []byte
if packer.beHw != nil {
beHash = packer.beHw.Sum(nil)
}
saveFile(t, be, int(packer.Size()), packer.tmpfile, packID, beHash)
} }
} }

View file

@ -322,7 +322,7 @@ func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []by
} }
h := restic.Handle{Type: t, Name: id.String()} h := restic.Handle{Type: t, Name: id.String()}
err = r.be.Save(ctx, h, restic.NewByteReader(ciphertext)) err = r.be.Save(ctx, h, restic.NewByteReader(ciphertext, r.be.Hasher()))
if err != nil { if err != nil {
debug.Log("error saving blob %v: %v", h, err) debug.Log("error saving blob %v: %v", h, err)
return restic.ID{}, err return restic.ID{}, err

View file

@ -2,6 +2,7 @@ package restic
import ( import (
"context" "context"
"hash"
"io" "io"
) )
@ -17,6 +18,9 @@ type Backend interface {
// repository. // repository.
Location() string Location() string
// Hasher may return a hash function for calculating a content hash for the backend
Hasher() hash.Hash
// Test a boolean value whether a File with the name and type exists. // Test a boolean value whether a File with the name and type exists.
Test(ctx context.Context, h Handle) (bool, error) Test(ctx context.Context, h Handle) (bool, error)

View file

@ -2,6 +2,7 @@ package restic
import ( import (
"bytes" "bytes"
"hash"
"io" "io"
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
@ -18,12 +19,16 @@ type RewindReader interface {
// Length returns the number of bytes that can be read from the Reader // Length returns the number of bytes that can be read from the Reader
// after calling Rewind. // after calling Rewind.
Length() int64 Length() int64
// Hash returns a hash of the data if requested by the backend.
Hash() []byte
} }
// ByteReader implements a RewindReader for a byte slice. // ByteReader implements a RewindReader for a byte slice.
type ByteReader struct { type ByteReader struct {
*bytes.Reader *bytes.Reader
Len int64 Len int64
hash []byte
} }
// Rewind restarts the reader from the beginning of the data. // Rewind restarts the reader from the beginning of the data.
@ -38,14 +43,26 @@ func (b *ByteReader) Length() int64 {
return b.Len return b.Len
} }
// Hash returns a hash of the data if requested by the backend.
func (b *ByteReader) Hash() []byte {
return b.hash
}
// statically ensure that *ByteReader implements RewindReader. // statically ensure that *ByteReader implements RewindReader.
var _ RewindReader = &ByteReader{} var _ RewindReader = &ByteReader{}
// NewByteReader prepares a ByteReader that can then be used to read buf. // NewByteReader prepares a ByteReader that can then be used to read buf.
func NewByteReader(buf []byte) *ByteReader { func NewByteReader(buf []byte, hasher hash.Hash) *ByteReader {
var hash []byte
if hasher != nil {
// must never fail according to interface
_, _ = hasher.Write(buf)
hash = hasher.Sum(nil)
}
return &ByteReader{ return &ByteReader{
Reader: bytes.NewReader(buf), Reader: bytes.NewReader(buf),
Len: int64(len(buf)), Len: int64(len(buf)),
hash: hash,
} }
} }
@ -55,7 +72,8 @@ var _ RewindReader = &FileReader{}
// FileReader implements a RewindReader for an open file. // FileReader implements a RewindReader for an open file.
type FileReader struct { type FileReader struct {
io.ReadSeeker io.ReadSeeker
Len int64 Len int64
hash []byte
} }
// Rewind seeks to the beginning of the file. // Rewind seeks to the beginning of the file.
@ -69,8 +87,13 @@ func (f *FileReader) Length() int64 {
return f.Len return f.Len
} }
// Hash returns a hash of the data if requested by the backend.
func (f *FileReader) Hash() []byte {
return f.hash
}
// NewFileReader wraps f in a *FileReader. // NewFileReader wraps f in a *FileReader.
func NewFileReader(f io.ReadSeeker) (*FileReader, error) { func NewFileReader(f io.ReadSeeker, hash []byte) (*FileReader, error) {
pos, err := f.Seek(0, io.SeekEnd) pos, err := f.Seek(0, io.SeekEnd)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Seek") return nil, errors.Wrap(err, "Seek")
@ -79,6 +102,7 @@ func NewFileReader(f io.ReadSeeker) (*FileReader, error) {
fr := &FileReader{ fr := &FileReader{
ReadSeeker: f, ReadSeeker: f,
Len: pos, Len: pos,
hash: hash,
} }
err = fr.Rewind() err = fr.Rewind()

View file

@ -2,6 +2,8 @@ package restic
import ( import (
"bytes" "bytes"
"crypto/md5"
"hash"
"io" "io"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
@ -15,10 +17,12 @@ import (
func TestByteReader(t *testing.T) { func TestByteReader(t *testing.T) {
buf := []byte("foobar") buf := []byte("foobar")
fn := func() RewindReader { for _, hasher := range []hash.Hash{nil, md5.New()} {
return NewByteReader(buf) fn := func() RewindReader {
return NewByteReader(buf, hasher)
}
testRewindReader(t, fn, buf)
} }
testRewindReader(t, fn, buf)
} }
func TestFileReader(t *testing.T) { func TestFileReader(t *testing.T) {
@ -28,7 +32,7 @@ func TestFileReader(t *testing.T) {
defer cleanup() defer cleanup()
filename := filepath.Join(d, "file-reader-test") filename := filepath.Join(d, "file-reader-test")
err := ioutil.WriteFile(filename, []byte("foobar"), 0600) err := ioutil.WriteFile(filename, buf, 0600)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -45,15 +49,23 @@ func TestFileReader(t *testing.T) {
} }
}() }()
fn := func() RewindReader { for _, hasher := range []hash.Hash{nil, md5.New()} {
rd, err := NewFileReader(f) fn := func() RewindReader {
if err != nil { var hash []byte
t.Fatal(err) if hasher != nil {
// must never fail according to interface
_, _ = hasher.Write(buf)
hash = hasher.Sum(nil)
}
rd, err := NewFileReader(f, hash)
if err != nil {
t.Fatal(err)
}
return rd
} }
return rd
}
testRewindReader(t, fn, buf) testRewindReader(t, fn, buf)
}
} }
func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) { func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) {
@ -104,6 +116,15 @@ func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) {
if rd.Length() != int64(len(data)) { if rd.Length() != int64(len(data)) {
t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length()) t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length())
} }
if rd.Hash() != nil {
hasher := md5.New()
// must never fail according to interface
_, _ = hasher.Write(buf2)
if !bytes.Equal(rd.Hash(), hasher.Sum(nil)) {
t.Fatal("hash does not match data")
}
}
}, },
func(t testing.TB, rd RewindReader, data []byte) { func(t testing.TB, rd RewindReader, data []byte) {
// read first bytes // read first bytes