forked from TrueCloudLab/restic
Replace restic.Progress with new progress.Counter
This fixes two race conditions while cleaning up the code.
This commit is contained in:
parent
5731e391f8
commit
ddca699cd2
15 changed files with 229 additions and 323 deletions
|
@ -296,7 +296,6 @@ func prune(opts PruneOptions, gopts GlobalOptions, repo restic.Repository, usedB
|
|||
|
||||
// loop over all packs and decide what to do
|
||||
bar := newProgressMax(!gopts.Quiet, uint64(len(indexPack)), "packs processed")
|
||||
bar.Start()
|
||||
err := repo.List(ctx, restic.PackFile, func(id restic.ID, packSize int64) error {
|
||||
p, ok := indexPack[id]
|
||||
if !ok {
|
||||
|
@ -345,7 +344,7 @@ func prune(opts PruneOptions, gopts GlobalOptions, repo restic.Repository, usedB
|
|||
}
|
||||
|
||||
delete(indexPack, id)
|
||||
bar.Report(restic.Stat{Blobs: 1})
|
||||
bar.Add(1)
|
||||
return nil
|
||||
})
|
||||
bar.Done()
|
||||
|
@ -517,6 +516,7 @@ func rebuildIndexFiles(gopts GlobalOptions, repo restic.Repository, removePacks
|
|||
bar := newProgressMax(!gopts.Quiet, packcount, "packs processed")
|
||||
obsoleteIndexes, err := (repo.Index()).(*repository.MasterIndex).
|
||||
Save(gopts.ctx, repo, removePacks, bar)
|
||||
bar.Done()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -533,7 +533,6 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, snapshots []*rest
|
|||
usedBlobs = restic.NewBlobSet()
|
||||
|
||||
bar := newProgressMax(!gopts.Quiet, uint64(len(snapshots)), "snapshots")
|
||||
bar.Start()
|
||||
defer bar.Done()
|
||||
for _, sn := range snapshots {
|
||||
debug.Log("process snapshot %v", sn.ID())
|
||||
|
@ -548,7 +547,7 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, snapshots []*rest
|
|||
}
|
||||
|
||||
debug.Log("processed snapshot %v", sn.ID())
|
||||
bar.Report(restic.Stat{Blobs: 1})
|
||||
bar.Add(1)
|
||||
}
|
||||
return usedBlobs, nil
|
||||
}
|
||||
|
|
|
@ -33,8 +33,8 @@ func deleteFiles(gopts GlobalOptions, ignoreError bool, repo restic.Repository,
|
|||
}()
|
||||
|
||||
bar := newProgressMax(!gopts.JSON && !gopts.Quiet, uint64(totalCount), "files deleted")
|
||||
defer bar.Done()
|
||||
wg, ctx := errgroup.WithContext(gopts.ctx)
|
||||
bar.Start()
|
||||
for i := 0; i < numDeleteWorkers; i++ {
|
||||
wg.Go(func() error {
|
||||
for id := range fileChan {
|
||||
|
@ -51,12 +51,11 @@ func deleteFiles(gopts GlobalOptions, ignoreError bool, repo restic.Repository,
|
|||
if !gopts.JSON && gopts.verbosity > 2 {
|
||||
Verbosef("removed %v\n", h)
|
||||
}
|
||||
bar.Report(restic.Stat{Blobs: 1})
|
||||
bar.Add(1)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
err := wg.Wait()
|
||||
bar.Done()
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,35 +2,45 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/restic/restic/internal/restic"
|
||||
"github.com/restic/restic/internal/ui/progress"
|
||||
)
|
||||
|
||||
// newProgressMax returns a progress that counts blobs.
|
||||
func newProgressMax(show bool, max uint64, description string) *restic.Progress {
|
||||
// newProgressMax returns a progress.Counter that prints to stdout.
|
||||
func newProgressMax(show bool, max uint64, description string) *progress.Counter {
|
||||
if !show {
|
||||
return nil
|
||||
}
|
||||
|
||||
p := restic.NewProgress()
|
||||
interval := time.Second / 60
|
||||
if !stdinIsTerminal() {
|
||||
interval = time.Second
|
||||
} else {
|
||||
fps, err := strconv.ParseInt(os.Getenv("RESTIC_PROGRESS_FPS"), 10, 64)
|
||||
if err == nil && fps >= 1 {
|
||||
if fps > 60 {
|
||||
fps = 60
|
||||
}
|
||||
interval = time.Second / time.Duration(fps)
|
||||
}
|
||||
}
|
||||
|
||||
p.OnUpdate = func(s restic.Stat, d time.Duration, ticker bool) {
|
||||
return progress.New(interval, func(v uint64, d time.Duration, final bool) {
|
||||
status := fmt.Sprintf("[%s] %s %d / %d %s",
|
||||
formatDuration(d),
|
||||
formatPercent(s.Blobs, max),
|
||||
s.Blobs, max, description)
|
||||
formatPercent(v, max),
|
||||
v, max, description)
|
||||
|
||||
if w := stdoutTerminalWidth(); w > 0 {
|
||||
status = shortenStatus(w, status)
|
||||
}
|
||||
|
||||
PrintProgress("%s", status)
|
||||
}
|
||||
|
||||
p.OnDone = func(s restic.Stat, d time.Duration, ticker bool) {
|
||||
fmt.Printf("\n")
|
||||
}
|
||||
|
||||
return p
|
||||
if final {
|
||||
fmt.Print("\n")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/restic/restic/internal/pack"
|
||||
"github.com/restic/restic/internal/repository"
|
||||
"github.com/restic/restic/internal/restic"
|
||||
"github.com/restic/restic/internal/ui/progress"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
|
@ -772,15 +773,13 @@ func checkPack(ctx context.Context, r restic.Repository, id restic.ID) error {
|
|||
}
|
||||
|
||||
// ReadData loads all data from the repository and checks the integrity.
|
||||
func (c *Checker) ReadData(ctx context.Context, p *restic.Progress, errChan chan<- error) {
|
||||
func (c *Checker) ReadData(ctx context.Context, p *progress.Counter, errChan chan<- error) {
|
||||
c.ReadPacks(ctx, c.packs, p, errChan)
|
||||
}
|
||||
|
||||
// ReadPacks loads data from specified packs and checks the integrity.
|
||||
func (c *Checker) ReadPacks(ctx context.Context, packs restic.IDSet, p *restic.Progress, errChan chan<- error) {
|
||||
func (c *Checker) ReadPacks(ctx context.Context, packs restic.IDSet, p *progress.Counter, errChan chan<- error) {
|
||||
defer close(errChan)
|
||||
|
||||
p.Start()
|
||||
defer p.Done()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
@ -803,7 +802,7 @@ func (c *Checker) ReadPacks(ctx context.Context, packs restic.IDSet, p *restic.P
|
|||
}
|
||||
|
||||
err := checkPack(ctx, c.repo, id)
|
||||
p.Report(restic.Stat{Blobs: 1})
|
||||
p.Add(1)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
"github.com/restic/restic/internal/errors"
|
||||
"github.com/restic/restic/internal/pack"
|
||||
"github.com/restic/restic/internal/restic"
|
||||
"github.com/restic/restic/internal/ui/progress"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
|
@ -48,8 +50,7 @@ type Lister interface {
|
|||
|
||||
// New creates a new index for repo from scratch. InvalidFiles contains all IDs
|
||||
// of files that cannot be listed successfully.
|
||||
func New(ctx context.Context, repo Lister, ignorePacks restic.IDSet, p *restic.Progress) (idx *Index, invalidFiles restic.IDs, err error) {
|
||||
p.Start()
|
||||
func New(ctx context.Context, repo Lister, ignorePacks restic.IDSet, p *progress.Counter) (idx *Index, invalidFiles restic.IDs, err error) {
|
||||
defer p.Done()
|
||||
|
||||
type Job struct {
|
||||
|
@ -118,7 +119,7 @@ func New(ctx context.Context, repo Lister, ignorePacks restic.IDSet, p *restic.P
|
|||
idx = newIndex()
|
||||
|
||||
for res := range outputCh {
|
||||
p.Report(restic.Stat{Blobs: 1})
|
||||
p.Add(1)
|
||||
if res.Error != nil {
|
||||
cause := errors.Cause(res.Error)
|
||||
if _, ok := cause.(pack.InvalidFileError); ok {
|
||||
|
@ -187,10 +188,9 @@ func loadIndexJSON(ctx context.Context, repo ListLoader, id restic.ID) (*indexJS
|
|||
}
|
||||
|
||||
// Load creates an index by loading all index files from the repo.
|
||||
func Load(ctx context.Context, repo ListLoader, p *restic.Progress) (*Index, error) {
|
||||
func Load(ctx context.Context, repo ListLoader, p *progress.Counter) (*Index, error) {
|
||||
debug.Log("loading indexes")
|
||||
|
||||
p.Start()
|
||||
defer p.Done()
|
||||
|
||||
supersedes := make(map[restic.ID]restic.IDSet)
|
||||
|
@ -199,7 +199,7 @@ func Load(ctx context.Context, repo ListLoader, p *restic.Progress) (*Index, err
|
|||
index := newIndex()
|
||||
|
||||
err := repo.List(ctx, restic.IndexFile, func(id restic.ID, size int64) error {
|
||||
p.Report(restic.Stat{Blobs: 1})
|
||||
p.Add(1)
|
||||
|
||||
debug.Log("Load index %v", id)
|
||||
idx, err := loadIndexJSON(ctx, repo, id)
|
||||
|
|
|
@ -4,9 +4,9 @@ import (
|
|||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/restic/restic/internal/restic"
|
||||
|
||||
"github.com/restic/restic/internal/debug"
|
||||
"github.com/restic/restic/internal/restic"
|
||||
"github.com/restic/restic/internal/ui/progress"
|
||||
)
|
||||
|
||||
// MasterIndex is a collection of indexes and IDs of chunks that are in the process of being saved.
|
||||
|
@ -266,10 +266,7 @@ func (mi *MasterIndex) MergeFinalIndexes() {
|
|||
// of all known indexes in the "supersedes" field. The IDs are also returned in
|
||||
// the IDSet obsolete
|
||||
// After calling this function, you should remove the obsolete index files.
|
||||
func (mi *MasterIndex) Save(ctx context.Context, repo restic.Repository, packBlacklist restic.IDSet, p *restic.Progress) (obsolete restic.IDSet, err error) {
|
||||
p.Start()
|
||||
defer p.Done()
|
||||
|
||||
func (mi *MasterIndex) Save(ctx context.Context, repo restic.Repository, packBlacklist restic.IDSet, p *progress.Counter) (obsolete restic.IDSet, err error) {
|
||||
mi.idxMutex.Lock()
|
||||
defer mi.idxMutex.Unlock()
|
||||
|
||||
|
@ -310,7 +307,7 @@ func (mi *MasterIndex) Save(ctx context.Context, repo restic.Repository, packBla
|
|||
|
||||
for pbs := range idx.EachByPack(ctx, packBlacklist) {
|
||||
newIndex.StorePack(pbs.packID, pbs.blobs)
|
||||
p.Report(restic.Stat{Blobs: 1})
|
||||
p.Add(1)
|
||||
if IndexFull(newIndex) {
|
||||
if err := finalize(); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
"github.com/restic/restic/internal/fs"
|
||||
"github.com/restic/restic/internal/pack"
|
||||
"github.com/restic/restic/internal/restic"
|
||||
"github.com/restic/restic/internal/ui/progress"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
|
@ -22,11 +24,8 @@ const numRepackWorkers = 8
|
|||
//
|
||||
// The map keepBlobs is modified by Repack, it is used to keep track of which
|
||||
// blobs have been processed.
|
||||
func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *restic.Progress) (obsoletePacks restic.IDSet, err error) {
|
||||
if p != nil {
|
||||
p.Start()
|
||||
defer p.Done()
|
||||
}
|
||||
func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet, p *progress.Counter) (obsoletePacks restic.IDSet, err error) {
|
||||
defer p.Done()
|
||||
|
||||
debug.Log("repacking %d packs while keeping %d blobs", len(packs), len(keepBlobs))
|
||||
|
||||
|
@ -172,9 +171,7 @@ func Repack(ctx context.Context, repo restic.Repository, packs restic.IDSet, kee
|
|||
if err = fs.RemoveIfExists(tempfile.Name()); err != nil {
|
||||
return errors.Wrap(err, "Remove")
|
||||
}
|
||||
if p != nil {
|
||||
p.Report(restic.Stat{Blobs: 1})
|
||||
}
|
||||
p.Add(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,232 +0,0 @@
|
|||
package restic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
// minTickerTime limits how often the progress ticker is updated. It can be
|
||||
// overridden using the RESTIC_PROGRESS_FPS (frames per second) environment
|
||||
// variable.
|
||||
var minTickerTime = time.Second / 60
|
||||
|
||||
var isTerminal = terminal.IsTerminal(int(os.Stdout.Fd()))
|
||||
var forceUpdateProgress = make(chan bool)
|
||||
|
||||
func init() {
|
||||
fps, err := strconv.ParseInt(os.Getenv("RESTIC_PROGRESS_FPS"), 10, 64)
|
||||
if err == nil && fps >= 1 {
|
||||
if fps > 60 {
|
||||
fps = 60
|
||||
}
|
||||
minTickerTime = time.Second / time.Duration(fps)
|
||||
}
|
||||
}
|
||||
|
||||
// Progress reports progress on an operation.
|
||||
type Progress struct {
|
||||
OnStart func()
|
||||
OnUpdate ProgressFunc
|
||||
OnDone ProgressFunc
|
||||
fnM sync.Mutex
|
||||
|
||||
cur Stat
|
||||
curM sync.Mutex
|
||||
start time.Time
|
||||
c *time.Ticker
|
||||
cancel chan struct{}
|
||||
once sync.Once
|
||||
d time.Duration
|
||||
lastUpdate time.Time
|
||||
|
||||
running bool
|
||||
}
|
||||
|
||||
// Stat captures newly done parts of the operation.
|
||||
type Stat struct {
|
||||
Files uint64
|
||||
Dirs uint64
|
||||
Bytes uint64
|
||||
Trees uint64
|
||||
Blobs uint64
|
||||
Errors uint64
|
||||
}
|
||||
|
||||
// ProgressFunc is used to report progress back to the user.
|
||||
type ProgressFunc func(s Stat, runtime time.Duration, ticker bool)
|
||||
|
||||
// NewProgress returns a new progress reporter. When Start() is called, the
|
||||
// function OnStart is executed once. Afterwards the function OnUpdate is
|
||||
// called when new data arrives or at least every d interval. The function
|
||||
// OnDone is called when Done() is called. Both functions are called
|
||||
// synchronously and can use shared state.
|
||||
func NewProgress() *Progress {
|
||||
var d time.Duration
|
||||
if isTerminal {
|
||||
d = time.Second
|
||||
}
|
||||
return &Progress{d: d}
|
||||
}
|
||||
|
||||
// Start resets and runs the progress reporter.
|
||||
func (p *Progress) Start() {
|
||||
if p == nil || p.running {
|
||||
return
|
||||
}
|
||||
|
||||
p.cancel = make(chan struct{})
|
||||
p.running = true
|
||||
p.Reset()
|
||||
p.start = time.Now()
|
||||
p.c = nil
|
||||
if p.d != 0 {
|
||||
p.c = time.NewTicker(p.d)
|
||||
}
|
||||
|
||||
if p.OnStart != nil {
|
||||
p.OnStart()
|
||||
}
|
||||
|
||||
go p.reporter()
|
||||
}
|
||||
|
||||
// Reset resets all statistic counters to zero.
|
||||
func (p *Progress) Reset() {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !p.running {
|
||||
panic("resetting a non-running Progress")
|
||||
}
|
||||
|
||||
p.curM.Lock()
|
||||
p.cur = Stat{}
|
||||
p.curM.Unlock()
|
||||
}
|
||||
|
||||
// Report adds the statistics from s to the current state and tries to report
|
||||
// the accumulated statistics via the feedback channel.
|
||||
func (p *Progress) Report(s Stat) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !p.running {
|
||||
panic("reporting in a non-running Progress")
|
||||
}
|
||||
|
||||
p.curM.Lock()
|
||||
p.cur.Add(s)
|
||||
cur := p.cur
|
||||
needUpdate := false
|
||||
if isTerminal && time.Since(p.lastUpdate) > minTickerTime {
|
||||
p.lastUpdate = time.Now()
|
||||
needUpdate = true
|
||||
}
|
||||
p.curM.Unlock()
|
||||
|
||||
if needUpdate {
|
||||
p.updateProgress(cur, false)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *Progress) updateProgress(cur Stat, ticker bool) {
|
||||
if p.OnUpdate == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.fnM.Lock()
|
||||
p.OnUpdate(cur, time.Since(p.start), ticker)
|
||||
p.fnM.Unlock()
|
||||
}
|
||||
|
||||
func (p *Progress) reporter() {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
updateProgress := func() {
|
||||
p.curM.Lock()
|
||||
cur := p.cur
|
||||
p.curM.Unlock()
|
||||
p.updateProgress(cur, true)
|
||||
}
|
||||
|
||||
var ticker <-chan time.Time
|
||||
if p.c != nil {
|
||||
ticker = p.c.C
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker:
|
||||
updateProgress()
|
||||
case <-forceUpdateProgress:
|
||||
updateProgress()
|
||||
case <-p.cancel:
|
||||
if p.c != nil {
|
||||
p.c.Stop()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Done closes the progress report.
|
||||
func (p *Progress) Done() {
|
||||
if p == nil || !p.running {
|
||||
return
|
||||
}
|
||||
|
||||
p.running = false
|
||||
p.once.Do(func() {
|
||||
close(p.cancel)
|
||||
})
|
||||
|
||||
cur := p.cur
|
||||
|
||||
if p.OnDone != nil {
|
||||
p.fnM.Lock()
|
||||
p.OnUpdate(cur, time.Since(p.start), false)
|
||||
p.OnDone(cur, time.Since(p.start), false)
|
||||
p.fnM.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Add accumulates other into s.
|
||||
func (s *Stat) Add(other Stat) {
|
||||
s.Bytes += other.Bytes
|
||||
s.Dirs += other.Dirs
|
||||
s.Files += other.Files
|
||||
s.Trees += other.Trees
|
||||
s.Blobs += other.Blobs
|
||||
s.Errors += other.Errors
|
||||
}
|
||||
|
||||
func (s Stat) String() string {
|
||||
b := float64(s.Bytes)
|
||||
var str string
|
||||
|
||||
switch {
|
||||
case s.Bytes > 1<<40:
|
||||
str = fmt.Sprintf("%.3f TiB", b/(1<<40))
|
||||
case s.Bytes > 1<<30:
|
||||
str = fmt.Sprintf("%.3f GiB", b/(1<<30))
|
||||
case s.Bytes > 1<<20:
|
||||
str = fmt.Sprintf("%.3f MiB", b/(1<<20))
|
||||
case s.Bytes > 1<<10:
|
||||
str = fmt.Sprintf("%.3f KiB", b/(1<<10))
|
||||
default:
|
||||
str = fmt.Sprintf("%dB", s.Bytes)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Stat(%d files, %d dirs, %v trees, %v blobs, %d errors, %v)",
|
||||
s.Files, s.Dirs, s.Trees, s.Blobs, s.Errors, str)
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
// +build !windows,!darwin,!freebsd,!netbsd,!openbsd,!dragonfly,!solaris
|
||||
|
||||
package restic
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/restic/restic/internal/debug"
|
||||
)
|
||||
|
||||
func init() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGUSR1)
|
||||
go func() {
|
||||
for s := range c {
|
||||
debug.Log("Signal received: %v\n", s)
|
||||
forceUpdateProgress <- true
|
||||
}
|
||||
}()
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
// +build darwin freebsd netbsd openbsd dragonfly
|
||||
|
||||
package restic
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/restic/restic/internal/debug"
|
||||
)
|
||||
|
||||
func init() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINFO, syscall.SIGUSR1)
|
||||
go func() {
|
||||
for s := range c {
|
||||
debug.Log("Signal received: %v\n", s)
|
||||
forceUpdateProgress <- true
|
||||
}
|
||||
}()
|
||||
}
|
99
internal/ui/progress/counter.go
Normal file
99
internal/ui/progress/counter.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package progress
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/restic/restic/internal/debug"
|
||||
)
|
||||
|
||||
// A Func is a callback for a Counter.
|
||||
//
|
||||
// The final argument is true if Counter.Done has been called,
|
||||
// which means that the current call will be the last.
|
||||
type Func func(value uint64, runtime time.Duration, final bool)
|
||||
|
||||
// A Counter tracks a running count and controls a goroutine that passes its
|
||||
// value periodically to a Func.
|
||||
//
|
||||
// The Func is also called when SIGUSR1 (or SIGINFO, on BSD) is received.
|
||||
type Counter struct {
|
||||
report Func
|
||||
start time.Time
|
||||
stopped chan struct{} // Closed by run.
|
||||
stop chan struct{} // Close to stop run.
|
||||
tick *time.Ticker
|
||||
value uint64
|
||||
}
|
||||
|
||||
// New starts a new Counter.
|
||||
func New(interval time.Duration, report Func) *Counter {
|
||||
signals.Once.Do(func() {
|
||||
signals.ch = make(chan os.Signal, 1)
|
||||
setupSignals()
|
||||
})
|
||||
|
||||
c := &Counter{
|
||||
report: report,
|
||||
start: time.Now(),
|
||||
stopped: make(chan struct{}),
|
||||
stop: make(chan struct{}),
|
||||
tick: time.NewTicker(interval),
|
||||
}
|
||||
|
||||
go c.run()
|
||||
return c
|
||||
}
|
||||
|
||||
// Add v to the Counter. This method is concurrency-safe.
|
||||
func (c *Counter) Add(v uint64) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
atomic.AddUint64(&c.value, v)
|
||||
}
|
||||
|
||||
// Done tells a Counter to stop and waits for it to report its final value.
|
||||
func (c *Counter) Done() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.tick.Stop()
|
||||
close(c.stop)
|
||||
<-c.stopped // Wait for last progress report.
|
||||
*c = Counter{} // Prevent reuse.
|
||||
}
|
||||
|
||||
func (c *Counter) get() uint64 { return atomic.LoadUint64(&c.value) }
|
||||
|
||||
func (c *Counter) run() {
|
||||
defer close(c.stopped)
|
||||
defer func() {
|
||||
// Must be a func so that time.Since isn't called at defer time.
|
||||
c.report(c.get(), time.Since(c.start), true)
|
||||
}()
|
||||
|
||||
for {
|
||||
var now time.Time
|
||||
|
||||
select {
|
||||
case now = <-c.tick.C:
|
||||
case sig := <-signals.ch:
|
||||
debug.Log("Signal received: %v\n", sig)
|
||||
now = time.Now()
|
||||
case <-c.stop:
|
||||
return
|
||||
}
|
||||
|
||||
c.report(c.get(), now.Sub(c.start), false)
|
||||
}
|
||||
}
|
||||
|
||||
// XXX The fact that signals is a single global variable means that only one
|
||||
// Counter receives each incoming signal.
|
||||
var signals struct {
|
||||
ch chan os.Signal
|
||||
sync.Once
|
||||
}
|
55
internal/ui/progress/counter_test.go
Normal file
55
internal/ui/progress/counter_test.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package progress_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/restic/restic/internal/test"
|
||||
"github.com/restic/restic/internal/ui/progress"
|
||||
)
|
||||
|
||||
func TestCounter(t *testing.T) {
|
||||
const N = 100
|
||||
|
||||
var (
|
||||
finalSeen = false
|
||||
increasing = true
|
||||
last uint64
|
||||
ncalls int
|
||||
)
|
||||
|
||||
report := func(value uint64, d time.Duration, final bool) {
|
||||
finalSeen = true
|
||||
if value < last {
|
||||
increasing = false
|
||||
}
|
||||
last = value
|
||||
ncalls++
|
||||
}
|
||||
c := progress.New(10*time.Millisecond, report)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for i := 0; i < N; i++ {
|
||||
time.Sleep(time.Millisecond)
|
||||
c.Add(1)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
c.Done()
|
||||
|
||||
test.Assert(t, finalSeen, "final call did not happen")
|
||||
test.Assert(t, increasing, "values not increasing")
|
||||
test.Equals(t, uint64(N), last)
|
||||
|
||||
t.Log("number of calls:", ncalls)
|
||||
}
|
||||
|
||||
func TestCounterNil(t *testing.T) {
|
||||
// Shouldn't panic.
|
||||
var c *progress.Counter = nil
|
||||
c.Add(1)
|
||||
c.Done()
|
||||
}
|
12
internal/ui/progress/signals_bsd.go
Normal file
12
internal/ui/progress/signals_bsd.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
// +build darwin dragonfly freebsd netbsd openbsd
|
||||
|
||||
package progress
|
||||
|
||||
import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setupSignals() {
|
||||
signal.Notify(signals.ch, syscall.SIGINFO, syscall.SIGUSR1)
|
||||
}
|
12
internal/ui/progress/signals_sysv.go
Normal file
12
internal/ui/progress/signals_sysv.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
// +build linux solaris
|
||||
|
||||
package progress
|
||||
|
||||
import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setupSignals() {
|
||||
signal.Notify(signals.ch, syscall.SIGUSR1)
|
||||
}
|
3
internal/ui/progress/signals_windows.go
Normal file
3
internal/ui/progress/signals_windows.go
Normal file
|
@ -0,0 +1,3 @@
|
|||
package progress
|
||||
|
||||
func setupSignals() {}
|
Loading…
Reference in a new issue