Merge pull request #994 from restic/add-context

Add context.Context to the backend
This commit is contained in:
Alexander Neumann 2017-06-07 19:11:56 +02:00
commit f90ce23f30
91 changed files with 885 additions and 857 deletions

View file

@ -2,6 +2,7 @@ package main
import ( import (
"bufio" "bufio"
"context"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -263,7 +264,7 @@ func readBackupFromStdin(opts BackupOptions, gopts GlobalOptions, args []string)
return err return err
} }
err = repo.LoadIndex() err = repo.LoadIndex(context.TODO())
if err != nil { if err != nil {
return err return err
} }
@ -274,7 +275,7 @@ func readBackupFromStdin(opts BackupOptions, gopts GlobalOptions, args []string)
Hostname: opts.Hostname, Hostname: opts.Hostname,
} }
_, id, err := r.Archive(opts.StdinFilename, os.Stdin, newArchiveStdinProgress(gopts)) _, id, err := r.Archive(context.TODO(), opts.StdinFilename, os.Stdin, newArchiveStdinProgress(gopts))
if err != nil { if err != nil {
return err return err
} }
@ -372,7 +373,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error {
return err return err
} }
err = repo.LoadIndex() err = repo.LoadIndex(context.TODO())
if err != nil { if err != nil {
return err return err
} }
@ -391,7 +392,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error {
// Find last snapshot to set it as parent, if not already set // Find last snapshot to set it as parent, if not already set
if !opts.Force && parentSnapshotID == nil { if !opts.Force && parentSnapshotID == nil {
id, err := restic.FindLatestSnapshot(repo, target, opts.Tags, opts.Hostname) id, err := restic.FindLatestSnapshot(context.TODO(), repo, target, opts.Tags, opts.Hostname)
if err == nil { if err == nil {
parentSnapshotID = &id parentSnapshotID = &id
} else if err != restic.ErrNoSnapshotFound { } else if err != restic.ErrNoSnapshotFound {
@ -489,7 +490,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error {
Warnf("%s\rwarning for %s: %v\n", ClearLine(), dir, err) Warnf("%s\rwarning for %s: %v\n", ClearLine(), dir, err)
} }
_, id, err := arch.Snapshot(newArchiveProgress(gopts, stat), target, opts.Tags, opts.Hostname, parentSnapshotID) _, id, err := arch.Snapshot(context.TODO(), newArchiveProgress(gopts, stat), target, opts.Tags, opts.Hostname, parentSnapshotID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -73,7 +74,7 @@ func runCat(gopts GlobalOptions, args []string) error {
fmt.Println(string(buf)) fmt.Println(string(buf))
return nil return nil
case "index": case "index":
buf, err := repo.LoadAndDecrypt(restic.IndexFile, id) buf, err := repo.LoadAndDecrypt(context.TODO(), restic.IndexFile, id)
if err != nil { if err != nil {
return err return err
} }
@ -83,7 +84,7 @@ func runCat(gopts GlobalOptions, args []string) error {
case "snapshot": case "snapshot":
sn := &restic.Snapshot{} sn := &restic.Snapshot{}
err = repo.LoadJSONUnpacked(restic.SnapshotFile, id, sn) err = repo.LoadJSONUnpacked(context.TODO(), restic.SnapshotFile, id, sn)
if err != nil { if err != nil {
return err return err
} }
@ -98,7 +99,7 @@ func runCat(gopts GlobalOptions, args []string) error {
return nil return nil
case "key": case "key":
h := restic.Handle{Type: restic.KeyFile, Name: id.String()} h := restic.Handle{Type: restic.KeyFile, Name: id.String()}
buf, err := backend.LoadAll(repo.Backend(), h) buf, err := backend.LoadAll(context.TODO(), repo.Backend(), h)
if err != nil { if err != nil {
return err return err
} }
@ -125,7 +126,7 @@ func runCat(gopts GlobalOptions, args []string) error {
fmt.Println(string(buf)) fmt.Println(string(buf))
return nil return nil
case "lock": case "lock":
lock, err := restic.LoadLock(repo, id) lock, err := restic.LoadLock(context.TODO(), repo, id)
if err != nil { if err != nil {
return err return err
} }
@ -141,7 +142,7 @@ func runCat(gopts GlobalOptions, args []string) error {
} }
// load index, handle all the other types // load index, handle all the other types
err = repo.LoadIndex() err = repo.LoadIndex(context.TODO())
if err != nil { if err != nil {
return err return err
} }
@ -149,7 +150,7 @@ func runCat(gopts GlobalOptions, args []string) error {
switch tpe { switch tpe {
case "pack": case "pack":
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
buf, err := backend.LoadAll(repo.Backend(), h) buf, err := backend.LoadAll(context.TODO(), repo.Backend(), h)
if err != nil { if err != nil {
return err return err
} }
@ -171,7 +172,7 @@ func runCat(gopts GlobalOptions, args []string) error {
blob := list[0] blob := list[0]
buf := make([]byte, blob.Length) buf := make([]byte, blob.Length)
n, err := repo.LoadBlob(t, id, buf) n, err := repo.LoadBlob(context.TODO(), t, id, buf)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"time" "time"
@ -92,7 +93,7 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
chkr := checker.New(repo) chkr := checker.New(repo)
Verbosef("Load indexes\n") Verbosef("Load indexes\n")
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
dupFound := false dupFound := false
for _, hint := range hints { for _, hint := range hints {
@ -113,14 +114,11 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
return errors.Fatal("LoadIndex returned errors") return errors.Fatal("LoadIndex returned errors")
} }
done := make(chan struct{})
defer close(done)
errorsFound := false errorsFound := false
errChan := make(chan error) errChan := make(chan error)
Verbosef("Check all packs\n") Verbosef("Check all packs\n")
go chkr.Packs(errChan, done) go chkr.Packs(context.TODO(), errChan)
for err := range errChan { for err := range errChan {
errorsFound = true errorsFound = true
@ -129,7 +127,7 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
Verbosef("Check snapshots, trees and blobs\n") Verbosef("Check snapshots, trees and blobs\n")
errChan = make(chan error) errChan = make(chan error)
go chkr.Structure(errChan, done) go chkr.Structure(context.TODO(), errChan)
for err := range errChan { for err := range errChan {
errorsFound = true errorsFound = true
@ -156,7 +154,7 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
p := newReadProgress(gopts, restic.Stat{Blobs: chkr.CountPacks()}) p := newReadProgress(gopts, restic.Stat{Blobs: chkr.CountPacks()})
errChan := make(chan error) errChan := make(chan error)
go chkr.ReadData(p, errChan, done) go chkr.ReadData(context.TODO(), p, errChan)
for err := range errChan { for err := range errChan {
errorsFound = true errorsFound = true

View file

@ -1,8 +1,9 @@
// +build debug // xbuild debug
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -44,11 +45,8 @@ func prettyPrintJSON(wr io.Writer, item interface{}) error {
} }
func debugPrintSnapshots(repo *repository.Repository, wr io.Writer) error { func debugPrintSnapshots(repo *repository.Repository, wr io.Writer) error {
done := make(chan struct{}) for id := range repo.List(context.TODO(), restic.SnapshotFile) {
defer close(done) snapshot, err := restic.LoadSnapshot(context.TODO(), repo, id)
for id := range repo.List(restic.SnapshotFile, done) {
snapshot, err := restic.LoadSnapshot(repo, id)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "LoadSnapshot(%v): %v", id.Str(), err) fmt.Fprintf(os.Stderr, "LoadSnapshot(%v): %v", id.Str(), err)
continue continue
@ -83,15 +81,12 @@ type Blob struct {
} }
func printPacks(repo *repository.Repository, wr io.Writer) error { func printPacks(repo *repository.Repository, wr io.Writer) error {
done := make(chan struct{}) f := func(ctx context.Context, job worker.Job) (interface{}, error) {
defer close(done)
f := func(job worker.Job, done <-chan struct{}) (interface{}, error) {
name := job.Data.(string) name := job.Data.(string)
h := restic.Handle{Type: restic.DataFile, Name: name} h := restic.Handle{Type: restic.DataFile, Name: name}
blobInfo, err := repo.Backend().Stat(h) blobInfo, err := repo.Backend().Stat(ctx, h)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -106,10 +101,10 @@ func printPacks(repo *repository.Repository, wr io.Writer) error {
jobCh := make(chan worker.Job) jobCh := make(chan worker.Job)
resCh := make(chan worker.Job) resCh := make(chan worker.Job)
wp := worker.New(dumpPackWorkers, f, jobCh, resCh) wp := worker.New(context.TODO(), dumpPackWorkers, f, jobCh, resCh)
go func() { go func() {
for name := range repo.Backend().List(restic.DataFile, done) { for name := range repo.Backend().List(context.TODO(), restic.DataFile) {
jobCh <- worker.Job{Data: name} jobCh <- worker.Job{Data: name}
} }
close(jobCh) close(jobCh)
@ -146,13 +141,10 @@ func printPacks(repo *repository.Repository, wr io.Writer) error {
} }
func dumpIndexes(repo restic.Repository) error { func dumpIndexes(repo restic.Repository) error {
done := make(chan struct{}) for id := range repo.List(context.TODO(), restic.IndexFile) {
defer close(done)
for id := range repo.List(restic.IndexFile, done) {
fmt.Printf("index_id: %v\n", id) fmt.Printf("index_id: %v\n", id)
idx, err := repository.LoadIndex(repo, id) idx, err := repository.LoadIndex(context.TODO(), repo, id)
if err != nil { if err != nil {
return err return err
} }
@ -184,7 +176,7 @@ func runDump(gopts GlobalOptions, args []string) error {
} }
} }
err = repo.LoadIndex() err = repo.LoadIndex(context.TODO())
if err != nil { if err != nil {
return err return err
} }

View file

@ -187,7 +187,7 @@ func (f *Finder) findInTree(treeID restic.ID, prefix string) error {
debug.Log("%v checking tree %v\n", prefix, treeID.Str()) debug.Log("%v checking tree %v\n", prefix, treeID.Str())
tree, err := f.repo.LoadTree(treeID) tree, err := f.repo.LoadTree(context.TODO(), treeID)
if err != nil { if err != nil {
return err return err
} }
@ -283,7 +283,7 @@ func runFind(opts FindOptions, gopts GlobalOptions, args []string) error {
} }
} }
if err = repo.LoadIndex(); err != nil { if err = repo.LoadIndex(context.TODO()); err != nil {
return err return err
} }

View file

@ -97,7 +97,7 @@ func runForget(opts ForgetOptions, gopts GlobalOptions, args []string) error {
// When explicit snapshots args are given, remove them immediately. // When explicit snapshots args are given, remove them immediately.
if !opts.DryRun { if !opts.DryRun {
h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()} h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()}
if err = repo.Backend().Remove(h); err != nil { if err = repo.Backend().Remove(context.TODO(), h); err != nil {
return err return err
} }
Verbosef("removed snapshot %v\n", sn.ID().Str()) Verbosef("removed snapshot %v\n", sn.ID().Str())
@ -167,7 +167,7 @@ func runForget(opts ForgetOptions, gopts GlobalOptions, args []string) error {
if !opts.DryRun { if !opts.DryRun {
for _, sn := range remove { for _, sn := range remove {
h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()} h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()}
err = repo.Backend().Remove(h) err = repo.Backend().Remove(context.TODO(), h)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"restic/errors" "restic/errors"
"restic/repository" "restic/repository"
@ -43,7 +44,7 @@ func runInit(gopts GlobalOptions, args []string) error {
s := repository.New(be) s := repository.New(be)
err = s.Init(gopts.password) err = s.Init(context.TODO(), gopts.password)
if err != nil { if err != nil {
return errors.Fatalf("create key in backend at %s failed: %v\n", gopts.Repo, err) return errors.Fatalf("create key in backend at %s failed: %v\n", gopts.Repo, err)
} }

View file

@ -30,8 +30,8 @@ func listKeys(ctx context.Context, s *repository.Repository) error {
tab.Header = fmt.Sprintf(" %-10s %-10s %-10s %s", "ID", "User", "Host", "Created") tab.Header = fmt.Sprintf(" %-10s %-10s %-10s %s", "ID", "User", "Host", "Created")
tab.RowFormat = "%s%-10s %-10s %-10s %s" tab.RowFormat = "%s%-10s %-10s %-10s %s"
for id := range s.List(restic.KeyFile, ctx.Done()) { for id := range s.List(ctx, restic.KeyFile) {
k, err := repository.LoadKey(s, id.String()) k, err := repository.LoadKey(ctx, s, id.String())
if err != nil { if err != nil {
Warnf("LoadKey() failed: %v\n", err) Warnf("LoadKey() failed: %v\n", err)
continue continue
@ -69,7 +69,7 @@ func addKey(gopts GlobalOptions, repo *repository.Repository) error {
return err return err
} }
id, err := repository.AddKey(repo, pw, repo.Key()) id, err := repository.AddKey(context.TODO(), repo, pw, repo.Key())
if err != nil { if err != nil {
return errors.Fatalf("creating new key failed: %v\n", err) return errors.Fatalf("creating new key failed: %v\n", err)
} }
@ -85,7 +85,7 @@ func deleteKey(repo *repository.Repository, name string) error {
} }
h := restic.Handle{Type: restic.KeyFile, Name: name} h := restic.Handle{Type: restic.KeyFile, Name: name}
err := repo.Backend().Remove(h) err := repo.Backend().Remove(context.TODO(), h)
if err != nil { if err != nil {
return err return err
} }
@ -100,13 +100,13 @@ func changePassword(gopts GlobalOptions, repo *repository.Repository) error {
return err return err
} }
id, err := repository.AddKey(repo, pw, repo.Key()) id, err := repository.AddKey(context.TODO(), repo, pw, repo.Key())
if err != nil { if err != nil {
return errors.Fatalf("creating new key failed: %v\n", err) return errors.Fatalf("creating new key failed: %v\n", err)
} }
h := restic.Handle{Type: restic.KeyFile, Name: repo.KeyName()} h := restic.Handle{Type: restic.KeyFile, Name: repo.KeyName()}
err = repo.Backend().Remove(h) err = repo.Backend().Remove(context.TODO(), h)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"restic" "restic"
"restic/errors" "restic/errors"
@ -55,7 +56,7 @@ func runList(opts GlobalOptions, args []string) error {
case "locks": case "locks":
t = restic.LockFile t = restic.LockFile
case "blobs": case "blobs":
idx, err := index.Load(repo, nil) idx, err := index.Load(context.TODO(), repo, nil)
if err != nil { if err != nil {
return err return err
} }
@ -71,7 +72,7 @@ func runList(opts GlobalOptions, args []string) error {
return errors.Fatal("invalid type") return errors.Fatal("invalid type")
} }
for id := range repo.List(t, nil) { for id := range repo.List(context.TODO(), t) {
Printf("%s\n", id) Printf("%s\n", id)
} }

View file

@ -46,7 +46,7 @@ func init() {
} }
func printTree(repo *repository.Repository, id *restic.ID, prefix string) error { func printTree(repo *repository.Repository, id *restic.ID, prefix string) error {
tree, err := repo.LoadTree(*id) tree, err := repo.LoadTree(context.TODO(), *id)
if err != nil { if err != nil {
return err return err
} }
@ -74,7 +74,7 @@ func runLs(opts LsOptions, gopts GlobalOptions, args []string) error {
return err return err
} }
if err = repo.LoadIndex(); err != nil { if err = repo.LoadIndex(context.TODO()); err != nil {
return err return err
} }

View file

@ -4,6 +4,7 @@
package main package main
import ( import (
"context"
"os" "os"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -64,7 +65,7 @@ func mount(opts MountOptions, gopts GlobalOptions, mountpoint string) error {
return err return err
} }
err = repo.LoadIndex() err = repo.LoadIndex(context.TODO())
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"restic" "restic"
"restic/debug" "restic/debug"
@ -76,14 +75,13 @@ func runPrune(gopts GlobalOptions) error {
} }
func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
err := repo.LoadIndex() ctx := gopts.ctx
err := repo.LoadIndex(ctx)
if err != nil { if err != nil {
return err return err
} }
ctx, cancel := context.WithCancel(gopts.ctx)
defer cancel()
var stats struct { var stats struct {
blobs int blobs int
packs int packs int
@ -92,14 +90,14 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
} }
Verbosef("counting files in repo\n") Verbosef("counting files in repo\n")
for range repo.List(restic.DataFile, ctx.Done()) { for range repo.List(ctx, restic.DataFile) {
stats.packs++ stats.packs++
} }
Verbosef("building new index for repo\n") Verbosef("building new index for repo\n")
bar := newProgressMax(!gopts.Quiet, uint64(stats.packs), "packs") bar := newProgressMax(!gopts.Quiet, uint64(stats.packs), "packs")
idx, err := index.New(repo, bar) idx, err := index.New(ctx, repo, bar)
if err != nil { if err != nil {
return err return err
} }
@ -135,7 +133,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
Verbosef("load all snapshots\n") Verbosef("load all snapshots\n")
// find referenced blobs // find referenced blobs
snapshots, err := restic.LoadAllSnapshots(repo) snapshots, err := restic.LoadAllSnapshots(ctx, repo)
if err != nil { if err != nil {
return err return err
} }
@ -152,7 +150,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
for _, sn := range snapshots { for _, sn := range snapshots {
debug.Log("process snapshot %v", sn.ID().Str()) debug.Log("process snapshot %v", sn.ID().Str())
err = restic.FindUsedBlobs(repo, *sn.Tree, usedBlobs, seenBlobs) err = restic.FindUsedBlobs(ctx, repo, *sn.Tree, usedBlobs, seenBlobs)
if err != nil { if err != nil {
return err return err
} }
@ -217,7 +215,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
if len(rewritePacks) != 0 { if len(rewritePacks) != 0 {
bar = newProgressMax(!gopts.Quiet, uint64(len(rewritePacks)), "packs rewritten") bar = newProgressMax(!gopts.Quiet, uint64(len(rewritePacks)), "packs rewritten")
bar.Start() bar.Start()
err = repository.Repack(repo, rewritePacks, usedBlobs, bar) err = repository.Repack(ctx, repo, rewritePacks, usedBlobs, bar)
if err != nil { if err != nil {
return err return err
} }
@ -229,7 +227,7 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error {
bar.Start() bar.Start()
for packID := range removePacks { for packID := range removePacks {
h := restic.Handle{Type: restic.DataFile, Name: packID.String()} h := restic.Handle{Type: restic.DataFile, Name: packID.String()}
err = repo.Backend().Remove(h) err = repo.Backend().Remove(ctx, h)
if err != nil { if err != nil {
Warnf("unable to remove file %v from the repository\n", packID.Str()) Warnf("unable to remove file %v from the repository\n", packID.Str())
} }

View file

@ -45,12 +45,12 @@ func rebuildIndex(ctx context.Context, repo restic.Repository) error {
Verbosef("counting files in repo\n") Verbosef("counting files in repo\n")
var packs uint64 var packs uint64
for range repo.List(restic.DataFile, ctx.Done()) { for range repo.List(ctx, restic.DataFile) {
packs++ packs++
} }
bar := newProgressMax(!globalOptions.Quiet, packs, "packs") bar := newProgressMax(!globalOptions.Quiet, packs, "packs")
idx, err := index.New(repo, bar) idx, err := index.New(ctx, repo, bar)
if err != nil { if err != nil {
return err return err
} }
@ -58,11 +58,11 @@ func rebuildIndex(ctx context.Context, repo restic.Repository) error {
Verbosef("finding old index files\n") Verbosef("finding old index files\n")
var supersedes restic.IDs var supersedes restic.IDs
for id := range repo.List(restic.IndexFile, ctx.Done()) { for id := range repo.List(ctx, restic.IndexFile) {
supersedes = append(supersedes, id) supersedes = append(supersedes, id)
} }
id, err := idx.Save(repo, supersedes) id, err := idx.Save(ctx, repo, supersedes)
if err != nil { if err != nil {
return err return err
} }
@ -72,7 +72,7 @@ func rebuildIndex(ctx context.Context, repo restic.Repository) error {
Verbosef("remove %d old index files\n", len(supersedes)) Verbosef("remove %d old index files\n", len(supersedes))
for _, id := range supersedes { for _, id := range supersedes {
if err := repo.Backend().Remove(restic.Handle{ if err := repo.Backend().Remove(ctx, restic.Handle{
Type: restic.IndexFile, Type: restic.IndexFile,
Name: id.String(), Name: id.String(),
}); err != nil { }); err != nil {

View file

@ -50,6 +50,8 @@ func init() {
} }
func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error { func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error {
ctx := gopts.ctx
if len(args) != 1 { if len(args) != 1 {
return errors.Fatal("no snapshot ID specified") return errors.Fatal("no snapshot ID specified")
} }
@ -79,7 +81,7 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error {
} }
} }
err = repo.LoadIndex() err = repo.LoadIndex(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -87,7 +89,7 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error {
var id restic.ID var id restic.ID
if snapshotIDString == "latest" { if snapshotIDString == "latest" {
id, err = restic.FindLatestSnapshot(repo, opts.Paths, opts.Tags, opts.Host) id, err = restic.FindLatestSnapshot(ctx, repo, opts.Paths, opts.Tags, opts.Host)
if err != nil { if err != nil {
Exitf(1, "latest snapshot for criteria not found: %v Paths:%v Host:%v", err, opts.Paths, opts.Host) Exitf(1, "latest snapshot for criteria not found: %v Paths:%v Host:%v", err, opts.Paths, opts.Host)
} }
@ -136,7 +138,7 @@ func runRestore(opts RestoreOptions, gopts GlobalOptions, args []string) error {
Verbosef("restoring %s to %s\n", res.Snapshot(), opts.Target) Verbosef("restoring %s to %s\n", res.Snapshot(), opts.Target)
err = res.RestoreTo(opts.Target) err = res.RestoreTo(ctx, opts.Target)
if totalErrors > 0 { if totalErrors > 0 {
Printf("There were %d errors\n", totalErrors) Printf("There were %d errors\n", totalErrors)
} }

View file

@ -76,7 +76,7 @@ func changeTags(repo *repository.Repository, sn *restic.Snapshot, setTags, addTa
} }
// Save the new snapshot. // Save the new snapshot.
id, err := repo.SaveJSONUnpacked(restic.SnapshotFile, sn) id, err := repo.SaveJSONUnpacked(context.TODO(), restic.SnapshotFile, sn)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -89,7 +89,7 @@ func changeTags(repo *repository.Repository, sn *restic.Snapshot, setTags, addTa
// Remove the old snapshot. // Remove the old snapshot.
h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()} h := restic.Handle{Type: restic.SnapshotFile, Name: sn.ID().String()}
if err = repo.Backend().Remove(h); err != nil { if err = repo.Backend().Remove(context.TODO(), h); err != nil {
return false, err return false, err
} }

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"restic" "restic"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -41,7 +42,7 @@ func runUnlock(opts UnlockOptions, gopts GlobalOptions) error {
fn = restic.RemoveAllLocks fn = restic.RemoveAllLocks
} }
err = fn(repo) err = fn(context.TODO(), repo)
if err != nil { if err != nil {
return err return err
} }

View file

@ -22,7 +22,7 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos
// Process all snapshot IDs given as arguments. // Process all snapshot IDs given as arguments.
for _, s := range snapshotIDs { for _, s := range snapshotIDs {
if s == "latest" { if s == "latest" {
id, err = restic.FindLatestSnapshot(repo, paths, tags, host) id, err = restic.FindLatestSnapshot(ctx, repo, paths, tags, host)
if err != nil { if err != nil {
Warnf("Ignoring %q, no snapshot matched given filter (Paths:%v Tags:%v Host:%v)\n", s, paths, tags, host) Warnf("Ignoring %q, no snapshot matched given filter (Paths:%v Tags:%v Host:%v)\n", s, paths, tags, host)
usedFilter = true usedFilter = true
@ -44,7 +44,7 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos
} }
for _, id := range ids.Uniq() { for _, id := range ids.Uniq() {
sn, err := restic.LoadSnapshot(repo, id) sn, err := restic.LoadSnapshot(ctx, repo, id)
if err != nil { if err != nil {
Warnf("Ignoring %q, could not load snapshot: %v\n", id, err) Warnf("Ignoring %q, could not load snapshot: %v\n", id, err)
continue continue
@ -58,8 +58,8 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos
return return
} }
for id := range repo.List(restic.SnapshotFile, ctx.Done()) { for id := range repo.List(ctx, restic.SnapshotFile) {
sn, err := restic.LoadSnapshot(repo, id) sn, err := restic.LoadSnapshot(ctx, repo, id)
if err != nil { if err != nil {
Warnf("Ignoring %q, could not load snapshot: %v\n", id, err) Warnf("Ignoring %q, could not load snapshot: %v\n", id, err)
continue continue

View file

@ -310,7 +310,7 @@ func OpenRepository(opts GlobalOptions) (*repository.Repository, error) {
} }
} }
err = s.SearchKey(opts.password, maxKeys) err = s.SearchKey(context.TODO(), opts.password, maxKeys)
if err != nil { if err != nil {
return nil, errors.Fatalf("unable to open repo: %v", err) return nil, errors.Fatalf("unable to open repo: %v", err)
} }
@ -440,7 +440,7 @@ func open(s string, opts options.Options) (restic.Backend, error) {
} }
// check if config is there // check if config is there
fi, err := be.Stat(restic.Handle{Type: restic.ConfigFile}) fi, err := be.Stat(context.TODO(), restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return nil, errors.Fatalf("unable to open config file: %v\nIs there a repository at the following location?\n%v", err, s) return nil, errors.Fatalf("unable to open config file: %v\nIs there a repository at the following location?\n%v", err, s)
} }

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@ -32,7 +33,7 @@ func lockRepository(repo *repository.Repository, exclusive bool) (*restic.Lock,
lockFn = restic.NewExclusiveLock lockFn = restic.NewExclusiveLock
} }
lock, err := lockFn(repo) lock, err := lockFn(context.TODO(), repo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,7 +76,7 @@ func refreshLocks(wg *sync.WaitGroup, done <-chan struct{}) {
debug.Log("refreshing locks") debug.Log("refreshing locks")
globalLocks.Lock() globalLocks.Lock()
for _, lock := range globalLocks.locks { for _, lock := range globalLocks.locks {
err := lock.Refresh() err := lock.Refresh(context.TODO())
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "unable to refresh lock: %v\n", err) fmt.Fprintf(os.Stderr, "unable to refresh lock: %v\n", err)
} }

View file

@ -1,6 +1,7 @@
package archiver package archiver
import ( import (
"context"
"io" "io"
"restic" "restic"
"restic/debug" "restic/debug"
@ -20,7 +21,7 @@ type Reader struct {
} }
// Archive reads data from the reader and saves it to the repo. // Archive reads data from the reader and saves it to the repo.
func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic.Snapshot, restic.ID, error) { func (r *Reader) Archive(ctx context.Context, name string, rd io.Reader, p *restic.Progress) (*restic.Snapshot, restic.ID, error) {
if name == "" { if name == "" {
return nil, restic.ID{}, errors.New("no filename given") return nil, restic.ID{}, errors.New("no filename given")
} }
@ -53,7 +54,7 @@ func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic
id := restic.Hash(chunk.Data) id := restic.Hash(chunk.Data)
if !repo.Index().Has(id, restic.DataBlob) { if !repo.Index().Has(id, restic.DataBlob) {
_, err := repo.SaveBlob(restic.DataBlob, chunk.Data, id) _, err := repo.SaveBlob(ctx, restic.DataBlob, chunk.Data, id)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
@ -87,14 +88,14 @@ func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic
}, },
} }
treeID, err := repo.SaveTree(tree) treeID, err := repo.SaveTree(ctx, tree)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
sn.Tree = &treeID sn.Tree = &treeID
debug.Log("tree saved as %v", treeID.Str()) debug.Log("tree saved as %v", treeID.Str())
id, err := repo.SaveJSONUnpacked(restic.SnapshotFile, sn) id, err := repo.SaveJSONUnpacked(ctx, restic.SnapshotFile, sn)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
@ -106,7 +107,7 @@ func (r *Reader) Archive(name string, rd io.Reader, p *restic.Progress) (*restic
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
err = repo.SaveIndex() err = repo.SaveIndex(ctx)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }

View file

@ -2,6 +2,7 @@ package archiver
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"io" "io"
"math/rand" "math/rand"
@ -12,7 +13,7 @@ import (
) )
func loadBlob(t *testing.T, repo restic.Repository, id restic.ID, buf []byte) int { func loadBlob(t *testing.T, repo restic.Repository, id restic.ID, buf []byte) int {
n, err := repo.LoadBlob(restic.DataBlob, id, buf) n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf)
if err != nil { if err != nil {
t.Fatalf("LoadBlob(%v) returned error %v", id, err) t.Fatalf("LoadBlob(%v) returned error %v", id, err)
} }
@ -21,7 +22,7 @@ func loadBlob(t *testing.T, repo restic.Repository, id restic.ID, buf []byte) in
} }
func checkSavedFile(t *testing.T, repo restic.Repository, treeID restic.ID, name string, rd io.Reader) { func checkSavedFile(t *testing.T, repo restic.Repository, treeID restic.ID, name string, rd io.Reader) {
tree, err := repo.LoadTree(treeID) tree, err := repo.LoadTree(context.TODO(), treeID)
if err != nil { if err != nil {
t.Fatalf("LoadTree() returned error %v", err) t.Fatalf("LoadTree() returned error %v", err)
} }
@ -85,7 +86,7 @@ func TestArchiveReader(t *testing.T) {
Tags: []string{"test"}, Tags: []string{"test"},
} }
sn, id, err := r.Archive("fakefile", f, nil) sn, id, err := r.Archive(context.TODO(), "fakefile", f, nil)
if err != nil { if err != nil {
t.Fatalf("ArchiveReader() returned error %v", err) t.Fatalf("ArchiveReader() returned error %v", err)
} }
@ -111,7 +112,7 @@ func TestArchiveReaderNull(t *testing.T) {
Tags: []string{"test"}, Tags: []string{"test"},
} }
sn, id, err := r.Archive("fakefile", bytes.NewReader(nil), nil) sn, id, err := r.Archive(context.TODO(), "fakefile", bytes.NewReader(nil), nil)
if err != nil { if err != nil {
t.Fatalf("ArchiveReader() returned error %v", err) t.Fatalf("ArchiveReader() returned error %v", err)
} }
@ -132,11 +133,8 @@ func (e errReader) Read([]byte) (int, error) {
} }
func countSnapshots(t testing.TB, repo restic.Repository) int { func countSnapshots(t testing.TB, repo restic.Repository) int {
done := make(chan struct{})
defer close(done)
snapshots := 0 snapshots := 0
for range repo.List(restic.SnapshotFile, done) { for range repo.List(context.TODO(), restic.SnapshotFile) {
snapshots++ snapshots++
} }
return snapshots return snapshots
@ -152,7 +150,7 @@ func TestArchiveReaderError(t *testing.T) {
Tags: []string{"test"}, Tags: []string{"test"},
} }
sn, id, err := r.Archive("fakefile", errReader("error returned by reading stdin"), nil) sn, id, err := r.Archive(context.TODO(), "fakefile", errReader("error returned by reading stdin"), nil)
if err == nil { if err == nil {
t.Errorf("expected error not returned") t.Errorf("expected error not returned")
} }
@ -195,7 +193,7 @@ func BenchmarkArchiveReader(t *testing.B) {
t.ResetTimer() t.ResetTimer()
for i := 0; i < t.N; i++ { for i := 0; i < t.N; i++ {
_, _, err := r.Archive("fakefile", bytes.NewReader(buf), nil) _, _, err := r.Archive(context.TODO(), "fakefile", bytes.NewReader(buf), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -1,6 +1,7 @@
package archiver package archiver
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -92,7 +93,7 @@ func (arch *Archiver) isKnownBlob(id restic.ID, t restic.BlobType) bool {
} }
// Save stores a blob read from rd in the repository. // Save stores a blob read from rd in the repository.
func (arch *Archiver) Save(t restic.BlobType, data []byte, id restic.ID) error { func (arch *Archiver) Save(ctx context.Context, t restic.BlobType, data []byte, id restic.ID) error {
debug.Log("Save(%v, %v)\n", t, id.Str()) debug.Log("Save(%v, %v)\n", t, id.Str())
if arch.isKnownBlob(id, restic.DataBlob) { if arch.isKnownBlob(id, restic.DataBlob) {
@ -100,7 +101,7 @@ func (arch *Archiver) Save(t restic.BlobType, data []byte, id restic.ID) error {
return nil return nil
} }
_, err := arch.repo.SaveBlob(t, data, id) _, err := arch.repo.SaveBlob(ctx, t, data, id)
if err != nil { if err != nil {
debug.Log("Save(%v, %v): error %v\n", t, id.Str(), err) debug.Log("Save(%v, %v): error %v\n", t, id.Str(), err)
return err return err
@ -111,7 +112,7 @@ func (arch *Archiver) Save(t restic.BlobType, data []byte, id restic.ID) error {
} }
// SaveTreeJSON stores a tree in the repository. // SaveTreeJSON stores a tree in the repository.
func (arch *Archiver) SaveTreeJSON(tree *restic.Tree) (restic.ID, error) { func (arch *Archiver) SaveTreeJSON(ctx context.Context, tree *restic.Tree) (restic.ID, error) {
data, err := json.Marshal(tree) data, err := json.Marshal(tree)
if err != nil { if err != nil {
return restic.ID{}, errors.Wrap(err, "Marshal") return restic.ID{}, errors.Wrap(err, "Marshal")
@ -124,7 +125,7 @@ func (arch *Archiver) SaveTreeJSON(tree *restic.Tree) (restic.ID, error) {
return id, nil return id, nil
} }
return arch.repo.SaveBlob(restic.TreeBlob, data, id) return arch.repo.SaveBlob(ctx, restic.TreeBlob, data, id)
} }
func (arch *Archiver) reloadFileIfChanged(node *restic.Node, file fs.File) (*restic.Node, error) { func (arch *Archiver) reloadFileIfChanged(node *restic.Node, file fs.File) (*restic.Node, error) {
@ -153,11 +154,11 @@ type saveResult struct {
bytes uint64 bytes uint64
} }
func (arch *Archiver) saveChunk(chunk chunker.Chunk, p *restic.Progress, token struct{}, file fs.File, resultChannel chan<- saveResult) { func (arch *Archiver) saveChunk(ctx context.Context, chunk chunker.Chunk, p *restic.Progress, token struct{}, file fs.File, resultChannel chan<- saveResult) {
defer freeBuf(chunk.Data) defer freeBuf(chunk.Data)
id := restic.Hash(chunk.Data) id := restic.Hash(chunk.Data)
err := arch.Save(restic.DataBlob, chunk.Data, id) err := arch.Save(ctx, restic.DataBlob, chunk.Data, id)
// TODO handle error // TODO handle error
if err != nil { if err != nil {
panic(err) panic(err)
@ -206,7 +207,7 @@ func updateNodeContent(node *restic.Node, results []saveResult) error {
// SaveFile stores the content of the file on the backend as a Blob by calling // SaveFile stores the content of the file on the backend as a Blob by calling
// Save for each chunk. // Save for each chunk.
func (arch *Archiver) SaveFile(p *restic.Progress, node *restic.Node) (*restic.Node, error) { func (arch *Archiver) SaveFile(ctx context.Context, p *restic.Progress, node *restic.Node) (*restic.Node, error) {
file, err := fs.Open(node.Path) file, err := fs.Open(node.Path)
defer file.Close() defer file.Close()
if err != nil { if err != nil {
@ -234,7 +235,7 @@ func (arch *Archiver) SaveFile(p *restic.Progress, node *restic.Node) (*restic.N
} }
resCh := make(chan saveResult, 1) resCh := make(chan saveResult, 1)
go arch.saveChunk(chunk, p, <-arch.blobToken, file, resCh) go arch.saveChunk(ctx, chunk, p, <-arch.blobToken, file, resCh)
resultChannels = append(resultChannels, resCh) resultChannels = append(resultChannels, resCh)
} }
@ -247,7 +248,7 @@ func (arch *Archiver) SaveFile(p *restic.Progress, node *restic.Node) (*restic.N
return node, err return node, err
} }
func (arch *Archiver) fileWorker(wg *sync.WaitGroup, p *restic.Progress, done <-chan struct{}, entCh <-chan pipe.Entry) { func (arch *Archiver) fileWorker(ctx context.Context, wg *sync.WaitGroup, p *restic.Progress, entCh <-chan pipe.Entry) {
defer func() { defer func() {
debug.Log("done") debug.Log("done")
wg.Done() wg.Done()
@ -305,7 +306,7 @@ func (arch *Archiver) fileWorker(wg *sync.WaitGroup, p *restic.Progress, done <-
// otherwise read file normally // otherwise read file normally
if node.Type == "file" && len(node.Content) == 0 { if node.Type == "file" && len(node.Content) == 0 {
debug.Log(" read and save %v", e.Path()) debug.Log(" read and save %v", e.Path())
node, err = arch.SaveFile(p, node) node, err = arch.SaveFile(ctx, p, node)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error for %v: %v\n", node.Path, err) fmt.Fprintf(os.Stderr, "error for %v: %v\n", node.Path, err)
arch.Warn(e.Path(), nil, err) arch.Warn(e.Path(), nil, err)
@ -322,14 +323,14 @@ func (arch *Archiver) fileWorker(wg *sync.WaitGroup, p *restic.Progress, done <-
debug.Log(" processed %v, %d blobs", e.Path(), len(node.Content)) debug.Log(" processed %v, %d blobs", e.Path(), len(node.Content))
e.Result() <- node e.Result() <- node
p.Report(restic.Stat{Files: 1}) p.Report(restic.Stat{Files: 1})
case <-done: case <-ctx.Done():
// pipeline was cancelled // pipeline was cancelled
return return
} }
} }
} }
func (arch *Archiver) dirWorker(wg *sync.WaitGroup, p *restic.Progress, done <-chan struct{}, dirCh <-chan pipe.Dir) { func (arch *Archiver) dirWorker(ctx context.Context, wg *sync.WaitGroup, p *restic.Progress, dirCh <-chan pipe.Dir) {
debug.Log("start") debug.Log("start")
defer func() { defer func() {
debug.Log("done") debug.Log("done")
@ -398,7 +399,7 @@ func (arch *Archiver) dirWorker(wg *sync.WaitGroup, p *restic.Progress, done <-c
node.Error = err.Error() node.Error = err.Error()
} }
id, err := arch.SaveTreeJSON(tree) id, err := arch.SaveTreeJSON(ctx, tree)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -415,7 +416,7 @@ func (arch *Archiver) dirWorker(wg *sync.WaitGroup, p *restic.Progress, done <-c
if dir.Path() != "" { if dir.Path() != "" {
p.Report(restic.Stat{Dirs: 1}) p.Report(restic.Stat{Dirs: 1})
} }
case <-done: case <-ctx.Done():
// pipeline was cancelled // pipeline was cancelled
return return
} }
@ -427,7 +428,7 @@ type archivePipe struct {
New <-chan pipe.Job New <-chan pipe.Job
} }
func copyJobs(done <-chan struct{}, in <-chan pipe.Job, out chan<- pipe.Job) { func copyJobs(ctx context.Context, in <-chan pipe.Job, out chan<- pipe.Job) {
var ( var (
// disable sending on the outCh until we received a job // disable sending on the outCh until we received a job
outCh chan<- pipe.Job outCh chan<- pipe.Job
@ -439,7 +440,7 @@ func copyJobs(done <-chan struct{}, in <-chan pipe.Job, out chan<- pipe.Job) {
for { for {
select { select {
case <-done: case <-ctx.Done():
return return
case job, ok = <-inCh: case job, ok = <-inCh:
if !ok { if !ok {
@ -462,7 +463,7 @@ type archiveJob struct {
new pipe.Job new pipe.Job
} }
func (a *archivePipe) compare(done <-chan struct{}, out chan<- pipe.Job) { func (a *archivePipe) compare(ctx context.Context, out chan<- pipe.Job) {
defer func() { defer func() {
close(out) close(out)
debug.Log("done") debug.Log("done")
@ -488,7 +489,7 @@ func (a *archivePipe) compare(done <-chan struct{}, out chan<- pipe.Job) {
out <- archiveJob{new: newJob}.Copy() out <- archiveJob{new: newJob}.Copy()
} }
copyJobs(done, a.New, out) copyJobs(ctx, a.New, out)
return return
} }
@ -585,7 +586,7 @@ func (j archiveJob) Copy() pipe.Job {
const saveIndexTime = 30 * time.Second const saveIndexTime = 30 * time.Second
// saveIndexes regularly queries the master index for full indexes and saves them. // saveIndexes regularly queries the master index for full indexes and saves them.
func (arch *Archiver) saveIndexes(wg *sync.WaitGroup, done <-chan struct{}) { func (arch *Archiver) saveIndexes(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
ticker := time.NewTicker(saveIndexTime) ticker := time.NewTicker(saveIndexTime)
@ -593,11 +594,11 @@ func (arch *Archiver) saveIndexes(wg *sync.WaitGroup, done <-chan struct{}) {
for { for {
select { select {
case <-done: case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
debug.Log("saving full indexes") debug.Log("saving full indexes")
err := arch.repo.SaveFullIndex() err := arch.repo.SaveFullIndex(ctx)
if err != nil { if err != nil {
debug.Log("save indexes returned an error: %v", err) debug.Log("save indexes returned an error: %v", err)
fmt.Fprintf(os.Stderr, "error saving preliminary index: %v\n", err) fmt.Fprintf(os.Stderr, "error saving preliminary index: %v\n", err)
@ -634,7 +635,7 @@ func (p baseNameSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// Snapshot creates a snapshot of the given paths. If parentrestic.ID is set, this is // Snapshot creates a snapshot of the given paths. If parentrestic.ID is set, this is
// used to compare the files to the ones archived at the time this snapshot was // used to compare the files to the ones archived at the time this snapshot was
// taken. // taken.
func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostname string, parentID *restic.ID) (*restic.Snapshot, restic.ID, error) { func (arch *Archiver) Snapshot(ctx context.Context, p *restic.Progress, paths, tags []string, hostname string, parentID *restic.ID) (*restic.Snapshot, restic.ID, error) {
paths = unique(paths) paths = unique(paths)
sort.Sort(baseNameSlice(paths)) sort.Sort(baseNameSlice(paths))
@ -643,7 +644,6 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
debug.RunHook("Archiver.Snapshot", nil) debug.RunHook("Archiver.Snapshot", nil)
// signal the whole pipeline to stop // signal the whole pipeline to stop
done := make(chan struct{})
var err error var err error
p.Start() p.Start()
@ -663,14 +663,14 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
sn.Parent = parentID sn.Parent = parentID
// load parent snapshot // load parent snapshot
parent, err := restic.LoadSnapshot(arch.repo, *parentID) parent, err := restic.LoadSnapshot(ctx, arch.repo, *parentID)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
// start walker on old tree // start walker on old tree
ch := make(chan walk.TreeJob) ch := make(chan walk.TreeJob)
go walk.Tree(arch.repo, *parent.Tree, done, ch) go walk.Tree(ctx, arch.repo, *parent.Tree, ch)
jobs.Old = ch jobs.Old = ch
} else { } else {
// use closed channel // use closed channel
@ -683,13 +683,13 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
pipeCh := make(chan pipe.Job) pipeCh := make(chan pipe.Job)
resCh := make(chan pipe.Result, 1) resCh := make(chan pipe.Result, 1)
go func() { go func() {
pipe.Walk(paths, arch.SelectFilter, done, pipeCh, resCh) pipe.Walk(ctx, paths, arch.SelectFilter, pipeCh, resCh)
debug.Log("pipe.Walk done") debug.Log("pipe.Walk done")
}() }()
jobs.New = pipeCh jobs.New = pipeCh
ch := make(chan pipe.Job) ch := make(chan pipe.Job)
go jobs.compare(done, ch) go jobs.compare(ctx, ch)
var wg sync.WaitGroup var wg sync.WaitGroup
entCh := make(chan pipe.Entry) entCh := make(chan pipe.Entry)
@ -708,22 +708,22 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
// run workers // run workers
for i := 0; i < maxConcurrency; i++ { for i := 0; i < maxConcurrency; i++ {
wg.Add(2) wg.Add(2)
go arch.fileWorker(&wg, p, done, entCh) go arch.fileWorker(ctx, &wg, p, entCh)
go arch.dirWorker(&wg, p, done, dirCh) go arch.dirWorker(ctx, &wg, p, dirCh)
} }
// run index saver // run index saver
var wgIndexSaver sync.WaitGroup var wgIndexSaver sync.WaitGroup
stopIndexSaver := make(chan struct{}) indexCtx, indexCancel := context.WithCancel(ctx)
wgIndexSaver.Add(1) wgIndexSaver.Add(1)
go arch.saveIndexes(&wgIndexSaver, stopIndexSaver) go arch.saveIndexes(indexCtx, &wgIndexSaver)
// wait for all workers to terminate // wait for all workers to terminate
debug.Log("wait for workers") debug.Log("wait for workers")
wg.Wait() wg.Wait()
// stop index saver // stop index saver
close(stopIndexSaver) indexCancel()
wgIndexSaver.Wait() wgIndexSaver.Wait()
debug.Log("workers terminated") debug.Log("workers terminated")
@ -740,7 +740,7 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
sn.Tree = root.Subtree sn.Tree = root.Subtree
// load top-level tree again to see if it is empty // load top-level tree again to see if it is empty
toptree, err := arch.repo.LoadTree(*root.Subtree) toptree, err := arch.repo.LoadTree(ctx, *root.Subtree)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }
@ -750,7 +750,7 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
} }
// save index // save index
err = arch.repo.SaveIndex() err = arch.repo.SaveIndex(ctx)
if err != nil { if err != nil {
debug.Log("error saving index: %v", err) debug.Log("error saving index: %v", err)
return nil, restic.ID{}, err return nil, restic.ID{}, err
@ -759,7 +759,7 @@ func (arch *Archiver) Snapshot(p *restic.Progress, paths, tags []string, hostnam
debug.Log("saved indexes") debug.Log("saved indexes")
// save snapshot // save snapshot
id, err := arch.repo.SaveJSONUnpacked(restic.SnapshotFile, sn) id, err := arch.repo.SaveJSONUnpacked(ctx, restic.SnapshotFile, sn)
if err != nil { if err != nil {
return nil, restic.ID{}, err return nil, restic.ID{}, err
} }

View file

@ -1,6 +1,7 @@
package archiver_test package archiver_test
import ( import (
"context"
"crypto/rand" "crypto/rand"
"io" "io"
mrand "math/rand" mrand "math/rand"
@ -39,33 +40,33 @@ func randomID() restic.ID {
func forgetfulBackend() restic.Backend { func forgetfulBackend() restic.Backend {
be := &mock.Backend{} be := &mock.Backend{}
be.TestFn = func(h restic.Handle) (bool, error) { be.TestFn = func(ctx context.Context, h restic.Handle) (bool, error) {
return false, nil return false, nil
} }
be.LoadFn = func(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { be.LoadFn = func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
return nil, errors.New("not found") return nil, errors.New("not found")
} }
be.SaveFn = func(h restic.Handle, rd io.Reader) error { be.SaveFn = func(ctx context.Context, h restic.Handle, rd io.Reader) error {
return nil return nil
} }
be.StatFn = func(h restic.Handle) (restic.FileInfo, error) { be.StatFn = func(ctx context.Context, h restic.Handle) (restic.FileInfo, error) {
return restic.FileInfo{}, errors.New("not found") return restic.FileInfo{}, errors.New("not found")
} }
be.RemoveFn = func(h restic.Handle) error { be.RemoveFn = func(ctx context.Context, h restic.Handle) error {
return nil return nil
} }
be.ListFn = func(t restic.FileType, done <-chan struct{}) <-chan string { be.ListFn = func(ctx context.Context, t restic.FileType) <-chan string {
ch := make(chan string) ch := make(chan string)
close(ch) close(ch)
return ch return ch
} }
be.DeleteFn = func() error { be.DeleteFn = func(ctx context.Context) error {
return nil return nil
} }
@ -80,7 +81,7 @@ func testArchiverDuplication(t *testing.T) {
repo := repository.New(forgetfulBackend()) repo := repository.New(forgetfulBackend())
err = repo.Init("foo") err = repo.Init(context.TODO(), "foo")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -108,7 +109,7 @@ func testArchiverDuplication(t *testing.T) {
buf := make([]byte, 50) buf := make([]byte, 50)
err := arch.Save(restic.DataBlob, buf, id) err := arch.Save(context.TODO(), restic.DataBlob, buf, id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -127,7 +128,7 @@ func testArchiverDuplication(t *testing.T) {
case <-done: case <-done:
return return
case <-ticker.C: case <-ticker.C:
err := repo.SaveFullIndex() err := repo.SaveFullIndex(context.TODO())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -1,6 +1,7 @@
package archiver package archiver
import ( import (
"context"
"os" "os"
"testing" "testing"
@ -83,10 +84,10 @@ func (j testPipeJob) Error() error { return j.err }
func (j testPipeJob) Info() os.FileInfo { return j.fi } func (j testPipeJob) Info() os.FileInfo { return j.fi }
func (j testPipeJob) Result() chan<- pipe.Result { return j.res } func (j testPipeJob) Result() chan<- pipe.Result { return j.res }
func testTreeWalker(done <-chan struct{}, out chan<- walk.TreeJob) { func testTreeWalker(ctx context.Context, out chan<- walk.TreeJob) {
for _, e := range treeJobs { for _, e := range treeJobs {
select { select {
case <-done: case <-ctx.Done():
return return
case out <- walk.TreeJob{Path: e}: case out <- walk.TreeJob{Path: e}:
} }
@ -95,10 +96,10 @@ func testTreeWalker(done <-chan struct{}, out chan<- walk.TreeJob) {
close(out) close(out)
} }
func testPipeWalker(done <-chan struct{}, out chan<- pipe.Job) { func testPipeWalker(ctx context.Context, out chan<- pipe.Job) {
for _, e := range pipeJobs { for _, e := range pipeJobs {
select { select {
case <-done: case <-ctx.Done():
return return
case out <- testPipeJob{path: e}: case out <- testPipeJob{path: e}:
} }
@ -108,19 +109,19 @@ func testPipeWalker(done <-chan struct{}, out chan<- pipe.Job) {
} }
func TestArchivePipe(t *testing.T) { func TestArchivePipe(t *testing.T) {
done := make(chan struct{}) ctx := context.TODO()
treeCh := make(chan walk.TreeJob) treeCh := make(chan walk.TreeJob)
pipeCh := make(chan pipe.Job) pipeCh := make(chan pipe.Job)
go testTreeWalker(done, treeCh) go testTreeWalker(ctx, treeCh)
go testPipeWalker(done, pipeCh) go testPipeWalker(ctx, pipeCh)
p := archivePipe{Old: treeCh, New: pipeCh} p := archivePipe{Old: treeCh, New: pipeCh}
ch := make(chan pipe.Job) ch := make(chan pipe.Job)
go p.compare(done, ch) go p.compare(ctx, ch)
i := 0 i := 0
for job := range ch { for job := range ch {

View file

@ -2,6 +2,7 @@ package archiver_test
import ( import (
"bytes" "bytes"
"context"
"io" "io"
"testing" "testing"
"time" "time"
@ -104,7 +105,7 @@ func archiveDirectory(b testing.TB) {
arch := archiver.New(repo) arch := archiver.New(repo)
_, id, err := arch.Snapshot(nil, []string{BenchArchiveDirectory}, nil, "localhost", nil) _, id, err := arch.Snapshot(context.TODO(), nil, []string{BenchArchiveDirectory}, nil, "localhost", nil)
OK(b, err) OK(b, err)
b.Logf("snapshot archived as %v", id) b.Logf("snapshot archived as %v", id)
@ -129,7 +130,7 @@ func BenchmarkArchiveDirectory(b *testing.B) {
} }
func countPacks(repo restic.Repository, t restic.FileType) (n uint) { func countPacks(repo restic.Repository, t restic.FileType) (n uint) {
for range repo.Backend().List(t, nil) { for range repo.Backend().List(context.TODO(), t) {
n++ n++
} }
@ -234,7 +235,7 @@ func testParallelSaveWithDuplication(t *testing.T, seed int) {
id := restic.Hash(c.Data) id := restic.Hash(c.Data)
time.Sleep(time.Duration(id[0])) time.Sleep(time.Duration(id[0]))
err := arch.Save(restic.DataBlob, c.Data, id) err := arch.Save(context.TODO(), restic.DataBlob, c.Data, id)
<-barrier <-barrier
errChan <- err errChan <- err
}(c, errChan) }(c, errChan)
@ -246,7 +247,7 @@ func testParallelSaveWithDuplication(t *testing.T, seed int) {
} }
OK(t, repo.Flush()) OK(t, repo.Flush())
OK(t, repo.SaveIndex()) OK(t, repo.SaveIndex(context.TODO()))
chkr := createAndInitChecker(t, repo) chkr := createAndInitChecker(t, repo)
assertNoUnreferencedPacks(t, chkr) assertNoUnreferencedPacks(t, chkr)
@ -271,7 +272,7 @@ func getRandomData(seed int, size int) []chunker.Chunk {
func createAndInitChecker(t *testing.T, repo restic.Repository) *checker.Checker { func createAndInitChecker(t *testing.T, repo restic.Repository) *checker.Checker {
chkr := checker.New(repo) chkr := checker.New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) > 0 { if len(errs) > 0 {
t.Fatalf("expected no errors, got %v: %v", len(errs), errs) t.Fatalf("expected no errors, got %v: %v", len(errs), errs)
} }
@ -284,11 +285,8 @@ func createAndInitChecker(t *testing.T, repo restic.Repository) *checker.Checker
} }
func assertNoUnreferencedPacks(t *testing.T, chkr *checker.Checker) { func assertNoUnreferencedPacks(t *testing.T, chkr *checker.Checker) {
done := make(chan struct{})
defer close(done)
errChan := make(chan error) errChan := make(chan error)
go chkr.Packs(errChan, done) go chkr.Packs(context.TODO(), errChan)
for err := range errChan { for err := range errChan {
OK(t, err) OK(t, err)
@ -301,7 +299,7 @@ func TestArchiveEmptySnapshot(t *testing.T) {
arch := archiver.New(repo) arch := archiver.New(repo)
sn, id, err := arch.Snapshot(nil, []string{"file-does-not-exist-123123213123", "file2-does-not-exist-too-123123123"}, nil, "localhost", nil) sn, id, err := arch.Snapshot(context.TODO(), nil, []string{"file-does-not-exist-123123213123", "file2-does-not-exist-too-123123123"}, nil, "localhost", nil)
if err == nil { if err == nil {
t.Errorf("expected error for empty snapshot, got nil") t.Errorf("expected error for empty snapshot, got nil")
} }

View file

@ -1,6 +1,7 @@
package archiver package archiver
import ( import (
"context"
"restic" "restic"
"testing" "testing"
) )
@ -8,7 +9,7 @@ import (
// TestSnapshot creates a new snapshot of path. // TestSnapshot creates a new snapshot of path.
func TestSnapshot(t testing.TB, repo restic.Repository, path string, parent *restic.ID) *restic.Snapshot { func TestSnapshot(t testing.TB, repo restic.Repository, path string, parent *restic.ID) *restic.Snapshot {
arch := New(repo) arch := New(repo)
sn, _, err := arch.Snapshot(nil, []string{path}, []string{"test"}, "localhost", parent) sn, _, err := arch.Snapshot(context.TODO(), nil, []string{path}, []string{"test"}, "localhost", parent)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -1,6 +1,9 @@
package restic package restic
import "io" import (
"context"
"io"
)
// Backend is used to store and access data. // Backend is used to store and access data.
type Backend interface { type Backend interface {
@ -9,30 +12,30 @@ type Backend interface {
Location() string Location() string
// Test a boolean value whether a File with the name and type exists. // Test a boolean value whether a File with the name and type exists.
Test(h Handle) (bool, error) Test(ctx context.Context, h Handle) (bool, error)
// Remove removes a File with type t and name. // Remove removes a File with type t and name.
Remove(h Handle) error Remove(ctx context.Context, h Handle) error
// Close the backend // Close the backend
Close() error Close() error
// Save stores the data in the backend under the given handle. // Save stores the data in the backend under the given handle.
Save(h Handle, rd io.Reader) error Save(ctx context.Context, h Handle, rd io.Reader) error
// Load returns a reader that yields the contents of the file at h at the // Load returns a reader that yields the contents of the file at h at the
// given offset. If length is larger than zero, only a portion of the file // given offset. If length is larger than zero, only a portion of the file
// is returned. rd must be closed after use. If an error is returned, the // is returned. rd must be closed after use. If an error is returned, the
// ReadCloser must be nil. // ReadCloser must be nil.
Load(h Handle, length int, offset int64) (io.ReadCloser, error) Load(ctx context.Context, h Handle, length int, offset int64) (io.ReadCloser, error)
// Stat returns information about the File identified by h. // Stat returns information about the File identified by h.
Stat(h Handle) (FileInfo, error) Stat(ctx context.Context, h Handle) (FileInfo, error)
// List returns a channel that yields all names of files of type t in an // List returns a channel that yields all names of files of type t in an
// arbitrary order. A goroutine is started for this. If the channel done is // arbitrary order. A goroutine is started for this, which is stopped when
// closed, sending stops. // ctx is cancelled.
List(t FileType, done <-chan struct{}) <-chan string List(ctx context.Context, t FileType) <-chan string
} }
// FileInfo is returned by Stat() and contains information about a file in the // FileInfo is returned by Stat() and contains information about a file in the

View file

@ -23,6 +23,9 @@ type b2Backend struct {
sem *backend.Semaphore sem *backend.Semaphore
} }
// ensure statically that *b2Backend implements restic.Backend.
var _ restic.Backend = &b2Backend{}
func newClient(ctx context.Context, cfg Config) (*b2.Client, error) { func newClient(ctx context.Context, cfg Config) (*b2.Client, error) {
opts := []b2.ClientOption{b2.Transport(backend.Transport())} opts := []b2.ClientOption{b2.Transport(backend.Transport())}
@ -96,7 +99,7 @@ func Create(cfg Config) (restic.Backend, error) {
sem: backend.NewSemaphore(cfg.Connections), sem: backend.NewSemaphore(cfg.Connections),
} }
present, err := be.Test(restic.Handle{Type: restic.ConfigFile}) present, err := be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -140,7 +143,7 @@ func (wr *wrapReader) Close() error {
// Load returns the data stored in the backend for h at the given offset // Load returns the data stored in the backend for h at the given offset
// and saves it in p. Load has the same semantics as io.ReaderAt. // and saves it in p. Load has the same semantics as io.ReaderAt.
func (be *b2Backend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (be *b2Backend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
debug.Log("Load %v, length %v, offset %v from %v", h, length, offset, be.Filename(h)) debug.Log("Load %v, length %v, offset %v from %v", h, length, offset, be.Filename(h))
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return nil, err return nil, err
@ -154,7 +157,7 @@ func (be *b2Backend) Load(h restic.Handle, length int, offset int64) (io.ReadClo
return nil, errors.Errorf("invalid length %d", length) return nil, errors.Errorf("invalid length %d", length)
} }
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(ctx)
be.sem.GetToken() be.sem.GetToken()
@ -191,8 +194,8 @@ func (be *b2Backend) Load(h restic.Handle, length int, offset int64) (io.ReadClo
} }
// Save stores data in the backend at the handle. // Save stores data in the backend at the handle.
func (be *b2Backend) Save(h restic.Handle, rd io.Reader) (err error) { func (be *b2Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
@ -225,12 +228,9 @@ func (be *b2Backend) Save(h restic.Handle, rd io.Reader) (err error) {
} }
// Stat returns information about a blob. // Stat returns information about a blob.
func (be *b2Backend) Stat(h restic.Handle) (bi restic.FileInfo, err error) { func (be *b2Backend) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInfo, err error) {
debug.Log("Stat %v", h) debug.Log("Stat %v", h)
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
be.sem.GetToken() be.sem.GetToken()
defer be.sem.ReleaseToken() defer be.sem.ReleaseToken()
@ -245,12 +245,9 @@ func (be *b2Backend) Stat(h restic.Handle) (bi restic.FileInfo, err error) {
} }
// Test returns true if a blob of the given type and name exists in the backend. // Test returns true if a blob of the given type and name exists in the backend.
func (be *b2Backend) Test(h restic.Handle) (bool, error) { func (be *b2Backend) Test(ctx context.Context, h restic.Handle) (bool, error) {
debug.Log("Test %v", h) debug.Log("Test %v", h)
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
be.sem.GetToken() be.sem.GetToken()
defer be.sem.ReleaseToken() defer be.sem.ReleaseToken()
@ -265,12 +262,9 @@ func (be *b2Backend) Test(h restic.Handle) (bool, error) {
} }
// Remove removes the blob with the given name and type. // Remove removes the blob with the given name and type.
func (be *b2Backend) Remove(h restic.Handle) error { func (be *b2Backend) Remove(ctx context.Context, h restic.Handle) error {
debug.Log("Remove %v", h) debug.Log("Remove %v", h)
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
be.sem.GetToken() be.sem.GetToken()
defer be.sem.ReleaseToken() defer be.sem.ReleaseToken()
@ -281,11 +275,11 @@ func (be *b2Backend) Remove(h restic.Handle) error {
// List returns a channel that yields all names of blobs of type t. A // List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending // goroutine is started for this. If the channel done is closed, sending
// stops. // stops.
func (be *b2Backend) List(t restic.FileType, done <-chan struct{}) <-chan string { func (be *b2Backend) List(ctx context.Context, t restic.FileType) <-chan string {
debug.Log("List %v", t) debug.Log("List %v", t)
ch := make(chan string) ch := make(chan string)
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(ctx)
be.sem.GetToken() be.sem.GetToken()
@ -315,7 +309,7 @@ func (be *b2Backend) List(t restic.FileType, done <-chan struct{}) <-chan string
select { select {
case ch <- m: case ch <- m:
case <-done: case <-ctx.Done():
return return
} }
} }
@ -330,13 +324,10 @@ func (be *b2Backend) List(t restic.FileType, done <-chan struct{}) <-chan string
} }
// Remove keys for a specified backend type. // Remove keys for a specified backend type.
func (be *b2Backend) removeKeys(t restic.FileType) error { func (be *b2Backend) removeKeys(ctx context.Context, t restic.FileType) error {
debug.Log("removeKeys %v", t) debug.Log("removeKeys %v", t)
for key := range be.List(ctx, t) {
done := make(chan struct{}) err := be.Remove(ctx, restic.Handle{Type: t, Name: key})
defer close(done)
for key := range be.List(t, done) {
err := be.Remove(restic.Handle{Type: t, Name: key})
if err != nil { if err != nil {
return err return err
} }
@ -345,7 +336,7 @@ func (be *b2Backend) removeKeys(t restic.FileType) error {
} }
// Delete removes all restic keys in the bucket. It will not remove the bucket itself. // Delete removes all restic keys in the bucket. It will not remove the bucket itself.
func (be *b2Backend) Delete() error { func (be *b2Backend) Delete(ctx context.Context) error {
alltypes := []restic.FileType{ alltypes := []restic.FileType{
restic.DataFile, restic.DataFile,
restic.KeyFile, restic.KeyFile,
@ -354,12 +345,12 @@ func (be *b2Backend) Delete() error {
restic.IndexFile} restic.IndexFile}
for _, t := range alltypes { for _, t := range alltypes {
err := be.removeKeys(t) err := be.removeKeys(ctx, t)
if err != nil { if err != nil {
return nil return nil
} }
} }
err := be.Remove(restic.Handle{Type: restic.ConfigFile}) err := be.Remove(ctx, restic.Handle{Type: restic.ConfigFile})
if err != nil && b2.IsNotExist(errors.Cause(err)) { if err != nil && b2.IsNotExist(errors.Cause(err)) {
err = nil err = nil
} }

View file

@ -1,6 +1,7 @@
package b2_test package b2_test
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"testing" "testing"
@ -52,7 +53,7 @@ func newB2TestSuite(t testing.TB) *test.Suite {
return err return err
} }
if err := be.(restic.Deleter).Delete(); err != nil { if err := be.(restic.Deleter).Delete(context.TODO()); err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package local package local
import ( import (
"context"
"path/filepath" "path/filepath"
"restic" "restic"
. "restic/test" . "restic/test"
@ -47,7 +48,7 @@ func TestLayout(t *testing.T) {
} }
datafiles := make(map[string]bool) datafiles := make(map[string]bool)
for id := range be.List(restic.DataFile, nil) { for id := range be.List(context.TODO(), restic.DataFile) {
datafiles[id] = false datafiles[id] = false
} }

View file

@ -1,6 +1,7 @@
package local package local
import ( import (
"context"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
@ -75,7 +76,7 @@ func (b *Local) Location() string {
} }
// Save stores data in the backend at the handle. // Save stores data in the backend at the handle.
func (b *Local) Save(h restic.Handle, rd io.Reader) (err error) { func (b *Local) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
debug.Log("Save %v", h) debug.Log("Save %v", h)
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return err return err
@ -100,7 +101,7 @@ func (b *Local) Save(h restic.Handle, rd io.Reader) (err error) {
return errors.Wrap(err, "MkdirAll") return errors.Wrap(err, "MkdirAll")
} }
return b.Save(h, rd) return b.Save(ctx, h, rd)
} }
if err != nil { if err != nil {
@ -110,12 +111,12 @@ func (b *Local) Save(h restic.Handle, rd io.Reader) (err error) {
// save data, then sync // save data, then sync
_, err = io.Copy(f, rd) _, err = io.Copy(f, rd)
if err != nil { if err != nil {
f.Close() _ = f.Close()
return errors.Wrap(err, "Write") return errors.Wrap(err, "Write")
} }
if err = f.Sync(); err != nil { if err = f.Sync(); err != nil {
f.Close() _ = f.Close()
return errors.Wrap(err, "Sync") return errors.Wrap(err, "Sync")
} }
@ -136,7 +137,7 @@ func (b *Local) Save(h restic.Handle, rd io.Reader) (err error) {
// Load returns a reader that yields the contents of the file at h at the // Load returns a reader that yields the contents of the file at h at the
// given offset. If length is nonzero, only a portion of the file is // given offset. If length is nonzero, only a portion of the file is
// returned. rd must be closed after use. // returned. rd must be closed after use.
func (b *Local) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (b *Local) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
debug.Log("Load %v, length %v, offset %v", h, length, offset) debug.Log("Load %v, length %v, offset %v", h, length, offset)
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return nil, err return nil, err
@ -154,7 +155,7 @@ func (b *Local) Load(h restic.Handle, length int, offset int64) (io.ReadCloser,
if offset > 0 { if offset > 0 {
_, err = f.Seek(offset, 0) _, err = f.Seek(offset, 0)
if err != nil { if err != nil {
f.Close() _ = f.Close()
return nil, err return nil, err
} }
} }
@ -167,7 +168,7 @@ func (b *Local) Load(h restic.Handle, length int, offset int64) (io.ReadCloser,
} }
// Stat returns information about a blob. // Stat returns information about a blob.
func (b *Local) Stat(h restic.Handle) (restic.FileInfo, error) { func (b *Local) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) {
debug.Log("Stat %v", h) debug.Log("Stat %v", h)
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return restic.FileInfo{}, err return restic.FileInfo{}, err
@ -182,7 +183,7 @@ func (b *Local) Stat(h restic.Handle) (restic.FileInfo, error) {
} }
// Test returns true if a blob of the given type and name exists in the backend. // Test returns true if a blob of the given type and name exists in the backend.
func (b *Local) Test(h restic.Handle) (bool, error) { func (b *Local) Test(ctx context.Context, h restic.Handle) (bool, error) {
debug.Log("Test %v", h) debug.Log("Test %v", h)
_, err := fs.Stat(b.Filename(h)) _, err := fs.Stat(b.Filename(h))
if err != nil { if err != nil {
@ -196,7 +197,7 @@ func (b *Local) Test(h restic.Handle) (bool, error) {
} }
// Remove removes the blob with the given name and type. // Remove removes the blob with the given name and type.
func (b *Local) Remove(h restic.Handle) error { func (b *Local) Remove(ctx context.Context, h restic.Handle) error {
debug.Log("Remove %v", h) debug.Log("Remove %v", h)
fn := b.Filename(h) fn := b.Filename(h)
@ -214,9 +215,8 @@ func isFile(fi os.FileInfo) bool {
} }
// List returns a channel that yields all names of blobs of type t. A // List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending // goroutine is started for this.
// stops. func (b *Local) List(ctx context.Context, t restic.FileType) <-chan string {
func (b *Local) List(t restic.FileType, done <-chan struct{}) <-chan string {
debug.Log("List %v", t) debug.Log("List %v", t)
ch := make(chan string) ch := make(chan string)
@ -235,7 +235,7 @@ func (b *Local) List(t restic.FileType, done <-chan struct{}) <-chan string {
select { select {
case ch <- filepath.Base(path): case ch <- filepath.Base(path):
case <-done: case <-ctx.Done():
return err return err
} }

View file

@ -2,6 +2,7 @@ package mem
import ( import (
"bytes" "bytes"
"context"
"io" "io"
"io/ioutil" "io/ioutil"
"restic" "restic"
@ -37,7 +38,7 @@ func New() *MemoryBackend {
} }
// Test returns whether a file exists. // Test returns whether a file exists.
func (be *MemoryBackend) Test(h restic.Handle) (bool, error) { func (be *MemoryBackend) Test(ctx context.Context, h restic.Handle) (bool, error) {
be.m.Lock() be.m.Lock()
defer be.m.Unlock() defer be.m.Unlock()
@ -51,7 +52,7 @@ func (be *MemoryBackend) Test(h restic.Handle) (bool, error) {
} }
// Save adds new Data to the backend. // Save adds new Data to the backend.
func (be *MemoryBackend) Save(h restic.Handle, rd io.Reader) error { func (be *MemoryBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return err return err
} }
@ -81,7 +82,7 @@ func (be *MemoryBackend) Save(h restic.Handle, rd io.Reader) error {
// Load returns a reader that yields the contents of the file at h at the // Load returns a reader that yields the contents of the file at h at the
// given offset. If length is nonzero, only a portion of the file is // given offset. If length is nonzero, only a portion of the file is
// returned. rd must be closed after use. // returned. rd must be closed after use.
func (be *MemoryBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (be *MemoryBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return nil, err return nil, err
} }
@ -117,7 +118,7 @@ func (be *MemoryBackend) Load(h restic.Handle, length int, offset int64) (io.Rea
} }
// Stat returns information about a file in the backend. // Stat returns information about a file in the backend.
func (be *MemoryBackend) Stat(h restic.Handle) (restic.FileInfo, error) { func (be *MemoryBackend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) {
be.m.Lock() be.m.Lock()
defer be.m.Unlock() defer be.m.Unlock()
@ -140,7 +141,7 @@ func (be *MemoryBackend) Stat(h restic.Handle) (restic.FileInfo, error) {
} }
// Remove deletes a file from the backend. // Remove deletes a file from the backend.
func (be *MemoryBackend) Remove(h restic.Handle) error { func (be *MemoryBackend) Remove(ctx context.Context, h restic.Handle) error {
be.m.Lock() be.m.Lock()
defer be.m.Unlock() defer be.m.Unlock()
@ -156,7 +157,7 @@ func (be *MemoryBackend) Remove(h restic.Handle) error {
} }
// List returns a channel which yields entries from the backend. // List returns a channel which yields entries from the backend.
func (be *MemoryBackend) List(t restic.FileType, done <-chan struct{}) <-chan string { func (be *MemoryBackend) List(ctx context.Context, t restic.FileType) <-chan string {
be.m.Lock() be.m.Lock()
defer be.m.Unlock() defer be.m.Unlock()
@ -177,7 +178,7 @@ func (be *MemoryBackend) List(t restic.FileType, done <-chan struct{}) <-chan st
for _, id := range ids { for _, id := range ids {
select { select {
case ch <- id: case ch <- id:
case <-done: case <-ctx.Done():
return return
} }
} }
@ -192,7 +193,7 @@ func (be *MemoryBackend) Location() string {
} }
// Delete removes all data in the backend. // Delete removes all data in the backend.
func (be *MemoryBackend) Delete() error { func (be *MemoryBackend) Delete(ctx context.Context) error {
be.m.Lock() be.m.Lock()
defer be.m.Unlock() defer be.m.Unlock()

View file

@ -1,6 +1,7 @@
package mem_test package mem_test
import ( import (
"context"
"restic" "restic"
"testing" "testing"
@ -25,7 +26,7 @@ func newTestSuite() *test.Suite {
Create: func(cfg interface{}) (restic.Backend, error) { Create: func(cfg interface{}) (restic.Backend, error) {
c := cfg.(*memConfig) c := cfg.(*memConfig)
if c.be != nil { if c.be != nil {
ok, err := c.be.Test(restic.Handle{Type: restic.ConfigFile}) ok, err := c.be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,6 +1,7 @@
package rest package rest
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -11,6 +12,8 @@ import (
"restic" "restic"
"strings" "strings"
"golang.org/x/net/context/ctxhttp"
"restic/debug" "restic/debug"
"restic/errors" "restic/errors"
@ -25,7 +28,7 @@ var _ restic.Backend = &restBackend{}
type restBackend struct { type restBackend struct {
url *url.URL url *url.URL
connChan chan struct{} connChan chan struct{}
client http.Client client *http.Client
backend.Layout backend.Layout
} }
@ -36,7 +39,7 @@ func Open(cfg Config) (restic.Backend, error) {
connChan <- struct{}{} connChan <- struct{}{}
} }
client := http.Client{Transport: backend.Transport()} client := &http.Client{Transport: backend.Transport()}
// use url without trailing slash for layout // use url without trailing slash for layout
url := cfg.URL.String() url := cfg.URL.String()
@ -61,7 +64,7 @@ func Create(cfg Config) (restic.Backend, error) {
return nil, err return nil, err
} }
_, err = be.Stat(restic.Handle{Type: restic.ConfigFile}) _, err = be.Stat(context.TODO(), restic.Handle{Type: restic.ConfigFile})
if err == nil { if err == nil {
return nil, errors.Fatal("config file already exists") return nil, errors.Fatal("config file already exists")
} }
@ -99,22 +102,25 @@ func (b *restBackend) Location() string {
} }
// Save stores data in the backend at the handle. // Save stores data in the backend at the handle.
func (b *restBackend) Save(h restic.Handle, rd io.Reader) (err error) { func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return err return err
} }
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// make sure that client.Post() cannot close the reader by wrapping it in // make sure that client.Post() cannot close the reader by wrapping it in
// backend.Closer, which has a noop method. // backend.Closer, which has a noop method.
rd = backend.Closer{Reader: rd} rd = backend.Closer{Reader: rd}
<-b.connChan <-b.connChan
resp, err := b.client.Post(b.Filename(h), "binary/octet-stream", rd) resp, err := ctxhttp.Post(ctx, b.client, b.Filename(h), "binary/octet-stream", rd)
b.connChan <- struct{}{} b.connChan <- struct{}{}
if resp != nil { if resp != nil {
defer func() { defer func() {
io.Copy(ioutil.Discard, resp.Body) _, _ = io.Copy(ioutil.Discard, resp.Body)
e := resp.Body.Close() e := resp.Body.Close()
if err == nil { if err == nil {
@ -137,7 +143,7 @@ func (b *restBackend) Save(h restic.Handle, rd io.Reader) (err error) {
// Load returns a reader that yields the contents of the file at h at the // Load returns a reader that yields the contents of the file at h at the
// given offset. If length is nonzero, only a portion of the file is // given offset. If length is nonzero, only a portion of the file is
// returned. rd must be closed after use. // returned. rd must be closed after use.
func (b *restBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (b *restBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
debug.Log("Load %v, length %v, offset %v", h, length, offset) debug.Log("Load %v, length %v, offset %v", h, length, offset)
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return nil, err return nil, err
@ -164,20 +170,19 @@ func (b *restBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCl
debug.Log("Load(%v) send range %v", h, byteRange) debug.Log("Load(%v) send range %v", h, byteRange)
<-b.connChan <-b.connChan
resp, err := b.client.Do(req) resp, err := ctxhttp.Do(ctx, b.client, req)
b.connChan <- struct{}{} b.connChan <- struct{}{}
if err != nil { if err != nil {
if resp != nil { if resp != nil {
io.Copy(ioutil.Discard, resp.Body) _, _ = io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close() _ = resp.Body.Close()
} }
return nil, errors.Wrap(err, "client.Do") return nil, errors.Wrap(err, "client.Do")
} }
if resp.StatusCode != 200 && resp.StatusCode != 206 { if resp.StatusCode != 200 && resp.StatusCode != 206 {
io.Copy(ioutil.Discard, resp.Body) _ = resp.Body.Close()
resp.Body.Close()
return nil, errors.Errorf("unexpected HTTP response (%v): %v", resp.StatusCode, resp.Status) return nil, errors.Errorf("unexpected HTTP response (%v): %v", resp.StatusCode, resp.Status)
} }
@ -185,19 +190,19 @@ func (b *restBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCl
} }
// Stat returns information about a blob. // Stat returns information about a blob.
func (b *restBackend) Stat(h restic.Handle) (restic.FileInfo, error) { func (b *restBackend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) {
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return restic.FileInfo{}, err return restic.FileInfo{}, err
} }
<-b.connChan <-b.connChan
resp, err := b.client.Head(b.Filename(h)) resp, err := ctxhttp.Head(ctx, b.client, b.Filename(h))
b.connChan <- struct{}{} b.connChan <- struct{}{}
if err != nil { if err != nil {
return restic.FileInfo{}, errors.Wrap(err, "client.Head") return restic.FileInfo{}, errors.Wrap(err, "client.Head")
} }
io.Copy(ioutil.Discard, resp.Body) _, _ = io.Copy(ioutil.Discard, resp.Body)
if err = resp.Body.Close(); err != nil { if err = resp.Body.Close(); err != nil {
return restic.FileInfo{}, errors.Wrap(err, "Close") return restic.FileInfo{}, errors.Wrap(err, "Close")
} }
@ -218,8 +223,8 @@ func (b *restBackend) Stat(h restic.Handle) (restic.FileInfo, error) {
} }
// Test returns true if a blob of the given type and name exists in the backend. // Test returns true if a blob of the given type and name exists in the backend.
func (b *restBackend) Test(h restic.Handle) (bool, error) { func (b *restBackend) Test(ctx context.Context, h restic.Handle) (bool, error) {
_, err := b.Stat(h) _, err := b.Stat(ctx, h)
if err != nil { if err != nil {
return false, nil return false, nil
} }
@ -228,7 +233,7 @@ func (b *restBackend) Test(h restic.Handle) (bool, error) {
} }
// Remove removes the blob with the given name and type. // Remove removes the blob with the given name and type.
func (b *restBackend) Remove(h restic.Handle) error { func (b *restBackend) Remove(ctx context.Context, h restic.Handle) error {
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return err return err
} }
@ -238,7 +243,7 @@ func (b *restBackend) Remove(h restic.Handle) error {
return errors.Wrap(err, "http.NewRequest") return errors.Wrap(err, "http.NewRequest")
} }
<-b.connChan <-b.connChan
resp, err := b.client.Do(req) resp, err := ctxhttp.Do(ctx, b.client, req)
b.connChan <- struct{}{} b.connChan <- struct{}{}
if err != nil { if err != nil {
@ -249,14 +254,18 @@ func (b *restBackend) Remove(h restic.Handle) error {
return errors.Errorf("blob not removed, server response: %v (%v)", resp.Status, resp.StatusCode) return errors.Errorf("blob not removed, server response: %v (%v)", resp.Status, resp.StatusCode)
} }
io.Copy(ioutil.Discard, resp.Body) _, err = io.Copy(ioutil.Discard, resp.Body)
return resp.Body.Close() if err != nil {
return errors.Wrap(err, "Copy")
}
return errors.Wrap(resp.Body.Close(), "Close")
} }
// List returns a channel that yields all names of blobs of type t. A // List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending // goroutine is started for this. If the channel done is closed, sending
// stops. // stops.
func (b *restBackend) List(t restic.FileType, done <-chan struct{}) <-chan string { func (b *restBackend) List(ctx context.Context, t restic.FileType) <-chan string {
ch := make(chan string) ch := make(chan string)
url := b.Dirname(restic.Handle{Type: t}) url := b.Dirname(restic.Handle{Type: t})
@ -265,12 +274,12 @@ func (b *restBackend) List(t restic.FileType, done <-chan struct{}) <-chan strin
} }
<-b.connChan <-b.connChan
resp, err := b.client.Get(url) resp, err := ctxhttp.Get(ctx, b.client, url)
b.connChan <- struct{}{} b.connChan <- struct{}{}
if resp != nil { if resp != nil {
defer func() { defer func() {
io.Copy(ioutil.Discard, resp.Body) _, _ = io.Copy(ioutil.Discard, resp.Body)
e := resp.Body.Close() e := resp.Body.Close()
if err == nil { if err == nil {
@ -296,7 +305,7 @@ func (b *restBackend) List(t restic.FileType, done <-chan struct{}) <-chan strin
for _, m := range list { for _, m := range list {
select { select {
case ch <- m: case ch <- m:
case <-done: case <-ctx.Done():
return return
} }
} }

View file

@ -1,6 +1,7 @@
package s3 package s3
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -31,6 +32,9 @@ type s3 struct {
backend.Layout backend.Layout
} }
// make sure that *s3 implements backend.Backend
var _ restic.Backend = &s3{}
const defaultLayout = "s3legacy" const defaultLayout = "s3legacy"
// Open opens the S3 backend at bucket and region. The bucket is created if it // Open opens the S3 backend at bucket and region. The bucket is created if it
@ -202,7 +206,7 @@ func (wr preventCloser) Close() error {
} }
// Save stores data in the backend at the handle. // Save stores data in the backend at the handle.
func (be *s3) Save(h restic.Handle, rd io.Reader) (err error) { func (be *s3) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return err return err
} }
@ -259,7 +263,7 @@ func (wr wrapReader) Close() error {
// Load returns a reader that yields the contents of the file at h at the // Load returns a reader that yields the contents of the file at h at the
// given offset. If length is nonzero, only a portion of the file is // given offset. If length is nonzero, only a portion of the file is
// returned. rd must be closed after use. // returned. rd must be closed after use.
func (be *s3) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (be *s3) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
debug.Log("Load %v, length %v, offset %v from %v", h, length, offset, be.Filename(h)) debug.Log("Load %v, length %v, offset %v from %v", h, length, offset, be.Filename(h))
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return nil, err return nil, err
@ -307,7 +311,7 @@ func (be *s3) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, er
} }
// Stat returns information about a blob. // Stat returns information about a blob.
func (be *s3) Stat(h restic.Handle) (bi restic.FileInfo, err error) { func (be *s3) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInfo, err error) {
debug.Log("%v", h) debug.Log("%v", h)
objName := be.Filename(h) objName := be.Filename(h)
@ -337,7 +341,7 @@ func (be *s3) Stat(h restic.Handle) (bi restic.FileInfo, err error) {
} }
// Test returns true if a blob of the given type and name exists in the backend. // Test returns true if a blob of the given type and name exists in the backend.
func (be *s3) Test(h restic.Handle) (bool, error) { func (be *s3) Test(ctx context.Context, h restic.Handle) (bool, error) {
found := false found := false
objName := be.Filename(h) objName := be.Filename(h)
_, err := be.client.StatObject(be.bucketname, objName) _, err := be.client.StatObject(be.bucketname, objName)
@ -350,7 +354,7 @@ func (be *s3) Test(h restic.Handle) (bool, error) {
} }
// Remove removes the blob with the given name and type. // Remove removes the blob with the given name and type.
func (be *s3) Remove(h restic.Handle) error { func (be *s3) Remove(ctx context.Context, h restic.Handle) error {
objName := be.Filename(h) objName := be.Filename(h)
err := be.client.RemoveObject(be.bucketname, objName) err := be.client.RemoveObject(be.bucketname, objName)
debug.Log("Remove(%v) at %v -> err %v", h, objName, err) debug.Log("Remove(%v) at %v -> err %v", h, objName, err)
@ -360,7 +364,7 @@ func (be *s3) Remove(h restic.Handle) error {
// List returns a channel that yields all names of blobs of type t. A // List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending // goroutine is started for this. If the channel done is closed, sending
// stops. // stops.
func (be *s3) List(t restic.FileType, done <-chan struct{}) <-chan string { func (be *s3) List(ctx context.Context, t restic.FileType) <-chan string {
debug.Log("listing %v", t) debug.Log("listing %v", t)
ch := make(chan string) ch := make(chan string)
@ -371,7 +375,7 @@ func (be *s3) List(t restic.FileType, done <-chan struct{}) <-chan string {
prefix += "/" prefix += "/"
} }
listresp := be.client.ListObjects(be.bucketname, prefix, true, done) listresp := be.client.ListObjects(be.bucketname, prefix, true, ctx.Done())
go func() { go func() {
defer close(ch) defer close(ch)
@ -383,7 +387,7 @@ func (be *s3) List(t restic.FileType, done <-chan struct{}) <-chan string {
select { select {
case ch <- path.Base(m): case ch <- path.Base(m):
case <-done: case <-ctx.Done():
return return
} }
} }
@ -393,11 +397,9 @@ func (be *s3) List(t restic.FileType, done <-chan struct{}) <-chan string {
} }
// Remove keys for a specified backend type. // Remove keys for a specified backend type.
func (be *s3) removeKeys(t restic.FileType) error { func (be *s3) removeKeys(ctx context.Context, t restic.FileType) error {
done := make(chan struct{}) for key := range be.List(ctx, restic.DataFile) {
defer close(done) err := be.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key})
for key := range be.List(restic.DataFile, done) {
err := be.Remove(restic.Handle{Type: restic.DataFile, Name: key})
if err != nil { if err != nil {
return err return err
} }
@ -407,7 +409,7 @@ func (be *s3) removeKeys(t restic.FileType) error {
} }
// Delete removes all restic keys in the bucket. It will not remove the bucket itself. // Delete removes all restic keys in the bucket. It will not remove the bucket itself.
func (be *s3) Delete() error { func (be *s3) Delete(ctx context.Context) error {
alltypes := []restic.FileType{ alltypes := []restic.FileType{
restic.DataFile, restic.DataFile,
restic.KeyFile, restic.KeyFile,
@ -416,13 +418,13 @@ func (be *s3) Delete() error {
restic.IndexFile} restic.IndexFile}
for _, t := range alltypes { for _, t := range alltypes {
err := be.removeKeys(t) err := be.removeKeys(ctx, t)
if err != nil { if err != nil {
return nil return nil
} }
} }
return be.Remove(restic.Handle{Type: restic.ConfigFile}) return be.Remove(ctx, restic.Handle{Type: restic.ConfigFile})
} }
// Close does nothing // Close does nothing

View file

@ -134,7 +134,7 @@ func newMinioTestSuite(ctx context.Context, t testing.TB) *test.Suite {
return nil, err return nil, err
} }
exists, err := be.Test(restic.Handle{Type: restic.ConfigFile}) exists, err := be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -228,7 +228,7 @@ func newS3TestSuite(t testing.TB) *test.Suite {
return nil, err return nil, err
} }
exists, err := be.Test(restic.Handle{Type: restic.ConfigFile}) exists, err := be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -255,7 +255,7 @@ func newS3TestSuite(t testing.TB) *test.Suite {
return err return err
} }
if err := be.(restic.Deleter).Delete(); err != nil { if err := be.(restic.Deleter).Delete(context.TODO()); err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package sftp_test package sftp_test
import ( import (
"context"
"fmt" "fmt"
"path/filepath" "path/filepath"
"restic" "restic"
@ -54,7 +55,7 @@ func TestLayout(t *testing.T) {
} }
datafiles := make(map[string]bool) datafiles := make(map[string]bool)
for id := range be.List(restic.DataFile, nil) { for id := range be.List(context.TODO(), restic.DataFile) {
datafiles[id] = false datafiles[id] = false
} }

View file

@ -2,6 +2,7 @@ package sftp
import ( import (
"bufio" "bufio"
"context"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -262,7 +263,7 @@ func Join(parts ...string) string {
} }
// Save stores data in the backend at the handle. // Save stores data in the backend at the handle.
func (r *SFTP) Save(h restic.Handle, rd io.Reader) (err error) { func (r *SFTP) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
debug.Log("Save %v", h) debug.Log("Save %v", h)
if err := r.clientError(); err != nil { if err := r.clientError(); err != nil {
return err return err
@ -283,7 +284,7 @@ func (r *SFTP) Save(h restic.Handle, rd io.Reader) (err error) {
return errors.Wrap(err, "MkdirAll") return errors.Wrap(err, "MkdirAll")
} }
return r.Save(h, rd) return r.Save(ctx, h, rd)
} }
if err != nil { if err != nil {
@ -315,7 +316,7 @@ func (r *SFTP) Save(h restic.Handle, rd io.Reader) (err error) {
// Load returns a reader that yields the contents of the file at h at the // Load returns a reader that yields the contents of the file at h at the
// given offset. If length is nonzero, only a portion of the file is // given offset. If length is nonzero, only a portion of the file is
// returned. rd must be closed after use. // returned. rd must be closed after use.
func (r *SFTP) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (r *SFTP) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
debug.Log("Load %v, length %v, offset %v", h, length, offset) debug.Log("Load %v, length %v, offset %v", h, length, offset)
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return nil, err return nil, err
@ -346,7 +347,7 @@ func (r *SFTP) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, e
} }
// Stat returns information about a blob. // Stat returns information about a blob.
func (r *SFTP) Stat(h restic.Handle) (restic.FileInfo, error) { func (r *SFTP) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) {
debug.Log("Stat(%v)", h) debug.Log("Stat(%v)", h)
if err := r.clientError(); err != nil { if err := r.clientError(); err != nil {
return restic.FileInfo{}, err return restic.FileInfo{}, err
@ -365,7 +366,7 @@ func (r *SFTP) Stat(h restic.Handle) (restic.FileInfo, error) {
} }
// Test returns true if a blob of the given type and name exists in the backend. // Test returns true if a blob of the given type and name exists in the backend.
func (r *SFTP) Test(h restic.Handle) (bool, error) { func (r *SFTP) Test(ctx context.Context, h restic.Handle) (bool, error) {
debug.Log("Test(%v)", h) debug.Log("Test(%v)", h)
if err := r.clientError(); err != nil { if err := r.clientError(); err != nil {
return false, err return false, err
@ -384,7 +385,7 @@ func (r *SFTP) Test(h restic.Handle) (bool, error) {
} }
// Remove removes the content stored at name. // Remove removes the content stored at name.
func (r *SFTP) Remove(h restic.Handle) error { func (r *SFTP) Remove(ctx context.Context, h restic.Handle) error {
debug.Log("Remove(%v)", h) debug.Log("Remove(%v)", h)
if err := r.clientError(); err != nil { if err := r.clientError(); err != nil {
return err return err
@ -396,7 +397,7 @@ func (r *SFTP) Remove(h restic.Handle) error {
// List returns a channel that yields all names of blobs of type t. A // List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending // goroutine is started for this. If the channel done is closed, sending
// stops. // stops.
func (r *SFTP) List(t restic.FileType, done <-chan struct{}) <-chan string { func (r *SFTP) List(ctx context.Context, t restic.FileType) <-chan string {
debug.Log("List %v", t) debug.Log("List %v", t)
ch := make(chan string) ch := make(chan string)
@ -416,7 +417,7 @@ func (r *SFTP) List(t restic.FileType, done <-chan struct{}) <-chan string {
select { select {
case ch <- path.Base(walker.Path()): case ch <- path.Base(walker.Path()):
case <-done: case <-ctx.Done():
return return
} }
} }

View file

@ -1,6 +1,7 @@
package swift package swift
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -27,6 +28,9 @@ type beSwift struct {
backend.Layout backend.Layout
} }
// ensure statically that *beSwift implements restic.Backend.
var _ restic.Backend = &beSwift{}
// Open opens the swift backend at a container in region. The container is // Open opens the swift backend at a container in region. The container is
// created if it does not exist yet. // created if it does not exist yet.
func Open(cfg Config) (restic.Backend, error) { func Open(cfg Config) (restic.Backend, error) {
@ -120,7 +124,7 @@ func (be *beSwift) Location() string {
// Load returns a reader that yields the contents of the file at h at the // Load returns a reader that yields the contents of the file at h at the
// given offset. If length is nonzero, only a portion of the file is // given offset. If length is nonzero, only a portion of the file is
// returned. rd must be closed after use. // returned. rd must be closed after use.
func (be *beSwift) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (be *beSwift) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
debug.Log("Load %v, length %v, offset %v", h, length, offset) debug.Log("Load %v, length %v, offset %v", h, length, offset)
if err := h.Valid(); err != nil { if err := h.Valid(); err != nil {
return nil, err return nil, err
@ -164,7 +168,7 @@ func (be *beSwift) Load(h restic.Handle, length int, offset int64) (io.ReadClose
} }
// Save stores data in the backend at the handle. // Save stores data in the backend at the handle.
func (be *beSwift) Save(h restic.Handle, rd io.Reader) (err error) { func (be *beSwift) Save(ctx context.Context, h restic.Handle, rd io.Reader) (err error) {
if err = h.Valid(); err != nil { if err = h.Valid(); err != nil {
return err return err
} }
@ -201,7 +205,7 @@ func (be *beSwift) Save(h restic.Handle, rd io.Reader) (err error) {
} }
// Stat returns information about a blob. // Stat returns information about a blob.
func (be *beSwift) Stat(h restic.Handle) (bi restic.FileInfo, err error) { func (be *beSwift) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInfo, err error) {
debug.Log("%v", h) debug.Log("%v", h)
objName := be.Filename(h) objName := be.Filename(h)
@ -216,7 +220,7 @@ func (be *beSwift) Stat(h restic.Handle) (bi restic.FileInfo, err error) {
} }
// Test returns true if a blob of the given type and name exists in the backend. // Test returns true if a blob of the given type and name exists in the backend.
func (be *beSwift) Test(h restic.Handle) (bool, error) { func (be *beSwift) Test(ctx context.Context, h restic.Handle) (bool, error) {
objName := be.Filename(h) objName := be.Filename(h)
switch _, _, err := be.conn.Object(be.container, objName); err { switch _, _, err := be.conn.Object(be.container, objName); err {
case nil: case nil:
@ -231,7 +235,7 @@ func (be *beSwift) Test(h restic.Handle) (bool, error) {
} }
// Remove removes the blob with the given name and type. // Remove removes the blob with the given name and type.
func (be *beSwift) Remove(h restic.Handle) error { func (be *beSwift) Remove(ctx context.Context, h restic.Handle) error {
objName := be.Filename(h) objName := be.Filename(h)
err := be.conn.ObjectDelete(be.container, objName) err := be.conn.ObjectDelete(be.container, objName)
debug.Log("Remove(%v) -> err %v", h, err) debug.Log("Remove(%v) -> err %v", h, err)
@ -241,7 +245,7 @@ func (be *beSwift) Remove(h restic.Handle) error {
// List returns a channel that yields all names of blobs of type t. A // List returns a channel that yields all names of blobs of type t. A
// goroutine is started for this. If the channel done is closed, sending // goroutine is started for this. If the channel done is closed, sending
// stops. // stops.
func (be *beSwift) List(t restic.FileType, done <-chan struct{}) <-chan string { func (be *beSwift) List(ctx context.Context, t restic.FileType) <-chan string {
debug.Log("listing %v", t) debug.Log("listing %v", t)
ch := make(chan string) ch := make(chan string)
@ -264,7 +268,7 @@ func (be *beSwift) List(t restic.FileType, done <-chan struct{}) <-chan string {
select { select {
case ch <- m: case ch <- m:
case <-done: case <-ctx.Done():
return nil, io.EOF return nil, io.EOF
} }
} }
@ -280,11 +284,9 @@ func (be *beSwift) List(t restic.FileType, done <-chan struct{}) <-chan string {
} }
// Remove keys for a specified backend type. // Remove keys for a specified backend type.
func (be *beSwift) removeKeys(t restic.FileType) error { func (be *beSwift) removeKeys(ctx context.Context, t restic.FileType) error {
done := make(chan struct{}) for key := range be.List(ctx, t) {
defer close(done) err := be.Remove(ctx, restic.Handle{Type: t, Name: key})
for key := range be.List(t, done) {
err := be.Remove(restic.Handle{Type: t, Name: key})
if err != nil { if err != nil {
return err return err
} }
@ -304,7 +306,7 @@ func (be *beSwift) IsNotExist(err error) bool {
// Delete removes all restic objects in the container. // Delete removes all restic objects in the container.
// It will not remove the container itself. // It will not remove the container itself.
func (be *beSwift) Delete() error { func (be *beSwift) Delete(ctx context.Context) error {
alltypes := []restic.FileType{ alltypes := []restic.FileType{
restic.DataFile, restic.DataFile,
restic.KeyFile, restic.KeyFile,
@ -313,13 +315,13 @@ func (be *beSwift) Delete() error {
restic.IndexFile} restic.IndexFile}
for _, t := range alltypes { for _, t := range alltypes {
err := be.removeKeys(t) err := be.removeKeys(ctx, t)
if err != nil { if err != nil {
return nil return nil
} }
} }
err := be.Remove(restic.Handle{Type: restic.ConfigFile}) err := be.Remove(ctx, restic.Handle{Type: restic.ConfigFile})
if err != nil && !be.IsNotExist(err) { if err != nil && !be.IsNotExist(err) {
return err return err
} }

View file

@ -1,6 +1,7 @@
package swift_test package swift_test
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"restic" "restic"
@ -44,7 +45,7 @@ func newSwiftTestSuite(t testing.TB) *test.Suite {
return nil, err return nil, err
} }
exists, err := be.Test(restic.Handle{Type: restic.ConfigFile}) exists, err := be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,7 +72,7 @@ func newSwiftTestSuite(t testing.TB) *test.Suite {
return err return err
} }
if err := be.(restic.Deleter).Delete(); err != nil { if err := be.(restic.Deleter).Delete(context.TODO()); err != nil {
return err return err
} }

View file

@ -2,6 +2,7 @@ package test
import ( import (
"bytes" "bytes"
"context"
"io" "io"
"restic" "restic"
"restic/test" "restic/test"
@ -12,14 +13,14 @@ func saveRandomFile(t testing.TB, be restic.Backend, length int) ([]byte, restic
data := test.Random(23, length) data := test.Random(23, length)
id := restic.Hash(data) id := restic.Hash(data)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()} handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
if err := be.Save(handle, bytes.NewReader(data)); err != nil { if err := be.Save(context.TODO(), handle, bytes.NewReader(data)); err != nil {
t.Fatalf("Save() error: %+v", err) t.Fatalf("Save() error: %+v", err)
} }
return data, handle return data, handle
} }
func remove(t testing.TB, be restic.Backend, h restic.Handle) { func remove(t testing.TB, be restic.Backend, h restic.Handle) {
if err := be.Remove(h); err != nil { if err := be.Remove(context.TODO(), h); err != nil {
t.Fatalf("Remove() returned error: %v", err) t.Fatalf("Remove() returned error: %v", err)
} }
} }
@ -40,7 +41,7 @@ func (s *Suite) BenchmarkLoadFile(t *testing.B) {
t.ResetTimer() t.ResetTimer()
for i := 0; i < t.N; i++ { for i := 0; i < t.N; i++ {
rd, err := be.Load(handle, 0, 0) rd, err := be.Load(context.TODO(), handle, 0, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -82,7 +83,7 @@ func (s *Suite) BenchmarkLoadPartialFile(t *testing.B) {
t.ResetTimer() t.ResetTimer()
for i := 0; i < t.N; i++ { for i := 0; i < t.N; i++ {
rd, err := be.Load(handle, testLength, 0) rd, err := be.Load(context.TODO(), handle, testLength, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -126,7 +127,7 @@ func (s *Suite) BenchmarkLoadPartialFileOffset(t *testing.B) {
t.ResetTimer() t.ResetTimer()
for i := 0; i < t.N; i++ { for i := 0; i < t.N; i++ {
rd, err := be.Load(handle, testLength, int64(testOffset)) rd, err := be.Load(context.TODO(), handle, testLength, int64(testOffset))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -171,11 +172,11 @@ func (s *Suite) BenchmarkSave(t *testing.B) {
t.Fatal(err) t.Fatal(err)
} }
if err := be.Save(handle, rd); err != nil { if err := be.Save(context.TODO(), handle, rd); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := be.Remove(handle); err != nil { if err := be.Remove(context.TODO(), handle); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View file

@ -2,6 +2,7 @@ package test
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -34,7 +35,7 @@ func (s *Suite) TestCreateWithConfig(t *testing.T) {
// remove a config if present // remove a config if present
cfgHandle := restic.Handle{Type: restic.ConfigFile} cfgHandle := restic.Handle{Type: restic.ConfigFile}
cfgPresent, err := b.Test(cfgHandle) cfgPresent, err := b.Test(context.TODO(), cfgHandle)
if err != nil { if err != nil {
t.Fatalf("unable to test for config: %+v", err) t.Fatalf("unable to test for config: %+v", err)
} }
@ -53,7 +54,7 @@ func (s *Suite) TestCreateWithConfig(t *testing.T) {
} }
// remove config // remove config
err = b.Remove(restic.Handle{Type: restic.ConfigFile, Name: ""}) err = b.Remove(context.TODO(), restic.Handle{Type: restic.ConfigFile, Name: ""})
if err != nil { if err != nil {
t.Fatalf("unexpected error removing config: %+v", err) t.Fatalf("unexpected error removing config: %+v", err)
} }
@ -78,12 +79,12 @@ func (s *Suite) TestConfig(t *testing.T) {
var testString = "Config" var testString = "Config"
// create config and read it back // create config and read it back
_, err := backend.LoadAll(b, restic.Handle{Type: restic.ConfigFile}) _, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.ConfigFile})
if err == nil { if err == nil {
t.Fatalf("did not get expected error for non-existing config") t.Fatalf("did not get expected error for non-existing config")
} }
err = b.Save(restic.Handle{Type: restic.ConfigFile}, strings.NewReader(testString)) err = b.Save(context.TODO(), restic.Handle{Type: restic.ConfigFile}, strings.NewReader(testString))
if err != nil { if err != nil {
t.Fatalf("Save() error: %+v", err) t.Fatalf("Save() error: %+v", err)
} }
@ -92,7 +93,7 @@ func (s *Suite) TestConfig(t *testing.T) {
// same config // same config
for _, name := range []string{"", "foo", "bar", "0000000000000000000000000000000000000000000000000000000000000000"} { for _, name := range []string{"", "foo", "bar", "0000000000000000000000000000000000000000000000000000000000000000"} {
h := restic.Handle{Type: restic.ConfigFile, Name: name} h := restic.Handle{Type: restic.ConfigFile, Name: name}
buf, err := backend.LoadAll(b, h) buf, err := backend.LoadAll(context.TODO(), b, h)
if err != nil { if err != nil {
t.Fatalf("unable to read config with name %q: %+v", name, err) t.Fatalf("unable to read config with name %q: %+v", name, err)
} }
@ -113,12 +114,12 @@ func (s *Suite) TestLoad(t *testing.T) {
b := s.open(t) b := s.open(t)
defer s.close(t, b) defer s.close(t, b)
rd, err := b.Load(restic.Handle{}, 0, 0) rd, err := b.Load(context.TODO(), restic.Handle{}, 0, 0)
if err == nil { if err == nil {
t.Fatalf("Load() did not return an error for invalid handle") t.Fatalf("Load() did not return an error for invalid handle")
} }
if rd != nil { if rd != nil {
rd.Close() _ = rd.Close()
} }
err = testLoad(b, restic.Handle{Type: restic.DataFile, Name: "foobar"}, 0, 0) err = testLoad(b, restic.Handle{Type: restic.DataFile, Name: "foobar"}, 0, 0)
@ -132,14 +133,14 @@ func (s *Suite) TestLoad(t *testing.T) {
id := restic.Hash(data) id := restic.Hash(data)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()} handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
err = b.Save(handle, bytes.NewReader(data)) err = b.Save(context.TODO(), handle, bytes.NewReader(data))
if err != nil { if err != nil {
t.Fatalf("Save() error: %+v", err) t.Fatalf("Save() error: %+v", err)
} }
t.Logf("saved %d bytes as %v", length, handle) t.Logf("saved %d bytes as %v", length, handle)
rd, err = b.Load(handle, 100, -1) rd, err = b.Load(context.TODO(), handle, 100, -1)
if err == nil { if err == nil {
t.Fatalf("Load() returned no error for negative offset!") t.Fatalf("Load() returned no error for negative offset!")
} }
@ -174,7 +175,7 @@ func (s *Suite) TestLoad(t *testing.T) {
d = d[:l] d = d[:l]
} }
rd, err := b.Load(handle, getlen, int64(o)) rd, err := b.Load(context.TODO(), handle, getlen, int64(o))
if err != nil { if err != nil {
t.Logf("Load, l %v, o %v, len(d) %v, getlen %v", l, o, len(d), getlen) t.Logf("Load, l %v, o %v, len(d) %v, getlen %v", l, o, len(d), getlen)
t.Errorf("Load(%d, %d) returned unexpected error: %+v", l, o, err) t.Errorf("Load(%d, %d) returned unexpected error: %+v", l, o, err)
@ -235,7 +236,7 @@ func (s *Suite) TestLoad(t *testing.T) {
} }
} }
test.OK(t, b.Remove(handle)) test.OK(t, b.Remove(context.TODO(), handle))
} }
type errorCloser struct { type errorCloser struct {
@ -276,10 +277,10 @@ func (s *Suite) TestSave(t *testing.T) {
Type: restic.DataFile, Type: restic.DataFile,
Name: fmt.Sprintf("%s-%d", id, i), Name: fmt.Sprintf("%s-%d", id, i),
} }
err := b.Save(h, bytes.NewReader(data)) err := b.Save(context.TODO(), h, bytes.NewReader(data))
test.OK(t, err) test.OK(t, err)
buf, err := backend.LoadAll(b, h) buf, err := backend.LoadAll(context.TODO(), b, h)
test.OK(t, err) test.OK(t, err)
if len(buf) != len(data) { if len(buf) != len(data) {
t.Fatalf("number of bytes does not match, want %v, got %v", len(data), len(buf)) t.Fatalf("number of bytes does not match, want %v, got %v", len(data), len(buf))
@ -289,14 +290,14 @@ func (s *Suite) TestSave(t *testing.T) {
t.Fatalf("data not equal") t.Fatalf("data not equal")
} }
fi, err := b.Stat(h) fi, err := b.Stat(context.TODO(), h)
test.OK(t, err) test.OK(t, err)
if fi.Size != int64(len(data)) { if fi.Size != int64(len(data)) {
t.Fatalf("Stat() returned different size, want %q, got %d", len(data), fi.Size) t.Fatalf("Stat() returned different size, want %q, got %d", len(data), fi.Size)
} }
err = b.Remove(h) err = b.Remove(context.TODO(), h)
if err != nil { if err != nil {
t.Fatalf("error removing item: %+v", err) t.Fatalf("error removing item: %+v", err)
} }
@ -324,12 +325,12 @@ func (s *Suite) TestSave(t *testing.T) {
// wrap the tempfile in an errorCloser, so we can detect if the backend // wrap the tempfile in an errorCloser, so we can detect if the backend
// closes the reader // closes the reader
err = b.Save(h, errorCloser{t: t, size: int64(length), Reader: tmpfile}) err = b.Save(context.TODO(), h, errorCloser{t: t, size: int64(length), Reader: tmpfile})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = b.Remove(h) err = b.Remove(context.TODO(), h)
if err != nil { if err != nil {
t.Fatalf("error removing item: %+v", err) t.Fatalf("error removing item: %+v", err)
} }
@ -339,7 +340,7 @@ func (s *Suite) TestSave(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = b.Save(h, tmpfile) err = b.Save(context.TODO(), h, tmpfile)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -348,7 +349,7 @@ func (s *Suite) TestSave(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = b.Remove(h) err = b.Remove(context.TODO(), h)
if err != nil { if err != nil {
t.Fatalf("error removing item: %+v", err) t.Fatalf("error removing item: %+v", err)
} }
@ -377,13 +378,13 @@ func (s *Suite) TestSaveFilenames(t *testing.T) {
for i, test := range filenameTests { for i, test := range filenameTests {
h := restic.Handle{Name: test.name, Type: restic.DataFile} h := restic.Handle{Name: test.name, Type: restic.DataFile}
err := b.Save(h, strings.NewReader(test.data)) err := b.Save(context.TODO(), h, strings.NewReader(test.data))
if err != nil { if err != nil {
t.Errorf("test %d failed: Save() returned %+v", i, err) t.Errorf("test %d failed: Save() returned %+v", i, err)
continue continue
} }
buf, err := backend.LoadAll(b, h) buf, err := backend.LoadAll(context.TODO(), b, h)
if err != nil { if err != nil {
t.Errorf("test %d failed: Load() returned %+v", i, err) t.Errorf("test %d failed: Load() returned %+v", i, err)
continue continue
@ -393,7 +394,7 @@ func (s *Suite) TestSaveFilenames(t *testing.T) {
t.Errorf("test %d: returned wrong bytes", i) t.Errorf("test %d: returned wrong bytes", i)
} }
err = b.Remove(h) err = b.Remove(context.TODO(), h)
if err != nil { if err != nil {
t.Errorf("test %d failed: Remove() returned %+v", i, err) t.Errorf("test %d failed: Remove() returned %+v", i, err)
continue continue
@ -414,14 +415,14 @@ var testStrings = []struct {
func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle { func store(t testing.TB, b restic.Backend, tpe restic.FileType, data []byte) restic.Handle {
id := restic.Hash(data) id := restic.Hash(data)
h := restic.Handle{Name: id.String(), Type: tpe} h := restic.Handle{Name: id.String(), Type: tpe}
err := b.Save(h, bytes.NewReader(data)) err := b.Save(context.TODO(), h, bytes.NewReader(data))
test.OK(t, err) test.OK(t, err)
return h return h
} }
// testLoad loads a blob (but discards its contents). // testLoad loads a blob (but discards its contents).
func testLoad(b restic.Backend, h restic.Handle, length int, offset int64) error { func testLoad(b restic.Backend, h restic.Handle, length int, offset int64) error {
rd, err := b.Load(h, 0, 0) rd, err := b.Load(context.TODO(), h, 0, 0)
if err != nil { if err != nil {
return err return err
} }
@ -437,14 +438,14 @@ func testLoad(b restic.Backend, h restic.Handle, length int, offset int64) error
func delayedRemove(b restic.Backend, h restic.Handle) error { func delayedRemove(b restic.Backend, h restic.Handle) error {
// Some backend (swift, I'm looking at you) may implement delayed // Some backend (swift, I'm looking at you) may implement delayed
// removal of data. Let's wait a bit if this happens. // removal of data. Let's wait a bit if this happens.
err := b.Remove(h) err := b.Remove(context.TODO(), h)
if err != nil { if err != nil {
return err return err
} }
found, err := b.Test(h) found, err := b.Test(context.TODO(), h)
for i := 0; found && i < 20; i++ { for i := 0; found && i < 20; i++ {
found, err = b.Test(h) found, err = b.Test(context.TODO(), h)
if found { if found {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
@ -468,12 +469,12 @@ func (s *Suite) TestBackend(t *testing.T) {
// test if blob is already in repository // test if blob is already in repository
h := restic.Handle{Type: tpe, Name: id.String()} h := restic.Handle{Type: tpe, Name: id.String()}
ret, err := b.Test(h) ret, err := b.Test(context.TODO(), h)
test.OK(t, err) test.OK(t, err)
test.Assert(t, !ret, "blob was found to exist before creating") test.Assert(t, !ret, "blob was found to exist before creating")
// try to stat a not existing blob // try to stat a not existing blob
_, err = b.Stat(h) _, err = b.Stat(context.TODO(), h)
test.Assert(t, err != nil, "blob data could be extracted before creation") test.Assert(t, err != nil, "blob data could be extracted before creation")
// try to read not existing blob // try to read not existing blob
@ -481,7 +482,7 @@ func (s *Suite) TestBackend(t *testing.T) {
test.Assert(t, err != nil, "blob could be read before creation") test.Assert(t, err != nil, "blob could be read before creation")
// try to get string out, should fail // try to get string out, should fail
ret, err = b.Test(h) ret, err = b.Test(context.TODO(), h)
test.OK(t, err) test.OK(t, err)
test.Assert(t, !ret, "id %q was found (but should not have)", ts.id) test.Assert(t, !ret, "id %q was found (but should not have)", ts.id)
} }
@ -492,7 +493,7 @@ func (s *Suite) TestBackend(t *testing.T) {
// test Load() // test Load()
h := restic.Handle{Type: tpe, Name: ts.id} h := restic.Handle{Type: tpe, Name: ts.id}
buf, err := backend.LoadAll(b, h) buf, err := backend.LoadAll(context.TODO(), b, h)
test.OK(t, err) test.OK(t, err)
test.Equals(t, ts.data, string(buf)) test.Equals(t, ts.data, string(buf))
@ -502,7 +503,7 @@ func (s *Suite) TestBackend(t *testing.T) {
length := end - start length := end - start
buf2 := make([]byte, length) buf2 := make([]byte, length)
rd, err := b.Load(h, len(buf2), int64(start)) rd, err := b.Load(context.TODO(), h, len(buf2), int64(start))
test.OK(t, err) test.OK(t, err)
n, err := io.ReadFull(rd, buf2) n, err := io.ReadFull(rd, buf2)
test.OK(t, err) test.OK(t, err)
@ -522,7 +523,7 @@ func (s *Suite) TestBackend(t *testing.T) {
// create blob // create blob
h := restic.Handle{Type: tpe, Name: ts.id} h := restic.Handle{Type: tpe, Name: ts.id}
err := b.Save(h, strings.NewReader(ts.data)) err := b.Save(context.TODO(), h, strings.NewReader(ts.data))
test.Assert(t, err != nil, "expected error for %v, got %v", h, err) test.Assert(t, err != nil, "expected error for %v, got %v", h, err)
// remove and recreate // remove and recreate
@ -530,12 +531,12 @@ func (s *Suite) TestBackend(t *testing.T) {
test.OK(t, err) test.OK(t, err)
// test that the blob is gone // test that the blob is gone
ok, err := b.Test(h) ok, err := b.Test(context.TODO(), h)
test.OK(t, err) test.OK(t, err)
test.Assert(t, !ok, "removed blob still present") test.Assert(t, !ok, "removed blob still present")
// create blob // create blob
err = b.Save(h, strings.NewReader(ts.data)) err = b.Save(context.TODO(), h, strings.NewReader(ts.data))
test.OK(t, err) test.OK(t, err)
// list items // list items
@ -549,7 +550,7 @@ func (s *Suite) TestBackend(t *testing.T) {
list := restic.IDs{} list := restic.IDs{}
for s := range b.List(tpe, nil) { for s := range b.List(context.TODO(), tpe) {
list = append(list, restic.TestParseID(s)) list = append(list, restic.TestParseID(s))
} }
@ -572,13 +573,13 @@ func (s *Suite) TestBackend(t *testing.T) {
h := restic.Handle{Type: tpe, Name: id.String()} h := restic.Handle{Type: tpe, Name: id.String()}
found, err := b.Test(h) found, err := b.Test(context.TODO(), h)
test.OK(t, err) test.OK(t, err)
test.Assert(t, found, fmt.Sprintf("id %q not found", id)) test.Assert(t, found, fmt.Sprintf("id %q not found", id))
test.OK(t, delayedRemove(b, h)) test.OK(t, delayedRemove(b, h))
found, err = b.Test(h) found, err = b.Test(context.TODO(), h)
test.OK(t, err) test.OK(t, err)
test.Assert(t, !found, fmt.Sprintf("id %q not found after removal", id)) test.Assert(t, !found, fmt.Sprintf("id %q not found after removal", id))
} }
@ -600,7 +601,7 @@ func (s *Suite) TestDelete(t *testing.T) {
return return
} }
err := be.Delete() err := be.Delete(context.TODO())
if err != nil { if err != nil {
t.Fatalf("error deleting backend: %+v", err) t.Fatalf("error deleting backend: %+v", err)
} }

View file

@ -1,6 +1,7 @@
package test_test package test_test
import ( import (
"context"
"restic" "restic"
"restic/errors" "restic/errors"
"testing" "testing"
@ -26,7 +27,7 @@ func newTestSuite(t testing.TB) *test.Suite {
Create: func(cfg interface{}) (restic.Backend, error) { Create: func(cfg interface{}) (restic.Backend, error) {
c := cfg.(*memConfig) c := cfg.(*memConfig)
if c.be != nil { if c.be != nil {
ok, err := c.be.Test(restic.Handle{Type: restic.ConfigFile}) ok, err := c.be.Test(context.TODO(), restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,14 +1,15 @@
package backend package backend
import ( import (
"context"
"io" "io"
"io/ioutil" "io/ioutil"
"restic" "restic"
) )
// LoadAll reads all data stored in the backend for the handle. // LoadAll reads all data stored in the backend for the handle.
func LoadAll(be restic.Backend, h restic.Handle) (buf []byte, err error) { func LoadAll(ctx context.Context, be restic.Backend, h restic.Handle) (buf []byte, err error) {
rd, err := be.Load(h, 0, 0) rd, err := be.Load(ctx, h, 0, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -2,6 +2,7 @@ package backend_test
import ( import (
"bytes" "bytes"
"context"
"math/rand" "math/rand"
"restic" "restic"
"testing" "testing"
@ -21,10 +22,10 @@ func TestLoadAll(t *testing.T) {
data := Random(23+i, rand.Intn(MiB)+500*KiB) data := Random(23+i, rand.Intn(MiB)+500*KiB)
id := restic.Hash(data) id := restic.Hash(data)
err := b.Save(restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data))
OK(t, err) OK(t, err)
buf, err := backend.LoadAll(b, restic.Handle{Type: restic.DataFile, Name: id.String()}) buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()})
OK(t, err) OK(t, err)
if len(buf) != len(data) { if len(buf) != len(data) {
@ -46,10 +47,10 @@ func TestLoadSmallBuffer(t *testing.T) {
data := Random(23+i, rand.Intn(MiB)+500*KiB) data := Random(23+i, rand.Intn(MiB)+500*KiB)
id := restic.Hash(data) id := restic.Hash(data)
err := b.Save(restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data))
OK(t, err) OK(t, err)
buf, err := backend.LoadAll(b, restic.Handle{Type: restic.DataFile, Name: id.String()}) buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()})
OK(t, err) OK(t, err)
if len(buf) != len(data) { if len(buf) != len(data) {
@ -71,10 +72,10 @@ func TestLoadLargeBuffer(t *testing.T) {
data := Random(23+i, rand.Intn(MiB)+500*KiB) data := Random(23+i, rand.Intn(MiB)+500*KiB)
id := restic.Hash(data) id := restic.Hash(data)
err := b.Save(restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data)) err := b.Save(context.TODO(), restic.Handle{Name: id.String(), Type: restic.DataFile}, bytes.NewReader(data))
OK(t, err) OK(t, err)
buf, err := backend.LoadAll(b, restic.Handle{Type: restic.DataFile, Name: id.String()}) buf, err := backend.LoadAll(context.TODO(), b, restic.Handle{Type: restic.DataFile, Name: id.String()})
OK(t, err) OK(t, err)
if len(buf) != len(data) { if len(buf) != len(data) {

View file

@ -1,6 +1,9 @@
package restic package restic
import "restic/errors" import (
"context"
"restic/errors"
)
// ErrNoIDPrefixFound is returned by Find() when no ID for the given prefix // ErrNoIDPrefixFound is returned by Find() when no ID for the given prefix
// could be found. // could be found.
@ -14,13 +17,10 @@ var ErrMultipleIDMatches = errors.New("multiple IDs with prefix found")
// start with prefix. If none is found, nil and ErrNoIDPrefixFound is returned. // start with prefix. If none is found, nil and ErrNoIDPrefixFound is returned.
// If more than one is found, nil and ErrMultipleIDMatches is returned. // If more than one is found, nil and ErrMultipleIDMatches is returned.
func Find(be Lister, t FileType, prefix string) (string, error) { func Find(be Lister, t FileType, prefix string) (string, error) {
done := make(chan struct{})
defer close(done)
match := "" match := ""
// TODO: optimize by sorting list etc. // TODO: optimize by sorting list etc.
for name := range be.List(t, done) { for name := range be.List(context.TODO(), t) {
if prefix == name[:len(prefix)] { if prefix == name[:len(prefix)] {
if match == "" { if match == "" {
match = name match = name
@ -42,12 +42,9 @@ const minPrefixLength = 8
// PrefixLength returns the number of bytes required so that all prefixes of // PrefixLength returns the number of bytes required so that all prefixes of
// all names of type t are unique. // all names of type t are unique.
func PrefixLength(be Lister, t FileType) (int, error) { func PrefixLength(be Lister, t FileType) (int, error) {
done := make(chan struct{})
defer close(done)
// load all IDs of the given type // load all IDs of the given type
list := make([]string, 0, 100) list := make([]string, 0, 100)
for name := range be.List(t, done) { for name := range be.List(context.TODO(), t) {
list = append(list, name) list = append(list, name)
} }

View file

@ -1,15 +1,16 @@
package restic package restic
import ( import (
"context"
"testing" "testing"
) )
type mockBackend struct { type mockBackend struct {
list func(FileType, <-chan struct{}) <-chan string list func(context.Context, FileType) <-chan string
} }
func (m mockBackend) List(t FileType, done <-chan struct{}) <-chan string { func (m mockBackend) List(ctx context.Context, t FileType) <-chan string {
return m.list(t, done) return m.list(ctx, t)
} }
var samples = IDs{ var samples = IDs{
@ -27,14 +28,14 @@ func TestPrefixLength(t *testing.T) {
list := samples list := samples
m := mockBackend{} m := mockBackend{}
m.list = func(t FileType, done <-chan struct{}) <-chan string { m.list = func(ctx context.Context, t FileType) <-chan string {
ch := make(chan string) ch := make(chan string)
go func() { go func() {
defer close(ch) defer close(ch)
for _, id := range list { for _, id := range list {
select { select {
case ch <- id.String(): case ch <- id.String():
case <-done: case <-ctx.Done():
return return
} }
} }

View file

@ -1,6 +1,7 @@
package checker package checker
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"io" "io"
@ -76,7 +77,7 @@ func (err ErrOldIndexFormat) Error() string {
} }
// LoadIndex loads all index files. // LoadIndex loads all index files.
func (c *Checker) LoadIndex() (hints []error, errs []error) { func (c *Checker) LoadIndex(ctx context.Context) (hints []error, errs []error) {
debug.Log("Start") debug.Log("Start")
type indexRes struct { type indexRes struct {
Index *repository.Index Index *repository.Index
@ -86,21 +87,21 @@ func (c *Checker) LoadIndex() (hints []error, errs []error) {
indexCh := make(chan indexRes) indexCh := make(chan indexRes)
worker := func(id restic.ID, done <-chan struct{}) error { worker := func(ctx context.Context, id restic.ID) error {
debug.Log("worker got index %v", id) debug.Log("worker got index %v", id)
idx, err := repository.LoadIndexWithDecoder(c.repo, id, repository.DecodeIndex) idx, err := repository.LoadIndexWithDecoder(ctx, c.repo, id, repository.DecodeIndex)
if errors.Cause(err) == repository.ErrOldIndexFormat { if errors.Cause(err) == repository.ErrOldIndexFormat {
debug.Log("index %v has old format", id.Str()) debug.Log("index %v has old format", id.Str())
hints = append(hints, ErrOldIndexFormat{id}) hints = append(hints, ErrOldIndexFormat{id})
idx, err = repository.LoadIndexWithDecoder(c.repo, id, repository.DecodeOldIndex) idx, err = repository.LoadIndexWithDecoder(ctx, c.repo, id, repository.DecodeOldIndex)
} }
err = errors.Wrapf(err, "error loading index %v", id.Str()) err = errors.Wrapf(err, "error loading index %v", id.Str())
select { select {
case indexCh <- indexRes{Index: idx, ID: id.String(), err: err}: case indexCh <- indexRes{Index: idx, ID: id.String(), err: err}:
case <-done: case <-ctx.Done():
} }
return nil return nil
@ -109,7 +110,7 @@ func (c *Checker) LoadIndex() (hints []error, errs []error) {
go func() { go func() {
defer close(indexCh) defer close(indexCh)
debug.Log("start loading indexes in parallel") debug.Log("start loading indexes in parallel")
err := repository.FilesInParallel(c.repo.Backend(), restic.IndexFile, defaultParallelism, err := repository.FilesInParallel(ctx, c.repo.Backend(), restic.IndexFile, defaultParallelism,
repository.ParallelWorkFuncParseID(worker)) repository.ParallelWorkFuncParseID(worker))
debug.Log("loading indexes finished, error: %v", err) debug.Log("loading indexes finished, error: %v", err)
if err != nil { if err != nil {
@ -183,7 +184,7 @@ func (e PackError) Error() string {
return "pack " + e.ID.String() + ": " + e.Err.Error() return "pack " + e.ID.String() + ": " + e.Err.Error()
} }
func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan<- error, wg *sync.WaitGroup, done <-chan struct{}) { func packIDTester(ctx context.Context, repo restic.Repository, inChan <-chan restic.ID, errChan chan<- error, wg *sync.WaitGroup) {
debug.Log("worker start") debug.Log("worker start")
defer debug.Log("worker done") defer debug.Log("worker done")
@ -191,7 +192,7 @@ func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan<
for id := range inChan { for id := range inChan {
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
ok, err := repo.Backend().Test(h) ok, err := repo.Backend().Test(ctx, h)
if err != nil { if err != nil {
err = PackError{ID: id, Err: err} err = PackError{ID: id, Err: err}
} else { } else {
@ -203,7 +204,7 @@ func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan<
if err != nil { if err != nil {
debug.Log("error checking for pack %s: %v", id.Str(), err) debug.Log("error checking for pack %s: %v", id.Str(), err)
select { select {
case <-done: case <-ctx.Done():
return return
case errChan <- err: case errChan <- err:
} }
@ -218,7 +219,7 @@ func packIDTester(repo restic.Repository, inChan <-chan restic.ID, errChan chan<
// Packs checks that all packs referenced in the index are still available and // Packs checks that all packs referenced in the index are still available and
// there are no packs that aren't in an index. errChan is closed after all // there are no packs that aren't in an index. errChan is closed after all
// packs have been checked. // packs have been checked.
func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) { func (c *Checker) Packs(ctx context.Context, errChan chan<- error) {
defer close(errChan) defer close(errChan)
debug.Log("checking for %d packs", len(c.packs)) debug.Log("checking for %d packs", len(c.packs))
@ -229,7 +230,7 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) {
IDChan := make(chan restic.ID) IDChan := make(chan restic.ID)
for i := 0; i < defaultParallelism; i++ { for i := 0; i < defaultParallelism; i++ {
workerWG.Add(1) workerWG.Add(1)
go packIDTester(c.repo, IDChan, errChan, &workerWG, done) go packIDTester(ctx, c.repo, IDChan, errChan, &workerWG)
} }
for id := range c.packs { for id := range c.packs {
@ -242,12 +243,12 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) {
workerWG.Wait() workerWG.Wait()
debug.Log("workers terminated") debug.Log("workers terminated")
for id := range c.repo.List(restic.DataFile, done) { for id := range c.repo.List(ctx, restic.DataFile) {
debug.Log("check data blob %v", id.Str()) debug.Log("check data blob %v", id.Str())
if !seenPacks.Has(id) { if !seenPacks.Has(id) {
c.orphanedPacks = append(c.orphanedPacks, id) c.orphanedPacks = append(c.orphanedPacks, id)
select { select {
case <-done: case <-ctx.Done():
return return
case errChan <- PackError{ID: id, Orphaned: true, Err: errors.New("not referenced in any index")}: case errChan <- PackError{ID: id, Orphaned: true, Err: errors.New("not referenced in any index")}:
} }
@ -277,8 +278,8 @@ func (e Error) Error() string {
return e.Err.Error() return e.Err.Error()
} }
func loadTreeFromSnapshot(repo restic.Repository, id restic.ID) (restic.ID, error) { func loadTreeFromSnapshot(ctx context.Context, repo restic.Repository, id restic.ID) (restic.ID, error) {
sn, err := restic.LoadSnapshot(repo, id) sn, err := restic.LoadSnapshot(ctx, repo, id)
if err != nil { if err != nil {
debug.Log("error loading snapshot %v: %v", id.Str(), err) debug.Log("error loading snapshot %v: %v", id.Str(), err)
return restic.ID{}, err return restic.ID{}, err
@ -293,7 +294,7 @@ func loadTreeFromSnapshot(repo restic.Repository, id restic.ID) (restic.ID, erro
} }
// loadSnapshotTreeIDs loads all snapshots from backend and returns the tree IDs. // loadSnapshotTreeIDs loads all snapshots from backend and returns the tree IDs.
func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) { func loadSnapshotTreeIDs(ctx context.Context, repo restic.Repository) (restic.IDs, []error) {
var trees struct { var trees struct {
IDs restic.IDs IDs restic.IDs
sync.Mutex sync.Mutex
@ -304,7 +305,7 @@ func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) {
sync.Mutex sync.Mutex
} }
snapshotWorker := func(strID string, done <-chan struct{}) error { snapshotWorker := func(ctx context.Context, strID string) error {
id, err := restic.ParseID(strID) id, err := restic.ParseID(strID)
if err != nil { if err != nil {
return err return err
@ -312,7 +313,7 @@ func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) {
debug.Log("load snapshot %v", id.Str()) debug.Log("load snapshot %v", id.Str())
treeID, err := loadTreeFromSnapshot(repo, id) treeID, err := loadTreeFromSnapshot(ctx, repo, id)
if err != nil { if err != nil {
errs.Lock() errs.Lock()
errs.errs = append(errs.errs, err) errs.errs = append(errs.errs, err)
@ -328,7 +329,7 @@ func loadSnapshotTreeIDs(repo restic.Repository) (restic.IDs, []error) {
return nil return nil
} }
err := repository.FilesInParallel(repo.Backend(), restic.SnapshotFile, defaultParallelism, snapshotWorker) err := repository.FilesInParallel(ctx, repo.Backend(), restic.SnapshotFile, defaultParallelism, snapshotWorker)
if err != nil { if err != nil {
errs.errs = append(errs.errs, err) errs.errs = append(errs.errs, err)
} }
@ -353,9 +354,9 @@ type treeJob struct {
} }
// loadTreeWorker loads trees from repo and sends them to out. // loadTreeWorker loads trees from repo and sends them to out.
func loadTreeWorker(repo restic.Repository, func loadTreeWorker(ctx context.Context, repo restic.Repository,
in <-chan restic.ID, out chan<- treeJob, in <-chan restic.ID, out chan<- treeJob,
done <-chan struct{}, wg *sync.WaitGroup) { wg *sync.WaitGroup) {
defer func() { defer func() {
debug.Log("exiting") debug.Log("exiting")
@ -371,7 +372,7 @@ func loadTreeWorker(repo restic.Repository,
outCh = nil outCh = nil
for { for {
select { select {
case <-done: case <-ctx.Done():
return return
case treeID, ok := <-inCh: case treeID, ok := <-inCh:
@ -380,7 +381,7 @@ func loadTreeWorker(repo restic.Repository,
} }
debug.Log("load tree %v", treeID.Str()) debug.Log("load tree %v", treeID.Str())
tree, err := repo.LoadTree(treeID) tree, err := repo.LoadTree(ctx, treeID)
debug.Log("load tree %v (%v) returned err: %v", tree, treeID.Str(), err) debug.Log("load tree %v (%v) returned err: %v", tree, treeID.Str(), err)
job = treeJob{ID: treeID, error: err, Tree: tree} job = treeJob{ID: treeID, error: err, Tree: tree}
outCh = out outCh = out
@ -395,7 +396,7 @@ func loadTreeWorker(repo restic.Repository,
} }
// checkTreeWorker checks the trees received and sends out errors to errChan. // checkTreeWorker checks the trees received and sends out errors to errChan.
func (c *Checker) checkTreeWorker(in <-chan treeJob, out chan<- error, done <-chan struct{}, wg *sync.WaitGroup) { func (c *Checker) checkTreeWorker(ctx context.Context, in <-chan treeJob, out chan<- error, wg *sync.WaitGroup) {
defer func() { defer func() {
debug.Log("exiting") debug.Log("exiting")
wg.Done() wg.Done()
@ -410,7 +411,7 @@ func (c *Checker) checkTreeWorker(in <-chan treeJob, out chan<- error, done <-ch
outCh = nil outCh = nil
for { for {
select { select {
case <-done: case <-ctx.Done():
debug.Log("done channel closed, exiting") debug.Log("done channel closed, exiting")
return return
@ -458,7 +459,7 @@ func (c *Checker) checkTreeWorker(in <-chan treeJob, out chan<- error, done <-ch
} }
} }
func filterTrees(backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob, done <-chan struct{}) { func filterTrees(ctx context.Context, backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob) {
defer func() { defer func() {
debug.Log("closing output channels") debug.Log("closing output channels")
close(loaderChan) close(loaderChan)
@ -489,7 +490,7 @@ func filterTrees(backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan tree
} }
select { select {
case <-done: case <-ctx.Done():
return return
case loadCh <- nextTreeID: case loadCh <- nextTreeID:
@ -549,15 +550,15 @@ func filterTrees(backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan tree
// Structure checks that for all snapshots all referenced data blobs and // Structure checks that for all snapshots all referenced data blobs and
// subtrees are available in the index. errChan is closed after all trees have // subtrees are available in the index. errChan is closed after all trees have
// been traversed. // been traversed.
func (c *Checker) Structure(errChan chan<- error, done <-chan struct{}) { func (c *Checker) Structure(ctx context.Context, errChan chan<- error) {
defer close(errChan) defer close(errChan)
trees, errs := loadSnapshotTreeIDs(c.repo) trees, errs := loadSnapshotTreeIDs(ctx, c.repo)
debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs)) debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs))
for _, err := range errs { for _, err := range errs {
select { select {
case <-done: case <-ctx.Done():
return return
case errChan <- err: case errChan <- err:
} }
@ -570,11 +571,11 @@ func (c *Checker) Structure(errChan chan<- error, done <-chan struct{}) {
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < defaultParallelism; i++ { for i := 0; i < defaultParallelism; i++ {
wg.Add(2) wg.Add(2)
go loadTreeWorker(c.repo, treeIDChan, treeJobChan1, done, &wg) go loadTreeWorker(ctx, c.repo, treeIDChan, treeJobChan1, &wg)
go c.checkTreeWorker(treeJobChan2, errChan, done, &wg) go c.checkTreeWorker(ctx, treeJobChan2, errChan, &wg)
} }
filterTrees(trees, treeIDChan, treeJobChan1, treeJobChan2, done) filterTrees(ctx, trees, treeIDChan, treeJobChan1, treeJobChan2)
wg.Wait() wg.Wait()
} }
@ -659,11 +660,11 @@ func (c *Checker) CountPacks() uint64 {
} }
// checkPack reads a pack and checks the integrity of all blobs. // checkPack reads a pack and checks the integrity of all blobs.
func checkPack(r restic.Repository, id restic.ID) error { func checkPack(ctx context.Context, r restic.Repository, id restic.ID) error {
debug.Log("checking pack %v", id.Str()) debug.Log("checking pack %v", id.Str())
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
rd, err := r.Backend().Load(h, 0, 0) rd, err := r.Backend().Load(ctx, h, 0, 0)
if err != nil { if err != nil {
return err return err
} }
@ -748,7 +749,7 @@ func checkPack(r restic.Repository, id restic.ID) error {
} }
// ReadData loads all data from the repository and checks the integrity. // ReadData loads all data from the repository and checks the integrity.
func (c *Checker) ReadData(p *restic.Progress, errChan chan<- error, done <-chan struct{}) { func (c *Checker) ReadData(ctx context.Context, p *restic.Progress, errChan chan<- error) {
defer close(errChan) defer close(errChan)
p.Start() p.Start()
@ -761,7 +762,7 @@ func (c *Checker) ReadData(p *restic.Progress, errChan chan<- error, done <-chan
var ok bool var ok bool
select { select {
case <-done: case <-ctx.Done():
return return
case id, ok = <-in: case id, ok = <-in:
if !ok { if !ok {
@ -769,21 +770,21 @@ func (c *Checker) ReadData(p *restic.Progress, errChan chan<- error, done <-chan
} }
} }
err := checkPack(c.repo, id) err := checkPack(ctx, c.repo, id)
p.Report(restic.Stat{Blobs: 1}) p.Report(restic.Stat{Blobs: 1})
if err == nil { if err == nil {
continue continue
} }
select { select {
case <-done: case <-ctx.Done():
return return
case errChan <- err: case errChan <- err:
} }
} }
} }
ch := c.repo.List(restic.DataFile, done) ch := c.repo.List(ctx, restic.DataFile)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < defaultParallelism; i++ { for i := 0; i < defaultParallelism; i++ {

View file

@ -1,6 +1,7 @@
package checker_test package checker_test
import ( import (
"context"
"io" "io"
"math/rand" "math/rand"
"path/filepath" "path/filepath"
@ -16,13 +17,13 @@ import (
var checkerTestData = filepath.Join("testdata", "checker-test-repo.tar.gz") var checkerTestData = filepath.Join("testdata", "checker-test-repo.tar.gz")
func collectErrors(f func(chan<- error, <-chan struct{})) (errs []error) { func collectErrors(ctx context.Context, f func(context.Context, chan<- error)) (errs []error) {
done := make(chan struct{}) ctx, cancel := context.WithCancel(ctx)
defer close(done) defer cancel()
errChan := make(chan error) errChan := make(chan error)
go f(errChan, done) go f(ctx, errChan)
for err := range errChan { for err := range errChan {
errs = append(errs, err) errs = append(errs, err)
@ -32,17 +33,18 @@ func collectErrors(f func(chan<- error, <-chan struct{})) (errs []error) {
} }
func checkPacks(chkr *checker.Checker) []error { func checkPacks(chkr *checker.Checker) []error {
return collectErrors(chkr.Packs) return collectErrors(context.TODO(), chkr.Packs)
} }
func checkStruct(chkr *checker.Checker) []error { func checkStruct(chkr *checker.Checker) []error {
return collectErrors(chkr.Structure) return collectErrors(context.TODO(), chkr.Structure)
} }
func checkData(chkr *checker.Checker) []error { func checkData(chkr *checker.Checker) []error {
return collectErrors( return collectErrors(
func(errCh chan<- error, done <-chan struct{}) { context.TODO(),
chkr.ReadData(nil, errCh, done) func(ctx context.Context, errCh chan<- error) {
chkr.ReadData(ctx, nil, errCh)
}, },
) )
} }
@ -54,7 +56,7 @@ func TestCheckRepo(t *testing.T) {
repo := repository.TestOpenLocal(t, repodir) repo := repository.TestOpenLocal(t, repodir)
chkr := checker.New(repo) chkr := checker.New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) > 0 { if len(errs) > 0 {
t.Fatalf("expected no errors, got %v: %v", len(errs), errs) t.Fatalf("expected no errors, got %v: %v", len(errs), errs)
} }
@ -77,10 +79,10 @@ func TestMissingPack(t *testing.T) {
Type: restic.DataFile, Type: restic.DataFile,
Name: "657f7fb64f6a854fff6fe9279998ee09034901eded4e6db9bcee0e59745bbce6", Name: "657f7fb64f6a854fff6fe9279998ee09034901eded4e6db9bcee0e59745bbce6",
} }
test.OK(t, repo.Backend().Remove(packHandle)) test.OK(t, repo.Backend().Remove(context.TODO(), packHandle))
chkr := checker.New(repo) chkr := checker.New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) > 0 { if len(errs) > 0 {
t.Fatalf("expected no errors, got %v: %v", len(errs), errs) t.Fatalf("expected no errors, got %v: %v", len(errs), errs)
} }
@ -113,10 +115,10 @@ func TestUnreferencedPack(t *testing.T) {
Type: restic.IndexFile, Type: restic.IndexFile,
Name: "3f1abfcb79c6f7d0a3be517d2c83c8562fba64ef2c8e9a3544b4edaf8b5e3b44", Name: "3f1abfcb79c6f7d0a3be517d2c83c8562fba64ef2c8e9a3544b4edaf8b5e3b44",
} }
test.OK(t, repo.Backend().Remove(indexHandle)) test.OK(t, repo.Backend().Remove(context.TODO(), indexHandle))
chkr := checker.New(repo) chkr := checker.New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) > 0 { if len(errs) > 0 {
t.Fatalf("expected no errors, got %v: %v", len(errs), errs) t.Fatalf("expected no errors, got %v: %v", len(errs), errs)
} }
@ -147,7 +149,7 @@ func TestUnreferencedBlobs(t *testing.T) {
Type: restic.SnapshotFile, Type: restic.SnapshotFile,
Name: "51d249d28815200d59e4be7b3f21a157b864dc343353df9d8e498220c2499b02", Name: "51d249d28815200d59e4be7b3f21a157b864dc343353df9d8e498220c2499b02",
} }
test.OK(t, repo.Backend().Remove(snapshotHandle)) test.OK(t, repo.Backend().Remove(context.TODO(), snapshotHandle))
unusedBlobsBySnapshot := restic.IDs{ unusedBlobsBySnapshot := restic.IDs{
restic.TestParseID("58c748bbe2929fdf30c73262bd8313fe828f8925b05d1d4a87fe109082acb849"), restic.TestParseID("58c748bbe2929fdf30c73262bd8313fe828f8925b05d1d4a87fe109082acb849"),
@ -161,7 +163,7 @@ func TestUnreferencedBlobs(t *testing.T) {
sort.Sort(unusedBlobsBySnapshot) sort.Sort(unusedBlobsBySnapshot)
chkr := checker.New(repo) chkr := checker.New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) > 0 { if len(errs) > 0 {
t.Fatalf("expected no errors, got %v: %v", len(errs), errs) t.Fatalf("expected no errors, got %v: %v", len(errs), errs)
} }
@ -192,7 +194,7 @@ func TestModifiedIndex(t *testing.T) {
Type: restic.IndexFile, Type: restic.IndexFile,
Name: "90f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", Name: "90f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd",
} }
f, err := repo.Backend().Load(h, 0, 0) f, err := repo.Backend().Load(context.TODO(), h, 0, 0)
test.OK(t, err) test.OK(t, err)
// save the index again with a modified name so that the hash doesn't match // save the index again with a modified name so that the hash doesn't match
@ -201,13 +203,13 @@ func TestModifiedIndex(t *testing.T) {
Type: restic.IndexFile, Type: restic.IndexFile,
Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd", Name: "80f838b4ac28735fda8644fe6a08dbc742e57aaf81b30977b4fefa357010eafd",
} }
err = repo.Backend().Save(h2, f) err = repo.Backend().Save(context.TODO(), h2, f)
test.OK(t, err) test.OK(t, err)
test.OK(t, f.Close()) test.OK(t, f.Close())
chkr := checker.New(repo) chkr := checker.New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) == 0 { if len(errs) == 0 {
t.Fatalf("expected errors not found") t.Fatalf("expected errors not found")
} }
@ -230,7 +232,7 @@ func TestDuplicatePacksInIndex(t *testing.T) {
repo := repository.TestOpenLocal(t, repodir) repo := repository.TestOpenLocal(t, repodir)
chkr := checker.New(repo) chkr := checker.New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(hints) == 0 { if len(hints) == 0 {
t.Fatalf("did not get expected checker hints for duplicate packs in indexes") t.Fatalf("did not get expected checker hints for duplicate packs in indexes")
} }
@ -259,8 +261,8 @@ type errorBackend struct {
ProduceErrors bool ProduceErrors bool
} }
func (b errorBackend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (b errorBackend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
rd, err := b.Backend.Load(h, length, offset) rd, err := b.Backend.Load(ctx, h, length, offset)
if err != nil { if err != nil {
return rd, err return rd, err
} }
@ -303,17 +305,17 @@ func TestCheckerModifiedData(t *testing.T) {
defer cleanup() defer cleanup()
arch := archiver.New(repo) arch := archiver.New(repo)
_, id, err := arch.Snapshot(nil, []string{"."}, nil, "localhost", nil) _, id, err := arch.Snapshot(context.TODO(), nil, []string{"."}, nil, "localhost", nil)
test.OK(t, err) test.OK(t, err)
t.Logf("archived as %v", id.Str()) t.Logf("archived as %v", id.Str())
beError := &errorBackend{Backend: repo.Backend()} beError := &errorBackend{Backend: repo.Backend()}
checkRepo := repository.New(beError) checkRepo := repository.New(beError)
test.OK(t, checkRepo.SearchKey(test.TestPassword, 5)) test.OK(t, checkRepo.SearchKey(context.TODO(), test.TestPassword, 5))
chkr := checker.New(checkRepo) chkr := checker.New(checkRepo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) > 0 { if len(errs) > 0 {
t.Fatalf("expected no errors, got %v: %v", len(errs), errs) t.Fatalf("expected no errors, got %v: %v", len(errs), errs)
} }
@ -349,7 +351,7 @@ func BenchmarkChecker(t *testing.B) {
repo := repository.TestOpenLocal(t, repodir) repo := repository.TestOpenLocal(t, repodir)
chkr := checker.New(repo) chkr := checker.New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) > 0 { if len(errs) > 0 {
t.Fatalf("expected no errors, got %v: %v", len(errs), errs) t.Fatalf("expected no errors, got %v: %v", len(errs), errs)
} }

View file

@ -1,6 +1,7 @@
package checker package checker
import ( import (
"context"
"restic" "restic"
"testing" "testing"
) )
@ -9,7 +10,7 @@ import (
func TestCheckRepo(t testing.TB, repo restic.Repository) { func TestCheckRepo(t testing.TB, repo restic.Repository) {
chkr := New(repo) chkr := New(repo)
hints, errs := chkr.LoadIndex() hints, errs := chkr.LoadIndex(context.TODO())
if len(errs) != 0 { if len(errs) != 0 {
t.Fatalf("errors loading index: %v", errs) t.Fatalf("errors loading index: %v", errs)
} }
@ -18,12 +19,9 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) {
t.Fatalf("errors loading index: %v", hints) t.Fatalf("errors loading index: %v", hints)
} }
done := make(chan struct{})
defer close(done)
// packs // packs
errChan := make(chan error) errChan := make(chan error)
go chkr.Packs(errChan, done) go chkr.Packs(context.TODO(), errChan)
for err := range errChan { for err := range errChan {
t.Error(err) t.Error(err)
@ -31,7 +29,7 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) {
// structure // structure
errChan = make(chan error) errChan = make(chan error)
go chkr.Structure(errChan, done) go chkr.Structure(context.TODO(), errChan)
for err := range errChan { for err := range errChan {
t.Error(err) t.Error(err)
@ -45,7 +43,7 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) {
// read data // read data
errChan = make(chan error) errChan = make(chan error)
go chkr.ReadData(nil, errChan, done) go chkr.ReadData(context.TODO(), nil, errChan)
for err := range errChan { for err := range errChan {
t.Error(err) t.Error(err)

View file

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"testing" "testing"
"restic/errors" "restic/errors"
@ -23,7 +24,7 @@ const RepoVersion = 1
// JSONUnpackedLoader loads unpacked JSON. // JSONUnpackedLoader loads unpacked JSON.
type JSONUnpackedLoader interface { type JSONUnpackedLoader interface {
LoadJSONUnpacked(FileType, ID, interface{}) error LoadJSONUnpacked(context.Context, FileType, ID, interface{}) error
} }
// CreateConfig creates a config file with a randomly selected polynomial and // CreateConfig creates a config file with a randomly selected polynomial and
@ -57,12 +58,12 @@ func TestCreateConfig(t testing.TB, pol chunker.Pol) (cfg Config) {
} }
// LoadConfig returns loads, checks and returns the config for a repository. // LoadConfig returns loads, checks and returns the config for a repository.
func LoadConfig(r JSONUnpackedLoader) (Config, error) { func LoadConfig(ctx context.Context, r JSONUnpackedLoader) (Config, error) {
var ( var (
cfg Config cfg Config
) )
err := r.LoadJSONUnpacked(ConfigFile, ID{}, &cfg) err := r.LoadJSONUnpacked(ctx, ConfigFile, ID{}, &cfg)
if err != nil { if err != nil {
return Config{}, err return Config{}, err
} }

View file

@ -1,6 +1,7 @@
package restic_test package restic_test
import ( import (
"context"
"restic" "restic"
"testing" "testing"
@ -13,10 +14,10 @@ func (s saver) SaveJSONUnpacked(t restic.FileType, arg interface{}) (restic.ID,
return s(t, arg) return s(t, arg)
} }
type loader func(restic.FileType, restic.ID, interface{}) error type loader func(context.Context, restic.FileType, restic.ID, interface{}) error
func (l loader) LoadJSONUnpacked(t restic.FileType, id restic.ID, arg interface{}) error { func (l loader) LoadJSONUnpacked(ctx context.Context, t restic.FileType, id restic.ID, arg interface{}) error {
return l(t, id, arg) return l(ctx, t, id, arg)
} }
func TestConfig(t *testing.T) { func TestConfig(t *testing.T) {
@ -36,7 +37,7 @@ func TestConfig(t *testing.T) {
_, err = saver(save).SaveJSONUnpacked(restic.ConfigFile, cfg1) _, err = saver(save).SaveJSONUnpacked(restic.ConfigFile, cfg1)
load := func(tpe restic.FileType, id restic.ID, arg interface{}) error { load := func(ctx context.Context, tpe restic.FileType, id restic.ID, arg interface{}) error {
Assert(t, tpe == restic.ConfigFile, Assert(t, tpe == restic.ConfigFile,
"wrong backend type: got %v, wanted %v", "wrong backend type: got %v, wanted %v",
tpe, restic.ConfigFile) tpe, restic.ConfigFile)
@ -46,7 +47,7 @@ func TestConfig(t *testing.T) {
return nil return nil
} }
cfg2, err := restic.LoadConfig(loader(load)) cfg2, err := restic.LoadConfig(context.TODO(), loader(load))
OK(t, err) OK(t, err)
Assert(t, cfg1 == cfg2, Assert(t, cfg1 == cfg2,

View file

@ -1,12 +1,14 @@
package restic package restic
import "context"
// FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data // FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data
// blobs) to the set blobs. The tree blobs in the `seen` BlobSet will not be visited // blobs) to the set blobs. The tree blobs in the `seen` BlobSet will not be visited
// again. // again.
func FindUsedBlobs(repo Repository, treeID ID, blobs BlobSet, seen BlobSet) error { func FindUsedBlobs(ctx context.Context, repo Repository, treeID ID, blobs BlobSet, seen BlobSet) error {
blobs.Insert(BlobHandle{ID: treeID, Type: TreeBlob}) blobs.Insert(BlobHandle{ID: treeID, Type: TreeBlob})
tree, err := repo.LoadTree(treeID) tree, err := repo.LoadTree(ctx, treeID)
if err != nil { if err != nil {
return err return err
} }
@ -26,7 +28,7 @@ func FindUsedBlobs(repo Repository, treeID ID, blobs BlobSet, seen BlobSet) erro
seen.Insert(h) seen.Insert(h)
err := FindUsedBlobs(repo, subtreeID, blobs, seen) err := FindUsedBlobs(ctx, repo, subtreeID, blobs, seen)
if err != nil { if err != nil {
return err return err
} }

View file

@ -2,6 +2,7 @@ package restic_test
import ( import (
"bufio" "bufio"
"context"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
@ -92,7 +93,7 @@ func TestFindUsedBlobs(t *testing.T) {
for i, sn := range snapshots { for i, sn := range snapshots {
usedBlobs := restic.NewBlobSet() usedBlobs := restic.NewBlobSet()
err := restic.FindUsedBlobs(repo, *sn.Tree, usedBlobs, restic.NewBlobSet()) err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, usedBlobs, restic.NewBlobSet())
if err != nil { if err != nil {
t.Errorf("FindUsedBlobs returned error: %v", err) t.Errorf("FindUsedBlobs returned error: %v", err)
continue continue
@ -128,7 +129,7 @@ func BenchmarkFindUsedBlobs(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
seen := restic.NewBlobSet() seen := restic.NewBlobSet()
blobs := restic.NewBlobSet() blobs := restic.NewBlobSet()
err := restic.FindUsedBlobs(repo, *sn.Tree, blobs, seen) err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, blobs, seen)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }

View file

@ -26,9 +26,9 @@ type dir struct {
ownerIsRoot bool ownerIsRoot bool
} }
func newDir(repo restic.Repository, node *restic.Node, ownerIsRoot bool) (*dir, error) { func newDir(ctx context.Context, repo restic.Repository, node *restic.Node, ownerIsRoot bool) (*dir, error) {
debug.Log("new dir for %v (%v)", node.Name, node.Subtree.Str()) debug.Log("new dir for %v (%v)", node.Name, node.Subtree.Str())
tree, err := repo.LoadTree(*node.Subtree) tree, err := repo.LoadTree(ctx, *node.Subtree)
if err != nil { if err != nil {
debug.Log(" error loading tree %v: %v", node.Subtree.Str(), err) debug.Log(" error loading tree %v: %v", node.Subtree.Str(), err)
return nil, err return nil, err
@ -49,7 +49,7 @@ func newDir(repo restic.Repository, node *restic.Node, ownerIsRoot bool) (*dir,
// replaceSpecialNodes replaces nodes with name "." and "/" by their contents. // replaceSpecialNodes replaces nodes with name "." and "/" by their contents.
// Otherwise, the node is returned. // Otherwise, the node is returned.
func replaceSpecialNodes(repo restic.Repository, node *restic.Node) ([]*restic.Node, error) { func replaceSpecialNodes(ctx context.Context, repo restic.Repository, node *restic.Node) ([]*restic.Node, error) {
if node.Type != "dir" || node.Subtree == nil { if node.Type != "dir" || node.Subtree == nil {
return []*restic.Node{node}, nil return []*restic.Node{node}, nil
} }
@ -58,7 +58,7 @@ func replaceSpecialNodes(repo restic.Repository, node *restic.Node) ([]*restic.N
return []*restic.Node{node}, nil return []*restic.Node{node}, nil
} }
tree, err := repo.LoadTree(*node.Subtree) tree, err := repo.LoadTree(ctx, *node.Subtree)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -66,16 +66,16 @@ func replaceSpecialNodes(repo restic.Repository, node *restic.Node) ([]*restic.N
return tree.Nodes, nil return tree.Nodes, nil
} }
func newDirFromSnapshot(repo restic.Repository, snapshot SnapshotWithId, ownerIsRoot bool) (*dir, error) { func newDirFromSnapshot(ctx context.Context, repo restic.Repository, snapshot SnapshotWithId, ownerIsRoot bool) (*dir, error) {
debug.Log("new dir for snapshot %v (%v)", snapshot.ID.Str(), snapshot.Tree.Str()) debug.Log("new dir for snapshot %v (%v)", snapshot.ID.Str(), snapshot.Tree.Str())
tree, err := repo.LoadTree(*snapshot.Tree) tree, err := repo.LoadTree(ctx, *snapshot.Tree)
if err != nil { if err != nil {
debug.Log(" loadTree(%v) failed: %v", snapshot.ID.Str(), err) debug.Log(" loadTree(%v) failed: %v", snapshot.ID.Str(), err)
return nil, err return nil, err
} }
items := make(map[string]*restic.Node) items := make(map[string]*restic.Node)
for _, n := range tree.Nodes { for _, n := range tree.Nodes {
nodes, err := replaceSpecialNodes(repo, n) nodes, err := replaceSpecialNodes(ctx, repo, n)
if err != nil { if err != nil {
debug.Log(" replaceSpecialNodes(%v) failed: %v", n, err) debug.Log(" replaceSpecialNodes(%v) failed: %v", n, err)
return nil, err return nil, err
@ -167,7 +167,7 @@ func (d *dir) Lookup(ctx context.Context, name string) (fs.Node, error) {
} }
switch node.Type { switch node.Type {
case "dir": case "dir":
return newDir(d.repo, node, d.ownerIsRoot) return newDir(ctx, d.repo, node, d.ownerIsRoot)
case "file": case "file":
return newFile(d.repo, node, d.ownerIsRoot) return newFile(d.repo, node, d.ownerIsRoot)
case "symlink": case "symlink":

View file

@ -9,6 +9,8 @@ import (
"restic" "restic"
"restic/debug" "restic/debug"
scontext "context"
"bazil.org/fuse" "bazil.org/fuse"
"bazil.org/fuse/fs" "bazil.org/fuse/fs"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -25,7 +27,7 @@ var _ = fs.HandleReleaser(&file{})
// for fuse operations. // for fuse operations.
type BlobLoader interface { type BlobLoader interface {
LookupBlobSize(restic.ID, restic.BlobType) (uint, error) LookupBlobSize(restic.ID, restic.BlobType) (uint, error)
LoadBlob(restic.BlobType, restic.ID, []byte) (int, error) LoadBlob(scontext.Context, restic.BlobType, restic.ID, []byte) (int, error)
} }
type file struct { type file struct {
@ -88,7 +90,7 @@ func (f *file) Attr(ctx context.Context, a *fuse.Attr) error {
} }
func (f *file) getBlobAt(i int) (blob []byte, err error) { func (f *file) getBlobAt(ctx context.Context, i int) (blob []byte, err error) {
debug.Log("getBlobAt(%v, %v)", f.node.Name, i) debug.Log("getBlobAt(%v, %v)", f.node.Name, i)
if f.blobs[i] != nil { if f.blobs[i] != nil {
return f.blobs[i], nil return f.blobs[i], nil
@ -100,7 +102,7 @@ func (f *file) getBlobAt(i int) (blob []byte, err error) {
} }
buf := restic.NewBlobBuffer(f.sizes[i]) buf := restic.NewBlobBuffer(f.sizes[i])
n, err := f.repo.LoadBlob(restic.DataBlob, f.node.Content[i], buf) n, err := f.repo.LoadBlob(ctx, restic.DataBlob, f.node.Content[i], buf)
if err != nil { if err != nil {
debug.Log("LoadBlob(%v, %v) failed: %v", f.node.Name, f.node.Content[i], err) debug.Log("LoadBlob(%v, %v) failed: %v", f.node.Name, f.node.Content[i], err)
return nil, err return nil, err
@ -137,7 +139,7 @@ func (f *file) Read(ctx context.Context, req *fuse.ReadRequest, resp *fuse.ReadR
readBytes := 0 readBytes := 0
remainingBytes := req.Size remainingBytes := req.Size
for i := startContent; remainingBytes > 0 && i < len(f.sizes); i++ { for i := startContent; remainingBytes > 0 && i < len(f.sizes); i++ {
blob, err := f.getBlobAt(i) blob, err := f.getBlobAt(ctx, i)
if err != nil { if err != nil {
return err return err
} }

View file

@ -34,9 +34,7 @@ func testRead(t testing.TB, f *file, offset, length int, data []byte) {
} }
func firstSnapshotID(t testing.TB, repo restic.Repository) (first restic.ID) { func firstSnapshotID(t testing.TB, repo restic.Repository) (first restic.ID) {
done := make(chan struct{}) for id := range repo.List(context.TODO(), restic.SnapshotFile) {
defer close(done)
for id := range repo.List(restic.SnapshotFile, done) {
if first.IsNull() { if first.IsNull() {
first = id first = id
} }
@ -46,13 +44,13 @@ func firstSnapshotID(t testing.TB, repo restic.Repository) (first restic.ID) {
func loadFirstSnapshot(t testing.TB, repo restic.Repository) *restic.Snapshot { func loadFirstSnapshot(t testing.TB, repo restic.Repository) *restic.Snapshot {
id := firstSnapshotID(t, repo) id := firstSnapshotID(t, repo)
sn, err := restic.LoadSnapshot(repo, id) sn, err := restic.LoadSnapshot(context.TODO(), repo, id)
OK(t, err) OK(t, err)
return sn return sn
} }
func loadTree(t testing.TB, repo restic.Repository, id restic.ID) *restic.Tree { func loadTree(t testing.TB, repo restic.Repository, id restic.ID) *restic.Tree {
tree, err := repo.LoadTree(id) tree, err := repo.LoadTree(context.TODO(), id)
OK(t, err) OK(t, err)
return tree return tree
} }
@ -87,7 +85,7 @@ func TestFuseFile(t *testing.T) {
filesize += uint64(size) filesize += uint64(size)
buf := restic.NewBlobBuffer(int(size)) buf := restic.NewBlobBuffer(int(size))
n, err := repo.LoadBlob(restic.DataBlob, id, buf) n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf)
OK(t, err) OK(t, err)
if uint(n) != size { if uint(n) != size {

View file

@ -73,14 +73,14 @@ func (sn *SnapshotsDir) updateCache(ctx context.Context) error {
sn.Lock() sn.Lock()
defer sn.Unlock() defer sn.Unlock()
for id := range sn.repo.List(restic.SnapshotFile, ctx.Done()) { for id := range sn.repo.List(ctx, restic.SnapshotFile) {
if sn.processed.Has(id) { if sn.processed.Has(id) {
debug.Log("skipping snapshot %v, already in list", id.Str()) debug.Log("skipping snapshot %v, already in list", id.Str())
continue continue
} }
debug.Log("found snapshot id %v", id.Str()) debug.Log("found snapshot id %v", id.Str())
snapshot, err := restic.LoadSnapshot(sn.repo, id) snapshot, err := restic.LoadSnapshot(ctx, sn.repo, id)
if err != nil { if err != nil {
return err return err
} }
@ -158,5 +158,5 @@ func (sn *SnapshotsDir) Lookup(ctx context.Context, name string) (fs.Node, error
} }
} }
return newDirFromSnapshot(sn.repo, snapshot, sn.ownerIsRoot) return newDirFromSnapshot(ctx, sn.repo, snapshot, sn.ownerIsRoot)
} }

View file

@ -2,6 +2,7 @@
package index package index
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"restic" "restic"
@ -33,15 +34,12 @@ func newIndex() *Index {
} }
// New creates a new index for repo from scratch. // New creates a new index for repo from scratch.
func New(repo restic.Repository, p *restic.Progress) (*Index, error) { func New(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Index, error) {
done := make(chan struct{})
defer close(done)
p.Start() p.Start()
defer p.Done() defer p.Done()
ch := make(chan worker.Job) ch := make(chan worker.Job)
go list.AllPacks(repo, ch, done) go list.AllPacks(ctx, repo, ch)
idx := newIndex() idx := newIndex()
@ -84,11 +82,11 @@ type indexJSON struct {
Packs []*packJSON `json:"packs"` Packs []*packJSON `json:"packs"`
} }
func loadIndexJSON(repo restic.Repository, id restic.ID) (*indexJSON, error) { func loadIndexJSON(ctx context.Context, repo restic.Repository, id restic.ID) (*indexJSON, error) {
debug.Log("process index %v\n", id.Str()) debug.Log("process index %v\n", id.Str())
var idx indexJSON var idx indexJSON
err := repo.LoadJSONUnpacked(restic.IndexFile, id, &idx) err := repo.LoadJSONUnpacked(ctx, restic.IndexFile, id, &idx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -97,25 +95,22 @@ func loadIndexJSON(repo restic.Repository, id restic.ID) (*indexJSON, error) {
} }
// Load creates an index by loading all index files from the repo. // Load creates an index by loading all index files from the repo.
func Load(repo restic.Repository, p *restic.Progress) (*Index, error) { func Load(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Index, error) {
debug.Log("loading indexes") debug.Log("loading indexes")
p.Start() p.Start()
defer p.Done() defer p.Done()
done := make(chan struct{})
defer close(done)
supersedes := make(map[restic.ID]restic.IDSet) supersedes := make(map[restic.ID]restic.IDSet)
results := make(map[restic.ID]map[restic.ID]Pack) results := make(map[restic.ID]map[restic.ID]Pack)
index := newIndex() index := newIndex()
for id := range repo.List(restic.IndexFile, done) { for id := range repo.List(ctx, restic.IndexFile) {
p.Report(restic.Stat{Blobs: 1}) p.Report(restic.Stat{Blobs: 1})
debug.Log("Load index %v", id.Str()) debug.Log("Load index %v", id.Str())
idx, err := loadIndexJSON(repo, id) idx, err := loadIndexJSON(ctx, repo, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -250,17 +245,17 @@ func (idx *Index) FindBlob(h restic.BlobHandle) (result []Location, err error) {
} }
// Save writes the complete index to the repo. // Save writes the complete index to the repo.
func (idx *Index) Save(repo restic.Repository, supersedes restic.IDs) (restic.ID, error) { func (idx *Index) Save(ctx context.Context, repo restic.Repository, supersedes restic.IDs) (restic.ID, error) {
packs := make(map[restic.ID][]restic.Blob, len(idx.Packs)) packs := make(map[restic.ID][]restic.Blob, len(idx.Packs))
for id, p := range idx.Packs { for id, p := range idx.Packs {
packs[id] = p.Entries packs[id] = p.Entries
} }
return Save(repo, packs, supersedes) return Save(ctx, repo, packs, supersedes)
} }
// Save writes a new index containing the given packs. // Save writes a new index containing the given packs.
func Save(repo restic.Repository, packs map[restic.ID][]restic.Blob, supersedes restic.IDs) (restic.ID, error) { func Save(ctx context.Context, repo restic.Repository, packs map[restic.ID][]restic.Blob, supersedes restic.IDs) (restic.ID, error) {
idx := &indexJSON{ idx := &indexJSON{
Supersedes: supersedes, Supersedes: supersedes,
Packs: make([]*packJSON, 0, len(packs)), Packs: make([]*packJSON, 0, len(packs)),
@ -285,5 +280,5 @@ func Save(repo restic.Repository, packs map[restic.ID][]restic.Blob, supersedes
idx.Packs = append(idx.Packs, p) idx.Packs = append(idx.Packs, p)
} }
return repo.SaveJSONUnpacked(restic.IndexFile, idx) return repo.SaveJSONUnpacked(ctx, restic.IndexFile, idx)
} }

View file

@ -1,6 +1,7 @@
package index package index
import ( import (
"context"
"math/rand" "math/rand"
"restic" "restic"
"restic/checker" "restic/checker"
@ -26,7 +27,7 @@ func createFilledRepo(t testing.TB, snapshots int, dup float32) (restic.Reposito
} }
func validateIndex(t testing.TB, repo restic.Repository, idx *Index) { func validateIndex(t testing.TB, repo restic.Repository, idx *Index) {
for id := range repo.List(restic.DataFile, nil) { for id := range repo.List(context.TODO(), restic.DataFile) {
p, ok := idx.Packs[id] p, ok := idx.Packs[id]
if !ok { if !ok {
t.Errorf("pack %v missing from index", id.Str()) t.Errorf("pack %v missing from index", id.Str())
@ -42,7 +43,7 @@ func TestIndexNew(t *testing.T) {
repo, cleanup := createFilledRepo(t, 3, 0) repo, cleanup := createFilledRepo(t, 3, 0)
defer cleanup() defer cleanup()
idx, err := New(repo, nil) idx, err := New(context.TODO(), repo, nil)
if err != nil { if err != nil {
t.Fatalf("New() returned error %v", err) t.Fatalf("New() returned error %v", err)
} }
@ -58,7 +59,7 @@ func TestIndexLoad(t *testing.T) {
repo, cleanup := createFilledRepo(t, 3, 0) repo, cleanup := createFilledRepo(t, 3, 0)
defer cleanup() defer cleanup()
loadIdx, err := Load(repo, nil) loadIdx, err := Load(context.TODO(), repo, nil)
if err != nil { if err != nil {
t.Fatalf("Load() returned error %v", err) t.Fatalf("Load() returned error %v", err)
} }
@ -69,7 +70,7 @@ func TestIndexLoad(t *testing.T) {
validateIndex(t, repo, loadIdx) validateIndex(t, repo, loadIdx)
newIdx, err := New(repo, nil) newIdx, err := New(context.TODO(), repo, nil)
if err != nil { if err != nil {
t.Fatalf("New() returned error %v", err) t.Fatalf("New() returned error %v", err)
} }
@ -133,7 +134,7 @@ func BenchmarkIndexNew(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
idx, err := New(repo, nil) idx, err := New(context.TODO(), repo, nil)
if err != nil { if err != nil {
b.Fatalf("New() returned error %v", err) b.Fatalf("New() returned error %v", err)
@ -150,7 +151,7 @@ func BenchmarkIndexSave(b *testing.B) {
repo, cleanup := repository.TestRepository(b) repo, cleanup := repository.TestRepository(b)
defer cleanup() defer cleanup()
idx, err := New(repo, nil) idx, err := New(context.TODO(), repo, nil)
test.OK(b, err) test.OK(b, err)
for i := 0; i < 8000; i++ { for i := 0; i < 8000; i++ {
@ -170,7 +171,7 @@ func BenchmarkIndexSave(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
id, err := idx.Save(repo, nil) id, err := idx.Save(context.TODO(), repo, nil)
if err != nil { if err != nil {
b.Fatalf("New() returned error %v", err) b.Fatalf("New() returned error %v", err)
} }
@ -183,7 +184,7 @@ func TestIndexDuplicateBlobs(t *testing.T) {
repo, cleanup := createFilledRepo(t, 3, 0.01) repo, cleanup := createFilledRepo(t, 3, 0.01)
defer cleanup() defer cleanup()
idx, err := New(repo, nil) idx, err := New(context.TODO(), repo, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -202,7 +203,7 @@ func TestIndexDuplicateBlobs(t *testing.T) {
} }
func loadIndex(t testing.TB, repo restic.Repository) *Index { func loadIndex(t testing.TB, repo restic.Repository) *Index {
idx, err := Load(repo, nil) idx, err := Load(context.TODO(), repo, nil)
if err != nil { if err != nil {
t.Fatalf("Load() returned error %v", err) t.Fatalf("Load() returned error %v", err)
} }
@ -225,7 +226,7 @@ func TestSave(t *testing.T) {
t.Logf("save %d/%d packs in a new index\n", len(packs), len(idx.Packs)) t.Logf("save %d/%d packs in a new index\n", len(packs), len(idx.Packs))
id, err := Save(repo, packs, idx.IndexIDs.List()) id, err := Save(context.TODO(), repo, packs, idx.IndexIDs.List())
if err != nil { if err != nil {
t.Fatalf("unable to save new index: %v", err) t.Fatalf("unable to save new index: %v", err)
} }
@ -235,7 +236,7 @@ func TestSave(t *testing.T) {
for id := range idx.IndexIDs { for id := range idx.IndexIDs {
t.Logf("remove index %v", id.Str()) t.Logf("remove index %v", id.Str())
h := restic.Handle{Type: restic.IndexFile, Name: id.String()} h := restic.Handle{Type: restic.IndexFile, Name: id.String()}
err = repo.Backend().Remove(h) err = repo.Backend().Remove(context.TODO(), h)
if err != nil { if err != nil {
t.Errorf("error removing index %v: %v", id, err) t.Errorf("error removing index %v: %v", id, err)
} }
@ -267,7 +268,7 @@ func TestIndexSave(t *testing.T) {
idx := loadIndex(t, repo) idx := loadIndex(t, repo)
id, err := idx.Save(repo, idx.IndexIDs.List()) id, err := idx.Save(context.TODO(), repo, idx.IndexIDs.List())
if err != nil { if err != nil {
t.Fatalf("unable to save new index: %v", err) t.Fatalf("unable to save new index: %v", err)
} }
@ -277,7 +278,7 @@ func TestIndexSave(t *testing.T) {
for id := range idx.IndexIDs { for id := range idx.IndexIDs {
t.Logf("remove index %v", id.Str()) t.Logf("remove index %v", id.Str())
h := restic.Handle{Type: restic.IndexFile, Name: id.String()} h := restic.Handle{Type: restic.IndexFile, Name: id.String()}
err = repo.Backend().Remove(h) err = repo.Backend().Remove(context.TODO(), h)
if err != nil { if err != nil {
t.Errorf("error removing index %v: %v", id, err) t.Errorf("error removing index %v: %v", id, err)
} }
@ -287,7 +288,7 @@ func TestIndexSave(t *testing.T) {
t.Logf("load new index with %d packs", len(idx2.Packs)) t.Logf("load new index with %d packs", len(idx2.Packs))
checker := checker.New(repo) checker := checker.New(repo)
hints, errs := checker.LoadIndex() hints, errs := checker.LoadIndex(context.TODO())
for _, h := range hints { for _, h := range hints {
t.Logf("hint: %v\n", h) t.Logf("hint: %v\n", h)
} }
@ -301,15 +302,12 @@ func TestIndexAddRemovePack(t *testing.T) {
repo, cleanup := createFilledRepo(t, 3, 0) repo, cleanup := createFilledRepo(t, 3, 0)
defer cleanup() defer cleanup()
idx, err := Load(repo, nil) idx, err := Load(context.TODO(), repo, nil)
if err != nil { if err != nil {
t.Fatalf("Load() returned error %v", err) t.Fatalf("Load() returned error %v", err)
} }
done := make(chan struct{}) packID := <-repo.List(context.TODO(), restic.DataFile)
defer close(done)
packID := <-repo.List(restic.DataFile, done)
t.Logf("selected pack %v", packID.Str()) t.Logf("selected pack %v", packID.Str())
@ -367,7 +365,7 @@ func TestIndexLoadDocReference(t *testing.T) {
repo, cleanup := repository.TestRepository(t) repo, cleanup := repository.TestRepository(t)
defer cleanup() defer cleanup()
id, err := repo.SaveUnpacked(restic.IndexFile, docExample) id, err := repo.SaveUnpacked(context.TODO(), restic.IndexFile, docExample)
if err != nil { if err != nil {
t.Fatalf("SaveUnpacked() returned error %v", err) t.Fatalf("SaveUnpacked() returned error %v", err)
} }

View file

@ -1,6 +1,7 @@
package list package list
import ( import (
"context"
"restic" "restic"
"restic/worker" "restic/worker"
) )
@ -9,8 +10,8 @@ const listPackWorkers = 10
// Lister combines lists packs in a repo and blobs in a pack. // Lister combines lists packs in a repo and blobs in a pack.
type Lister interface { type Lister interface {
List(restic.FileType, <-chan struct{}) <-chan restic.ID List(context.Context, restic.FileType) <-chan restic.ID
ListPack(restic.ID) ([]restic.Blob, int64, error) ListPack(context.Context, restic.ID) ([]restic.Blob, int64, error)
} }
// Result is returned in the channel from LoadBlobsFromAllPacks. // Result is returned in the channel from LoadBlobsFromAllPacks.
@ -36,10 +37,10 @@ func (l Result) Entries() []restic.Blob {
} }
// AllPacks sends the contents of all packs to ch. // AllPacks sends the contents of all packs to ch.
func AllPacks(repo Lister, ch chan<- worker.Job, done <-chan struct{}) { func AllPacks(ctx context.Context, repo Lister, ch chan<- worker.Job) {
f := func(job worker.Job, done <-chan struct{}) (interface{}, error) { f := func(ctx context.Context, job worker.Job) (interface{}, error) {
packID := job.Data.(restic.ID) packID := job.Data.(restic.ID)
entries, size, err := repo.ListPack(packID) entries, size, err := repo.ListPack(ctx, packID)
return Result{ return Result{
packID: packID, packID: packID,
@ -49,14 +50,14 @@ func AllPacks(repo Lister, ch chan<- worker.Job, done <-chan struct{}) {
} }
jobCh := make(chan worker.Job) jobCh := make(chan worker.Job)
wp := worker.New(listPackWorkers, f, jobCh, ch) wp := worker.New(ctx, listPackWorkers, f, jobCh, ch)
go func() { go func() {
defer close(jobCh) defer close(jobCh)
for id := range repo.List(restic.DataFile, done) { for id := range repo.List(ctx, restic.DataFile) {
select { select {
case jobCh <- worker.Job{Data: id}: case jobCh <- worker.Job{Data: id}:
case <-done: case <-ctx.Done():
return return
} }
} }

View file

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
@ -58,15 +59,15 @@ func IsAlreadyLocked(err error) bool {
// NewLock returns a new, non-exclusive lock for the repository. If an // NewLock returns a new, non-exclusive lock for the repository. If an
// exclusive lock is already held by another process, ErrAlreadyLocked is // exclusive lock is already held by another process, ErrAlreadyLocked is
// returned. // returned.
func NewLock(repo Repository) (*Lock, error) { func NewLock(ctx context.Context, repo Repository) (*Lock, error) {
return newLock(repo, false) return newLock(ctx, repo, false)
} }
// NewExclusiveLock returns a new, exclusive lock for the repository. If // NewExclusiveLock returns a new, exclusive lock for the repository. If
// another lock (normal and exclusive) is already held by another process, // another lock (normal and exclusive) is already held by another process,
// ErrAlreadyLocked is returned. // ErrAlreadyLocked is returned.
func NewExclusiveLock(repo Repository) (*Lock, error) { func NewExclusiveLock(ctx context.Context, repo Repository) (*Lock, error) {
return newLock(repo, true) return newLock(ctx, repo, true)
} }
var waitBeforeLockCheck = 200 * time.Millisecond var waitBeforeLockCheck = 200 * time.Millisecond
@ -77,7 +78,7 @@ func TestSetLockTimeout(t testing.TB, d time.Duration) {
waitBeforeLockCheck = d waitBeforeLockCheck = d
} }
func newLock(repo Repository, excl bool) (*Lock, error) { func newLock(ctx context.Context, repo Repository, excl bool) (*Lock, error) {
lock := &Lock{ lock := &Lock{
Time: time.Now(), Time: time.Now(),
PID: os.Getpid(), PID: os.Getpid(),
@ -94,11 +95,11 @@ func newLock(repo Repository, excl bool) (*Lock, error) {
return nil, err return nil, err
} }
if err = lock.checkForOtherLocks(); err != nil { if err = lock.checkForOtherLocks(ctx); err != nil {
return nil, err return nil, err
} }
lockID, err := lock.createLock() lockID, err := lock.createLock(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -107,7 +108,7 @@ func newLock(repo Repository, excl bool) (*Lock, error) {
time.Sleep(waitBeforeLockCheck) time.Sleep(waitBeforeLockCheck)
if err = lock.checkForOtherLocks(); err != nil { if err = lock.checkForOtherLocks(ctx); err != nil {
lock.Unlock() lock.Unlock()
return nil, err return nil, err
} }
@ -132,8 +133,8 @@ func (l *Lock) fillUserInfo() error {
// if there are any other locks, regardless if exclusive or not. If a // if there are any other locks, regardless if exclusive or not. If a
// non-exclusive lock is to be created, an error is only returned when an // non-exclusive lock is to be created, an error is only returned when an
// exclusive lock is found. // exclusive lock is found.
func (l *Lock) checkForOtherLocks() error { func (l *Lock) checkForOtherLocks(ctx context.Context) error {
return eachLock(l.repo, func(id ID, lock *Lock, err error) error { return eachLock(ctx, l.repo, func(id ID, lock *Lock, err error) error {
if l.lockID != nil && id.Equal(*l.lockID) { if l.lockID != nil && id.Equal(*l.lockID) {
return nil return nil
} }
@ -155,12 +156,9 @@ func (l *Lock) checkForOtherLocks() error {
}) })
} }
func eachLock(repo Repository, f func(ID, *Lock, error) error) error { func eachLock(ctx context.Context, repo Repository, f func(ID, *Lock, error) error) error {
done := make(chan struct{}) for id := range repo.List(ctx, LockFile) {
defer close(done) lock, err := LoadLock(ctx, repo, id)
for id := range repo.List(LockFile, done) {
lock, err := LoadLock(repo, id)
err = f(id, lock, err) err = f(id, lock, err)
if err != nil { if err != nil {
return err return err
@ -171,8 +169,8 @@ func eachLock(repo Repository, f func(ID, *Lock, error) error) error {
} }
// createLock acquires the lock by creating a file in the repository. // createLock acquires the lock by creating a file in the repository.
func (l *Lock) createLock() (ID, error) { func (l *Lock) createLock(ctx context.Context) (ID, error) {
id, err := l.repo.SaveJSONUnpacked(LockFile, l) id, err := l.repo.SaveJSONUnpacked(ctx, LockFile, l)
if err != nil { if err != nil {
return ID{}, err return ID{}, err
} }
@ -186,7 +184,7 @@ func (l *Lock) Unlock() error {
return nil return nil
} }
return l.repo.Backend().Remove(Handle{Type: LockFile, Name: l.lockID.String()}) return l.repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: l.lockID.String()})
} }
var staleTimeout = 30 * time.Minute var staleTimeout = 30 * time.Minute
@ -227,14 +225,14 @@ func (l *Lock) Stale() bool {
// Refresh refreshes the lock by creating a new file in the backend with a new // Refresh refreshes the lock by creating a new file in the backend with a new
// timestamp. Afterwards the old lock is removed. // timestamp. Afterwards the old lock is removed.
func (l *Lock) Refresh() error { func (l *Lock) Refresh(ctx context.Context) error {
debug.Log("refreshing lock %v", l.lockID.Str()) debug.Log("refreshing lock %v", l.lockID.Str())
id, err := l.createLock() id, err := l.createLock(ctx)
if err != nil { if err != nil {
return err return err
} }
err = l.repo.Backend().Remove(Handle{Type: LockFile, Name: l.lockID.String()}) err = l.repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: l.lockID.String()})
if err != nil { if err != nil {
return err return err
} }
@ -270,9 +268,9 @@ func init() {
} }
// LoadLock loads and unserializes a lock from a repository. // LoadLock loads and unserializes a lock from a repository.
func LoadLock(repo Repository, id ID) (*Lock, error) { func LoadLock(ctx context.Context, repo Repository, id ID) (*Lock, error) {
lock := &Lock{} lock := &Lock{}
if err := repo.LoadJSONUnpacked(LockFile, id, lock); err != nil { if err := repo.LoadJSONUnpacked(ctx, LockFile, id, lock); err != nil {
return nil, err return nil, err
} }
lock.lockID = &id lock.lockID = &id
@ -281,15 +279,15 @@ func LoadLock(repo Repository, id ID) (*Lock, error) {
} }
// RemoveStaleLocks deletes all locks detected as stale from the repository. // RemoveStaleLocks deletes all locks detected as stale from the repository.
func RemoveStaleLocks(repo Repository) error { func RemoveStaleLocks(ctx context.Context, repo Repository) error {
return eachLock(repo, func(id ID, lock *Lock, err error) error { return eachLock(ctx, repo, func(id ID, lock *Lock, err error) error {
// ignore locks that cannot be loaded // ignore locks that cannot be loaded
if err != nil { if err != nil {
return nil return nil
} }
if lock.Stale() { if lock.Stale() {
return repo.Backend().Remove(Handle{Type: LockFile, Name: id.String()}) return repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: id.String()})
} }
return nil return nil
@ -297,8 +295,8 @@ func RemoveStaleLocks(repo Repository) error {
} }
// RemoveAllLocks removes all locks forcefully. // RemoveAllLocks removes all locks forcefully.
func RemoveAllLocks(repo Repository) error { func RemoveAllLocks(ctx context.Context, repo Repository) error {
return eachLock(repo, func(id ID, lock *Lock, err error) error { return eachLock(ctx, repo, func(id ID, lock *Lock, err error) error {
return repo.Backend().Remove(Handle{Type: LockFile, Name: id.String()}) return repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: id.String()})
}) })
} }

View file

@ -1,6 +1,7 @@
package restic_test package restic_test
import ( import (
"context"
"os" "os"
"testing" "testing"
"time" "time"
@ -14,7 +15,7 @@ func TestLock(t *testing.T) {
repo, cleanup := repository.TestRepository(t) repo, cleanup := repository.TestRepository(t)
defer cleanup() defer cleanup()
lock, err := restic.NewLock(repo) lock, err := restic.NewLock(context.TODO(), repo)
OK(t, err) OK(t, err)
OK(t, lock.Unlock()) OK(t, lock.Unlock())
@ -24,7 +25,7 @@ func TestDoubleUnlock(t *testing.T) {
repo, cleanup := repository.TestRepository(t) repo, cleanup := repository.TestRepository(t)
defer cleanup() defer cleanup()
lock, err := restic.NewLock(repo) lock, err := restic.NewLock(context.TODO(), repo)
OK(t, err) OK(t, err)
OK(t, lock.Unlock()) OK(t, lock.Unlock())
@ -38,10 +39,10 @@ func TestMultipleLock(t *testing.T) {
repo, cleanup := repository.TestRepository(t) repo, cleanup := repository.TestRepository(t)
defer cleanup() defer cleanup()
lock1, err := restic.NewLock(repo) lock1, err := restic.NewLock(context.TODO(), repo)
OK(t, err) OK(t, err)
lock2, err := restic.NewLock(repo) lock2, err := restic.NewLock(context.TODO(), repo)
OK(t, err) OK(t, err)
OK(t, lock1.Unlock()) OK(t, lock1.Unlock())
@ -52,7 +53,7 @@ func TestLockExclusive(t *testing.T) {
repo, cleanup := repository.TestRepository(t) repo, cleanup := repository.TestRepository(t)
defer cleanup() defer cleanup()
elock, err := restic.NewExclusiveLock(repo) elock, err := restic.NewExclusiveLock(context.TODO(), repo)
OK(t, err) OK(t, err)
OK(t, elock.Unlock()) OK(t, elock.Unlock())
} }
@ -61,10 +62,10 @@ func TestLockOnExclusiveLockedRepo(t *testing.T) {
repo, cleanup := repository.TestRepository(t) repo, cleanup := repository.TestRepository(t)
defer cleanup() defer cleanup()
elock, err := restic.NewExclusiveLock(repo) elock, err := restic.NewExclusiveLock(context.TODO(), repo)
OK(t, err) OK(t, err)
lock, err := restic.NewLock(repo) lock, err := restic.NewLock(context.TODO(), repo)
Assert(t, err != nil, Assert(t, err != nil,
"create normal lock with exclusively locked repo didn't return an error") "create normal lock with exclusively locked repo didn't return an error")
Assert(t, restic.IsAlreadyLocked(err), Assert(t, restic.IsAlreadyLocked(err),
@ -78,10 +79,10 @@ func TestExclusiveLockOnLockedRepo(t *testing.T) {
repo, cleanup := repository.TestRepository(t) repo, cleanup := repository.TestRepository(t)
defer cleanup() defer cleanup()
elock, err := restic.NewLock(repo) elock, err := restic.NewLock(context.TODO(), repo)
OK(t, err) OK(t, err)
lock, err := restic.NewExclusiveLock(repo) lock, err := restic.NewExclusiveLock(context.TODO(), repo)
Assert(t, err != nil, Assert(t, err != nil,
"create normal lock with exclusively locked repo didn't return an error") "create normal lock with exclusively locked repo didn't return an error")
Assert(t, restic.IsAlreadyLocked(err), Assert(t, restic.IsAlreadyLocked(err),
@ -98,12 +99,12 @@ func createFakeLock(repo restic.Repository, t time.Time, pid int) (restic.ID, er
} }
newLock := &restic.Lock{Time: t, PID: pid, Hostname: hostname} newLock := &restic.Lock{Time: t, PID: pid, Hostname: hostname}
return repo.SaveJSONUnpacked(restic.LockFile, &newLock) return repo.SaveJSONUnpacked(context.TODO(), restic.LockFile, &newLock)
} }
func removeLock(repo restic.Repository, id restic.ID) error { func removeLock(repo restic.Repository, id restic.ID) error {
h := restic.Handle{Type: restic.LockFile, Name: id.String()} h := restic.Handle{Type: restic.LockFile, Name: id.String()}
return repo.Backend().Remove(h) return repo.Backend().Remove(context.TODO(), h)
} }
var staleLockTests = []struct { var staleLockTests = []struct {
@ -164,7 +165,7 @@ func TestLockStale(t *testing.T) {
func lockExists(repo restic.Repository, t testing.TB, id restic.ID) bool { func lockExists(repo restic.Repository, t testing.TB, id restic.ID) bool {
h := restic.Handle{Type: restic.LockFile, Name: id.String()} h := restic.Handle{Type: restic.LockFile, Name: id.String()}
exists, err := repo.Backend().Test(h) exists, err := repo.Backend().Test(context.TODO(), h)
OK(t, err) OK(t, err)
return exists return exists
@ -183,7 +184,7 @@ func TestLockWithStaleLock(t *testing.T) {
id3, err := createFakeLock(repo, time.Now().Add(-time.Minute), os.Getpid()+500000) id3, err := createFakeLock(repo, time.Now().Add(-time.Minute), os.Getpid()+500000)
OK(t, err) OK(t, err)
OK(t, restic.RemoveStaleLocks(repo)) OK(t, restic.RemoveStaleLocks(context.TODO(), repo))
Assert(t, lockExists(repo, t, id1) == false, Assert(t, lockExists(repo, t, id1) == false,
"stale lock still exists after RemoveStaleLocks was called") "stale lock still exists after RemoveStaleLocks was called")
@ -208,7 +209,7 @@ func TestRemoveAllLocks(t *testing.T) {
id3, err := createFakeLock(repo, time.Now().Add(-time.Minute), os.Getpid()+500000) id3, err := createFakeLock(repo, time.Now().Add(-time.Minute), os.Getpid()+500000)
OK(t, err) OK(t, err)
OK(t, restic.RemoveAllLocks(repo)) OK(t, restic.RemoveAllLocks(context.TODO(), repo))
Assert(t, lockExists(repo, t, id1) == false, Assert(t, lockExists(repo, t, id1) == false,
"lock still exists after RemoveAllLocks was called") "lock still exists after RemoveAllLocks was called")
@ -222,21 +223,21 @@ func TestLockRefresh(t *testing.T) {
repo, cleanup := repository.TestRepository(t) repo, cleanup := repository.TestRepository(t)
defer cleanup() defer cleanup()
lock, err := restic.NewLock(repo) lock, err := restic.NewLock(context.TODO(), repo)
OK(t, err) OK(t, err)
var lockID *restic.ID var lockID *restic.ID
for id := range repo.List(restic.LockFile, nil) { for id := range repo.List(context.TODO(), restic.LockFile) {
if lockID != nil { if lockID != nil {
t.Error("more than one lock found") t.Error("more than one lock found")
} }
lockID = &id lockID = &id
} }
OK(t, lock.Refresh()) OK(t, lock.Refresh(context.TODO()))
var lockID2 *restic.ID var lockID2 *restic.ID
for id := range repo.List(restic.LockFile, nil) { for id := range repo.List(context.TODO(), restic.LockFile) {
if lockID2 != nil { if lockID2 != nil {
t.Error("more than one lock found") t.Error("more than one lock found")
} }

View file

@ -1,6 +1,7 @@
package mock package mock
import ( import (
"context"
"io" "io"
"restic" "restic"
@ -10,13 +11,13 @@ import (
// Backend implements a mock backend. // Backend implements a mock backend.
type Backend struct { type Backend struct {
CloseFn func() error CloseFn func() error
SaveFn func(h restic.Handle, rd io.Reader) error SaveFn func(ctx context.Context, h restic.Handle, rd io.Reader) error
LoadFn func(h restic.Handle, length int, offset int64) (io.ReadCloser, error) LoadFn func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error)
StatFn func(h restic.Handle) (restic.FileInfo, error) StatFn func(ctx context.Context, h restic.Handle) (restic.FileInfo, error)
ListFn func(restic.FileType, <-chan struct{}) <-chan string ListFn func(ctx context.Context, t restic.FileType) <-chan string
RemoveFn func(h restic.Handle) error RemoveFn func(ctx context.Context, h restic.Handle) error
TestFn func(h restic.Handle) (bool, error) TestFn func(ctx context.Context, h restic.Handle) (bool, error)
DeleteFn func() error DeleteFn func(ctx context.Context) error
LocationFn func() string LocationFn func() string
} }
@ -39,68 +40,68 @@ func (m *Backend) Location() string {
} }
// Save data in the backend. // Save data in the backend.
func (m *Backend) Save(h restic.Handle, rd io.Reader) error { func (m *Backend) Save(ctx context.Context, h restic.Handle, rd io.Reader) error {
if m.SaveFn == nil { if m.SaveFn == nil {
return errors.New("not implemented") return errors.New("not implemented")
} }
return m.SaveFn(h, rd) return m.SaveFn(ctx, h, rd)
} }
// Load loads data from the backend. // Load loads data from the backend.
func (m *Backend) Load(h restic.Handle, length int, offset int64) (io.ReadCloser, error) { func (m *Backend) Load(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) {
if m.LoadFn == nil { if m.LoadFn == nil {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
return m.LoadFn(h, length, offset) return m.LoadFn(ctx, h, length, offset)
} }
// Stat an object in the backend. // Stat an object in the backend.
func (m *Backend) Stat(h restic.Handle) (restic.FileInfo, error) { func (m *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) {
if m.StatFn == nil { if m.StatFn == nil {
return restic.FileInfo{}, errors.New("not implemented") return restic.FileInfo{}, errors.New("not implemented")
} }
return m.StatFn(h) return m.StatFn(ctx, h)
} }
// List items of type t. // List items of type t.
func (m *Backend) List(t restic.FileType, done <-chan struct{}) <-chan string { func (m *Backend) List(ctx context.Context, t restic.FileType) <-chan string {
if m.ListFn == nil { if m.ListFn == nil {
ch := make(chan string) ch := make(chan string)
close(ch) close(ch)
return ch return ch
} }
return m.ListFn(t, done) return m.ListFn(ctx, t)
} }
// Remove data from the backend. // Remove data from the backend.
func (m *Backend) Remove(h restic.Handle) error { func (m *Backend) Remove(ctx context.Context, h restic.Handle) error {
if m.RemoveFn == nil { if m.RemoveFn == nil {
return errors.New("not implemented") return errors.New("not implemented")
} }
return m.RemoveFn(h) return m.RemoveFn(ctx, h)
} }
// Test for the existence of a specific item. // Test for the existence of a specific item.
func (m *Backend) Test(h restic.Handle) (bool, error) { func (m *Backend) Test(ctx context.Context, h restic.Handle) (bool, error) {
if m.TestFn == nil { if m.TestFn == nil {
return false, errors.New("not implemented") return false, errors.New("not implemented")
} }
return m.TestFn(h) return m.TestFn(ctx, h)
} }
// Delete all data. // Delete all data.
func (m *Backend) Delete() error { func (m *Backend) Delete(ctx context.Context) error {
if m.DeleteFn == nil { if m.DeleteFn == nil {
return errors.New("not implemented") return errors.New("not implemented")
} }
return m.DeleteFn() return m.DeleteFn(ctx)
} }
// Make sure that Backend implements the backend interface. // Make sure that Backend implements the backend interface.

View file

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -116,7 +117,7 @@ func (node Node) GetExtendedAttribute(a string) []byte {
} }
// CreateAt creates the node at the given path and restores all the meta data. // CreateAt creates the node at the given path and restores all the meta data.
func (node *Node) CreateAt(path string, repo Repository, idx *HardlinkIndex) error { func (node *Node) CreateAt(ctx context.Context, path string, repo Repository, idx *HardlinkIndex) error {
debug.Log("create node %v at %v", node.Name, path) debug.Log("create node %v at %v", node.Name, path)
switch node.Type { switch node.Type {
@ -125,7 +126,7 @@ func (node *Node) CreateAt(path string, repo Repository, idx *HardlinkIndex) err
return err return err
} }
case "file": case "file":
if err := node.createFileAt(path, repo, idx); err != nil { if err := node.createFileAt(ctx, path, repo, idx); err != nil {
return err return err
} }
case "symlink": case "symlink":
@ -228,7 +229,7 @@ func (node Node) createDirAt(path string) error {
return nil return nil
} }
func (node Node) createFileAt(path string, repo Repository, idx *HardlinkIndex) error { func (node Node) createFileAt(ctx context.Context, path string, repo Repository, idx *HardlinkIndex) error {
if node.Links > 1 && idx.Has(node.Inode, node.DeviceID) { if node.Links > 1 && idx.Has(node.Inode, node.DeviceID) {
if err := fs.Remove(path); !os.IsNotExist(err) { if err := fs.Remove(path); !os.IsNotExist(err) {
return errors.Wrap(err, "RemoveCreateHardlink") return errors.Wrap(err, "RemoveCreateHardlink")
@ -259,7 +260,7 @@ func (node Node) createFileAt(path string, repo Repository, idx *HardlinkIndex)
buf = NewBlobBuffer(int(size)) buf = NewBlobBuffer(int(size))
} }
n, err := repo.LoadBlob(DataBlob, id, buf) n, err := repo.LoadBlob(ctx, DataBlob, id, buf)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package restic_test package restic_test
import ( import (
"context"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -180,7 +181,7 @@ func TestNodeRestoreAt(t *testing.T) {
for _, test := range nodeTests { for _, test := range nodeTests {
nodePath := filepath.Join(tempdir, test.Name) nodePath := filepath.Join(tempdir, test.Name)
OK(t, test.CreateAt(nodePath, nil, idx)) OK(t, test.CreateAt(context.TODO(), nodePath, nil, idx))
if test.Type == "symlink" && runtime.GOOS == "windows" { if test.Type == "symlink" && runtime.GOOS == "windows" {
continue continue

View file

@ -2,6 +2,7 @@ package pack_test
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
@ -126,7 +127,7 @@ func TestUnpackReadSeeker(t *testing.T) {
id := restic.Hash(packData) id := restic.Hash(packData)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()} handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
OK(t, b.Save(handle, bytes.NewReader(packData))) OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData)))
verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize)
} }
@ -139,6 +140,6 @@ func TestShortPack(t *testing.T) {
id := restic.Hash(packData) id := restic.Hash(packData)
handle := restic.Handle{Type: restic.DataFile, Name: id.String()} handle := restic.Handle{Type: restic.DataFile, Name: id.String()}
OK(t, b.Save(handle, bytes.NewReader(packData))) OK(t, b.Save(context.TODO(), handle, bytes.NewReader(packData)))
verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize) verifyBlobs(t, bufs, k, restic.ReaderAt(b, handle), packSize)
} }

View file

@ -1,6 +1,7 @@
package pipe package pipe
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -78,7 +79,7 @@ func readDirNames(dirname string) ([]string, error) {
// dirs). If false is returned, files are ignored and dirs are not even walked. // dirs). If false is returned, files are ignored and dirs are not even walked.
type SelectFunc func(item string, fi os.FileInfo) bool type SelectFunc func(item string, fi os.FileInfo) bool
func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs chan<- Job, res chan<- Result) (excluded bool) { func walk(ctx context.Context, basedir, dir string, selectFunc SelectFunc, jobs chan<- Job, res chan<- Result) (excluded bool) {
debug.Log("start on %q, basedir %q", dir, basedir) debug.Log("start on %q, basedir %q", dir, basedir)
relpath, err := filepath.Rel(basedir, dir) relpath, err := filepath.Rel(basedir, dir)
@ -92,7 +93,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
debug.Log("error for %v: %v, res %p", dir, err, res) debug.Log("error for %v: %v, res %p", dir, err, res)
select { select {
case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}: case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}:
case <-done: case <-ctx.Done():
} }
return return
} }
@ -107,7 +108,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
debug.Log("sending file job for %v, res %p", dir, res) debug.Log("sending file job for %v, res %p", dir, res)
select { select {
case jobs <- Entry{info: info, basedir: basedir, path: relpath, result: res}: case jobs <- Entry{info: info, basedir: basedir, path: relpath, result: res}:
case <-done: case <-ctx.Done():
} }
return return
} }
@ -117,7 +118,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
if err != nil { if err != nil {
debug.Log("Readdirnames(%v) returned error: %v, res %p", dir, err, res) debug.Log("Readdirnames(%v) returned error: %v, res %p", dir, err, res)
select { select {
case <-done: case <-ctx.Done():
case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}: case jobs <- Dir{basedir: basedir, path: relpath, info: info, error: err, result: res}:
} }
return return
@ -146,7 +147,7 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
debug.Log("sending file job for %v, err %v, res %p", subpath, err, res) debug.Log("sending file job for %v, err %v, res %p", subpath, err, res)
select { select {
case jobs <- Entry{info: fi, error: statErr, basedir: basedir, path: filepath.Join(relpath, name), result: ch}: case jobs <- Entry{info: fi, error: statErr, basedir: basedir, path: filepath.Join(relpath, name), result: ch}:
case <-done: case <-ctx.Done():
return return
} }
continue continue
@ -156,13 +157,13 @@ func walk(basedir, dir string, selectFunc SelectFunc, done <-chan struct{}, jobs
// between walk and open // between walk and open
debug.RunHook("pipe.walk2", filepath.Join(relpath, name)) debug.RunHook("pipe.walk2", filepath.Join(relpath, name))
walk(basedir, subpath, selectFunc, done, jobs, ch) walk(ctx, basedir, subpath, selectFunc, jobs, ch)
} }
debug.Log("sending dirjob for %q, basedir %q, res %p", dir, basedir, res) debug.Log("sending dirjob for %q, basedir %q, res %p", dir, basedir, res)
select { select {
case jobs <- Dir{basedir: basedir, path: relpath, info: info, Entries: entries, result: res}: case jobs <- Dir{basedir: basedir, path: relpath, info: info, Entries: entries, result: res}:
case <-done: case <-ctx.Done():
} }
return return
@ -191,7 +192,7 @@ func cleanupPath(path string) ([]string, error) {
// Walk sends a Job for each file and directory it finds below the paths. When // Walk sends a Job for each file and directory it finds below the paths. When
// the channel done is closed, processing stops. // the channel done is closed, processing stops.
func Walk(walkPaths []string, selectFunc SelectFunc, done chan struct{}, jobs chan<- Job, res chan<- Result) { func Walk(ctx context.Context, walkPaths []string, selectFunc SelectFunc, jobs chan<- Job, res chan<- Result) {
var paths []string var paths []string
for _, p := range walkPaths { for _, p := range walkPaths {
@ -215,7 +216,7 @@ func Walk(walkPaths []string, selectFunc SelectFunc, done chan struct{}, jobs ch
for _, path := range paths { for _, path := range paths {
debug.Log("start walker for %v", path) debug.Log("start walker for %v", path)
ch := make(chan Result, 1) ch := make(chan Result, 1)
excluded := walk(filepath.Dir(path), path, selectFunc, done, jobs, ch) excluded := walk(ctx, filepath.Dir(path), path, selectFunc, jobs, ch)
if excluded { if excluded {
debug.Log("walker for %v done, it was excluded by the filter", path) debug.Log("walker for %v done, it was excluded by the filter", path)
@ -228,7 +229,7 @@ func Walk(walkPaths []string, selectFunc SelectFunc, done chan struct{}, jobs ch
debug.Log("sending root node, res %p", res) debug.Log("sending root node, res %p", res)
select { select {
case <-done: case <-ctx.Done():
return return
case jobs <- Dir{Entries: entries, result: res}: case jobs <- Dir{Entries: entries, result: res}:
} }

View file

@ -1,6 +1,7 @@
package pipe_test package pipe_test
import ( import (
"context"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -127,7 +128,7 @@ func TestPipelineWalkerWithSplit(t *testing.T) {
}() }()
resCh := make(chan pipe.Result, 1) resCh := make(chan pipe.Result, 1)
pipe.Walk([]string{TestWalkerPath}, acceptAll, done, jobs, resCh) pipe.Walk(context.TODO(), []string{TestWalkerPath}, acceptAll, jobs, resCh)
// wait for all workers to terminate // wait for all workers to terminate
wg.Wait() wg.Wait()
@ -146,6 +147,9 @@ func TestPipelineWalker(t *testing.T) {
t.Skipf("walkerpath not set, skipping TestPipelineWalker") t.Skipf("walkerpath not set, skipping TestPipelineWalker")
} }
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
var err error var err error
if !filepath.IsAbs(TestWalkerPath) { if !filepath.IsAbs(TestWalkerPath) {
TestWalkerPath, err = filepath.Abs(TestWalkerPath) TestWalkerPath, err = filepath.Abs(TestWalkerPath)
@ -164,7 +168,7 @@ func TestPipelineWalker(t *testing.T) {
after := stats{} after := stats{}
m := sync.Mutex{} m := sync.Mutex{}
worker := func(wg *sync.WaitGroup, done <-chan struct{}, jobs <-chan pipe.Job) { worker := func(ctx context.Context, wg *sync.WaitGroup, jobs <-chan pipe.Job) {
defer wg.Done() defer wg.Done()
for { for {
select { select {
@ -195,7 +199,7 @@ func TestPipelineWalker(t *testing.T) {
j.Result() <- true j.Result() <- true
} }
case <-done: case <-ctx.Done():
// pipeline was cancelled // pipeline was cancelled
return return
} }
@ -203,16 +207,15 @@ func TestPipelineWalker(t *testing.T) {
} }
var wg sync.WaitGroup var wg sync.WaitGroup
done := make(chan struct{})
jobs := make(chan pipe.Job) jobs := make(chan pipe.Job)
for i := 0; i < maxWorkers; i++ { for i := 0; i < maxWorkers; i++ {
wg.Add(1) wg.Add(1)
go worker(&wg, done, jobs) go worker(ctx, &wg, jobs)
} }
resCh := make(chan pipe.Result, 1) resCh := make(chan pipe.Result, 1)
pipe.Walk([]string{TestWalkerPath}, acceptAll, done, jobs, resCh) pipe.Walk(ctx, []string{TestWalkerPath}, acceptAll, jobs, resCh)
// wait for all workers to terminate // wait for all workers to terminate
wg.Wait() wg.Wait()
@ -286,11 +289,12 @@ func TestPipeWalkerError(t *testing.T) {
OK(t, os.RemoveAll(testdir)) OK(t, os.RemoveAll(testdir))
}) })
done := make(chan struct{}) ctx, cancel := context.WithCancel(context.TODO())
ch := make(chan pipe.Job) ch := make(chan pipe.Job)
resCh := make(chan pipe.Result, 1) resCh := make(chan pipe.Result, 1)
go pipe.Walk([]string{dir}, acceptAll, done, ch, resCh) go pipe.Walk(ctx, []string{dir}, acceptAll, ch, resCh)
i := 0 i := 0
for job := range ch { for job := range ch {
@ -321,7 +325,7 @@ func TestPipeWalkerError(t *testing.T) {
t.Errorf("expected %d jobs, got %d", len(testjobs), i) t.Errorf("expected %d jobs, got %d", len(testjobs), i)
} }
close(done) cancel()
Assert(t, ranHook, "hook did not run") Assert(t, ranHook, "hook did not run")
OK(t, os.RemoveAll(dir)) OK(t, os.RemoveAll(dir))
@ -335,7 +339,7 @@ func BenchmarkPipelineWalker(b *testing.B) {
var max time.Duration var max time.Duration
m := sync.Mutex{} m := sync.Mutex{}
fileWorker := func(wg *sync.WaitGroup, done <-chan struct{}, ch <-chan pipe.Entry) { fileWorker := func(ctx context.Context, wg *sync.WaitGroup, ch <-chan pipe.Entry) {
defer wg.Done() defer wg.Done()
for { for {
select { select {
@ -349,14 +353,14 @@ func BenchmarkPipelineWalker(b *testing.B) {
//time.Sleep(10 * time.Millisecond) //time.Sleep(10 * time.Millisecond)
e.Result() <- true e.Result() <- true
case <-done: case <-ctx.Done():
// pipeline was cancelled // pipeline was cancelled
return return
} }
} }
} }
dirWorker := func(wg *sync.WaitGroup, done <-chan struct{}, ch <-chan pipe.Dir) { dirWorker := func(ctx context.Context, wg *sync.WaitGroup, ch <-chan pipe.Dir) {
defer wg.Done() defer wg.Done()
for { for {
select { select {
@ -381,16 +385,18 @@ func BenchmarkPipelineWalker(b *testing.B) {
m.Unlock() m.Unlock()
dir.Result() <- true dir.Result() <- true
case <-done: case <-ctx.Done():
// pipeline was cancelled // pipeline was cancelled
return return
} }
} }
} }
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
max = 0 max = 0
done := make(chan struct{})
entCh := make(chan pipe.Entry, 200) entCh := make(chan pipe.Entry, 200)
dirCh := make(chan pipe.Dir, 200) dirCh := make(chan pipe.Dir, 200)
@ -398,8 +404,8 @@ func BenchmarkPipelineWalker(b *testing.B) {
b.Logf("starting %d workers", maxWorkers) b.Logf("starting %d workers", maxWorkers)
for i := 0; i < maxWorkers; i++ { for i := 0; i < maxWorkers; i++ {
wg.Add(2) wg.Add(2)
go dirWorker(&wg, done, dirCh) go dirWorker(ctx, &wg, dirCh)
go fileWorker(&wg, done, entCh) go fileWorker(ctx, &wg, entCh)
} }
jobs := make(chan pipe.Job, 200) jobs := make(chan pipe.Job, 200)
@ -412,7 +418,7 @@ func BenchmarkPipelineWalker(b *testing.B) {
}() }()
resCh := make(chan pipe.Result, 1) resCh := make(chan pipe.Result, 1)
pipe.Walk([]string{TestWalkerPath}, acceptAll, done, jobs, resCh) pipe.Walk(ctx, []string{TestWalkerPath}, acceptAll, jobs, resCh)
// wait for all workers to terminate // wait for all workers to terminate
wg.Wait() wg.Wait()
@ -429,6 +435,9 @@ func TestPipelineWalkerMultiple(t *testing.T) {
t.Skipf("walkerpath not set, skipping TestPipelineWalker") t.Skipf("walkerpath not set, skipping TestPipelineWalker")
} }
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
paths, err := filepath.Glob(filepath.Join(TestWalkerPath, "*")) paths, err := filepath.Glob(filepath.Join(TestWalkerPath, "*"))
OK(t, err) OK(t, err)
@ -441,7 +450,7 @@ func TestPipelineWalkerMultiple(t *testing.T) {
after := stats{} after := stats{}
m := sync.Mutex{} m := sync.Mutex{}
worker := func(wg *sync.WaitGroup, done <-chan struct{}, jobs <-chan pipe.Job) { worker := func(ctx context.Context, wg *sync.WaitGroup, jobs <-chan pipe.Job) {
defer wg.Done() defer wg.Done()
for { for {
select { select {
@ -472,7 +481,7 @@ func TestPipelineWalkerMultiple(t *testing.T) {
j.Result() <- true j.Result() <- true
} }
case <-done: case <-ctx.Done():
// pipeline was cancelled // pipeline was cancelled
return return
} }
@ -480,16 +489,15 @@ func TestPipelineWalkerMultiple(t *testing.T) {
} }
var wg sync.WaitGroup var wg sync.WaitGroup
done := make(chan struct{})
jobs := make(chan pipe.Job) jobs := make(chan pipe.Job)
for i := 0; i < maxWorkers; i++ { for i := 0; i < maxWorkers; i++ {
wg.Add(1) wg.Add(1)
go worker(&wg, done, jobs) go worker(ctx, &wg, jobs)
} }
resCh := make(chan pipe.Result, 1) resCh := make(chan pipe.Result, 1)
pipe.Walk(paths, acceptAll, done, jobs, resCh) pipe.Walk(ctx, paths, acceptAll, jobs, resCh)
// wait for all workers to terminate // wait for all workers to terminate
wg.Wait() wg.Wait()
@ -547,9 +555,6 @@ func testPipeWalkerRootWithPath(path string, t *testing.T) {
t.Logf("paths in %v (pattern %q) expanded to %v items", path, pattern, len(rootPaths)) t.Logf("paths in %v (pattern %q) expanded to %v items", path, pattern, len(rootPaths))
done := make(chan struct{})
defer close(done)
jobCh := make(chan pipe.Job) jobCh := make(chan pipe.Job)
var jobs []pipe.Job var jobs []pipe.Job
@ -571,7 +576,7 @@ func testPipeWalkerRootWithPath(path string, t *testing.T) {
} }
resCh := make(chan pipe.Result, 1) resCh := make(chan pipe.Result, 1)
pipe.Walk([]string{path}, filter, done, jobCh, resCh) pipe.Walk(context.TODO(), []string{path}, filter, jobCh, resCh)
wg.Wait() wg.Wait()

View file

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"io" "io"
"restic/debug" "restic/debug"
) )
@ -11,7 +12,7 @@ type backendReaderAt struct {
} }
func (brd backendReaderAt) ReadAt(p []byte, offset int64) (n int, err error) { func (brd backendReaderAt) ReadAt(p []byte, offset int64) (n int, err error) {
return ReadAt(brd.be, brd.h, offset, p) return ReadAt(context.TODO(), brd.be, brd.h, offset, p)
} }
// ReaderAt returns an io.ReaderAt for a file in the backend. // ReaderAt returns an io.ReaderAt for a file in the backend.
@ -20,9 +21,9 @@ func ReaderAt(be Backend, h Handle) io.ReaderAt {
} }
// ReadAt reads from the backend handle h at the given position. // ReadAt reads from the backend handle h at the given position.
func ReadAt(be Backend, h Handle, offset int64, p []byte) (n int, err error) { func ReadAt(ctx context.Context, be Backend, h Handle, offset int64, p []byte) (n int, err error) {
debug.Log("ReadAt(%v) at %v, len %v", h, offset, len(p)) debug.Log("ReadAt(%v) at %v, len %v", h, offset, len(p))
rd, err := be.Load(h, len(p), offset) rd, err := be.Load(ctx, h, len(p), offset)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View file

@ -1,6 +1,9 @@
package restic package restic
import "restic/crypto" import (
"context"
"restic/crypto"
)
// Repository stores data in a backend. It provides high-level functions and // Repository stores data in a backend. It provides high-level functions and
// transparently encrypts/decrypts data. // transparently encrypts/decrypts data.
@ -14,40 +17,40 @@ type Repository interface {
SetIndex(Index) SetIndex(Index)
Index() Index Index() Index
SaveFullIndex() error SaveFullIndex(context.Context) error
SaveIndex() error SaveIndex(context.Context) error
LoadIndex() error LoadIndex(context.Context) error
Config() Config Config() Config
LookupBlobSize(ID, BlobType) (uint, error) LookupBlobSize(ID, BlobType) (uint, error)
List(FileType, <-chan struct{}) <-chan ID List(context.Context, FileType) <-chan ID
ListPack(ID) ([]Blob, int64, error) ListPack(context.Context, ID) ([]Blob, int64, error)
Flush() error Flush() error
SaveUnpacked(FileType, []byte) (ID, error) SaveUnpacked(context.Context, FileType, []byte) (ID, error)
SaveJSONUnpacked(FileType, interface{}) (ID, error) SaveJSONUnpacked(context.Context, FileType, interface{}) (ID, error)
LoadJSONUnpacked(FileType, ID, interface{}) error LoadJSONUnpacked(context.Context, FileType, ID, interface{}) error
LoadAndDecrypt(FileType, ID) ([]byte, error) LoadAndDecrypt(context.Context, FileType, ID) ([]byte, error)
LoadBlob(BlobType, ID, []byte) (int, error) LoadBlob(context.Context, BlobType, ID, []byte) (int, error)
SaveBlob(BlobType, []byte, ID) (ID, error) SaveBlob(context.Context, BlobType, []byte, ID) (ID, error)
LoadTree(ID) (*Tree, error) LoadTree(context.Context, ID) (*Tree, error)
SaveTree(t *Tree) (ID, error) SaveTree(context.Context, *Tree) (ID, error)
} }
// Deleter removes all data stored in a backend/repo. // Deleter removes all data stored in a backend/repo.
type Deleter interface { type Deleter interface {
Delete() error Delete(context.Context) error
} }
// Lister allows listing files in a backend. // Lister allows listing files in a backend.
type Lister interface { type Lister interface {
List(FileType, <-chan struct{}) <-chan string List(context.Context, FileType) <-chan string
} }
// Index keeps track of the blobs are stored within files. // Index keeps track of the blobs are stored within files.

View file

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"restic" "restic"
@ -519,10 +520,10 @@ func DecodeOldIndex(buf []byte) (idx *Index, err error) {
} }
// LoadIndexWithDecoder loads the index and decodes it with fn. // LoadIndexWithDecoder loads the index and decodes it with fn.
func LoadIndexWithDecoder(repo restic.Repository, id restic.ID, fn func([]byte) (*Index, error)) (idx *Index, err error) { func LoadIndexWithDecoder(ctx context.Context, repo restic.Repository, id restic.ID, fn func([]byte) (*Index, error)) (idx *Index, err error) {
debug.Log("Loading index %v", id.Str()) debug.Log("Loading index %v", id.Str())
buf, err := repo.LoadAndDecrypt(restic.IndexFile, id) buf, err := repo.LoadAndDecrypt(ctx, restic.IndexFile, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -2,6 +2,7 @@ package repository
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -58,12 +59,12 @@ var (
// createMasterKey creates a new master key in the given backend and encrypts // createMasterKey creates a new master key in the given backend and encrypts
// it with the password. // it with the password.
func createMasterKey(s *Repository, password string) (*Key, error) { func createMasterKey(s *Repository, password string) (*Key, error) {
return AddKey(s, password, nil) return AddKey(context.TODO(), s, password, nil)
} }
// OpenKey tries do decrypt the key specified by name with the given password. // OpenKey tries do decrypt the key specified by name with the given password.
func OpenKey(s *Repository, name string, password string) (*Key, error) { func OpenKey(ctx context.Context, s *Repository, name string, password string) (*Key, error) {
k, err := LoadKey(s, name) k, err := LoadKey(ctx, s, name)
if err != nil { if err != nil {
debug.Log("LoadKey(%v) returned error %v", name[:12], err) debug.Log("LoadKey(%v) returned error %v", name[:12], err)
return nil, err return nil, err
@ -113,19 +114,17 @@ func OpenKey(s *Repository, name string, password string) (*Key, error) {
// given password. If none could be found, ErrNoKeyFound is returned. When // given password. If none could be found, ErrNoKeyFound is returned. When
// maxKeys is reached, ErrMaxKeysReached is returned. When setting maxKeys to // maxKeys is reached, ErrMaxKeysReached is returned. When setting maxKeys to
// zero, all keys in the repo are checked. // zero, all keys in the repo are checked.
func SearchKey(s *Repository, password string, maxKeys int) (*Key, error) { func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int) (*Key, error) {
checked := 0 checked := 0
// try at most maxKeysForSearch keys in repo // try at most maxKeysForSearch keys in repo
done := make(chan struct{}) for name := range s.Backend().List(ctx, restic.KeyFile) {
defer close(done)
for name := range s.Backend().List(restic.KeyFile, done) {
if maxKeys > 0 && checked > maxKeys { if maxKeys > 0 && checked > maxKeys {
return nil, ErrMaxKeysReached return nil, ErrMaxKeysReached
} }
debug.Log("trying key %v", name[:12]) debug.Log("trying key %v", name[:12])
key, err := OpenKey(s, name, password) key, err := OpenKey(ctx, s, name, password)
if err != nil { if err != nil {
debug.Log("key %v returned error %v", name[:12], err) debug.Log("key %v returned error %v", name[:12], err)
@ -145,9 +144,9 @@ func SearchKey(s *Repository, password string, maxKeys int) (*Key, error) {
} }
// LoadKey loads a key from the backend. // LoadKey loads a key from the backend.
func LoadKey(s *Repository, name string) (k *Key, err error) { func LoadKey(ctx context.Context, s *Repository, name string) (k *Key, err error) {
h := restic.Handle{Type: restic.KeyFile, Name: name} h := restic.Handle{Type: restic.KeyFile, Name: name}
data, err := backend.LoadAll(s.be, h) data, err := backend.LoadAll(ctx, s.be, h)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -162,7 +161,7 @@ func LoadKey(s *Repository, name string) (k *Key, err error) {
} }
// AddKey adds a new key to an already existing repository. // AddKey adds a new key to an already existing repository.
func AddKey(s *Repository, password string, template *crypto.Key) (*Key, error) { func AddKey(ctx context.Context, s *Repository, password string, template *crypto.Key) (*Key, error) {
// make sure we have valid KDF parameters // make sure we have valid KDF parameters
if KDFParams == nil { if KDFParams == nil {
p, err := crypto.Calibrate(KDFTimeout, KDFMemory) p, err := crypto.Calibrate(KDFTimeout, KDFMemory)
@ -233,7 +232,7 @@ func AddKey(s *Repository, password string, template *crypto.Key) (*Key, error)
Name: restic.Hash(buf).String(), Name: restic.Hash(buf).String(),
} }
err = s.be.Save(h, bytes.NewReader(buf)) err = s.be.Save(ctx, h, bytes.NewReader(buf))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"io" "io"
"os" "os"
@ -18,7 +19,7 @@ import (
// Saver implements saving data in a backend. // Saver implements saving data in a backend.
type Saver interface { type Saver interface {
Save(restic.Handle, io.Reader) error Save(context.Context, restic.Handle, io.Reader) error
} }
// Packer holds a pack.Packer together with a hash writer. // Packer holds a pack.Packer together with a hash writer.
@ -118,7 +119,7 @@ func (r *Repository) savePacker(p *Packer) error {
id := restic.IDFromHash(p.hw.Sum(nil)) id := restic.IDFromHash(p.hw.Sum(nil))
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
err = r.be.Save(h, p.tmpfile) err = r.be.Save(context.TODO(), h, p.tmpfile)
if err != nil { if err != nil {
debug.Log("Save(%v) error: %v", h, err) debug.Log("Save(%v) error: %v", h, err)
return err return err

View file

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"io" "io"
"math/rand" "math/rand"
"os" "os"
@ -52,7 +53,7 @@ func saveFile(t testing.TB, be Saver, f *os.File, id restic.ID) {
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
t.Logf("save file %v", h) t.Logf("save file %v", h)
if err := be.Save(h, f); err != nil { if err := be.Save(context.TODO(), h, f); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -145,7 +146,7 @@ func BenchmarkPackerManager(t *testing.B) {
rnd := newRandReader(rand.NewSource(23)) rnd := newRandReader(rand.NewSource(23))
be := &mock.Backend{ be := &mock.Backend{
SaveFn: func(restic.Handle, io.Reader) error { return nil }, SaveFn: func(context.Context, restic.Handle, io.Reader) error { return nil },
} }
blobBuf := make([]byte, maxBlobSize) blobBuf := make([]byte, maxBlobSize)

View file

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"restic" "restic"
"sync" "sync"
@ -18,24 +19,19 @@ func closeIfOpen(ch chan struct{}) {
} }
// ParallelWorkFunc gets one file ID to work on. If an error is returned, // ParallelWorkFunc gets one file ID to work on. If an error is returned,
// processing stops. If done is closed, the function should return. // processing stops. When the contect is cancelled the function should return.
type ParallelWorkFunc func(id string, done <-chan struct{}) error type ParallelWorkFunc func(ctx context.Context, id string) error
// ParallelIDWorkFunc gets one restic.ID to work on. If an error is returned, // ParallelIDWorkFunc gets one restic.ID to work on. If an error is returned,
// processing stops. If done is closed, the function should return. // processing stops. When the context is cancelled the function should return.
type ParallelIDWorkFunc func(id restic.ID, done <-chan struct{}) error type ParallelIDWorkFunc func(ctx context.Context, id restic.ID) error
// FilesInParallel runs n workers of f in parallel, on the IDs that // FilesInParallel runs n workers of f in parallel, on the IDs that
// repo.List(t) yield. If f returns an error, the process is aborted and the // repo.List(t) yield. If f returns an error, the process is aborted and the
// first error is returned. // first error is returned.
func FilesInParallel(repo restic.Lister, t restic.FileType, n uint, f ParallelWorkFunc) error { func FilesInParallel(ctx context.Context, repo restic.Lister, t restic.FileType, n uint, f ParallelWorkFunc) error {
done := make(chan struct{})
defer closeIfOpen(done)
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
ch := repo.List(ctx, t)
ch := repo.List(t, done)
errors := make(chan error, n) errors := make(chan error, n)
for i := 0; uint(i) < n; i++ { for i := 0; uint(i) < n; i++ {
@ -50,13 +46,12 @@ func FilesInParallel(repo restic.Lister, t restic.FileType, n uint, f ParallelWo
return return
} }
err := f(id, done) err := f(ctx, id)
if err != nil { if err != nil {
closeIfOpen(done)
errors <- err errors <- err
return return
} }
case <-done: case <-ctx.Done():
return return
} }
} }
@ -79,13 +74,13 @@ func FilesInParallel(repo restic.Lister, t restic.FileType, n uint, f ParallelWo
// function that takes a string. Filenames that do not parse as a restic.ID // function that takes a string. Filenames that do not parse as a restic.ID
// are ignored. // are ignored.
func ParallelWorkFuncParseID(f ParallelIDWorkFunc) ParallelWorkFunc { func ParallelWorkFuncParseID(f ParallelIDWorkFunc) ParallelWorkFunc {
return func(s string, done <-chan struct{}) error { return func(ctx context.Context, s string) error {
id, err := restic.ParseID(s) id, err := restic.ParseID(s)
if err != nil { if err != nil {
debug.Log("invalid ID %q: %v", id, err) debug.Log("invalid ID %q: %v", id, err)
return err return err
} }
return f(id, done) return f(ctx, id)
} }
} }

View file

@ -1,6 +1,7 @@
package repository_test package repository_test
import ( import (
"context"
"math/rand" "math/rand"
"restic" "restic"
"testing" "testing"
@ -73,7 +74,7 @@ var lister = testIDs{
"34dd044c228727f2226a0c9c06a3e5ceb5e30e31cb7854f8fa1cde846b395a58", "34dd044c228727f2226a0c9c06a3e5ceb5e30e31cb7854f8fa1cde846b395a58",
} }
func (tests testIDs) List(t restic.FileType, done <-chan struct{}) <-chan string { func (tests testIDs) List(ctx context.Context, t restic.FileType) <-chan string {
ch := make(chan string) ch := make(chan string)
go func() { go func() {
@ -83,7 +84,7 @@ func (tests testIDs) List(t restic.FileType, done <-chan struct{}) <-chan string
for _, id := range tests { for _, id := range tests {
select { select {
case ch <- id: case ch <- id:
case <-done: case <-ctx.Done():
return return
} }
} }
@ -94,13 +95,13 @@ func (tests testIDs) List(t restic.FileType, done <-chan struct{}) <-chan string
} }
func TestFilesInParallel(t *testing.T) { func TestFilesInParallel(t *testing.T) {
f := func(id string, done <-chan struct{}) error { f := func(ctx context.Context, id string) error {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
return nil return nil
} }
for n := uint(1); n < 5; n++ { for n := uint(1); n < 5; n++ {
err := repository.FilesInParallel(lister, restic.DataFile, n*100, f) err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f)
OK(t, err) OK(t, err)
} }
} }
@ -109,7 +110,7 @@ var errTest = errors.New("test error")
func TestFilesInParallelWithError(t *testing.T) { func TestFilesInParallelWithError(t *testing.T) {
f := func(id string, done <-chan struct{}) error { f := func(ctx context.Context, id string) error {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
if rand.Float32() < 0.01 { if rand.Float32() < 0.01 {
@ -120,7 +121,7 @@ func TestFilesInParallelWithError(t *testing.T) {
} }
for n := uint(1); n < 5; n++ { for n := uint(1); n < 5; n++ {
err := repository.FilesInParallel(lister, restic.DataFile, n*100, f) err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f)
Equals(t, errTest, err) Equals(t, errTest, err)
} }
} }

View file

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"io" "io"
"restic" "restic"
@ -17,7 +18,7 @@ import (
// these packs. Each pack is loaded and the blobs listed in keepBlobs is saved // these packs. Each pack is loaded and the blobs listed in keepBlobs is saved
// into a new pack. Afterwards, the packs are removed. This operation requires // into a new pack. Afterwards, the packs are removed. This operation requires
// an exclusive lock on the repo. // an exclusive lock on the repo.
func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *restic.Progress) (err error) { func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *restic.Progress) (err error) {
debug.Log("repacking %d packs while keeping %d blobs", len(packs), len(keepBlobs)) debug.Log("repacking %d packs while keeping %d blobs", len(packs), len(keepBlobs))
for packID := range packs { for packID := range packs {
@ -29,7 +30,7 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet
return errors.Wrap(err, "TempFile") return errors.Wrap(err, "TempFile")
} }
beRd, err := repo.Backend().Load(h, 0, 0) beRd, err := repo.Backend().Load(ctx, h, 0, 0)
if err != nil { if err != nil {
return err return err
} }
@ -100,7 +101,7 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet
h, tempfile.Name(), id) h, tempfile.Name(), id)
} }
_, err = repo.SaveBlob(entry.Type, buf, entry.ID) _, err = repo.SaveBlob(ctx, entry.Type, buf, entry.ID)
if err != nil { if err != nil {
return err return err
} }
@ -128,7 +129,7 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet
for packID := range packs { for packID := range packs {
h := restic.Handle{Type: restic.DataFile, Name: packID.String()} h := restic.Handle{Type: restic.DataFile, Name: packID.String()}
err := repo.Backend().Remove(h) err := repo.Backend().Remove(ctx, h)
if err != nil { if err != nil {
debug.Log("error removing pack %v: %v", packID.Str(), err) debug.Log("error removing pack %v: %v", packID.Str(), err)
return err return err

View file

@ -1,6 +1,7 @@
package repository_test package repository_test
import ( import (
"context"
"io" "io"
"math/rand" "math/rand"
"restic" "restic"
@ -47,7 +48,7 @@ func createRandomBlobs(t testing.TB, repo restic.Repository, blobs int, pData fl
continue continue
} }
_, err := repo.SaveBlob(tpe, buf, id) _, err := repo.SaveBlob(context.TODO(), tpe, buf, id)
if err != nil { if err != nil {
t.Fatalf("SaveFrom() error %v", err) t.Fatalf("SaveFrom() error %v", err)
} }
@ -67,16 +68,13 @@ func createRandomBlobs(t testing.TB, repo restic.Repository, blobs int, pData fl
// selectBlobs splits the list of all blobs randomly into two lists. A blob // selectBlobs splits the list of all blobs randomly into two lists. A blob
// will be contained in the firstone ith probability p. // will be contained in the firstone ith probability p.
func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 restic.BlobSet) { func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 restic.BlobSet) {
done := make(chan struct{})
defer close(done)
list1 = restic.NewBlobSet() list1 = restic.NewBlobSet()
list2 = restic.NewBlobSet() list2 = restic.NewBlobSet()
blobs := restic.NewBlobSet() blobs := restic.NewBlobSet()
for id := range repo.List(restic.DataFile, done) { for id := range repo.List(context.TODO(), restic.DataFile) {
entries, _, err := repo.ListPack(id) entries, _, err := repo.ListPack(context.TODO(), id)
if err != nil { if err != nil {
t.Fatalf("error listing pack %v: %v", id, err) t.Fatalf("error listing pack %v: %v", id, err)
} }
@ -102,11 +100,8 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2
} }
func listPacks(t *testing.T, repo restic.Repository) restic.IDSet { func listPacks(t *testing.T, repo restic.Repository) restic.IDSet {
done := make(chan struct{})
defer close(done)
list := restic.NewIDSet() list := restic.NewIDSet()
for id := range repo.List(restic.DataFile, done) { for id := range repo.List(context.TODO(), restic.DataFile) {
list.Insert(id) list.Insert(id)
} }
@ -132,35 +127,36 @@ func findPacksForBlobs(t *testing.T, repo restic.Repository, blobs restic.BlobSe
} }
func repack(t *testing.T, repo restic.Repository, packs restic.IDSet, blobs restic.BlobSet) { func repack(t *testing.T, repo restic.Repository, packs restic.IDSet, blobs restic.BlobSet) {
err := repository.Repack(repo, packs, blobs, nil) err := repository.Repack(context.TODO(), repo, packs, blobs, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
func saveIndex(t *testing.T, repo restic.Repository) { func saveIndex(t *testing.T, repo restic.Repository) {
if err := repo.SaveIndex(); err != nil { if err := repo.SaveIndex(context.TODO()); err != nil {
t.Fatalf("repo.SaveIndex() %v", err) t.Fatalf("repo.SaveIndex() %v", err)
} }
} }
func rebuildIndex(t *testing.T, repo restic.Repository) { func rebuildIndex(t *testing.T, repo restic.Repository) {
idx, err := index.New(repo, nil) idx, err := index.New(context.TODO(), repo, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
for id := range repo.List(restic.IndexFile, nil) { for id := range repo.List(context.TODO(), restic.IndexFile) {
err = repo.Backend().Remove(restic.Handle{ h := restic.Handle{
Type: restic.IndexFile, Type: restic.IndexFile,
Name: id.String(), Name: id.String(),
}) }
err = repo.Backend().Remove(context.TODO(), h)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
_, err = idx.Save(repo, nil) _, err = idx.Save(context.TODO(), repo, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -168,7 +164,7 @@ func rebuildIndex(t *testing.T, repo restic.Repository) {
func reloadIndex(t *testing.T, repo restic.Repository) { func reloadIndex(t *testing.T, repo restic.Repository) {
repo.SetIndex(repository.NewMasterIndex()) repo.SetIndex(repository.NewMasterIndex())
if err := repo.LoadIndex(); err != nil { if err := repo.LoadIndex(context.TODO()); err != nil {
t.Fatalf("error loading new index: %v", err) t.Fatalf("error loading new index: %v", err)
} }
} }

View file

@ -2,6 +2,7 @@ package repository
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -50,11 +51,11 @@ func (r *Repository) PrefixLength(t restic.FileType) (int, error) {
// LoadAndDecrypt loads and decrypts data identified by t and id from the // LoadAndDecrypt loads and decrypts data identified by t and id from the
// backend. // backend.
func (r *Repository) LoadAndDecrypt(t restic.FileType, id restic.ID) ([]byte, error) { func (r *Repository) LoadAndDecrypt(ctx context.Context, t restic.FileType, id restic.ID) ([]byte, error) {
debug.Log("load %v with id %v", t, id.Str()) debug.Log("load %v with id %v", t, id.Str())
h := restic.Handle{Type: t, Name: id.String()} h := restic.Handle{Type: t, Name: id.String()}
buf, err := backend.LoadAll(r.be, h) buf, err := backend.LoadAll(ctx, r.be, h)
if err != nil { if err != nil {
debug.Log("error loading %v: %v", h, err) debug.Log("error loading %v: %v", h, err)
return nil, err return nil, err
@ -76,7 +77,7 @@ func (r *Repository) LoadAndDecrypt(t restic.FileType, id restic.ID) ([]byte, er
// loadBlob tries to load and decrypt content identified by t and id from a // loadBlob tries to load and decrypt content identified by t and id from a
// pack from the backend, the result is stored in plaintextBuf, which must be // pack from the backend, the result is stored in plaintextBuf, which must be
// large enough to hold the complete blob. // large enough to hold the complete blob.
func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []byte) (int, error) { func (r *Repository) loadBlob(ctx context.Context, id restic.ID, t restic.BlobType, plaintextBuf []byte) (int, error) {
debug.Log("load %v with id %v (buf len %v, cap %d)", t, id.Str(), len(plaintextBuf), cap(plaintextBuf)) debug.Log("load %v with id %v (buf len %v, cap %d)", t, id.Str(), len(plaintextBuf), cap(plaintextBuf))
// lookup packs // lookup packs
@ -103,7 +104,7 @@ func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []by
plaintextBuf = plaintextBuf[:blob.Length] plaintextBuf = plaintextBuf[:blob.Length]
n, err := restic.ReadAt(r.be, h, int64(blob.Offset), plaintextBuf) n, err := restic.ReadAt(ctx, r.be, h, int64(blob.Offset), plaintextBuf)
if err != nil { if err != nil {
debug.Log("error loading blob %v: %v", blob, err) debug.Log("error loading blob %v: %v", blob, err)
lastError = err lastError = err
@ -143,8 +144,8 @@ func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []by
// LoadJSONUnpacked decrypts the data and afterwards calls json.Unmarshal on // LoadJSONUnpacked decrypts the data and afterwards calls json.Unmarshal on
// the item. // the item.
func (r *Repository) LoadJSONUnpacked(t restic.FileType, id restic.ID, item interface{}) (err error) { func (r *Repository) LoadJSONUnpacked(ctx context.Context, t restic.FileType, id restic.ID, item interface{}) (err error) {
buf, err := r.LoadAndDecrypt(t, id) buf, err := r.LoadAndDecrypt(ctx, t, id)
if err != nil { if err != nil {
return err return err
} }
@ -159,7 +160,7 @@ func (r *Repository) LookupBlobSize(id restic.ID, tpe restic.BlobType) (uint, er
// SaveAndEncrypt encrypts data and stores it to the backend as type t. If data // SaveAndEncrypt encrypts data and stores it to the backend as type t. If data
// is small enough, it will be packed together with other small blobs. // is small enough, it will be packed together with other small blobs.
func (r *Repository) SaveAndEncrypt(t restic.BlobType, data []byte, id *restic.ID) (restic.ID, error) { func (r *Repository) SaveAndEncrypt(ctx context.Context, t restic.BlobType, data []byte, id *restic.ID) (restic.ID, error) {
if id == nil { if id == nil {
// compute plaintext hash // compute plaintext hash
hashedID := restic.Hash(data) hashedID := restic.Hash(data)
@ -204,19 +205,19 @@ func (r *Repository) SaveAndEncrypt(t restic.BlobType, data []byte, id *restic.I
// SaveJSONUnpacked serialises item as JSON and encrypts and saves it in the // SaveJSONUnpacked serialises item as JSON and encrypts and saves it in the
// backend as type t, without a pack. It returns the storage hash. // backend as type t, without a pack. It returns the storage hash.
func (r *Repository) SaveJSONUnpacked(t restic.FileType, item interface{}) (restic.ID, error) { func (r *Repository) SaveJSONUnpacked(ctx context.Context, t restic.FileType, item interface{}) (restic.ID, error) {
debug.Log("save new blob %v", t) debug.Log("save new blob %v", t)
plaintext, err := json.Marshal(item) plaintext, err := json.Marshal(item)
if err != nil { if err != nil {
return restic.ID{}, errors.Wrap(err, "json.Marshal") return restic.ID{}, errors.Wrap(err, "json.Marshal")
} }
return r.SaveUnpacked(t, plaintext) return r.SaveUnpacked(ctx, t, plaintext)
} }
// SaveUnpacked encrypts data and stores it in the backend. Returned is the // SaveUnpacked encrypts data and stores it in the backend. Returned is the
// storage hash. // storage hash.
func (r *Repository) SaveUnpacked(t restic.FileType, p []byte) (id restic.ID, err error) { func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []byte) (id restic.ID, err error) {
ciphertext := restic.NewBlobBuffer(len(p)) ciphertext := restic.NewBlobBuffer(len(p))
ciphertext, err = r.Encrypt(ciphertext, p) ciphertext, err = r.Encrypt(ciphertext, p)
if err != nil { if err != nil {
@ -226,7 +227,7 @@ func (r *Repository) SaveUnpacked(t restic.FileType, p []byte) (id restic.ID, er
id = restic.Hash(ciphertext) id = restic.Hash(ciphertext)
h := restic.Handle{Type: t, Name: id.String()} h := restic.Handle{Type: t, Name: id.String()}
err = r.be.Save(h, bytes.NewReader(ciphertext)) err = r.be.Save(ctx, h, bytes.NewReader(ciphertext))
if err != nil { if err != nil {
debug.Log("error saving blob %v: %v", h, err) debug.Log("error saving blob %v: %v", h, err)
return restic.ID{}, err return restic.ID{}, err
@ -269,7 +270,7 @@ func (r *Repository) SetIndex(i restic.Index) {
} }
// SaveIndex saves an index in the repository. // SaveIndex saves an index in the repository.
func SaveIndex(repo restic.Repository, index *Index) (restic.ID, error) { func SaveIndex(ctx context.Context, repo restic.Repository, index *Index) (restic.ID, error) {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
err := index.Finalize(buf) err := index.Finalize(buf)
@ -277,15 +278,15 @@ func SaveIndex(repo restic.Repository, index *Index) (restic.ID, error) {
return restic.ID{}, err return restic.ID{}, err
} }
return repo.SaveUnpacked(restic.IndexFile, buf.Bytes()) return repo.SaveUnpacked(ctx, restic.IndexFile, buf.Bytes())
} }
// saveIndex saves all indexes in the backend. // saveIndex saves all indexes in the backend.
func (r *Repository) saveIndex(indexes ...*Index) error { func (r *Repository) saveIndex(ctx context.Context, indexes ...*Index) error {
for i, idx := range indexes { for i, idx := range indexes {
debug.Log("Saving index %d", i) debug.Log("Saving index %d", i)
sid, err := SaveIndex(r, idx) sid, err := SaveIndex(ctx, r, idx)
if err != nil { if err != nil {
return err return err
} }
@ -297,34 +298,34 @@ func (r *Repository) saveIndex(indexes ...*Index) error {
} }
// SaveIndex saves all new indexes in the backend. // SaveIndex saves all new indexes in the backend.
func (r *Repository) SaveIndex() error { func (r *Repository) SaveIndex(ctx context.Context) error {
return r.saveIndex(r.idx.NotFinalIndexes()...) return r.saveIndex(ctx, r.idx.NotFinalIndexes()...)
} }
// SaveFullIndex saves all full indexes in the backend. // SaveFullIndex saves all full indexes in the backend.
func (r *Repository) SaveFullIndex() error { func (r *Repository) SaveFullIndex(ctx context.Context) error {
return r.saveIndex(r.idx.FullIndexes()...) return r.saveIndex(ctx, r.idx.FullIndexes()...)
} }
const loadIndexParallelism = 20 const loadIndexParallelism = 20
// LoadIndex loads all index files from the backend in parallel and stores them // LoadIndex loads all index files from the backend in parallel and stores them
// in the master index. The first error that occurred is returned. // in the master index. The first error that occurred is returned.
func (r *Repository) LoadIndex() error { func (r *Repository) LoadIndex(ctx context.Context) error {
debug.Log("Loading index") debug.Log("Loading index")
errCh := make(chan error, 1) errCh := make(chan error, 1)
indexes := make(chan *Index) indexes := make(chan *Index)
worker := func(id restic.ID, done <-chan struct{}) error { worker := func(ctx context.Context, id restic.ID) error {
idx, err := LoadIndex(r, id) idx, err := LoadIndex(ctx, r, id)
if err != nil { if err != nil {
return err return err
} }
select { select {
case indexes <- idx: case indexes <- idx:
case <-done: case <-ctx.Done():
} }
return nil return nil
@ -332,7 +333,7 @@ func (r *Repository) LoadIndex() error {
go func() { go func() {
defer close(indexes) defer close(indexes)
errCh <- FilesInParallel(r.be, restic.IndexFile, loadIndexParallelism, errCh <- FilesInParallel(ctx, r.be, restic.IndexFile, loadIndexParallelism,
ParallelWorkFuncParseID(worker)) ParallelWorkFuncParseID(worker))
}() }()
@ -348,15 +349,15 @@ func (r *Repository) LoadIndex() error {
} }
// LoadIndex loads the index id from backend and returns it. // LoadIndex loads the index id from backend and returns it.
func LoadIndex(repo restic.Repository, id restic.ID) (*Index, error) { func LoadIndex(ctx context.Context, repo restic.Repository, id restic.ID) (*Index, error) {
idx, err := LoadIndexWithDecoder(repo, id, DecodeIndex) idx, err := LoadIndexWithDecoder(ctx, repo, id, DecodeIndex)
if err == nil { if err == nil {
return idx, nil return idx, nil
} }
if errors.Cause(err) == ErrOldIndexFormat { if errors.Cause(err) == ErrOldIndexFormat {
fmt.Fprintf(os.Stderr, "index %v has old format\n", id.Str()) fmt.Fprintf(os.Stderr, "index %v has old format\n", id.Str())
return LoadIndexWithDecoder(repo, id, DecodeOldIndex) return LoadIndexWithDecoder(ctx, repo, id, DecodeOldIndex)
} }
return nil, err return nil, err
@ -364,8 +365,8 @@ func LoadIndex(repo restic.Repository, id restic.ID) (*Index, error) {
// SearchKey finds a key with the supplied password, afterwards the config is // SearchKey finds a key with the supplied password, afterwards the config is
// read and parsed. It tries at most maxKeys key files in the repo. // read and parsed. It tries at most maxKeys key files in the repo.
func (r *Repository) SearchKey(password string, maxKeys int) error { func (r *Repository) SearchKey(ctx context.Context, password string, maxKeys int) error {
key, err := SearchKey(r, password, maxKeys) key, err := SearchKey(ctx, r, password, maxKeys)
if err != nil { if err != nil {
return err return err
} }
@ -373,14 +374,14 @@ func (r *Repository) SearchKey(password string, maxKeys int) error {
r.key = key.master r.key = key.master
r.packerManager.key = key.master r.packerManager.key = key.master
r.keyName = key.Name() r.keyName = key.Name()
r.cfg, err = restic.LoadConfig(r) r.cfg, err = restic.LoadConfig(ctx, r)
return err return err
} }
// Init creates a new master key with the supplied password, initializes and // Init creates a new master key with the supplied password, initializes and
// saves the repository config. // saves the repository config.
func (r *Repository) Init(password string) error { func (r *Repository) Init(ctx context.Context, password string) error {
has, err := r.be.Test(restic.Handle{Type: restic.ConfigFile}) has, err := r.be.Test(ctx, restic.Handle{Type: restic.ConfigFile})
if err != nil { if err != nil {
return err return err
} }
@ -393,12 +394,12 @@ func (r *Repository) Init(password string) error {
return err return err
} }
return r.init(password, cfg) return r.init(ctx, password, cfg)
} }
// init creates a new master key with the supplied password and uses it to save // init creates a new master key with the supplied password and uses it to save
// the config into the repo. // the config into the repo.
func (r *Repository) init(password string, cfg restic.Config) error { func (r *Repository) init(ctx context.Context, password string, cfg restic.Config) error {
key, err := createMasterKey(r, password) key, err := createMasterKey(r, password)
if err != nil { if err != nil {
return err return err
@ -408,7 +409,7 @@ func (r *Repository) init(password string, cfg restic.Config) error {
r.packerManager.key = key.master r.packerManager.key = key.master
r.keyName = key.Name() r.keyName = key.Name()
r.cfg = cfg r.cfg = cfg
_, err = r.SaveJSONUnpacked(restic.ConfigFile, cfg) _, err = r.SaveJSONUnpacked(ctx, restic.ConfigFile, cfg)
return err return err
} }
@ -443,15 +444,15 @@ func (r *Repository) KeyName() string {
} }
// List returns a channel that yields all IDs of type t in the backend. // List returns a channel that yields all IDs of type t in the backend.
func (r *Repository) List(t restic.FileType, done <-chan struct{}) <-chan restic.ID { func (r *Repository) List(ctx context.Context, t restic.FileType) <-chan restic.ID {
out := make(chan restic.ID) out := make(chan restic.ID)
go func() { go func() {
defer close(out) defer close(out)
for strID := range r.be.List(t, done) { for strID := range r.be.List(ctx, t) {
if id, err := restic.ParseID(strID); err == nil { if id, err := restic.ParseID(strID); err == nil {
select { select {
case out <- id: case out <- id:
case <-done: case <-ctx.Done():
return return
} }
} }
@ -462,10 +463,10 @@ func (r *Repository) List(t restic.FileType, done <-chan struct{}) <-chan restic
// ListPack returns the list of blobs saved in the pack id and the length of // ListPack returns the list of blobs saved in the pack id and the length of
// the file as stored in the backend. // the file as stored in the backend.
func (r *Repository) ListPack(id restic.ID) ([]restic.Blob, int64, error) { func (r *Repository) ListPack(ctx context.Context, id restic.ID) ([]restic.Blob, int64, error) {
h := restic.Handle{Type: restic.DataFile, Name: id.String()} h := restic.Handle{Type: restic.DataFile, Name: id.String()}
blobInfo, err := r.Backend().Stat(h) blobInfo, err := r.Backend().Stat(ctx, h)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -480,9 +481,9 @@ func (r *Repository) ListPack(id restic.ID) ([]restic.Blob, int64, error) {
// Delete calls backend.Delete() if implemented, and returns an error // Delete calls backend.Delete() if implemented, and returns an error
// otherwise. // otherwise.
func (r *Repository) Delete() error { func (r *Repository) Delete(ctx context.Context) error {
if b, ok := r.be.(restic.Deleter); ok { if b, ok := r.be.(restic.Deleter); ok {
return b.Delete() return b.Delete(ctx)
} }
return errors.New("Delete() called for backend that does not implement this method") return errors.New("Delete() called for backend that does not implement this method")
@ -496,7 +497,7 @@ func (r *Repository) Close() error {
// LoadBlob loads a blob of type t from the repository to the buffer. buf must // LoadBlob loads a blob of type t from the repository to the buffer. buf must
// be large enough to hold the encrypted blob, since it is used as scratch // be large enough to hold the encrypted blob, since it is used as scratch
// space. // space.
func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int, error) { func (r *Repository) LoadBlob(ctx context.Context, t restic.BlobType, id restic.ID, buf []byte) (int, error) {
debug.Log("load blob %v into buf (len %v, cap %v)", id.Str(), len(buf), cap(buf)) debug.Log("load blob %v into buf (len %v, cap %v)", id.Str(), len(buf), cap(buf))
size, err := r.idx.LookupSize(id, t) size, err := r.idx.LookupSize(id, t)
if err != nil { if err != nil {
@ -507,7 +508,7 @@ func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int,
return 0, errors.Errorf("buffer is too small for data blob (%d < %d)", cap(buf), restic.CiphertextLength(int(size))) return 0, errors.Errorf("buffer is too small for data blob (%d < %d)", cap(buf), restic.CiphertextLength(int(size)))
} }
n, err := r.loadBlob(id, t, buf) n, err := r.loadBlob(ctx, id, t, buf)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -520,16 +521,16 @@ func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int,
// SaveBlob saves a blob of type t into the repository. If id is the null id, it // SaveBlob saves a blob of type t into the repository. If id is the null id, it
// will be computed and returned. // will be computed and returned.
func (r *Repository) SaveBlob(t restic.BlobType, buf []byte, id restic.ID) (restic.ID, error) { func (r *Repository) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID) (restic.ID, error) {
var i *restic.ID var i *restic.ID
if !id.IsNull() { if !id.IsNull() {
i = &id i = &id
} }
return r.SaveAndEncrypt(t, buf, i) return r.SaveAndEncrypt(ctx, t, buf, i)
} }
// LoadTree loads a tree from the repository. // LoadTree loads a tree from the repository.
func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) { func (r *Repository) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) {
debug.Log("load tree %v", id.Str()) debug.Log("load tree %v", id.Str())
size, err := r.idx.LookupSize(id, restic.TreeBlob) size, err := r.idx.LookupSize(id, restic.TreeBlob)
@ -540,7 +541,7 @@ func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) {
debug.Log("size is %d, create buffer", size) debug.Log("size is %d, create buffer", size)
buf := restic.NewBlobBuffer(int(size)) buf := restic.NewBlobBuffer(int(size))
n, err := r.loadBlob(id, restic.TreeBlob, buf) n, err := r.loadBlob(ctx, id, restic.TreeBlob, buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -558,7 +559,7 @@ func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) {
// SaveTree stores a tree into the repository and returns the ID. The ID is // SaveTree stores a tree into the repository and returns the ID. The ID is
// checked against the index. The tree is only stored when the index does not // checked against the index. The tree is only stored when the index does not
// contain the ID. // contain the ID.
func (r *Repository) SaveTree(t *restic.Tree) (restic.ID, error) { func (r *Repository) SaveTree(ctx context.Context, t *restic.Tree) (restic.ID, error) {
buf, err := json.Marshal(t) buf, err := json.Marshal(t)
if err != nil { if err != nil {
return restic.ID{}, errors.Wrap(err, "MarshalJSON") return restic.ID{}, errors.Wrap(err, "MarshalJSON")
@ -573,6 +574,6 @@ func (r *Repository) SaveTree(t *restic.Tree) (restic.ID, error) {
return id, nil return id, nil
} }
_, err = r.SaveBlob(restic.TreeBlob, buf, id) _, err = r.SaveBlob(ctx, restic.TreeBlob, buf, id)
return id, err return id, err
} }

View file

@ -2,6 +2,7 @@ package repository_test
import ( import (
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"io" "io"
"math/rand" "math/rand"
@ -31,7 +32,7 @@ func TestSave(t *testing.T) {
id := restic.Hash(data) id := restic.Hash(data)
// save // save
sid, err := repo.SaveBlob(restic.DataBlob, data, restic.ID{}) sid, err := repo.SaveBlob(context.TODO(), restic.DataBlob, data, restic.ID{})
OK(t, err) OK(t, err)
Equals(t, id, sid) Equals(t, id, sid)
@ -41,7 +42,7 @@ func TestSave(t *testing.T) {
// read back // read back
buf := restic.NewBlobBuffer(size) buf := restic.NewBlobBuffer(size)
n, err := repo.LoadBlob(restic.DataBlob, id, buf) n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf)
OK(t, err) OK(t, err)
Equals(t, len(buf), n) Equals(t, len(buf), n)
@ -67,7 +68,7 @@ func TestSaveFrom(t *testing.T) {
id := restic.Hash(data) id := restic.Hash(data)
// save // save
id2, err := repo.SaveBlob(restic.DataBlob, data, id) id2, err := repo.SaveBlob(context.TODO(), restic.DataBlob, data, id)
OK(t, err) OK(t, err)
Equals(t, id, id2) Equals(t, id, id2)
@ -75,7 +76,7 @@ func TestSaveFrom(t *testing.T) {
// read back // read back
buf := restic.NewBlobBuffer(size) buf := restic.NewBlobBuffer(size)
n, err := repo.LoadBlob(restic.DataBlob, id, buf) n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf)
OK(t, err) OK(t, err)
Equals(t, len(buf), n) Equals(t, len(buf), n)
@ -106,7 +107,7 @@ func BenchmarkSaveAndEncrypt(t *testing.B) {
for i := 0; i < t.N; i++ { for i := 0; i < t.N; i++ {
// save // save
_, err = repo.SaveBlob(restic.DataBlob, data, id) _, err = repo.SaveBlob(context.TODO(), restic.DataBlob, data, id)
OK(t, err) OK(t, err)
} }
} }
@ -123,7 +124,7 @@ func TestLoadTree(t *testing.T) {
sn := archiver.TestSnapshot(t, repo, BenchArchiveDirectory, nil) sn := archiver.TestSnapshot(t, repo, BenchArchiveDirectory, nil)
OK(t, repo.Flush()) OK(t, repo.Flush())
_, err := repo.LoadTree(*sn.Tree) _, err := repo.LoadTree(context.TODO(), *sn.Tree)
OK(t, err) OK(t, err)
} }
@ -142,7 +143,7 @@ func BenchmarkLoadTree(t *testing.B) {
t.ResetTimer() t.ResetTimer()
for i := 0; i < t.N; i++ { for i := 0; i < t.N; i++ {
_, err := repo.LoadTree(*sn.Tree) _, err := repo.LoadTree(context.TODO(), *sn.Tree)
OK(t, err) OK(t, err)
} }
} }
@ -156,14 +157,14 @@ func TestLoadBlob(t *testing.T) {
_, err := io.ReadFull(rnd, buf) _, err := io.ReadFull(rnd, buf)
OK(t, err) OK(t, err)
id, err := repo.SaveBlob(restic.DataBlob, buf, restic.ID{}) id, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{})
OK(t, err) OK(t, err)
OK(t, repo.Flush()) OK(t, repo.Flush())
// first, test with buffers that are too small // first, test with buffers that are too small
for _, testlength := range []int{length - 20, length, restic.CiphertextLength(length) - 1} { for _, testlength := range []int{length - 20, length, restic.CiphertextLength(length) - 1} {
buf = make([]byte, 0, testlength) buf = make([]byte, 0, testlength)
n, err := repo.LoadBlob(restic.DataBlob, id, buf) n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf)
if err == nil { if err == nil {
t.Errorf("LoadBlob() did not return an error for a buffer that is too small to hold the blob") t.Errorf("LoadBlob() did not return an error for a buffer that is too small to hold the blob")
continue continue
@ -179,7 +180,7 @@ func TestLoadBlob(t *testing.T) {
base := restic.CiphertextLength(length) base := restic.CiphertextLength(length)
for _, testlength := range []int{base, base + 7, base + 15, base + 1000} { for _, testlength := range []int{base, base + 7, base + 15, base + 1000} {
buf = make([]byte, 0, testlength) buf = make([]byte, 0, testlength)
n, err := repo.LoadBlob(restic.DataBlob, id, buf) n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf)
if err != nil { if err != nil {
t.Errorf("LoadBlob() returned an error for buffer size %v: %v", testlength, err) t.Errorf("LoadBlob() returned an error for buffer size %v: %v", testlength, err)
continue continue
@ -201,7 +202,7 @@ func BenchmarkLoadBlob(b *testing.B) {
_, err := io.ReadFull(rnd, buf) _, err := io.ReadFull(rnd, buf)
OK(b, err) OK(b, err)
id, err := repo.SaveBlob(restic.DataBlob, buf, restic.ID{}) id, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{})
OK(b, err) OK(b, err)
OK(b, repo.Flush()) OK(b, repo.Flush())
@ -209,7 +210,7 @@ func BenchmarkLoadBlob(b *testing.B) {
b.SetBytes(int64(length)) b.SetBytes(int64(length))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
n, err := repo.LoadBlob(restic.DataBlob, id, buf) n, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, buf)
OK(b, err) OK(b, err)
if n != length { if n != length {
b.Errorf("wanted %d bytes, got %d", length, n) b.Errorf("wanted %d bytes, got %d", length, n)
@ -233,7 +234,7 @@ func BenchmarkLoadAndDecrypt(b *testing.B) {
dataID := restic.Hash(buf) dataID := restic.Hash(buf)
storageID, err := repo.SaveUnpacked(restic.DataFile, buf) storageID, err := repo.SaveUnpacked(context.TODO(), restic.DataFile, buf)
OK(b, err) OK(b, err)
// OK(b, repo.Flush()) // OK(b, repo.Flush())
@ -241,7 +242,7 @@ func BenchmarkLoadAndDecrypt(b *testing.B) {
b.SetBytes(int64(length)) b.SetBytes(int64(length))
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
data, err := repo.LoadAndDecrypt(restic.DataFile, storageID) data, err := repo.LoadAndDecrypt(context.TODO(), restic.DataFile, storageID)
OK(b, err) OK(b, err)
if len(data) != length { if len(data) != length {
b.Errorf("wanted %d bytes, got %d", length, len(data)) b.Errorf("wanted %d bytes, got %d", length, len(data))
@ -267,13 +268,13 @@ func TestLoadJSONUnpacked(t *testing.T) {
sn.Hostname = "foobar" sn.Hostname = "foobar"
sn.Username = "test!" sn.Username = "test!"
id, err := repo.SaveJSONUnpacked(restic.SnapshotFile, &sn) id, err := repo.SaveJSONUnpacked(context.TODO(), restic.SnapshotFile, &sn)
OK(t, err) OK(t, err)
var sn2 restic.Snapshot var sn2 restic.Snapshot
// restore // restore
err = repo.LoadJSONUnpacked(restic.SnapshotFile, id, &sn2) err = repo.LoadJSONUnpacked(context.TODO(), restic.SnapshotFile, id, &sn2)
OK(t, err) OK(t, err)
Equals(t, sn.Hostname, sn2.Hostname) Equals(t, sn.Hostname, sn2.Hostname)
@ -287,7 +288,7 @@ func TestRepositoryLoadIndex(t *testing.T) {
defer cleanup() defer cleanup()
repo := repository.TestOpenLocal(t, repodir) repo := repository.TestOpenLocal(t, repodir)
OK(t, repo.LoadIndex()) OK(t, repo.LoadIndex(context.TODO()))
} }
func BenchmarkLoadIndex(b *testing.B) { func BenchmarkLoadIndex(b *testing.B) {
@ -310,18 +311,18 @@ func BenchmarkLoadIndex(b *testing.B) {
}) })
} }
id, err := repository.SaveIndex(repo, idx) id, err := repository.SaveIndex(context.TODO(), repo, idx)
OK(b, err) OK(b, err)
b.Logf("index saved as %v (%v entries)", id.Str(), idx.Count(restic.DataBlob)) b.Logf("index saved as %v (%v entries)", id.Str(), idx.Count(restic.DataBlob))
fi, err := repo.Backend().Stat(restic.Handle{Type: restic.IndexFile, Name: id.String()}) fi, err := repo.Backend().Stat(context.TODO(), restic.Handle{Type: restic.IndexFile, Name: id.String()})
OK(b, err) OK(b, err)
b.Logf("filesize is %v", fi.Size) b.Logf("filesize is %v", fi.Size)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := repository.LoadIndex(repo, id) _, err := repository.LoadIndex(context.TODO(), repo, id)
OK(b, err) OK(b, err)
} }
} }
@ -335,7 +336,7 @@ func saveRandomDataBlobs(t testing.TB, repo restic.Repository, num int, sizeMax
_, err := io.ReadFull(rnd, buf) _, err := io.ReadFull(rnd, buf)
OK(t, err) OK(t, err)
_, err = repo.SaveBlob(restic.DataBlob, buf, restic.ID{}) _, err = repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{})
OK(t, err) OK(t, err)
} }
} }
@ -354,7 +355,7 @@ func TestRepositoryIncrementalIndex(t *testing.T) {
OK(t, repo.Flush()) OK(t, repo.Flush())
} }
OK(t, repo.SaveFullIndex()) OK(t, repo.SaveFullIndex(context.TODO()))
} }
// add another 5 packs // add another 5 packs
@ -364,12 +365,12 @@ func TestRepositoryIncrementalIndex(t *testing.T) {
} }
// save final index // save final index
OK(t, repo.SaveIndex()) OK(t, repo.SaveIndex(context.TODO()))
packEntries := make(map[restic.ID]map[restic.ID]struct{}) packEntries := make(map[restic.ID]map[restic.ID]struct{})
for id := range repo.List(restic.IndexFile, nil) { for id := range repo.List(context.TODO(), restic.IndexFile) {
idx, err := repository.LoadIndex(repo, id) idx, err := repository.LoadIndex(context.TODO(), repo, id)
OK(t, err) OK(t, err)
for pb := range idx.Each(nil) { for pb := range idx.Each(nil) {

View file

@ -1,6 +1,7 @@
package repository package repository
import ( import (
"context"
"os" "os"
"restic" "restic"
"restic/backend/local" "restic/backend/local"
@ -50,7 +51,7 @@ func TestRepositoryWithBackend(t testing.TB, be restic.Backend) (r restic.Reposi
repo := New(be) repo := New(be)
cfg := restic.TestCreateConfig(t, testChunkerPol) cfg := restic.TestCreateConfig(t, testChunkerPol)
err := repo.init(test.TestPassword, cfg) err := repo.init(context.TODO(), test.TestPassword, cfg)
if err != nil { if err != nil {
t.Fatalf("TestRepository(): initialize repo failed: %v", err) t.Fatalf("TestRepository(): initialize repo failed: %v", err)
} }
@ -94,7 +95,7 @@ func TestOpenLocal(t testing.TB, dir string) (r restic.Repository) {
} }
repo := New(be) repo := New(be)
err = repo.SearchKey(test.TestPassword, 10) err = repo.SearchKey(context.TODO(), test.TestPassword, 10)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"os" "os"
"path/filepath" "path/filepath"
@ -30,7 +31,7 @@ func NewRestorer(repo Repository, id ID) (*Restorer, error) {
var err error var err error
r.sn, err = LoadSnapshot(repo, id) r.sn, err = LoadSnapshot(context.TODO(), repo, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -38,8 +39,8 @@ func NewRestorer(repo Repository, id ID) (*Restorer, error) {
return r, nil return r, nil
} }
func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkIndex) error { func (res *Restorer) restoreTo(ctx context.Context, dst string, dir string, treeID ID, idx *HardlinkIndex) error {
tree, err := res.repo.LoadTree(treeID) tree, err := res.repo.LoadTree(ctx, treeID)
if err != nil { if err != nil {
return res.Error(dir, nil, err) return res.Error(dir, nil, err)
} }
@ -50,7 +51,7 @@ func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkI
debug.Log("SelectForRestore returned %v", selectedForRestore) debug.Log("SelectForRestore returned %v", selectedForRestore)
if selectedForRestore { if selectedForRestore {
err := res.restoreNodeTo(node, dir, dst, idx) err := res.restoreNodeTo(ctx, node, dir, dst, idx)
if err != nil { if err != nil {
return err return err
} }
@ -62,7 +63,7 @@ func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkI
} }
subp := filepath.Join(dir, node.Name) subp := filepath.Join(dir, node.Name)
err = res.restoreTo(dst, subp, *node.Subtree, idx) err = res.restoreTo(ctx, dst, subp, *node.Subtree, idx)
if err != nil { if err != nil {
err = res.Error(subp, node, err) err = res.Error(subp, node, err)
if err != nil { if err != nil {
@ -83,11 +84,11 @@ func (res *Restorer) restoreTo(dst string, dir string, treeID ID, idx *HardlinkI
return nil return nil
} }
func (res *Restorer) restoreNodeTo(node *Node, dir string, dst string, idx *HardlinkIndex) error { func (res *Restorer) restoreNodeTo(ctx context.Context, node *Node, dir string, dst string, idx *HardlinkIndex) error {
debug.Log("node %v, dir %v, dst %v", node.Name, dir, dst) debug.Log("node %v, dir %v, dst %v", node.Name, dir, dst)
dstPath := filepath.Join(dst, dir, node.Name) dstPath := filepath.Join(dst, dir, node.Name)
err := node.CreateAt(dstPath, res.repo, idx) err := node.CreateAt(ctx, dstPath, res.repo, idx)
if err != nil { if err != nil {
debug.Log("node.CreateAt(%s) error %v", dstPath, err) debug.Log("node.CreateAt(%s) error %v", dstPath, err)
} }
@ -99,7 +100,7 @@ func (res *Restorer) restoreNodeTo(node *Node, dir string, dst string, idx *Hard
// Create parent directories and retry // Create parent directories and retry
err = fs.MkdirAll(filepath.Dir(dstPath), 0700) err = fs.MkdirAll(filepath.Dir(dstPath), 0700)
if err == nil || os.IsExist(errors.Cause(err)) { if err == nil || os.IsExist(errors.Cause(err)) {
err = node.CreateAt(dstPath, res.repo, idx) err = node.CreateAt(ctx, dstPath, res.repo, idx)
} }
} }
@ -118,9 +119,9 @@ func (res *Restorer) restoreNodeTo(node *Node, dir string, dst string, idx *Hard
// RestoreTo creates the directories and files in the snapshot below dst. // RestoreTo creates the directories and files in the snapshot below dst.
// Before an item is created, res.Filter is called. // Before an item is created, res.Filter is called.
func (res *Restorer) RestoreTo(dst string) error { func (res *Restorer) RestoreTo(ctx context.Context, dst string) error {
idx := NewHardlinkIndex() idx := NewHardlinkIndex()
return res.restoreTo(dst, string(filepath.Separator), *res.sn.Tree, idx) return res.restoreTo(ctx, dst, string(filepath.Separator), *res.sn.Tree, idx)
} }
// Snapshot returns the snapshot this restorer is configured to use. // Snapshot returns the snapshot this restorer is configured to use.

View file

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"fmt" "fmt"
"os/user" "os/user"
"path/filepath" "path/filepath"
@ -51,9 +52,9 @@ func NewSnapshot(paths []string, tags []string, hostname string) (*Snapshot, err
} }
// LoadSnapshot loads the snapshot with the id and returns it. // LoadSnapshot loads the snapshot with the id and returns it.
func LoadSnapshot(repo Repository, id ID) (*Snapshot, error) { func LoadSnapshot(ctx context.Context, repo Repository, id ID) (*Snapshot, error) {
sn := &Snapshot{id: &id} sn := &Snapshot{id: &id}
err := repo.LoadJSONUnpacked(SnapshotFile, id, sn) err := repo.LoadJSONUnpacked(ctx, SnapshotFile, id, sn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -62,12 +63,9 @@ func LoadSnapshot(repo Repository, id ID) (*Snapshot, error) {
} }
// LoadAllSnapshots returns a list of all snapshots in the repo. // LoadAllSnapshots returns a list of all snapshots in the repo.
func LoadAllSnapshots(repo Repository) (snapshots []*Snapshot, err error) { func LoadAllSnapshots(ctx context.Context, repo Repository) (snapshots []*Snapshot, err error) {
done := make(chan struct{}) for id := range repo.List(ctx, SnapshotFile) {
defer close(done) sn, err := LoadSnapshot(ctx, repo, id)
for id := range repo.List(SnapshotFile, done) {
sn, err := LoadSnapshot(repo, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -178,15 +176,15 @@ func (sn *Snapshot) SamePaths(paths []string) bool {
var ErrNoSnapshotFound = errors.New("no snapshot found") var ErrNoSnapshotFound = errors.New("no snapshot found")
// FindLatestSnapshot finds latest snapshot with optional target/directory, tags and hostname filters. // FindLatestSnapshot finds latest snapshot with optional target/directory, tags and hostname filters.
func FindLatestSnapshot(repo Repository, targets []string, tags []string, hostname string) (ID, error) { func FindLatestSnapshot(ctx context.Context, repo Repository, targets []string, tags []string, hostname string) (ID, error) {
var ( var (
latest time.Time latest time.Time
latestID ID latestID ID
found bool found bool
) )
for snapshotID := range repo.List(SnapshotFile, make(chan struct{})) { for snapshotID := range repo.List(ctx, SnapshotFile) {
snapshot, err := LoadSnapshot(repo, snapshotID) snapshot, err := LoadSnapshot(ctx, repo, snapshotID)
if err != nil { if err != nil {
return ID{}, errors.Errorf("Error listing snapshot: %v", err) return ID{}, errors.Errorf("Error listing snapshot: %v", err)
} }

View file

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -29,7 +30,7 @@ type fakeFileSystem struct {
// saveFile reads from rd and saves the blobs in the repository. The list of // saveFile reads from rd and saves the blobs in the repository. The list of
// IDs is returned. // IDs is returned.
func (fs *fakeFileSystem) saveFile(rd io.Reader) (blobs IDs) { func (fs *fakeFileSystem) saveFile(ctx context.Context, rd io.Reader) (blobs IDs) {
if fs.buf == nil { if fs.buf == nil {
fs.buf = make([]byte, chunker.MaxSize) fs.buf = make([]byte, chunker.MaxSize)
} }
@ -53,7 +54,7 @@ func (fs *fakeFileSystem) saveFile(rd io.Reader) (blobs IDs) {
id := Hash(chunk.Data) id := Hash(chunk.Data)
if !fs.blobIsKnown(id, DataBlob) { if !fs.blobIsKnown(id, DataBlob) {
_, err := fs.repo.SaveBlob(DataBlob, chunk.Data, id) _, err := fs.repo.SaveBlob(ctx, DataBlob, chunk.Data, id)
if err != nil { if err != nil {
fs.t.Fatalf("error saving chunk: %v", err) fs.t.Fatalf("error saving chunk: %v", err)
} }
@ -103,7 +104,7 @@ func (fs *fakeFileSystem) blobIsKnown(id ID, t BlobType) bool {
} }
// saveTree saves a tree of fake files in the repo and returns the ID. // saveTree saves a tree of fake files in the repo and returns the ID.
func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID { func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) ID {
rnd := rand.NewSource(seed) rnd := rand.NewSource(seed)
numNodes := int(rnd.Int63() % maxNodes) numNodes := int(rnd.Int63() % maxNodes)
@ -113,7 +114,7 @@ func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID {
// randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4). // randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4).
if depth > 1 && rnd.Int63()%4 == 0 { if depth > 1 && rnd.Int63()%4 == 0 {
treeSeed := rnd.Int63() % maxSeed treeSeed := rnd.Int63() % maxSeed
id := fs.saveTree(treeSeed, depth-1) id := fs.saveTree(ctx, treeSeed, depth-1)
node := &Node{ node := &Node{
Name: fmt.Sprintf("dir-%v", treeSeed), Name: fmt.Sprintf("dir-%v", treeSeed),
@ -136,7 +137,7 @@ func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID {
Size: uint64(fileSize), Size: uint64(fileSize),
} }
node.Content = fs.saveFile(fakeFile(fs.t, fileSeed, fileSize)) node.Content = fs.saveFile(ctx, fakeFile(fs.t, fileSeed, fileSize))
tree.Nodes = append(tree.Nodes, node) tree.Nodes = append(tree.Nodes, node)
} }
@ -145,7 +146,7 @@ func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID {
return id return id
} }
_, err := fs.repo.SaveBlob(TreeBlob, buf, id) _, err := fs.repo.SaveBlob(ctx, TreeBlob, buf, id)
if err != nil { if err != nil {
fs.t.Fatal(err) fs.t.Fatal(err)
} }
@ -176,10 +177,10 @@ func TestCreateSnapshot(t testing.TB, repo Repository, at time.Time, depth int,
duplication: duplication, duplication: duplication,
} }
treeID := fs.saveTree(seed, depth) treeID := fs.saveTree(context.TODO(), seed, depth)
snapshot.Tree = &treeID snapshot.Tree = &treeID
id, err := repo.SaveJSONUnpacked(SnapshotFile, snapshot) id, err := repo.SaveJSONUnpacked(context.TODO(), SnapshotFile, snapshot)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -193,7 +194,7 @@ func TestCreateSnapshot(t testing.TB, repo Repository, at time.Time, depth int,
t.Fatal(err) t.Fatal(err)
} }
err = repo.SaveIndex() err = repo.SaveIndex(context.TODO())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -1,6 +1,7 @@
package restic_test package restic_test
import ( import (
"context"
"restic" "restic"
"restic/checker" "restic/checker"
"restic/repository" "restic/repository"
@ -23,7 +24,7 @@ func TestCreateSnapshot(t *testing.T) {
restic.TestCreateSnapshot(t, repo, testSnapshotTime.Add(time.Duration(i)*time.Second), testDepth, 0) restic.TestCreateSnapshot(t, repo, testSnapshotTime.Add(time.Duration(i)*time.Second), testDepth, 0)
} }
snapshots, err := restic.LoadAllSnapshots(repo) snapshots, err := restic.LoadAllSnapshots(context.TODO(), repo)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -1,6 +1,7 @@
package restic_test package restic_test
import ( import (
"context"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"os" "os"
@ -98,14 +99,14 @@ func TestLoadTree(t *testing.T) {
// save tree // save tree
tree := restic.NewTree() tree := restic.NewTree()
id, err := repo.SaveTree(tree) id, err := repo.SaveTree(context.TODO(), tree)
OK(t, err) OK(t, err)
// save packs // save packs
OK(t, repo.Flush()) OK(t, repo.Flush())
// load tree again // load tree again
tree2, err := repo.LoadTree(id) tree2, err := repo.LoadTree(context.TODO(), id)
OK(t, err) OK(t, err)
Assert(t, tree.Equals(tree2), Assert(t, tree.Equals(tree2),

View file

@ -1,6 +1,7 @@
package walk package walk
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -34,7 +35,7 @@ func NewTreeWalker(ch chan<- loadTreeJob, out chan<- TreeJob) *TreeWalker {
// Walk starts walking the tree given by id. When the channel done is closed, // Walk starts walking the tree given by id. When the channel done is closed,
// processing stops. // processing stops.
func (tw *TreeWalker) Walk(path string, id restic.ID, done chan struct{}) { func (tw *TreeWalker) Walk(ctx context.Context, path string, id restic.ID) {
debug.Log("starting on tree %v for %v", id.Str(), path) debug.Log("starting on tree %v for %v", id.Str(), path)
defer debug.Log("done walking tree %v for %v", id.Str(), path) defer debug.Log("done walking tree %v for %v", id.Str(), path)
@ -48,22 +49,22 @@ func (tw *TreeWalker) Walk(path string, id restic.ID, done chan struct{}) {
if res.err != nil { if res.err != nil {
select { select {
case tw.out <- TreeJob{Path: path, Error: res.err}: case tw.out <- TreeJob{Path: path, Error: res.err}:
case <-done: case <-ctx.Done():
return return
} }
return return
} }
tw.walk(path, res.tree, done) tw.walk(ctx, path, res.tree)
select { select {
case tw.out <- TreeJob{Path: path, Tree: res.tree}: case tw.out <- TreeJob{Path: path, Tree: res.tree}:
case <-done: case <-ctx.Done():
return return
} }
} }
func (tw *TreeWalker) walk(path string, tree *restic.Tree, done chan struct{}) { func (tw *TreeWalker) walk(ctx context.Context, path string, tree *restic.Tree) {
debug.Log("start on %q", path) debug.Log("start on %q", path)
defer debug.Log("done for %q", path) defer debug.Log("done for %q", path)
@ -94,7 +95,7 @@ func (tw *TreeWalker) walk(path string, tree *restic.Tree, done chan struct{}) {
res := <-results[i] res := <-results[i]
if res.err == nil { if res.err == nil {
tw.walk(p, res.tree, done) tw.walk(ctx, p, res.tree)
} else { } else {
fmt.Fprintf(os.Stderr, "error loading tree: %v\n", res.err) fmt.Fprintf(os.Stderr, "error loading tree: %v\n", res.err)
} }
@ -106,7 +107,7 @@ func (tw *TreeWalker) walk(path string, tree *restic.Tree, done chan struct{}) {
select { select {
case tw.out <- job: case tw.out <- job:
case <-done: case <-ctx.Done():
return return
} }
} }
@ -124,14 +125,14 @@ type loadTreeJob struct {
type treeLoader func(restic.ID) (*restic.Tree, error) type treeLoader func(restic.ID) (*restic.Tree, error)
func loadTreeWorker(wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader, done <-chan struct{}) { func loadTreeWorker(ctx context.Context, wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader) {
debug.Log("start") debug.Log("start")
defer debug.Log("exit") defer debug.Log("exit")
defer wg.Done() defer wg.Done()
for { for {
select { select {
case <-done: case <-ctx.Done():
debug.Log("done channel closed") debug.Log("done channel closed")
return return
case job, ok := <-in: case job, ok := <-in:
@ -148,7 +149,7 @@ func loadTreeWorker(wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader,
select { select {
case job.res <- loadTreeResult{tree, err}: case job.res <- loadTreeResult{tree, err}:
debug.Log("job result sent") debug.Log("job result sent")
case <-done: case <-ctx.Done():
debug.Log("done channel closed before result could be sent") debug.Log("done channel closed before result could be sent")
return return
} }
@ -158,7 +159,7 @@ func loadTreeWorker(wg *sync.WaitGroup, in <-chan loadTreeJob, load treeLoader,
// TreeLoader loads tree objects. // TreeLoader loads tree objects.
type TreeLoader interface { type TreeLoader interface {
LoadTree(restic.ID) (*restic.Tree, error) LoadTree(context.Context, restic.ID) (*restic.Tree, error)
} }
const loadTreeWorkers = 10 const loadTreeWorkers = 10
@ -166,11 +167,11 @@ const loadTreeWorkers = 10
// Tree walks the tree specified by id recursively and sends a job for each // Tree walks the tree specified by id recursively and sends a job for each
// file and directory it finds. When the channel done is closed, processing // file and directory it finds. When the channel done is closed, processing
// stops. // stops.
func Tree(repo TreeLoader, id restic.ID, done chan struct{}, jobCh chan<- TreeJob) { func Tree(ctx context.Context, repo TreeLoader, id restic.ID, jobCh chan<- TreeJob) {
debug.Log("start on %v, start workers", id.Str()) debug.Log("start on %v, start workers", id.Str())
load := func(id restic.ID) (*restic.Tree, error) { load := func(id restic.ID) (*restic.Tree, error) {
tree, err := repo.LoadTree(id) tree, err := repo.LoadTree(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -182,11 +183,11 @@ func Tree(repo TreeLoader, id restic.ID, done chan struct{}, jobCh chan<- TreeJo
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < loadTreeWorkers; i++ { for i := 0; i < loadTreeWorkers; i++ {
wg.Add(1) wg.Add(1)
go loadTreeWorker(&wg, ch, load, done) go loadTreeWorker(ctx, &wg, ch, load)
} }
tw := NewTreeWalker(ch, jobCh) tw := NewTreeWalker(ch, jobCh)
tw.Walk("", id, done) tw.Walk(ctx, "", id)
close(jobCh) close(jobCh)
close(ch) close(ch)

View file

@ -1,6 +1,7 @@
package walk_test package walk_test
import ( import (
"context"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -24,17 +25,15 @@ func TestWalkTree(t *testing.T) {
// archive a few files // archive a few files
arch := archiver.New(repo) arch := archiver.New(repo)
sn, _, err := arch.Snapshot(nil, dirs, nil, "localhost", nil) sn, _, err := arch.Snapshot(context.TODO(), nil, dirs, nil, "localhost", nil)
OK(t, err) OK(t, err)
// flush repo, write all packs // flush repo, write all packs
OK(t, repo.Flush()) OK(t, repo.Flush())
done := make(chan struct{})
// start tree walker // start tree walker
treeJobs := make(chan walk.TreeJob) treeJobs := make(chan walk.TreeJob)
go walk.Tree(repo, *sn.Tree, done, treeJobs) go walk.Tree(context.TODO(), repo, *sn.Tree, treeJobs)
// start filesystem walker // start filesystem walker
fsJobs := make(chan pipe.Job) fsJobs := make(chan pipe.Job)
@ -43,7 +42,7 @@ func TestWalkTree(t *testing.T) {
f := func(string, os.FileInfo) bool { f := func(string, os.FileInfo) bool {
return true return true
} }
go pipe.Walk(dirs, f, done, fsJobs, resCh) go pipe.Walk(context.TODO(), dirs, f, fsJobs, resCh)
for { for {
// receive fs job // receive fs job
@ -95,9 +94,9 @@ type delayRepo struct {
delay time.Duration delay time.Duration
} }
func (d delayRepo) LoadTree(id restic.ID) (*restic.Tree, error) { func (d delayRepo) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) {
time.Sleep(d.delay) time.Sleep(d.delay)
return d.repo.LoadTree(id) return d.repo.LoadTree(ctx, id)
} }
var repoFixture = filepath.Join("testdata", "walktree-test-repo.tar.gz") var repoFixture = filepath.Join("testdata", "walktree-test-repo.tar.gz")
@ -1345,7 +1344,7 @@ func TestDelayedWalkTree(t *testing.T) {
defer cleanup() defer cleanup()
repo := repository.TestOpenLocal(t, repodir) repo := repository.TestOpenLocal(t, repodir)
OK(t, repo.LoadIndex()) OK(t, repo.LoadIndex(context.TODO()))
root, err := restic.ParseID("937a2f64f736c64ee700c6ab06f840c68c94799c288146a0e81e07f4c94254da") root, err := restic.ParseID("937a2f64f736c64ee700c6ab06f840c68c94799c288146a0e81e07f4c94254da")
OK(t, err) OK(t, err)
@ -1354,7 +1353,7 @@ func TestDelayedWalkTree(t *testing.T) {
// start tree walker // start tree walker
treeJobs := make(chan walk.TreeJob) treeJobs := make(chan walk.TreeJob)
go walk.Tree(dr, root, nil, treeJobs) go walk.Tree(context.TODO(), dr, root, treeJobs)
i := 0 i := 0
for job := range treeJobs { for job := range treeJobs {
@ -1375,7 +1374,7 @@ func BenchmarkDelayedWalkTree(t *testing.B) {
defer cleanup() defer cleanup()
repo := repository.TestOpenLocal(t, repodir) repo := repository.TestOpenLocal(t, repodir)
OK(t, repo.LoadIndex()) OK(t, repo.LoadIndex(context.TODO()))
root, err := restic.ParseID("937a2f64f736c64ee700c6ab06f840c68c94799c288146a0e81e07f4c94254da") root, err := restic.ParseID("937a2f64f736c64ee700c6ab06f840c68c94799c288146a0e81e07f4c94254da")
OK(t, err) OK(t, err)
@ -1387,7 +1386,7 @@ func BenchmarkDelayedWalkTree(t *testing.B) {
for i := 0; i < t.N; i++ { for i := 0; i < t.N; i++ {
// start tree walker // start tree walker
treeJobs := make(chan walk.TreeJob) treeJobs := make(chan walk.TreeJob)
go walk.Tree(dr, root, nil, treeJobs) go walk.Tree(context.TODO(), dr, root, treeJobs)
for range treeJobs { for range treeJobs {
} }

View file

@ -1,5 +1,7 @@
package worker package worker
import "context"
// Job is one unit of work. It is given to a Func, and the returned result and // Job is one unit of work. It is given to a Func, and the returned result and
// error are stored in Result and Error. // error are stored in Result and Error.
type Job struct { type Job struct {
@ -9,12 +11,11 @@ type Job struct {
} }
// Func does the actual work within a Pool. // Func does the actual work within a Pool.
type Func func(job Job, done <-chan struct{}) (result interface{}, err error) type Func func(ctx context.Context, job Job) (result interface{}, err error)
// Pool implements a worker pool. // Pool implements a worker pool.
type Pool struct { type Pool struct {
f Func f Func
done chan struct{}
jobCh <-chan Job jobCh <-chan Job
resCh chan<- Job resCh chan<- Job
@ -25,10 +26,9 @@ type Pool struct {
// New returns a new worker pool with n goroutines, each running the function // New returns a new worker pool with n goroutines, each running the function
// f. The workers are started immediately. // f. The workers are started immediately.
func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool { func New(ctx context.Context, n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool {
p := &Pool{ p := &Pool{
f: f, f: f,
done: make(chan struct{}),
workersExit: make(chan struct{}), workersExit: make(chan struct{}),
allWorkersDone: make(chan struct{}), allWorkersDone: make(chan struct{}),
numWorkers: n, numWorkers: n,
@ -37,7 +37,7 @@ func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool {
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
go p.runWorker(i) go p.runWorker(ctx, i)
} }
go p.waitForExit() go p.waitForExit()
@ -58,7 +58,7 @@ func (p *Pool) waitForExit() {
} }
// runWorker runs a worker function. // runWorker runs a worker function.
func (p *Pool) runWorker(numWorker int) { func (p *Pool) runWorker(ctx context.Context, numWorker int) {
defer func() { defer func() {
p.workersExit <- struct{}{} p.workersExit <- struct{}{}
}() }()
@ -75,7 +75,7 @@ func (p *Pool) runWorker(numWorker int) {
for { for {
select { select {
case <-p.done: case <-ctx.Done():
return return
case job, ok = <-inCh: case job, ok = <-inCh:
@ -83,7 +83,7 @@ func (p *Pool) runWorker(numWorker int) {
return return
} }
job.Result, job.Error = p.f(job, p.done) job.Result, job.Error = p.f(ctx, job)
inCh = nil inCh = nil
outCh = p.resCh outCh = p.resCh

View file

@ -1,6 +1,7 @@
package worker_test package worker_test
import ( import (
"context"
"testing" "testing"
"restic/errors" "restic/errors"
@ -12,7 +13,7 @@ const concurrency = 10
var errTooLarge = errors.New("too large") var errTooLarge = errors.New("too large")
func square(job worker.Job, done <-chan struct{}) (interface{}, error) { func square(ctx context.Context, job worker.Job) (interface{}, error) {
n := job.Data.(int) n := job.Data.(int)
if n > 2000 { if n > 2000 {
return nil, errTooLarge return nil, errTooLarge
@ -20,15 +21,15 @@ func square(job worker.Job, done <-chan struct{}) (interface{}, error) {
return n * n, nil return n * n, nil
} }
func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) { func newBufferedPool(ctx context.Context, bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) {
inCh := make(chan worker.Job, bufsize) inCh := make(chan worker.Job, bufsize)
outCh := make(chan worker.Job, bufsize) outCh := make(chan worker.Job, bufsize)
return inCh, outCh, worker.New(n, f, inCh, outCh) return inCh, outCh, worker.New(ctx, n, f, inCh, outCh)
} }
func TestPool(t *testing.T) { func TestPool(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square) inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square)
for i := 0; i < 150; i++ { for i := 0; i < 150; i++ {
inCh <- worker.Job{Data: i} inCh <- worker.Job{Data: i}
@ -53,7 +54,7 @@ func TestPool(t *testing.T) {
} }
func TestPoolErrors(t *testing.T) { func TestPoolErrors(t *testing.T) {
inCh, outCh, p := newBufferedPool(200, concurrency, square) inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square)
for i := 0; i < 150; i++ { for i := 0; i < 150; i++ {
inCh <- worker.Job{Data: i + 1900} inCh <- worker.Job{Data: i + 1900}