From ae1cb889dd566a1c3417234f18e3b6b7c908df6f Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Wed, 31 Jul 2024 19:30:47 +0200 Subject: [PATCH] Add more checks for canceled contexts --- cmd/restic/cmd_diff.go | 12 ++++++++++++ cmd/restic/cmd_dump.go | 4 ++++ cmd/restic/cmd_find.go | 4 ++++ cmd/restic/cmd_forget.go | 4 ++++ cmd/restic/cmd_recover.go | 4 ++++ cmd/restic/cmd_snapshots.go | 8 ++++++++ internal/backend/sftp/sftp.go | 4 ++++ internal/fuse/dir.go | 8 ++++++++ internal/fuse/file.go | 6 +++++- internal/fuse/snapshots_dir.go | 4 ++++ internal/repository/check.go | 4 ++++ internal/repository/repository.go | 4 ++++ internal/restic/snapshot_find.go | 4 ++++ internal/restorer/filerestorer.go | 4 ++++ internal/restorer/restorer.go | 8 ++++++-- internal/walker/rewriter.go | 4 ++++ internal/walker/walker.go | 4 ++++ 17 files changed, 87 insertions(+), 3 deletions(-) diff --git a/cmd/restic/cmd_diff.go b/cmd/restic/cmd_diff.go index 6488a7c35..b15882b09 100644 --- a/cmd/restic/cmd_diff.go +++ b/cmd/restic/cmd_diff.go @@ -177,6 +177,10 @@ func (c *Comparer) printDir(ctx context.Context, mode string, stats *DiffStat, b } for _, node := range tree.Nodes { + if ctx.Err() != nil { + return ctx.Err() + } + name := path.Join(prefix, node.Name) if node.Type == "dir" { name += "/" @@ -204,6 +208,10 @@ func (c *Comparer) collectDir(ctx context.Context, blobs restic.BlobSet, id rest } for _, node := range tree.Nodes { + if ctx.Err() != nil { + return ctx.Err() + } + addBlobs(blobs, node) if node.Type == "dir" { @@ -255,6 +263,10 @@ func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, pref tree1Nodes, tree2Nodes, names := uniqueNodeNames(tree1, tree2) for _, name := range names { + if ctx.Err() != nil { + return ctx.Err() + } + node1, t1 := tree1Nodes[name] node2, t2 := tree2Nodes[name] diff --git a/cmd/restic/cmd_dump.go b/cmd/restic/cmd_dump.go index 7e1efa3ae..9c0fe535e 100644 --- a/cmd/restic/cmd_dump.go +++ b/cmd/restic/cmd_dump.go @@ -85,6 +85,10 @@ func printFromTree(ctx context.Context, tree *restic.Tree, repo restic.BlobLoade item := filepath.Join(prefix, pathComponents[0]) l := len(pathComponents) for _, node := range tree.Nodes { + if ctx.Err() != nil { + return ctx.Err() + } + // If dumping something in the highest level it will just take the // first item it finds and dump that according to the switch case below. if node.Name == pathComponents[0] { diff --git a/cmd/restic/cmd_find.go b/cmd/restic/cmd_find.go index 4f9549ca4..aebca594e 100644 --- a/cmd/restic/cmd_find.go +++ b/cmd/restic/cmd_find.go @@ -377,6 +377,10 @@ func (f *Finder) findIDs(ctx context.Context, sn *restic.Snapshot) error { if node.Type == "file" && f.blobIDs != nil { for _, id := range node.Content { + if ctx.Err() != nil { + return ctx.Err() + } + idStr := id.String() if _, ok := f.blobIDs[idStr]; !ok { // Look for short ID form diff --git a/cmd/restic/cmd_forget.go b/cmd/restic/cmd_forget.go index 87738b518..27b8f4f74 100644 --- a/cmd/restic/cmd_forget.go +++ b/cmd/restic/cmd_forget.go @@ -246,6 +246,10 @@ func runForget(ctx context.Context, opts ForgetOptions, pruneOptions PruneOption printer.P("Applying Policy: %v\n", policy) for k, snapshotGroup := range snapshotGroups { + if ctx.Err() != nil { + return ctx.Err() + } + if gopts.Verbose >= 1 && !gopts.JSON { err = PrintSnapshotGroupHeader(globalOptions.stdout, k) if err != nil { diff --git a/cmd/restic/cmd_recover.go b/cmd/restic/cmd_recover.go index 5e4744bb6..4e8b8c077 100644 --- a/cmd/restic/cmd_recover.go +++ b/cmd/restic/cmd_recover.go @@ -118,6 +118,10 @@ func runRecover(ctx context.Context, gopts GlobalOptions) error { return nil } + if ctx.Err() != nil { + return ctx.Err() + } + tree := restic.NewTree(len(roots)) for id := range roots { var subtreeID = id diff --git a/cmd/restic/cmd_snapshots.go b/cmd/restic/cmd_snapshots.go index 9112e1b95..826ab55ec 100644 --- a/cmd/restic/cmd_snapshots.go +++ b/cmd/restic/cmd_snapshots.go @@ -81,6 +81,10 @@ func runSnapshots(ctx context.Context, opts SnapshotOptions, gopts GlobalOptions } for k, list := range snapshotGroups { + if ctx.Err() != nil { + return ctx.Err() + } + if opts.Last { // This branch should be removed in the same time // that --last. @@ -101,6 +105,10 @@ func runSnapshots(ctx context.Context, opts SnapshotOptions, gopts GlobalOptions } for k, list := range snapshotGroups { + if ctx.Err() != nil { + return ctx.Err() + } + if grouped { err := PrintSnapshotGroupHeader(globalOptions.stdout, k) if err != nil { diff --git a/internal/backend/sftp/sftp.go b/internal/backend/sftp/sftp.go index 70fc30a62..efbd0c8d5 100644 --- a/internal/backend/sftp/sftp.go +++ b/internal/backend/sftp/sftp.go @@ -578,6 +578,10 @@ func (r *SFTP) deleteRecursive(ctx context.Context, name string) error { } for _, fi := range entries { + if ctx.Err() != nil { + return ctx.Err() + } + itemName := r.Join(name, fi.Name()) if fi.IsDir() { err := r.deleteRecursive(ctx, itemName) diff --git a/internal/fuse/dir.go b/internal/fuse/dir.go index 763a9640c..fd030295b 100644 --- a/internal/fuse/dir.go +++ b/internal/fuse/dir.go @@ -107,6 +107,10 @@ func (d *dir) open(ctx context.Context) error { } items := make(map[string]*restic.Node) for _, n := range tree.Nodes { + if ctx.Err() != nil { + return ctx.Err() + } + nodes, err := replaceSpecialNodes(ctx, d.root.repo, n) if err != nil { debug.Log(" replaceSpecialNodes(%v) failed: %v", n, err) @@ -171,6 +175,10 @@ func (d *dir) ReadDirAll(ctx context.Context) ([]fuse.Dirent, error) { }) for _, node := range d.items { + if ctx.Err() != nil { + return nil, ctx.Err() + } + name := cleanupNodeName(node.Name) var typ fuse.DirentType switch node.Type { diff --git a/internal/fuse/file.go b/internal/fuse/file.go index e2e0cf9a0..494fca283 100644 --- a/internal/fuse/file.go +++ b/internal/fuse/file.go @@ -66,12 +66,16 @@ func (f *file) Attr(_ context.Context, a *fuse.Attr) error { } -func (f *file) Open(_ context.Context, _ *fuse.OpenRequest, _ *fuse.OpenResponse) (fs.Handle, error) { +func (f *file) Open(ctx context.Context, _ *fuse.OpenRequest, _ *fuse.OpenResponse) (fs.Handle, error) { debug.Log("open file %v with %d blobs", f.node.Name, len(f.node.Content)) var bytes uint64 cumsize := make([]uint64, 1+len(f.node.Content)) for i, id := range f.node.Content { + if ctx.Err() != nil { + return nil, ctx.Err() + } + size, found := f.root.repo.LookupBlobSize(restic.DataBlob, id) if !found { return nil, errors.Errorf("id %v not found in repository", id) diff --git a/internal/fuse/snapshots_dir.go b/internal/fuse/snapshots_dir.go index 7369ea17a..4cae7106c 100644 --- a/internal/fuse/snapshots_dir.go +++ b/internal/fuse/snapshots_dir.go @@ -78,6 +78,10 @@ func (d *SnapshotsDir) ReadDirAll(ctx context.Context) ([]fuse.Dirent, error) { } for name, entry := range meta.names { + if ctx.Err() != nil { + return nil, ctx.Err() + } + d := fuse.Dirent{ Inode: inodeFromName(d.inode, name), Name: name, diff --git a/internal/repository/check.go b/internal/repository/check.go index 27eb11d71..1eeea58dc 100644 --- a/internal/repository/check.go +++ b/internal/repository/check.go @@ -95,6 +95,10 @@ func checkPackInner(ctx context.Context, r *Repository, id restic.ID, blobs []re it := newPackBlobIterator(id, newBufReader(bufRd), 0, blobs, r.Key(), dec) for { + if ctx.Err() != nil { + return ctx.Err() + } + val, err := it.Next() if err == errPackEOF { break diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 838858c38..f7fd65c71 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -1000,6 +1000,10 @@ func streamPackPart(ctx context.Context, beLoad backendLoadFn, loadBlobFn loadBl it := newPackBlobIterator(packID, newByteReader(data), dataStart, blobs, key, dec) for { + if ctx.Err() != nil { + return ctx.Err() + } + val, err := it.Next() if err == errPackEOF { break diff --git a/internal/restic/snapshot_find.go b/internal/restic/snapshot_find.go index 6d1ab9a7a..6eb51b237 100644 --- a/internal/restic/snapshot_find.go +++ b/internal/restic/snapshot_find.go @@ -134,6 +134,10 @@ func (f *SnapshotFilter) FindAll(ctx context.Context, be Lister, loader LoaderUn ids := NewIDSet() // Process all snapshot IDs given as arguments. for _, s := range snapshotIDs { + if ctx.Err() != nil { + return ctx.Err() + } + var sn *Snapshot if s == "latest" { if usedFilter { diff --git a/internal/restorer/filerestorer.go b/internal/restorer/filerestorer.go index e517e6284..31234b960 100644 --- a/internal/restorer/filerestorer.go +++ b/internal/restorer/filerestorer.go @@ -122,6 +122,10 @@ func (r *fileRestorer) restoreFiles(ctx context.Context) error { // create packInfo from fileInfo for _, file := range r.files { + if ctx.Err() != nil { + return ctx.Err() + } + fileBlobs := file.blobs.(restic.IDs) largeFile := len(fileBlobs) > largeFileBlobCount var packsMap map[restic.ID][]fileBlobInfo diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index cd3fd076d..00da4e18e 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -450,7 +450,7 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { }, leaveDir: func(node *restic.Node, target, location string, expectedFilenames []string) error { if res.opts.Delete { - if err := res.removeUnexpectedFiles(target, location, expectedFilenames); err != nil { + if err := res.removeUnexpectedFiles(ctx, target, location, expectedFilenames); err != nil { return err } } @@ -469,7 +469,7 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { return err } -func (res *Restorer) removeUnexpectedFiles(target, location string, expectedFilenames []string) error { +func (res *Restorer) removeUnexpectedFiles(ctx context.Context, target, location string, expectedFilenames []string) error { if !res.opts.Delete { panic("internal error") } @@ -487,6 +487,10 @@ func (res *Restorer) removeUnexpectedFiles(target, location string, expectedFile } for _, entry := range entries { + if ctx.Err() != nil { + return ctx.Err() + } + if _, ok := keep[toComparableFilename(entry)]; ok { continue } diff --git a/internal/walker/rewriter.go b/internal/walker/rewriter.go index 6c27b26ac..7e984ae25 100644 --- a/internal/walker/rewriter.go +++ b/internal/walker/rewriter.go @@ -116,6 +116,10 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node tb := restic.NewTreeJSONBuilder() for _, node := range curTree.Nodes { + if ctx.Err() != nil { + return restic.ID{}, ctx.Err() + } + path := path.Join(nodepath, node.Name) node = t.opts.RewriteNode(node, path) if node == nil { diff --git a/internal/walker/walker.go b/internal/walker/walker.go index 091b05489..788ece1cf 100644 --- a/internal/walker/walker.go +++ b/internal/walker/walker.go @@ -57,6 +57,10 @@ func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTree }) for _, node := range tree.Nodes { + if ctx.Err() != nil { + return ctx.Err() + } + p := path.Join(prefix, node.Name) if node.Type == "" {