diff --git a/cmd/restic/cmd_find.go b/cmd/restic/cmd_find.go index 33fff864f..04e6ae3dd 100644 --- a/cmd/restic/cmd_find.go +++ b/cmd/restic/cmd_find.go @@ -260,7 +260,7 @@ func (f *Finder) findInSnapshot(ctx context.Context, sn *restic.Snapshot) error } f.out.newsn = sn - return walker.Walk(ctx, f.repo, *sn.Tree, func(parentTreeID restic.ID, nodepath string, node *restic.Node, err error) error { + return walker.Walk(ctx, f.repo, *sn.Tree, walker.WalkVisitor{ProcessNode: func(parentTreeID restic.ID, nodepath string, node *restic.Node, err error) error { if err != nil { debug.Log("Error loading tree %v: %v", parentTreeID, err) @@ -327,7 +327,7 @@ func (f *Finder) findInSnapshot(ctx context.Context, sn *restic.Snapshot) error debug.Log(" found match\n") f.out.PrintPattern(nodepath, node) return nil - }) + }}) } func (f *Finder) findIDs(ctx context.Context, sn *restic.Snapshot) error { @@ -338,7 +338,7 @@ func (f *Finder) findIDs(ctx context.Context, sn *restic.Snapshot) error { } f.out.newsn = sn - return walker.Walk(ctx, f.repo, *sn.Tree, func(parentTreeID restic.ID, nodepath string, node *restic.Node, err error) error { + return walker.Walk(ctx, f.repo, *sn.Tree, walker.WalkVisitor{ProcessNode: func(parentTreeID restic.ID, nodepath string, node *restic.Node, err error) error { if err != nil { debug.Log("Error loading tree %v: %v", parentTreeID, err) @@ -388,7 +388,7 @@ func (f *Finder) findIDs(ctx context.Context, sn *restic.Snapshot) error { } return nil - }) + }}) } var errAllPacksFound = errors.New("all packs found") diff --git a/cmd/restic/cmd_ls.go b/cmd/restic/cmd_ls.go index 755addfe1..e38985a26 100644 --- a/cmd/restic/cmd_ls.go +++ b/cmd/restic/cmd_ls.go @@ -318,7 +318,7 @@ func runLs(ctx context.Context, opts LsOptions, gopts GlobalOptions, args []stri printSnapshot(sn) - err = walker.Walk(ctx, repo, *sn.Tree, func(_ restic.ID, nodepath string, node *restic.Node, err error) error { + processNode := func(_ restic.ID, nodepath string, node *restic.Node, err error) error { if err != nil { return err } @@ -349,6 +349,10 @@ func runLs(ctx context.Context, opts LsOptions, gopts GlobalOptions, args []stri return walker.ErrSkipNode } return nil + } + + err = walker.Walk(ctx, repo, *sn.Tree, walker.WalkVisitor{ + ProcessNode: processNode, }) if err != nil { diff --git a/cmd/restic/cmd_stats.go b/cmd/restic/cmd_stats.go index b0837510d..f7febf4d0 100644 --- a/cmd/restic/cmd_stats.go +++ b/cmd/restic/cmd_stats.go @@ -203,7 +203,9 @@ func statsWalkSnapshot(ctx context.Context, snapshot *restic.Snapshot, repo rest } hardLinkIndex := restorer.NewHardlinkIndex[struct{}]() - err := walker.Walk(ctx, repo, *snapshot.Tree, statsWalkTree(repo, opts, stats, hardLinkIndex)) + err := walker.Walk(ctx, repo, *snapshot.Tree, walker.WalkVisitor{ + ProcessNode: statsWalkTree(repo, opts, stats, hardLinkIndex), + }) if err != nil { return fmt.Errorf("walking tree %s: %v", *snapshot.Tree, err) } diff --git a/internal/dump/common.go b/internal/dump/common.go index 3ca1ced82..88b59e689 100644 --- a/internal/dump/common.go +++ b/internal/dump/common.go @@ -70,7 +70,7 @@ func sendNodes(ctx context.Context, repo restic.Repository, root *restic.Node, c return nil } - err := walker.Walk(ctx, repo, *root.Subtree, func(_ restic.ID, nodepath string, node *restic.Node, err error) error { + err := walker.Walk(ctx, repo, *root.Subtree, walker.WalkVisitor{ProcessNode: func(_ restic.ID, nodepath string, node *restic.Node, err error) error { if err != nil { return err } @@ -91,7 +91,7 @@ func sendNodes(ctx context.Context, repo restic.Repository, root *restic.Node, c } return nil - }) + }}) return err } diff --git a/internal/walker/walker.go b/internal/walker/walker.go index aba2e39e5..1bcdda16e 100644 --- a/internal/walker/walker.go +++ b/internal/walker/walker.go @@ -23,12 +23,20 @@ var ErrSkipNode = errors.New("skip this node") // tree are skipped. type WalkFunc func(parentTreeID restic.ID, path string, node *restic.Node, nodeErr error) (err error) +type WalkVisitor struct { + // If the node is a `dir`, it will be entered afterwards unless `ErrSkipNode` + // was returned. This function is mandatory + ProcessNode WalkFunc + // Optional callback + LeaveDir func(path string) +} + // Walk calls walkFn recursively for each node in root. If walkFn returns an // error, it is passed up the call stack. The trees in ignoreTrees are not // walked. If walkFn ignores trees, these are added to the set. -func Walk(ctx context.Context, repo restic.BlobLoader, root restic.ID, walkFn WalkFunc) error { +func Walk(ctx context.Context, repo restic.BlobLoader, root restic.ID, visitor WalkVisitor) error { tree, err := restic.LoadTree(ctx, repo, root) - err = walkFn(root, "/", nil, err) + err = visitor.ProcessNode(root, "/", nil, err) if err != nil { if err == ErrSkipNode { @@ -37,13 +45,13 @@ func Walk(ctx context.Context, repo restic.BlobLoader, root restic.ID, walkFn Wa return err } - return walk(ctx, repo, "/", root, tree, walkFn) + return walk(ctx, repo, "/", root, tree, visitor) } // walk recursively traverses the tree, ignoring subtrees when the ID of the // subtree is in ignoreTrees. If err is nil and ignore is true, the subtree ID // will be added to ignoreTrees by walk. -func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTreeID restic.ID, tree *restic.Tree, walkFn WalkFunc) (err error) { +func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTreeID restic.ID, tree *restic.Tree, visitor WalkVisitor) (err error) { sort.Slice(tree.Nodes, func(i, j int) bool { return tree.Nodes[i].Name < tree.Nodes[j].Name }) @@ -56,7 +64,7 @@ func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTree } if node.Type != "dir" { - err := walkFn(parentTreeID, p, node, nil) + err := visitor.ProcessNode(parentTreeID, p, node, nil) if err != nil { if err == ErrSkipNode { // skip the remaining entries in this tree @@ -74,18 +82,22 @@ func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTree } subtree, err := restic.LoadTree(ctx, repo, *node.Subtree) - err = walkFn(parentTreeID, p, node, err) + err = visitor.ProcessNode(parentTreeID, p, node, err) if err != nil { if err == ErrSkipNode { continue } } - err = walk(ctx, repo, p, *node.Subtree, subtree, walkFn) + err = walk(ctx, repo, p, *node.Subtree, subtree, visitor) if err != nil { return err } } + if visitor.LeaveDir != nil { + visitor.LeaveDir(prefix) + } + return nil } diff --git a/internal/walker/walker_test.go b/internal/walker/walker_test.go index 786570e02..e2d1f866f 100644 --- a/internal/walker/walker_test.go +++ b/internal/walker/walker_test.go @@ -406,7 +406,7 @@ func TestWalker(t *testing.T) { defer cancel() fn, last := check(t) - err := Walk(ctx, repo, root, fn) + err := Walk(ctx, repo, root, WalkVisitor{ProcessNode: fn}) if err != nil { t.Error(err) }