From b2f5381737d73f509c40437ae42c5dd20ec9d23f Mon Sep 17 00:00:00 2001
From: Alexander Weiss <alex@weissfam.de>
Date: Sun, 19 Jul 2020 07:13:41 +0200
Subject: [PATCH] Make realistic forget --prune --dryrun

---
 cmd/restic/cmd_forget.go        | 3 +--
 cmd/restic/cmd_prune.go         | 6 +++---
 internal/restic/snapshot.go     | 6 +++++-
 internal/restic/testing_test.go | 2 +-
 4 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/cmd/restic/cmd_forget.go b/cmd/restic/cmd_forget.go
index fa9739c0b..596c7c550 100644
--- a/cmd/restic/cmd_forget.go
+++ b/cmd/restic/cmd_forget.go
@@ -214,9 +214,8 @@ func runForget(opts ForgetOptions, gopts GlobalOptions, args []string) error {
 		if !gopts.JSON {
 			Verbosef("%d snapshots have been removed, running prune\n", len(removeSnIDs))
 		}
-
 		pruneOptions.DryRun = opts.DryRun
-		return runPruneWithRepo(pruneOptions, gopts, repo)
+		return runPruneWithRepo(pruneOptions, gopts, repo, removeSnIDs)
 	}
 
 	return nil
diff --git a/cmd/restic/cmd_prune.go b/cmd/restic/cmd_prune.go
index 9a93c600c..605f6258e 100644
--- a/cmd/restic/cmd_prune.go
+++ b/cmd/restic/cmd_prune.go
@@ -128,15 +128,15 @@ func runPrune(opts PruneOptions, gopts GlobalOptions) error {
 		return err
 	}
 
-	return runPruneWithRepo(opts, gopts, repo)
+	return runPruneWithRepo(opts, gopts, repo, restic.NewIDSet())
 }
 
-func runPruneWithRepo(opts PruneOptions, gopts GlobalOptions, repo *repository.Repository) error {
+func runPruneWithRepo(opts PruneOptions, gopts GlobalOptions, repo *repository.Repository, ignoreSnapshots restic.IDSet) error {
 	// we do not need index updates while pruning!
 	repo.DisableAutoIndexUpdate()
 
 	Verbosef("loading all snapshots...\n")
-	snapshots, err := restic.LoadAllSnapshots(gopts.ctx, repo)
+	snapshots, err := restic.LoadAllSnapshots(gopts.ctx, repo, ignoreSnapshots)
 	if err != nil {
 		return err
 	}
diff --git a/internal/restic/snapshot.go b/internal/restic/snapshot.go
index dc0dd5949..86e98e234 100644
--- a/internal/restic/snapshot.go
+++ b/internal/restic/snapshot.go
@@ -67,8 +67,12 @@ func LoadSnapshot(ctx context.Context, repo Repository, id ID) (*Snapshot, error
 }
 
 // LoadAllSnapshots returns a list of all snapshots in the repo.
-func LoadAllSnapshots(ctx context.Context, repo Repository) (snapshots []*Snapshot, err error) {
+// If a snapshot ID is in excludeIDs, it will not be included in the result.
+func LoadAllSnapshots(ctx context.Context, repo Repository, excludeIDs IDSet) (snapshots []*Snapshot, err error) {
 	err = repo.List(ctx, SnapshotFile, func(id ID, size int64) error {
+		if excludeIDs.Has(id) {
+			return nil
+		}
 		sn, err := LoadSnapshot(ctx, repo, id)
 		if err != nil {
 			return err
diff --git a/internal/restic/testing_test.go b/internal/restic/testing_test.go
index 0386fb76a..c3989f55f 100644
--- a/internal/restic/testing_test.go
+++ b/internal/restic/testing_test.go
@@ -25,7 +25,7 @@ func TestCreateSnapshot(t *testing.T) {
 		restic.TestCreateSnapshot(t, repo, testSnapshotTime.Add(time.Duration(i)*time.Second), testDepth, 0)
 	}
 
-	snapshots, err := restic.LoadAllSnapshots(context.TODO(), repo)
+	snapshots, err := restic.LoadAllSnapshots(context.TODO(), repo, restic.NewIDSet())
 	if err != nil {
 		t.Fatal(err)
 	}