walker: restructure FilterTree into TreeRewriter

The more generic RewriteNode callback replaces the SelectByName and
PrintExclude functions. This change mainly prepares the TreeRewriter
for use by the `repair snapshots` command.
This commit is contained in:
Michael Eischer 2022-12-28 11:04:28 +01:00
parent bc2399fbd9
commit 38dac78180
4 changed files with 103 additions and 48 deletions

View file

@ -87,12 +87,19 @@ func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *resti
return true return true
} }
rewriter := walker.NewTreeRewriter(walker.RewriteOpts{
RewriteNode: func(node *restic.Node, path string) *restic.Node {
if selectByName(path) {
return node
}
Verbosef(fmt.Sprintf("excluding %s\n", path))
return nil
},
})
return filterAndReplaceSnapshot(ctx, repo, sn, return filterAndReplaceSnapshot(ctx, repo, sn,
func(ctx context.Context, sn *restic.Snapshot) (restic.ID, error) { func(ctx context.Context, sn *restic.Snapshot) (restic.ID, error) {
return walker.FilterTree(ctx, repo, "/", *sn.Tree, &walker.TreeFilterVisitor{ return rewriter.RewriteTree(ctx, repo, "/", *sn.Tree)
SelectByName: selectByName,
PrintExclude: func(path string) { Verbosef(fmt.Sprintf("excluding %s\n", path)) },
})
}, opts.DryRun, opts.Forget, "rewrite") }, opts.DryRun, opts.Forget, "rewrite")
} }

View file

@ -9,13 +9,28 @@ import (
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
) )
// SelectByNameFunc returns true for all items that should be included (files and type NodeRewriteFunc func(node *restic.Node, path string) *restic.Node
// dirs). If false is returned, files are ignored and dirs are not even walked.
type SelectByNameFunc func(item string) bool
type TreeFilterVisitor struct { type RewriteOpts struct {
SelectByName SelectByNameFunc // return nil to remove the node
PrintExclude func(string) RewriteNode NodeRewriteFunc
}
// TreeRewriter rewrites trees using the callbacks configured in its
// RewriteOpts. Create instances with NewTreeRewriter, which fills in a
// default for a missing RewriteNode callback.
type TreeRewriter struct {
// opts holds the callbacks consulted by RewriteTree for every node.
opts RewriteOpts
}
// NewTreeRewriter returns a TreeRewriter that uses the given options.
// If opts.RewriteNode is nil, it is replaced by a callback that returns
// every node unchanged, so no node is removed or modified.
func NewTreeRewriter(opts RewriteOpts) *TreeRewriter {
rw := &TreeRewriter{
opts: opts,
}
// Set up default implementations: RewriteTree calls RewriteNode without
// a nil check, so a missing callback falls back to the identity function.
if rw.opts.RewriteNode == nil {
rw.opts.RewriteNode = func(node *restic.Node, path string) *restic.Node {
return node
}
}
return rw
} }
type BlobLoadSaver interface { type BlobLoadSaver interface {
@ -23,7 +38,7 @@ type BlobLoadSaver interface {
restic.BlobLoader restic.BlobLoader
} }
func FilterTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID restic.ID, visitor *TreeFilterVisitor) (newNodeID restic.ID, err error) { func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID restic.ID) (newNodeID restic.ID, err error) {
curTree, err := restic.LoadTree(ctx, repo, nodeID) curTree, err := restic.LoadTree(ctx, repo, nodeID)
if err != nil { if err != nil {
return restic.ID{}, err return restic.ID{}, err
@ -45,10 +60,8 @@ func FilterTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID
tb := restic.NewTreeJSONBuilder() tb := restic.NewTreeJSONBuilder()
for _, node := range curTree.Nodes { for _, node := range curTree.Nodes {
path := path.Join(nodepath, node.Name) path := path.Join(nodepath, node.Name)
if !visitor.SelectByName(path) { node = t.opts.RewriteNode(node, path)
if visitor.PrintExclude != nil { if node == nil {
visitor.PrintExclude(path)
}
continue continue
} }
@ -59,7 +72,7 @@ func FilterTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID
} }
continue continue
} }
newID, err := FilterTree(ctx, repo, path, *node.Subtree, visitor) newID, err := t.RewriteTree(ctx, repo, path, *node.Subtree)
if err != nil { if err != nil {
return restic.ID{}, err return restic.ID{}, err
} }

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
) )
@ -38,26 +37,26 @@ func (t WritableTreeMap) Dump() {
} }
} }
type checkRewriteFunc func(t testing.TB) (visitor TreeFilterVisitor, final func(testing.TB)) type checkRewriteFunc func(t testing.TB) (rewriter *TreeRewriter, final func(testing.TB))
// checkRewriteItemOrder ensures that the order of the 'path' arguments is the one passed in as 'want'. // checkRewriteItemOrder ensures that the order of the 'path' arguments is the one passed in as 'want'.
func checkRewriteItemOrder(want []string) checkRewriteFunc { func checkRewriteItemOrder(want []string) checkRewriteFunc {
pos := 0 pos := 0
return func(t testing.TB) (visitor TreeFilterVisitor, final func(testing.TB)) { return func(t testing.TB) (rewriter *TreeRewriter, final func(testing.TB)) {
vis := TreeFilterVisitor{ rewriter = NewTreeRewriter(RewriteOpts{
SelectByName: func(path string) bool { RewriteNode: func(node *restic.Node, path string) *restic.Node {
if pos >= len(want) { if pos >= len(want) {
t.Errorf("additional unexpected path found: %v", path) t.Errorf("additional unexpected path found: %v", path)
return false return nil
} }
if path != want[pos] { if path != want[pos] {
t.Errorf("wrong path found, want %q, got %q", want[pos], path) t.Errorf("wrong path found, want %q, got %q", want[pos], path)
} }
pos++ pos++
return true return node
}, },
} })
final = func(t testing.TB) { final = func(t testing.TB) {
if pos != len(want) { if pos != len(want) {
@ -65,21 +64,20 @@ func checkRewriteItemOrder(want []string) checkRewriteFunc {
} }
} }
return vis, final return rewriter, final
} }
} }
// checkRewriteSkips excludes nodes if path is in skipFor, it checks that all excluded entries are printed. // checkRewriteSkips excludes nodes if path is in skipFor, it checks that rewriting proceeds in the correct order.
func checkRewriteSkips(skipFor map[string]struct{}, want []string) checkRewriteFunc { func checkRewriteSkips(skipFor map[string]struct{}, want []string) checkRewriteFunc {
var pos int var pos int
printed := make(map[string]struct{})
return func(t testing.TB) (visitor TreeFilterVisitor, final func(testing.TB)) { return func(t testing.TB) (rewriter *TreeRewriter, final func(testing.TB)) {
vis := TreeFilterVisitor{ rewriter = NewTreeRewriter(RewriteOpts{
SelectByName: func(path string) bool { RewriteNode: func(node *restic.Node, path string) *restic.Node {
if pos >= len(want) { if pos >= len(want) {
t.Errorf("additional unexpected path found: %v", path) t.Errorf("additional unexpected path found: %v", path)
return false return nil
} }
if path != want[pos] { if path != want[pos] {
@ -87,27 +85,39 @@ func checkRewriteSkips(skipFor map[string]struct{}, want []string) checkRewriteF
} }
pos++ pos++
_, ok := skipFor[path] _, skip := skipFor[path]
return !ok if skip {
}, return nil
PrintExclude: func(s string) {
if _, ok := printed[s]; ok {
t.Errorf("path was already printed %v", s)
} }
printed[s] = struct{}{} return node
}, },
} })
final = func(t testing.TB) { final = func(t testing.TB) {
if !cmp.Equal(skipFor, printed) {
t.Errorf("unexpected paths skipped: %s", cmp.Diff(skipFor, printed))
}
if pos != len(want) { if pos != len(want) {
t.Errorf("not enough items returned, want %d, got %d", len(want), pos) t.Errorf("not enough items returned, want %d, got %d", len(want), pos)
} }
} }
return vis, final return rewriter, final
}
}
// checkIncreaseNodeSize returns a checkRewriteFunc whose rewriter adds
// increase to the Size of every node of type "file" and returns all other
// nodes unchanged. The returned final callback performs no checks.
func checkIncreaseNodeSize(increase uint64) checkRewriteFunc {
return func(t testing.TB) (rewriter *TreeRewriter, final func(testing.TB)) {
rewriter = NewTreeRewriter(RewriteOpts{
RewriteNode: func(node *restic.Node, path string) *restic.Node {
// Only file nodes are modified; the node is changed in place
// and the same pointer is returned.
if node.Type == "file" {
node.Size += increase
}
return node
},
})
// Nothing to verify after the rewrite has finished.
final = func(t testing.TB) {}
return rewriter, final
} }
} }
@ -172,6 +182,21 @@ func TestRewriter(t *testing.T) {
}, },
), ),
}, },
{ // modify node
tree: TestTree{
"foo": TestFile{Size: 21},
"subdir": TestTree{
"subfile": TestFile{Size: 21},
},
},
newTree: TestTree{
"foo": TestFile{Size: 42},
"subdir": TestTree{
"subfile": TestFile{Size: 42},
},
},
check: checkIncreaseNodeSize(21),
},
} }
for _, test := range tests { for _, test := range tests {
@ -186,8 +211,8 @@ func TestRewriter(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
defer cancel() defer cancel()
vis, last := test.check(t) rewriter, last := test.check(t)
newRoot, err := FilterTree(ctx, modrepo, "/", root, &vis) newRoot, err := rewriter.RewriteTree(ctx, modrepo, "/", root)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -213,8 +238,15 @@ func TestRewriterFailOnUnknownFields(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
defer cancel() defer cancel()
// use nil visitor to crash if the tree loading works unexpectedly
_, err := FilterTree(ctx, tm, "/", id, nil) rewriter := NewTreeRewriter(RewriteOpts{
RewriteNode: func(node *restic.Node, path string) *restic.Node {
// tree loading must not succeed
t.Fail()
return node
},
})
_, err := rewriter.RewriteTree(ctx, tm, "/", id)
if err == nil { if err == nil {
t.Error("missing error on unknown field") t.Error("missing error on unknown field")

View file

@ -14,7 +14,9 @@ import (
type TestTree map[string]interface{} type TestTree map[string]interface{}
// TestFile is used to test the walker. // TestFile is used to test the walker.
type TestFile struct{} type TestFile struct {
Size uint64
}
func BuildTreeMap(tree TestTree) (m TreeMap, root restic.ID) { func BuildTreeMap(tree TestTree) (m TreeMap, root restic.ID) {
m = TreeMap{} m = TreeMap{}
@ -37,6 +39,7 @@ func buildTreeMap(tree TestTree, m TreeMap) restic.ID {
err := tb.AddNode(&restic.Node{ err := tb.AddNode(&restic.Node{
Name: name, Name: name,
Type: "file", Type: "file",
Size: elem.Size,
}) })
if err != nil { if err != nil {
panic(err) panic(err)