restore: configurable overwrite behavior

This commit is contained in:
Michael Eischer 2024-05-31 11:43:42 +02:00
parent a23cb3a428
commit 6a4ae9d6b1
2 changed files with 129 additions and 17 deletions

View file

@ -51,8 +51,9 @@ type RestoreOptions struct {
InsensitiveInclude []string InsensitiveInclude []string
Target string Target string
restic.SnapshotFilter restic.SnapshotFilter
Sparse bool Sparse bool
Verify bool Verify bool
Overwrite restorer.OverwriteBehavior
} }
var restoreOptions RestoreOptions var restoreOptions RestoreOptions
@ -70,6 +71,7 @@ func init() {
initSingleSnapshotFilter(flags, &restoreOptions.SnapshotFilter) initSingleSnapshotFilter(flags, &restoreOptions.SnapshotFilter)
flags.BoolVar(&restoreOptions.Sparse, "sparse", false, "restore files as sparse") flags.BoolVar(&restoreOptions.Sparse, "sparse", false, "restore files as sparse")
flags.BoolVar(&restoreOptions.Verify, "verify", false, "verify restored files content") flags.BoolVar(&restoreOptions.Verify, "verify", false, "verify restored files content")
flags.Var(&restoreOptions.Overwrite, "overwrite", "overwrite behavior, one of (always|if-newer|never) (default: always)")
} }
func runRestore(ctx context.Context, opts RestoreOptions, gopts GlobalOptions, func runRestore(ctx context.Context, opts RestoreOptions, gopts GlobalOptions,
@ -165,6 +167,7 @@ func runRestore(ctx context.Context, opts RestoreOptions, gopts GlobalOptions,
res := restorer.NewRestorer(repo, sn, restorer.Options{ res := restorer.NewRestorer(repo, sn, restorer.Options{
Sparse: opts.Sparse, Sparse: opts.Sparse,
Progress: progress, Progress: progress,
Overwrite: opts.Overwrite,
}) })
totalErrors := 0 totalErrors := 0

View file

@ -2,6 +2,7 @@ package restorer
import ( import (
"context" "context"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"sync/atomic" "sync/atomic"
@ -17,10 +18,13 @@ import (
// Restorer is used to restore a snapshot to a directory. // Restorer is used to restore a snapshot to a directory.
type Restorer struct { type Restorer struct {
repo restic.Repository repo restic.Repository
sn *restic.Snapshot sn *restic.Snapshot
sparse bool sparse bool
progress *restoreui.Progress progress *restoreui.Progress
overwrite OverwriteBehavior
fileList map[string]struct{}
Error func(location string, err error) error Error func(location string, err error) error
Warn func(message string) Warn func(message string)
@ -30,8 +34,53 @@ type Restorer struct {
var restorerAbortOnAllErrors = func(_ string, err error) error { return err } var restorerAbortOnAllErrors = func(_ string, err error) error { return err }
type Options struct { type Options struct {
Sparse bool Sparse bool
Progress *restoreui.Progress Progress *restoreui.Progress
Overwrite OverwriteBehavior
}
type OverwriteBehavior int
// Constants for different overwrite behavior
const (
OverwriteAlways OverwriteBehavior = 0
OverwriteIfNewer OverwriteBehavior = 1
OverwriteNever OverwriteBehavior = 2
OverwriteInvalid OverwriteBehavior = 3
)
// Set implements the method needed for pflag command flag parsing.
func (c *OverwriteBehavior) Set(s string) error {
switch s {
case "always":
*c = OverwriteAlways
case "if-newer":
*c = OverwriteIfNewer
case "never":
*c = OverwriteNever
default:
*c = OverwriteInvalid
return fmt.Errorf("invalid overwrite behavior %q, must be one of (always|if-newer|never)", s)
}
return nil
}
func (c *OverwriteBehavior) String() string {
switch *c {
case OverwriteAlways:
return "always"
case OverwriteIfNewer:
return "if-newer"
case OverwriteNever:
return "never"
default:
return "invalid"
}
}
func (c *OverwriteBehavior) Type() string {
return "behavior"
} }
// NewRestorer creates a restorer preloaded with the content from the snapshot id. // NewRestorer creates a restorer preloaded with the content from the snapshot id.
@ -40,6 +89,8 @@ func NewRestorer(repo restic.Repository, sn *restic.Snapshot, opts Options) *Res
repo: repo, repo: repo,
sparse: opts.Sparse, sparse: opts.Sparse,
progress: opts.Progress, progress: opts.Progress,
overwrite: opts.Overwrite,
fileList: make(map[string]struct{}),
Error: restorerAbortOnAllErrors, Error: restorerAbortOnAllErrors,
SelectFilter: func(string, string, *restic.Node) (bool, bool) { return true, true }, SelectFilter: func(string, string, *restic.Node) (bool, bool) { return true, true },
sn: sn, sn: sn,
@ -252,10 +303,12 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error {
idx.Add(node.Inode, node.DeviceID, location) idx.Add(node.Inode, node.DeviceID, location)
} }
res.progress.AddFile(node.Size) return res.withOverwriteCheck(node, target, location, false, func() error {
filerestorer.addFile(location, node.Content, int64(node.Size)) res.progress.AddFile(node.Size)
filerestorer.addFile(location, node.Content, int64(node.Size))
return nil res.trackFile(location)
return nil
})
}, },
}) })
if err != nil { if err != nil {
@ -274,14 +327,22 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error {
visitNode: func(node *restic.Node, target, location string) error { visitNode: func(node *restic.Node, target, location string) error {
debug.Log("second pass, visitNode: restore node %q", location) debug.Log("second pass, visitNode: restore node %q", location)
if node.Type != "file" { if node.Type != "file" {
return res.restoreNodeTo(ctx, node, target, location) return res.withOverwriteCheck(node, target, location, false, func() error {
return res.restoreNodeTo(ctx, node, target, location)
})
} }
if idx.Has(node.Inode, node.DeviceID) && idx.Value(node.Inode, node.DeviceID) != location { if idx.Has(node.Inode, node.DeviceID) && idx.Value(node.Inode, node.DeviceID) != location {
return res.restoreHardlinkAt(node, filerestorer.targetPath(idx.Value(node.Inode, node.DeviceID)), target, location) return res.withOverwriteCheck(node, target, location, true, func() error {
return res.restoreHardlinkAt(node, filerestorer.targetPath(idx.Value(node.Inode, node.DeviceID)), target, location)
})
} }
return res.restoreNodeMetadataTo(node, target, location) if res.hasRestoredFile(location) {
return res.restoreNodeMetadataTo(node, target, location)
}
// don't touch skipped files
return nil
}, },
leaveDir: func(node *restic.Node, target, location string) error { leaveDir: func(node *restic.Node, target, location string) error {
err := res.restoreNodeMetadataTo(node, target, location) err := res.restoreNodeMetadataTo(node, target, location)
@ -294,6 +355,54 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error {
return err return err
} }
func (res *Restorer) trackFile(location string) {
res.fileList[location] = struct{}{}
}
func (res *Restorer) hasRestoredFile(location string) bool {
_, ok := res.fileList[location]
return ok
}
func (res *Restorer) withOverwriteCheck(node *restic.Node, target, location string, isHardlink bool, cb func() error) error {
overwrite, err := shouldOverwrite(res.overwrite, node, target)
if err != nil {
return err
} else if !overwrite {
size := node.Size
if isHardlink {
size = 0
}
res.progress.AddFile(size)
res.progress.AddProgress(location, size, size)
return nil
}
return cb()
}
func shouldOverwrite(overwrite OverwriteBehavior, node *restic.Node, destination string) (bool, error) {
if overwrite == OverwriteAlways {
return true, nil
}
fi, err := fs.Lstat(destination)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return true, nil
}
return false, err
}
if overwrite == OverwriteIfNewer {
// return if node is newer
return node.ModTime.After(fi.ModTime()), nil
} else if overwrite == OverwriteNever {
// file exists
return false, nil
}
panic("unknown overwrite behavior")
}
// Snapshot returns the snapshot this restorer is configured to use. // Snapshot returns the snapshot this restorer is configured to use.
func (res *Restorer) Snapshot() *restic.Snapshot { func (res *Restorer) Snapshot() *restic.Snapshot {
return res.sn return res.sn
@ -324,8 +433,8 @@ func (res *Restorer) VerifyFiles(ctx context.Context, dst string) (int, error) {
defer close(work) defer close(work)
_, err := res.traverseTree(ctx, dst, string(filepath.Separator), *res.sn.Tree, treeVisitor{ _, err := res.traverseTree(ctx, dst, string(filepath.Separator), *res.sn.Tree, treeVisitor{
visitNode: func(node *restic.Node, target, _ string) error { visitNode: func(node *restic.Node, target, location string) error {
if node.Type != "file" { if node.Type != "file" || !res.hasRestoredFile(location) {
return nil return nil
} }
select { select {