From d886cb5c27980d552e952230afc97b3d7008041d Mon Sep 17 00:00:00 2001 From: George Armhold Date: Wed, 22 Nov 2017 06:27:29 -0500 Subject: [PATCH] replace ad-hoc context.TODO() with gopts.ctx, so that cancellation can properly trickle down from cmd_*. gh-1434 --- cmd/restic/cmd_backup.go | 11 +++++------ cmd/restic/cmd_tag.go | 4 ++-- cmd/restic/global.go | 6 +++--- internal/archiver/archive_reader.go | 2 +- internal/archiver/archiver.go | 2 +- internal/archiver/archiver_duplication_test.go | 2 +- internal/archiver/archiver_test.go | 2 +- internal/backend/b2/b2.go | 8 ++++---- internal/backend/b2/b2_test.go | 6 +++--- internal/repository/packer_manager.go | 4 ++-- internal/repository/repack.go | 2 +- internal/repository/repack_test.go | 4 ++-- internal/repository/repository.go | 6 +++--- internal/repository/repository_test.go | 16 ++++++++-------- internal/restic/repository.go | 2 +- internal/restic/testing.go | 2 +- internal/restic/tree_test.go | 2 +- internal/walk/walk_test.go | 2 +- 18 files changed, 41 insertions(+), 42 deletions(-) diff --git a/cmd/restic/cmd_backup.go b/cmd/restic/cmd_backup.go index c866e788a..48a607ecb 100644 --- a/cmd/restic/cmd_backup.go +++ b/cmd/restic/cmd_backup.go @@ -2,7 +2,6 @@ package main import ( "bufio" - "context" "fmt" "io" "os" @@ -256,7 +255,7 @@ func readBackupFromStdin(opts BackupOptions, gopts GlobalOptions, args []string) return err } - err = repo.LoadIndex(context.TODO()) + err = repo.LoadIndex(gopts.ctx) if err != nil { return err } @@ -267,7 +266,7 @@ func readBackupFromStdin(opts BackupOptions, gopts GlobalOptions, args []string) Hostname: opts.Hostname, } - _, id, err := r.Archive(context.TODO(), opts.StdinFilename, os.Stdin, newArchiveStdinProgress(gopts)) + _, id, err := r.Archive(gopts.ctx, opts.StdinFilename, os.Stdin, newArchiveStdinProgress(gopts)) if err != nil { return err } @@ -404,7 +403,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error { rejectFuncs = append(rejectFuncs, f) } - err = repo.LoadIndex(context.TODO()) + err = repo.LoadIndex(gopts.ctx) if err != nil { return err } @@ -423,7 +422,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error { // Find last snapshot to set it as parent, if not already set if !opts.Force && parentSnapshotID == nil { - id, err := restic.FindLatestSnapshot(context.TODO(), repo, target, []restic.TagList{}, opts.Hostname) + id, err := restic.FindLatestSnapshot(gopts.ctx, repo, target, []restic.TagList{}, opts.Hostname) if err == nil { parentSnapshotID = &id } else if err != restic.ErrNoSnapshotFound { @@ -469,7 +468,7 @@ func runBackup(opts BackupOptions, gopts GlobalOptions, args []string) error { } } - _, id, err := arch.Snapshot(context.TODO(), newArchiveProgress(gopts, stat), target, opts.Tags, opts.Hostname, parentSnapshotID, timeStamp) + _, id, err := arch.Snapshot(gopts.ctx, newArchiveProgress(gopts, stat), target, opts.Tags, opts.Hostname, parentSnapshotID, timeStamp) if err != nil { return err } diff --git a/cmd/restic/cmd_tag.go b/cmd/restic/cmd_tag.go index a07d627c5..cb098d61e 100644 --- a/cmd/restic/cmd_tag.go +++ b/cmd/restic/cmd_tag.go @@ -77,14 +77,14 @@ func changeTags(repo *repository.Repository, sn *restic.Snapshot, setTags, addTa } // Save the new snapshot. - id, err := repo.SaveJSONUnpacked(context.TODO(), restic.SnapshotFile, sn) + id, err := repo.SaveJSONUnpacked(globalOptions.ctx, restic.SnapshotFile, sn) if err != nil { return false, err } debug.Log("new snapshot saved as %v", id.Str()) - if err = repo.Flush(); err != nil { + if err = repo.Flush(globalOptions.ctx); err != nil { return false, err } diff --git a/cmd/restic/global.go b/cmd/restic/global.go index df36df9b8..b2d7fcb1f 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -553,7 +553,7 @@ func open(s string, opts options.Options) (restic.Backend, error) { case "swift": be, err = swift.Open(cfg.(swift.Config), rt) case "b2": - be, err = b2.Open(cfg.(b2.Config), rt) + be, err = b2.Open(globalOptions.ctx, cfg.(b2.Config), rt) case "rest": be, err = rest.Open(cfg.(rest.Config), rt) @@ -566,7 +566,7 @@ func open(s string, opts options.Options) (restic.Backend, error) { } // check if config is there - fi, err := be.Stat(context.TODO(), restic.Handle{Type: restic.ConfigFile}) + fi, err := be.Stat(globalOptions.ctx, restic.Handle{Type: restic.ConfigFile}) if err != nil { return nil, errors.Fatalf("unable to open config file: %v\nIs there a repository at the following location?\n%v", err, s) } @@ -610,7 +610,7 @@ func create(s string, opts options.Options) (restic.Backend, error) { case "swift": return swift.Open(cfg.(swift.Config), rt) case "b2": - return b2.Create(cfg.(b2.Config), rt) + return b2.Create(globalOptions.ctx, cfg.(b2.Config), rt) case "rest": return rest.Create(cfg.(rest.Config), rt) } diff --git a/internal/archiver/archive_reader.go b/internal/archiver/archive_reader.go index b6dab993b..07b224ad8 100644 --- a/internal/archiver/archive_reader.go +++ b/internal/archiver/archive_reader.go @@ -103,7 +103,7 @@ func (r *Reader) Archive(ctx context.Context, name string, rd io.Reader, p *rest debug.Log("snapshot saved as %v", id.Str()) - err = repo.Flush() + err = repo.Flush(ctx) if err != nil { return nil, restic.ID{}, err } diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index 967972700..20ba863c3 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -764,7 +764,7 @@ func (arch *Archiver) Snapshot(ctx context.Context, p *restic.Progress, paths, t debug.Log("workers terminated") // flush repository - err = arch.repo.Flush() + err = arch.repo.Flush(ctx) if err != nil { return nil, restic.ID{}, err } diff --git a/internal/archiver/archiver_duplication_test.go b/internal/archiver/archiver_duplication_test.go index 2ac2c1308..783dce11c 100644 --- a/internal/archiver/archiver_duplication_test.go +++ b/internal/archiver/archiver_duplication_test.go @@ -144,7 +144,7 @@ func testArchiverDuplication(t *testing.T) { wg.Wait() - err = repo.Flush() + err = repo.Flush(context.Background()) if err != nil { t.Fatal(err) } diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index 293de9152..e578ab3de 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -248,7 +248,7 @@ func testParallelSaveWithDuplication(t *testing.T, seed int) { rtest.OK(t, <-errChan) } - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) rtest.OK(t, repo.SaveIndex(context.TODO())) chkr := createAndInitChecker(t, repo) diff --git a/internal/backend/b2/b2.go b/internal/backend/b2/b2.go index 7e570c6eb..3e0c28c71 100644 --- a/internal/backend/b2/b2.go +++ b/internal/backend/b2/b2.go @@ -41,10 +41,10 @@ func newClient(ctx context.Context, cfg Config, rt http.RoundTripper) (*b2.Clien } // Open opens a connection to the B2 service. -func Open(cfg Config, rt http.RoundTripper) (restic.Backend, error) { +func Open(ctx context.Context, cfg Config, rt http.RoundTripper) (restic.Backend, error) { debug.Log("cfg %#v", cfg) - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(ctx) defer cancel() client, err := newClient(ctx, cfg, rt) @@ -79,10 +79,10 @@ func Open(cfg Config, rt http.RoundTripper) (restic.Backend, error) { // Create opens a connection to the B2 service. If the bucket does not exist yet, // it is created. -func Create(cfg Config, rt http.RoundTripper) (restic.Backend, error) { +func Create(ctx context.Context, cfg Config, rt http.RoundTripper) (restic.Backend, error) { debug.Log("cfg %#v", cfg) - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(ctx) defer cancel() client, err := newClient(ctx, cfg, rt) diff --git a/internal/backend/b2/b2_test.go b/internal/backend/b2/b2_test.go index 5784bd9a2..7f22a7986 100644 --- a/internal/backend/b2/b2_test.go +++ b/internal/backend/b2/b2_test.go @@ -45,19 +45,19 @@ func newB2TestSuite(t testing.TB) *test.Suite { // CreateFn is a function that creates a temporary repository for the tests. Create: func(config interface{}) (restic.Backend, error) { cfg := config.(b2.Config) - return b2.Create(cfg, tr) + return b2.Create(context.Background(), cfg, tr) }, // OpenFn is a function that opens a previously created temporary repository. Open: func(config interface{}) (restic.Backend, error) { cfg := config.(b2.Config) - return b2.Open(cfg, tr) + return b2.Open(context.Background(), cfg, tr) }, // CleanupFn removes data created during the tests. Cleanup: func(config interface{}) error { cfg := config.(b2.Config) - be, err := b2.Open(cfg, tr) + be, err := b2.Open(context.Background(), cfg, tr) if err != nil { return err } diff --git a/internal/repository/packer_manager.go b/internal/repository/packer_manager.go index 3b905903c..cfee2e365 100644 --- a/internal/repository/packer_manager.go +++ b/internal/repository/packer_manager.go @@ -89,7 +89,7 @@ func (r *packerManager) insertPacker(p *Packer) { } // savePacker stores p in the backend. -func (r *Repository) savePacker(t restic.BlobType, p *Packer) error { +func (r *Repository) savePacker(ctx context.Context, t restic.BlobType, p *Packer) error { debug.Log("save packer for %v with %d blobs (%d bytes)\n", t, p.Packer.Count(), p.Packer.Size()) _, err := p.Packer.Finalize() if err != nil { @@ -104,7 +104,7 @@ func (r *Repository) savePacker(t restic.BlobType, p *Packer) error { id := restic.IDFromHash(p.hw.Sum(nil)) h := restic.Handle{Type: restic.DataFile, Name: id.String()} - err = r.be.Save(context.TODO(), h, p.tmpfile) + err = r.be.Save(ctx, h, p.tmpfile) if err != nil { debug.Log("Save(%v) error: %v", h, err) return err diff --git a/internal/repository/repack.go b/internal/repository/repack.go index 3eacea182..b8e80c3a2 100644 --- a/internal/repository/repack.go +++ b/internal/repository/repack.go @@ -126,7 +126,7 @@ func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, kee } } - if err := repo.Flush(); err != nil { + if err := repo.Flush(ctx); err != nil { return nil, err } diff --git a/internal/repository/repack_test.go b/internal/repository/repack_test.go index 2f265150d..2d29a589a 100644 --- a/internal/repository/repack_test.go +++ b/internal/repository/repack_test.go @@ -55,13 +55,13 @@ func createRandomBlobs(t testing.TB, repo restic.Repository, blobs int, pData fl } if rand.Float32() < 0.2 { - if err = repo.Flush(); err != nil { + if err = repo.Flush(context.Background()); err != nil { t.Fatalf("repo.Flush() returned error %v", err) } } } - if err := repo.Flush(); err != nil { + if err := repo.Flush(context.Background()); err != nil { t.Fatalf("repo.Flush() returned error %v", err) } } diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 9bf686548..193ec1ca7 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -250,7 +250,7 @@ func (r *Repository) SaveAndEncrypt(ctx context.Context, t restic.BlobType, data } // else write the pack to the backend - return *id, r.savePacker(t, packer) + return *id, r.savePacker(ctx, t, packer) } // SaveJSONUnpacked serialises item as JSON and encrypts and saves it in the @@ -289,7 +289,7 @@ func (r *Repository) SaveUnpacked(ctx context.Context, t restic.FileType, p []by } // Flush saves all remaining packs. -func (r *Repository) Flush() error { +func (r *Repository) Flush(ctx context.Context) error { pms := []struct { t restic.BlobType pm *packerManager @@ -303,7 +303,7 @@ func (r *Repository) Flush() error { debug.Log("manually flushing %d packs", len(p.pm.packers)) for _, packer := range p.pm.packers { - err := r.savePacker(p.t, packer) + err := r.savePacker(ctx, p.t, packer) if err != nil { p.pm.pm.Unlock() return err diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index d3b9dec93..a90f0959b 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -37,7 +37,7 @@ func TestSave(t *testing.T) { rtest.Equals(t, id, sid) - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) // rtest.OK(t, repo.SaveIndex()) // read back @@ -72,7 +72,7 @@ func TestSaveFrom(t *testing.T) { rtest.OK(t, err) rtest.Equals(t, id, id2) - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) // read back buf := restic.NewBlobBuffer(size) @@ -122,7 +122,7 @@ func TestLoadTree(t *testing.T) { // archive a few files sn := archiver.TestSnapshot(t, repo, rtest.BenchArchiveDirectory, nil) - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) _, err := repo.LoadTree(context.TODO(), *sn.Tree) rtest.OK(t, err) @@ -138,7 +138,7 @@ func BenchmarkLoadTree(t *testing.B) { // archive a few files sn := archiver.TestSnapshot(t, repo, rtest.BenchArchiveDirectory, nil) - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) t.ResetTimer() @@ -159,7 +159,7 @@ func TestLoadBlob(t *testing.T) { id, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{}) rtest.OK(t, err) - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) // first, test with buffers that are too small for _, testlength := range []int{length - 20, length, restic.CiphertextLength(length) - 1} { @@ -204,7 +204,7 @@ func BenchmarkLoadBlob(b *testing.B) { id, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{}) rtest.OK(b, err) - rtest.OK(b, repo.Flush()) + rtest.OK(b, repo.Flush(context.Background())) b.ResetTimer() b.SetBytes(int64(length)) @@ -352,7 +352,7 @@ func TestRepositoryIncrementalIndex(t *testing.T) { // add 3 packs, write intermediate index for i := 0; i < 3; i++ { saveRandomDataBlobs(t, repo, 5, 1<<15) - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) } rtest.OK(t, repo.SaveFullIndex(context.TODO())) @@ -361,7 +361,7 @@ func TestRepositoryIncrementalIndex(t *testing.T) { // add another 5 packs for i := 0; i < 5; i++ { saveRandomDataBlobs(t, repo, 5, 1<<15) - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) } // save final index diff --git a/internal/restic/repository.go b/internal/restic/repository.go index 6c8cad863..2daae41b2 100644 --- a/internal/restic/repository.go +++ b/internal/restic/repository.go @@ -29,7 +29,7 @@ type Repository interface { List(context.Context, FileType) <-chan ID ListPack(context.Context, ID) ([]Blob, int64, error) - Flush() error + Flush(context.Context) error SaveUnpacked(context.Context, FileType, []byte) (ID, error) SaveJSONUnpacked(context.Context, FileType, interface{}) (ID, error) diff --git a/internal/restic/testing.go b/internal/restic/testing.go index 5e1f3372b..ad7604a6c 100644 --- a/internal/restic/testing.go +++ b/internal/restic/testing.go @@ -189,7 +189,7 @@ func TestCreateSnapshot(t testing.TB, repo Repository, at time.Time, depth int, t.Logf("saved snapshot %v", id.Str()) - err = repo.Flush() + err = repo.Flush(context.Background()) if err != nil { t.Fatal(err) } diff --git a/internal/restic/tree_test.go b/internal/restic/tree_test.go index d1cc8df91..2bcda6760 100644 --- a/internal/restic/tree_test.go +++ b/internal/restic/tree_test.go @@ -103,7 +103,7 @@ func TestLoadTree(t *testing.T) { rtest.OK(t, err) // save packs - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) // load tree again tree2, err := repo.LoadTree(context.TODO(), id) diff --git a/internal/walk/walk_test.go b/internal/walk/walk_test.go index 2e6d4f7cc..b67ae9151 100644 --- a/internal/walk/walk_test.go +++ b/internal/walk/walk_test.go @@ -29,7 +29,7 @@ func TestWalkTree(t *testing.T) { rtest.OK(t, err) // flush repo, write all packs - rtest.OK(t, repo.Flush()) + rtest.OK(t, repo.Flush(context.Background())) // start tree walker treeJobs := make(chan walk.TreeJob)