From d92957dd78ca01dc16a7998ca6a9e339fe23044b Mon Sep 17 00:00:00 2001
From: Michael Eischer <michael.eischer@fau.de>
Date: Sun, 31 Oct 2021 23:25:36 +0100
Subject: [PATCH] lock: Implement strict lock expiry monitoring

Until now, restic continued e.g. a backup task even when it failed to
renew the lock or failed to do so in time. For example, if a backup
client enters standby during the backup, other operations such as
`prune` can run in the meantime (after calling `unlock`). After leaving
standby, the backup client continues its backup and uploads indexes
that refer to pack files which were removed in the meantime.

This commit introduces a goroutine that explicitly monitors whether
each lock is refreshed in time. To simplify the implementation, every
lock now gets its own pair of goroutines: one refreshes the lock and
the other watches for refresh timeouts. In the scenario above, the
monitoring goroutine now causes the backup to fail, as the client has
lost its lock in the meantime.

The lock refresh goroutines are bound to the context that was used to
lock the repository initially. The context returned by `lockRepo` is
also cancelled when any of these goroutines exits. This ensures that
the context is cancelled whenever the lock is no longer refreshed, for
whatever reason.
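
To make the refresh/monitor pairing easier to follow, here is a
minimal, self-contained sketch of the pattern. The names
(`refreshLock`, `monitorLock`), the shortened intervals and the fake
`renew` callback are made up for this illustration and do not match
the code in the diff below verbatim.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// Shortened intervals so the example terminates quickly; the patch below
// uses a 5 minute refresh interval and a timeout derived from the
// 30 minute stale-lock threshold.
const (
	refreshInterval = 100 * time.Millisecond
	refreshTimeout  = 350 * time.Millisecond
)

// refreshLock renews the lock on every tick and signals each success.
func refreshLock(ctx context.Context, renew func() error, refreshed chan<- struct{}) {
	ticker := time.NewTicker(refreshInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := renew(); err != nil {
				fmt.Println("unable to refresh lock:", err)
				continue
			}
			select {
			case refreshed <- struct{}{}:
			case <-ctx.Done():
			}
		}
	}
}

// monitorLock cancels ctx when no successful refresh arrived in time.
func monitorLock(ctx context.Context, cancel context.CancelFunc, refreshed <-chan struct{}) {
	timer := time.NewTimer(refreshTimeout)
	defer timer.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-refreshed:
			timer.Reset(refreshTimeout)
		case <-timer.C:
			fmt.Println("Fatal: failed to refresh lock in time")
			cancel() // everything bound to ctx must stop now
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Fake lock renewal that starts failing after two successful refreshes.
	calls := 0
	renew := func() error {
		calls++
		if calls > 2 {
			return errors.New("backend unreachable")
		}
		return nil
	}

	refreshed := make(chan struct{})
	go refreshLock(ctx, renew, refreshed)
	go monitorLock(ctx, cancel, refreshed)

	// The long-running operation would use ctx and abort once it is cancelled.
	<-ctx.Done()
	fmt.Println("operation aborted:", ctx.Err())
}
```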
---
 cmd/restic/lock.go      | 149 ++++++++++++++++++++++++++--------------
 internal/restic/lock.go |   4 +-
 2 files changed, 99 insertions(+), 54 deletions(-)

diff --git a/cmd/restic/lock.go b/cmd/restic/lock.go
index 0cea02cfd..188bb1d59 100644
--- a/cmd/restic/lock.go
+++ b/cmd/restic/lock.go
@@ -11,10 +11,13 @@ import (
 	"github.com/restic/restic/internal/restic"
 )
 
+type lockContext struct {
+	cancel    context.CancelFunc
+	refreshWG sync.WaitGroup
+}
+
 var globalLocks struct {
-	locks         []*restic.Lock
-	cancelRefresh chan struct{}
-	refreshWG     sync.WaitGroup
+	locks map[*restic.Lock]*lockContext
 	sync.Mutex
 	sync.Once
 }
@@ -27,6 +30,8 @@ func lockRepoExclusive(ctx context.Context, repo *repository.Repository) (*resti
 	return lockRepository(ctx, repo, true)
 }
 
+// lockRepository wraps the ctx such that it is cancelled when the repository is unlocked.
+// Cancelling the original context also stops the lock refresh.
 func lockRepository(ctx context.Context, repo *repository.Repository, exclusive bool) (*restic.Lock, context.Context, error) {
 	// make sure that a repository is unlocked properly and after cancel() was
 	// called by the cleanup handler in global.go
@@ -45,16 +50,17 @@ func lockRepository(ctx context.Context, repo *repository.Repository, exclusive
 	}
 	debug.Log("create lock %p (exclusive %v)", lock, exclusive)
 
-	globalLocks.Lock()
-	if globalLocks.cancelRefresh == nil {
-		debug.Log("start goroutine for lock refresh")
-		globalLocks.cancelRefresh = make(chan struct{})
-		globalLocks.refreshWG = sync.WaitGroup{}
-		globalLocks.refreshWG.Add(1)
-		go refreshLocks(&globalLocks.refreshWG, globalLocks.cancelRefresh)
+	ctx, cancel := context.WithCancel(ctx)
+	lockInfo := &lockContext{
+		cancel: cancel,
 	}
+	lockInfo.refreshWG.Add(2)
+	refreshChan := make(chan struct{})
 
-	globalLocks.locks = append(globalLocks.locks, lock)
+	globalLocks.Lock()
+	globalLocks.locks[lock] = lockInfo
+	go refreshLocks(ctx, lock, lockInfo, refreshChan)
+	go monitorLockRefresh(ctx, lock, lockInfo, refreshChan)
 	globalLocks.Unlock()
 
 	return lock, ctx, err
@@ -62,32 +68,76 @@ func lockRepository(ctx context.Context, repo *repository.Repository, exclusive
 
 var refreshInterval = 5 * time.Minute
 
-func refreshLocks(wg *sync.WaitGroup, done <-chan struct{}) {
-	debug.Log("start")
-	defer func() {
-		wg.Done()
-		globalLocks.Lock()
-		globalLocks.cancelRefresh = nil
-		globalLocks.Unlock()
-	}()
+// consider a lock refresh failed a bit before the lock actually becomes stale.
+// The difference allows compensating for a small time drift between clients.
+var refreshabilityTimeout = restic.StaleLockTimeout - refreshInterval*3/2
 
+func refreshLocks(ctx context.Context, lock *restic.Lock, lockInfo *lockContext, refreshed chan<- struct{}) {
+	debug.Log("start")
 	ticker := time.NewTicker(refreshInterval)
+	lastRefresh := lock.Time
+
+	defer func() {
+		ticker.Stop()
+		// ensure that the context was cancelled before removing the lock
+		lockInfo.cancel()
+
+		// remove the lock from the repo
+		debug.Log("unlocking repository with lock %v", lock)
+		if err := lock.Unlock(); err != nil {
+			debug.Log("error while unlocking: %v", err)
+			Warnf("error while unlocking: %v", err)
+		}
+
+		lockInfo.refreshWG.Done()
+	}()
 
 	for {
 		select {
-		case <-done:
+		case <-ctx.Done():
 			debug.Log("terminate")
 			return
 		case <-ticker.C:
+			if time.Since(lastRefresh) > refreshabilityTimeout {
+				// the lock is too old, wait until the expiry monitor cancels the context
+				continue
+			}
+
 			debug.Log("refreshing locks")
-			globalLocks.Lock()
-			for _, lock := range globalLocks.locks {
-				err := lock.Refresh(context.TODO())
-				if err != nil {
-					Warnf("unable to refresh lock: %v\n", err)
+			err := lock.Refresh(context.TODO())
+			if err != nil {
+				Warnf("unable to refresh lock: %v\n", err)
+			} else {
+				lastRefresh = lock.Time
+				// inform monitor goroutine about successful refresh
+				select {
+				case <-ctx.Done():
+				case refreshed <- struct{}{}:
 				}
 			}
-			globalLocks.Unlock()
+		}
+	}
+}
+
+func monitorLockRefresh(ctx context.Context, lock *restic.Lock, lockInfo *lockContext, refreshed <-chan struct{}) {
+	timer := time.NewTimer(refreshabilityTimeout)
+	defer func() {
+		timer.Stop()
+		lockInfo.cancel()
+		lockInfo.refreshWG.Done()
+	}()
+
+	for {
+		select {
+		case <-ctx.Done():
+			debug.Log("terminate expiry monitoring")
+			return
+		case <-refreshed:
+			// reset timer once the lock was refreshed successfully
+			timer.Reset(refreshabilityTimeout)
+		case <-timer.C:
+			Warnf("Fatal: failed to refresh lock in time\n")
+			return
 		}
 	}
 }
@@ -98,40 +148,35 @@ func unlockRepo(lock *restic.Lock) {
 	}
 
 	globalLocks.Lock()
-	defer globalLocks.Unlock()
+	lockInfo, exists := globalLocks.locks[lock]
+	delete(globalLocks.locks, lock)
+	globalLocks.Unlock()
 
-	for i := 0; i < len(globalLocks.locks); i++ {
-		if lock == globalLocks.locks[i] {
-			// remove the lock from the repo
-			debug.Log("unlocking repository with lock %v", lock)
-			if err := lock.Unlock(); err != nil {
-				debug.Log("error while unlocking: %v", err)
-				Warnf("error while unlocking: %v", err)
-				return
-			}
-
-			// remove the lock from the list of locks
-			globalLocks.locks = append(globalLocks.locks[:i], globalLocks.locks[i+1:]...)
-			return
-		}
+	if !exists {
+		debug.Log("unable to find lock %v in the global list of locks, ignoring", lock)
+		return
 	}
-
-	debug.Log("unable to find lock %v in the global list of locks, ignoring", lock)
+	lockInfo.cancel()
+	lockInfo.refreshWG.Wait()
 }
 
 func unlockAll(code int) (int, error) {
 	globalLocks.Lock()
-	defer globalLocks.Unlock()
-
+	locks := globalLocks.locks
 	debug.Log("unlocking %d locks", len(globalLocks.locks))
-	for _, lock := range globalLocks.locks {
-		if err := lock.Unlock(); err != nil {
-			debug.Log("error while unlocking: %v", err)
-			return code, err
-		}
-		debug.Log("successfully removed lock")
+	for _, lockInfo := range globalLocks.locks {
+		lockInfo.cancel()
+	}
+	globalLocks.locks = make(map[*restic.Lock]*lockContext)
+	globalLocks.Unlock()
+
+	for _, lockInfo := range locks {
+		lockInfo.refreshWG.Wait()
 	}
-	globalLocks.locks = globalLocks.locks[:0]
 
 	return code, nil
 }
+
+func init() {
+	globalLocks.locks = make(map[*restic.Lock]*lockContext)
+}
diff --git a/internal/restic/lock.go b/internal/restic/lock.go
index 031e8755c..c8079f58d 100644
--- a/internal/restic/lock.go
+++ b/internal/restic/lock.go
@@ -175,14 +175,14 @@ func (l *Lock) Unlock() error {
 	return l.repo.Backend().Remove(context.TODO(), Handle{Type: LockFile, Name: l.lockID.String()})
 }
 
-var staleTimeout = 30 * time.Minute
+var StaleLockTimeout = 30 * time.Minute
 
 // Stale returns true if the lock is stale. A lock is stale if the timestamp is
 // older than 30 minutes or if it was created on the current machine and the
 // process isn't alive any more.
 func (l *Lock) Stale() bool {
 	debug.Log("testing if lock %v for process %d is stale", l, l.PID)
-	if time.Since(l.Time) > staleTimeout {
+	if time.Since(l.Time) > StaleLockTimeout {
 		debug.Log("lock is stale, timestamp is too old: %v\n", l.Time)
 		return true
 	}