Replace restic.Progress with new progress.Counter

This fixes two race conditions while cleaning up the code.
This commit is contained in:
greatroar 2020-11-04 14:11:29 +01:00
parent 5731e391f8
commit ddca699cd2
15 changed files with 229 additions and 323 deletions

View file

@ -296,7 +296,6 @@ func prune(opts PruneOptions, gopts GlobalOptions, repo restic.Repository, usedB
// loop over all packs and decide what to do
bar := newProgressMax(!gopts.Quiet, uint64(len(indexPack)), "packs processed")
bar.Start()
err := repo.List(ctx, restic.PackFile, func(id restic.ID, packSize int64) error {
p, ok := indexPack[id]
if !ok {
@ -345,7 +344,7 @@ func prune(opts PruneOptions, gopts GlobalOptions, repo restic.Repository, usedB
}
delete(indexPack, id)
bar.Report(restic.Stat{Blobs: 1})
bar.Add(1)
return nil
})
bar.Done()
@ -517,6 +516,7 @@ func rebuildIndexFiles(gopts GlobalOptions, repo restic.Repository, removePacks
bar := newProgressMax(!gopts.Quiet, packcount, "packs processed")
obsoleteIndexes, err := (repo.Index()).(*repository.MasterIndex).
Save(gopts.ctx, repo, removePacks, bar)
bar.Done()
if err != nil {
return err
}
@ -533,7 +533,6 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, snapshots []*rest
usedBlobs = restic.NewBlobSet()
bar := newProgressMax(!gopts.Quiet, uint64(len(snapshots)), "snapshots")
bar.Start()
defer bar.Done()
for _, sn := range snapshots {
debug.Log("process snapshot %v", sn.ID())
@ -548,7 +547,7 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, snapshots []*rest
}
debug.Log("processed snapshot %v", sn.ID())
bar.Report(restic.Stat{Blobs: 1})
bar.Add(1)
}
return usedBlobs, nil
}

View file

@ -33,8 +33,8 @@ func deleteFiles(gopts GlobalOptions, ignoreError bool, repo restic.Repository,
}()
bar := newProgressMax(!gopts.JSON && !gopts.Quiet, uint64(totalCount), "files deleted")
defer bar.Done()
wg, ctx := errgroup.WithContext(gopts.ctx)
bar.Start()
for i := 0; i < numDeleteWorkers; i++ {
wg.Go(func() error {
for id := range fileChan {
@ -51,12 +51,11 @@ func deleteFiles(gopts GlobalOptions, ignoreError bool, repo restic.Repository,
if !gopts.JSON && gopts.verbosity > 2 {
Verbosef("removed %v\n", h)
}
bar.Report(restic.Stat{Blobs: 1})
bar.Add(1)
}
return nil
})
}
err := wg.Wait()
bar.Done()
return err
}

View file

@ -2,35 +2,45 @@ package main
import (
"fmt"
"os"
"strconv"
"time"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/ui/progress"
)
// newProgressMax returns a progress that counts blobs.
func newProgressMax(show bool, max uint64, description string) *restic.Progress {
// newProgressMax returns a progress.Counter that prints to stdout.
func newProgressMax(show bool, max uint64, description string) *progress.Counter {
if !show {
return nil
}
p := restic.NewProgress()
interval := time.Second / 60
if !stdinIsTerminal() {
interval = time.Second
} else {
fps, err := strconv.ParseInt(os.Getenv("RESTIC_PROGRESS_FPS"), 10, 64)
if err == nil && fps >= 1 {
if fps > 60 {
fps = 60
}
interval = time.Second / time.Duration(fps)
}
}
p.OnUpdate = func(s restic.Stat, d time.Duration, ticker bool) {
return progress.New(interval, func(v uint64, d time.Duration, final bool) {
status := fmt.Sprintf("[%s] %s %d / %d %s",
formatDuration(d),
formatPercent(s.Blobs, max),
s.Blobs, max, description)
formatPercent(v, max),
v, max, description)
if w := stdoutTerminalWidth(); w > 0 {
status = shortenStatus(w, status)
}
PrintProgress("%s", status)
if final {
fmt.Print("\n")
}
p.OnDone = func(s restic.Stat, d time.Duration, ticker bool) {
fmt.Printf("\n")
}
return p
})
}

View file

@ -12,6 +12,7 @@ import (
"github.com/restic/restic/internal/pack"
"github.com/restic/restic/internal/repository"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/ui/progress"
"golang.org/x/sync/errgroup"
)
@ -772,15 +773,13 @@ func checkPack(ctx context.Context, r restic.Repository, id restic.ID) error {
}
// ReadData loads all data from the repository and checks the integrity.
func (c *Checker) ReadData(ctx context.Context, p *restic.Progress, errChan chan<- error) {
func (c *Checker) ReadData(ctx context.Context, p *progress.Counter, errChan chan<- error) {
c.ReadPacks(ctx, c.packs, p, errChan)
}
// ReadPacks loads data from specified packs and checks the integrity.
func (c *Checker) ReadPacks(ctx context.Context, packs restic.IDSet, p *restic.Progress, errChan chan<- error) {
func (c *Checker) ReadPacks(ctx context.Context, packs restic.IDSet, p *progress.Counter, errChan chan<- error) {
defer close(errChan)
p.Start()
defer p.Done()
g, ctx := errgroup.WithContext(ctx)
@ -803,7 +802,7 @@ func (c *Checker) ReadPacks(ctx context.Context, packs restic.IDSet, p *restic.P
}
err := checkPack(ctx, c.repo, id)
p.Report(restic.Stat{Blobs: 1})
p.Add(1)
if err == nil {
continue
}

View file

@ -11,6 +11,8 @@ import (
"github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/pack"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/ui/progress"
"golang.org/x/sync/errgroup"
)
@ -48,8 +50,7 @@ type Lister interface {
// New creates a new index for repo from scratch. InvalidFiles contains all IDs
// of files that cannot be listed successfully.
func New(ctx context.Context, repo Lister, ignorePacks restic.IDSet, p *restic.Progress) (idx *Index, invalidFiles restic.IDs, err error) {
p.Start()
func New(ctx context.Context, repo Lister, ignorePacks restic.IDSet, p *progress.Counter) (idx *Index, invalidFiles restic.IDs, err error) {
defer p.Done()
type Job struct {
@ -118,7 +119,7 @@ func New(ctx context.Context, repo Lister, ignorePacks restic.IDSet, p *restic.P
idx = newIndex()
for res := range outputCh {
p.Report(restic.Stat{Blobs: 1})
p.Add(1)
if res.Error != nil {
cause := errors.Cause(res.Error)
if _, ok := cause.(pack.InvalidFileError); ok {
@ -187,10 +188,9 @@ func loadIndexJSON(ctx context.Context, repo ListLoader, id restic.ID) (*indexJS
}
// Load creates an index by loading all index files from the repo.
func Load(ctx context.Context, repo ListLoader, p *restic.Progress) (*Index, error) {
func Load(ctx context.Context, repo ListLoader, p *progress.Counter) (*Index, error) {
debug.Log("loading indexes")
p.Start()
defer p.Done()
supersedes := make(map[restic.ID]restic.IDSet)
@ -199,7 +199,7 @@ func Load(ctx context.Context, repo ListLoader, p *restic.Progress) (*Index, err
index := newIndex()
err := repo.List(ctx, restic.IndexFile, func(id restic.ID, size int64) error {
p.Report(restic.Stat{Blobs: 1})
p.Add(1)
debug.Log("Load index %v", id)
idx, err := loadIndexJSON(ctx, repo, id)

View file

@ -4,9 +4,9 @@ import (
"context"
"sync"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/ui/progress"
)
// MasterIndex is a collection of indexes and IDs of chunks that are in the process of being saved.
@ -266,10 +266,7 @@ func (mi *MasterIndex) MergeFinalIndexes() {
// of all known indexes in the "supersedes" field. The IDs are also returned in
// the IDSet obsolete
// After calling this function, you should remove the obsolete index files.
func (mi *MasterIndex) Save(ctx context.Context, repo restic.Repository, packBlacklist restic.IDSet, p *restic.Progress) (obsolete restic.IDSet, err error) {
p.Start()
defer p.Done()
func (mi *MasterIndex) Save(ctx context.Context, repo restic.Repository, packBlacklist restic.IDSet, p *progress.Counter) (obsolete restic.IDSet, err error) {
mi.idxMutex.Lock()
defer mi.idxMutex.Unlock()
@ -310,7 +307,7 @@ func (mi *MasterIndex) Save(ctx context.Context, repo restic.Repository, packBla
for pbs := range idx.EachByPack(ctx, packBlacklist) {
newIndex.StorePack(pbs.packID, pbs.blobs)
p.Report(restic.Stat{Blobs: 1})
p.Add(1)
if IndexFull(newIndex) {
if err := finalize(); err != nil {
return nil, err

View file

@ -10,6 +10,8 @@ import (
"github.com/restic/restic/internal/fs"
"github.com/restic/restic/internal/pack"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/ui/progress"
"golang.org/x/sync/errgroup"
)
@ -22,11 +24,8 @@ const numRepackWorkers = 8
//
// The map keepBlobs is modified by Repack, it is used to keep track of which
// blobs have been processed.
func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *restic.Progress) (obsoletePacks restic.IDSet, err error) {
if p != nil {
p.Start()
func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *progress.Counter) (obsoletePacks restic.IDSet, err error) {
defer p.Done()
}
debug.Log("repacking %d packs while keeping %d blobs", len(packs), len(keepBlobs))
@ -172,9 +171,7 @@ func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, kee
if err = fs.RemoveIfExists(tempfile.Name()); err != nil {
return errors.Wrap(err, "Remove")
}
if p != nil {
p.Report(restic.Stat{Blobs: 1})
}
p.Add(1)
}
return nil
}

View file

@ -1,232 +0,0 @@
package restic
import (
"fmt"
"os"
"strconv"
"sync"
"time"
"golang.org/x/crypto/ssh/terminal"
)
// minTickerTime limits how often the progress ticker is updated. It can be
// overridden using the RESTIC_PROGRESS_FPS (frames per second) environment
// variable.
var minTickerTime = time.Second / 60
var isTerminal = terminal.IsTerminal(int(os.Stdout.Fd()))
var forceUpdateProgress = make(chan bool)
func init() {
fps, err := strconv.ParseInt(os.Getenv("RESTIC_PROGRESS_FPS"), 10, 64)
if err == nil && fps >= 1 {
if fps > 60 {
fps = 60
}
minTickerTime = time.Second / time.Duration(fps)
}
}
// Progress reports progress on an operation.
type Progress struct {
OnStart func()
OnUpdate ProgressFunc
OnDone ProgressFunc
fnM sync.Mutex
cur Stat
curM sync.Mutex
start time.Time
c *time.Ticker
cancel chan struct{}
once sync.Once
d time.Duration
lastUpdate time.Time
running bool
}
// Stat captures newly done parts of the operation.
type Stat struct {
Files uint64
Dirs uint64
Bytes uint64
Trees uint64
Blobs uint64
Errors uint64
}
// ProgressFunc is used to report progress back to the user.
type ProgressFunc func(s Stat, runtime time.Duration, ticker bool)
// NewProgress returns a new progress reporter. When Start() is called, the
// function OnStart is executed once. Afterwards the function OnUpdate is
// called when new data arrives or at least every d interval. The function
// OnDone is called when Done() is called. Both functions are called
// synchronously and can use shared state.
func NewProgress() *Progress {
var d time.Duration
if isTerminal {
d = time.Second
}
return &Progress{d: d}
}
// Start resets and runs the progress reporter.
func (p *Progress) Start() {
if p == nil || p.running {
return
}
p.cancel = make(chan struct{})
p.running = true
p.Reset()
p.start = time.Now()
p.c = nil
if p.d != 0 {
p.c = time.NewTicker(p.d)
}
if p.OnStart != nil {
p.OnStart()
}
go p.reporter()
}
// Reset resets all statistic counters to zero.
func (p *Progress) Reset() {
if p == nil {
return
}
if !p.running {
panic("resetting a non-running Progress")
}
p.curM.Lock()
p.cur = Stat{}
p.curM.Unlock()
}
// Report adds the statistics from s to the current state and tries to report
// the accumulated statistics via the feedback channel.
func (p *Progress) Report(s Stat) {
if p == nil {
return
}
if !p.running {
panic("reporting in a non-running Progress")
}
p.curM.Lock()
p.cur.Add(s)
cur := p.cur
needUpdate := false
if isTerminal && time.Since(p.lastUpdate) > minTickerTime {
p.lastUpdate = time.Now()
needUpdate = true
}
p.curM.Unlock()
if needUpdate {
p.updateProgress(cur, false)
}
}
func (p *Progress) updateProgress(cur Stat, ticker bool) {
if p.OnUpdate == nil {
return
}
p.fnM.Lock()
p.OnUpdate(cur, time.Since(p.start), ticker)
p.fnM.Unlock()
}
func (p *Progress) reporter() {
if p == nil {
return
}
updateProgress := func() {
p.curM.Lock()
cur := p.cur
p.curM.Unlock()
p.updateProgress(cur, true)
}
var ticker <-chan time.Time
if p.c != nil {
ticker = p.c.C
}
for {
select {
case <-ticker:
updateProgress()
case <-forceUpdateProgress:
updateProgress()
case <-p.cancel:
if p.c != nil {
p.c.Stop()
}
return
}
}
}
// Done closes the progress report.
func (p *Progress) Done() {
if p == nil || !p.running {
return
}
p.running = false
p.once.Do(func() {
close(p.cancel)
})
cur := p.cur
if p.OnDone != nil {
p.fnM.Lock()
p.OnUpdate(cur, time.Since(p.start), false)
p.OnDone(cur, time.Since(p.start), false)
p.fnM.Unlock()
}
}
// Add accumulates other into s.
func (s *Stat) Add(other Stat) {
s.Bytes += other.Bytes
s.Dirs += other.Dirs
s.Files += other.Files
s.Trees += other.Trees
s.Blobs += other.Blobs
s.Errors += other.Errors
}
func (s Stat) String() string {
b := float64(s.Bytes)
var str string
switch {
case s.Bytes > 1<<40:
str = fmt.Sprintf("%.3f TiB", b/(1<<40))
case s.Bytes > 1<<30:
str = fmt.Sprintf("%.3f GiB", b/(1<<30))
case s.Bytes > 1<<20:
str = fmt.Sprintf("%.3f MiB", b/(1<<20))
case s.Bytes > 1<<10:
str = fmt.Sprintf("%.3f KiB", b/(1<<10))
default:
str = fmt.Sprintf("%dB", s.Bytes)
}
return fmt.Sprintf("Stat(%d files, %d dirs, %v trees, %v blobs, %d errors, %v)",
s.Files, s.Dirs, s.Trees, s.Blobs, s.Errors, str)
}

View file

@ -1,22 +0,0 @@
// +build !windows,!darwin,!freebsd,!netbsd,!openbsd,!dragonfly,!solaris
package restic
import (
"os"
"os/signal"
"syscall"
"github.com/restic/restic/internal/debug"
)
func init() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGUSR1)
go func() {
for s := range c {
debug.Log("Signal received: %v\n", s)
forceUpdateProgress <- true
}
}()
}

View file

@ -1,22 +0,0 @@
// +build darwin freebsd netbsd openbsd dragonfly
package restic
import (
"os"
"os/signal"
"syscall"
"github.com/restic/restic/internal/debug"
)
func init() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINFO, syscall.SIGUSR1)
go func() {
for s := range c {
debug.Log("Signal received: %v\n", s)
forceUpdateProgress <- true
}
}()
}

View file

@ -0,0 +1,99 @@
package progress
import (
"os"
"sync"
"sync/atomic"
"time"
"github.com/restic/restic/internal/debug"
)
// A Func is a callback for a Counter.
//
// The final argument is true if Counter.Done has been called,
// which means that the current call will be the last.
type Func func(value uint64, runtime time.Duration, final bool)
// A Counter tracks a running count and controls a goroutine that passes its
// value periodically to a Func.
//
// The Func is also called when SIGUSR1 (or SIGINFO, on BSD) is received.
type Counter struct {
report Func
start time.Time
stopped chan struct{} // Closed by run.
stop chan struct{} // Close to stop run.
tick *time.Ticker
value uint64
}
// New starts a new Counter.
func New(interval time.Duration, report Func) *Counter {
signals.Once.Do(func() {
signals.ch = make(chan os.Signal, 1)
setupSignals()
})
c := &Counter{
report: report,
start: time.Now(),
stopped: make(chan struct{}),
stop: make(chan struct{}),
tick: time.NewTicker(interval),
}
go c.run()
return c
}
// Add v to the Counter. This method is concurrency-safe.
func (c *Counter) Add(v uint64) {
if c == nil {
return
}
atomic.AddUint64(&c.value, v)
}
// Done tells a Counter to stop and waits for it to report its final value.
func (c *Counter) Done() {
if c == nil {
return
}
c.tick.Stop()
close(c.stop)
<-c.stopped // Wait for last progress report.
*c = Counter{} // Prevent reuse.
}
func (c *Counter) get() uint64 { return atomic.LoadUint64(&c.value) }
func (c *Counter) run() {
defer close(c.stopped)
defer func() {
// Must be a func so that time.Since isn't called at defer time.
c.report(c.get(), time.Since(c.start), true)
}()
for {
var now time.Time
select {
case now = <-c.tick.C:
case sig := <-signals.ch:
debug.Log("Signal received: %v\n", sig)
now = time.Now()
case <-c.stop:
return
}
c.report(c.get(), now.Sub(c.start), false)
}
}
// XXX The fact that signals is a single global variable means that only one
// Counter receives each incoming signal.
var signals struct {
ch chan os.Signal
sync.Once
}

View file

@ -0,0 +1,55 @@
package progress_test
import (
"testing"
"time"
"github.com/restic/restic/internal/test"
"github.com/restic/restic/internal/ui/progress"
)
func TestCounter(t *testing.T) {
const N = 100
var (
finalSeen = false
increasing = true
last uint64
ncalls int
)
report := func(value uint64, d time.Duration, final bool) {
finalSeen = true
if value < last {
increasing = false
}
last = value
ncalls++
}
c := progress.New(10*time.Millisecond, report)
done := make(chan struct{})
go func() {
defer close(done)
for i := 0; i < N; i++ {
time.Sleep(time.Millisecond)
c.Add(1)
}
}()
<-done
c.Done()
test.Assert(t, finalSeen, "final call did not happen")
test.Assert(t, increasing, "values not increasing")
test.Equals(t, uint64(N), last)
t.Log("number of calls:", ncalls)
}
func TestCounterNil(t *testing.T) {
// Shouldn't panic.
var c *progress.Counter = nil
c.Add(1)
c.Done()
}

View file

@ -0,0 +1,12 @@
// +build darwin dragonfly freebsd netbsd openbsd
package progress
import (
"os/signal"
"syscall"
)
func setupSignals() {
signal.Notify(signals.ch, syscall.SIGINFO, syscall.SIGUSR1)
}

View file

@ -0,0 +1,12 @@
// +build linux solaris
package progress
import (
"os/signal"
"syscall"
)
func setupSignals() {
signal.Notify(signals.ch, syscall.SIGUSR1)
}

View file

@ -0,0 +1,3 @@
package progress
func setupSignals() {}