Merge pull request from MichaelEischer/extract-progress-updater

ui/progress: Extract progress updater
This commit is contained in:
Michael Eischer 2023-01-14 12:07:26 +01:00 committed by GitHub
commit 41d31b1e27
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 182 additions and 148 deletions

View file

@ -483,16 +483,12 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter
}
progressReporter := backup.NewProgress(progressPrinter,
calculateProgressInterval(!gopts.Quiet, gopts.JSON))
defer progressReporter.Done()
if opts.DryRun {
repo.SetDryRun()
}
wg, wgCtx := errgroup.WithContext(ctx)
cancelCtx, cancel := context.WithCancel(wgCtx)
defer cancel()
wg.Go(func() error { progressReporter.Run(cancelCtx); return nil })
if !gopts.JSON {
progressPrinter.V("lock repository")
}
@ -590,6 +586,10 @@ func runBackup(ctx context.Context, opts BackupOptions, gopts GlobalOptions, ter
targets = []string{filename}
}
wg, wgCtx := errgroup.WithContext(ctx)
cancelCtx, cancel := context.WithCancel(wgCtx)
defer cancel()
if !opts.NoScan {
sc := archiver.NewScanner(targetFS)
sc.SelectByName = selectByNameFilter

View file

@ -37,7 +37,7 @@ func newProgressMax(show bool, max uint64, description string) *progress.Counter
interval := calculateProgressInterval(show, false)
canUpdateStatus := stdoutCanUpdateStatus()
return progress.New(interval, max, func(v uint64, max uint64, d time.Duration, final bool) {
return progress.NewCounter(interval, max, func(v uint64, max uint64, d time.Duration, final bool) {
var status string
if max == 0 {
status = fmt.Sprintf("[%s] %d %s",

View file

@ -93,7 +93,7 @@ func TestFindUsedBlobs(t *testing.T) {
snapshots = append(snapshots, sn)
}
p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
p := progress.NewCounter(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
defer p.Done()
for i, sn := range snapshots {
@ -142,7 +142,7 @@ func TestMultiFindUsedBlobs(t *testing.T) {
want.Merge(loadIDSet(t, goldenFilename))
}
p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
p := progress.NewCounter(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
defer p.Done()
// run twice to check progress bar handling of duplicate tree roots

View file

@ -1,13 +1,12 @@
package backup
import (
"context"
"sync"
"time"
"github.com/restic/restic/internal/archiver"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/ui/signals"
"github.com/restic/restic/internal/ui/progress"
)
// A ProgressPrinter can print various progress messages.
@ -41,10 +40,10 @@ type Summary struct {
// Progress reports progress for the `backup` command.
type Progress struct {
progress.Updater
mu sync.Mutex
interval time.Duration
start time.Time
start time.Time
scanStarted, scanFinished bool
@ -52,66 +51,37 @@ type Progress struct {
processed, total Counter
errors uint
closed chan struct{}
summary Summary
printer ProgressPrinter
}
func NewProgress(printer ProgressPrinter, interval time.Duration) *Progress {
return &Progress{
interval: interval,
start: time.Now(),
p := &Progress{
start: time.Now(),
currentFiles: make(map[string]struct{}),
closed: make(chan struct{}),
printer: printer,
printer: printer,
}
}
p.Updater = *progress.NewUpdater(interval, func(runtime time.Duration, final bool) {
if final {
p.printer.Reset()
} else {
p.mu.Lock()
defer p.mu.Unlock()
if !p.scanStarted {
return
}
// Run regularly updates the status lines. It should be called in a separate
// goroutine.
func (p *Progress) Run(ctx context.Context) {
defer close(p.closed)
// Reset status when finished
defer p.printer.Reset()
var secondsRemaining uint64
if p.scanFinished {
secs := float64(runtime / time.Second)
todo := float64(p.total.Bytes - p.processed.Bytes)
secondsRemaining = uint64(secs / float64(p.processed.Bytes) * todo)
}
var tick <-chan time.Time
if p.interval != 0 {
t := time.NewTicker(p.interval)
defer t.Stop()
tick = t.C
}
signalsCh := signals.GetProgressChannel()
for {
var now time.Time
select {
case <-ctx.Done():
return
case now = <-tick:
case <-signalsCh:
now = time.Now()
p.printer.Update(p.total, p.processed, p.errors, p.currentFiles, p.start, secondsRemaining)
}
p.mu.Lock()
if !p.scanStarted {
p.mu.Unlock()
continue
}
var secondsRemaining uint64
if p.scanFinished {
secs := float64(now.Sub(p.start) / time.Second)
todo := float64(p.total.Bytes - p.processed.Bytes)
secondsRemaining = uint64(secs / float64(p.processed.Bytes) * todo)
}
p.printer.Update(p.total, p.processed, p.errors, p.currentFiles, p.start, secondsRemaining)
p.mu.Unlock()
}
})
return p
}
// Error is the error callback function for the archiver, it prints the error and returns nil.
@ -236,6 +206,6 @@ func (p *Progress) ReportTotal(item string, s archiver.ScanStats) {
// Finish prints the finishing messages.
func (p *Progress) Finish(snapshotID restic.ID, dryrun bool) {
// wait for the status update goroutine to shut down
<-p.closed
p.Updater.Done()
p.printer.Finish(snapshotID, p.start, &p.summary, dryrun)
}

View file

@ -1,7 +1,6 @@
package backup
import (
"context"
"sync"
"testing"
"time"
@ -53,9 +52,6 @@ func TestProgress(t *testing.T) {
prnt := &mockPrinter{}
prog := NewProgress(prnt, time.Millisecond)
ctx, cancel := context.WithCancel(context.Background())
go prog.Run(ctx)
prog.StartFile("foo")
prog.CompleteBlob(1024)
@ -67,7 +63,6 @@ func TestProgress(t *testing.T) {
prog.CompleteItem("foo", nil, &node, archiver.ItemStats{}, 0)
time.Sleep(10 * time.Millisecond)
cancel()
id := restic.NewRandomID()
prog.Finish(id, false)

View file

@ -3,9 +3,6 @@ package progress
import (
"sync"
"time"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/ui/signals"
)
// A Func is a callback for a Counter.
@ -19,32 +16,22 @@ type Func func(value uint64, total uint64, runtime time.Duration, final bool)
//
// The Func is also called when SIGUSR1 (or SIGINFO, on BSD) is received.
type Counter struct {
report Func
start time.Time
stopped chan struct{} // Closed by run.
stop chan struct{} // Close to stop run.
tick *time.Ticker
Updater
valueMutex sync.Mutex
value uint64
max uint64
}
// New starts a new Counter.
func New(interval time.Duration, total uint64, report Func) *Counter {
// NewCounter starts a new Counter.
func NewCounter(interval time.Duration, total uint64, report Func) *Counter {
c := &Counter{
report: report,
start: time.Now(),
stopped: make(chan struct{}),
stop: make(chan struct{}),
max: total,
max: total,
}
if interval > 0 {
c.tick = time.NewTicker(interval)
}
go c.run()
c.Updater = *NewUpdater(interval, func(runtime time.Duration, final bool) {
v, max := c.Get()
report(v, max, runtime, final)
})
return c
}
@ -69,18 +56,6 @@ func (c *Counter) SetMax(max uint64) {
c.valueMutex.Unlock()
}
// Done tells a Counter to stop and waits for it to report its final value.
func (c *Counter) Done() {
if c == nil {
return
}
if c.tick != nil {
c.tick.Stop()
}
close(c.stop)
<-c.stopped // Wait for last progress report.
}
// Get returns the current value and the maximum of c.
// This method is concurrency-safe.
func (c *Counter) Get() (v, max uint64) {
@ -91,32 +66,8 @@ func (c *Counter) Get() (v, max uint64) {
return v, max
}
func (c *Counter) run() {
defer close(c.stopped)
defer func() {
// Must be a func so that time.Since isn't called at defer time.
v, max := c.Get()
c.report(v, max, time.Since(c.start), true)
}()
var tick <-chan time.Time
if c.tick != nil {
tick = c.tick.C
}
signalsCh := signals.GetProgressChannel()
for {
var now time.Time
select {
case now = <-tick:
case sig := <-signalsCh:
debug.Log("Signal received: %v\n", sig)
now = time.Now()
case <-c.stop:
return
}
v, max := c.Get()
c.report(v, max, now.Sub(c.start), false)
func (c *Counter) Done() {
if c != nil {
c.Updater.Done()
}
}

View file

@ -35,7 +35,7 @@ func TestCounter(t *testing.T) {
lastTotal = total
ncalls++
}
c := progress.New(10*time.Millisecond, startTotal, report)
c := progress.NewCounter(10*time.Millisecond, startTotal, report)
done := make(chan struct{})
go func() {
@ -63,24 +63,6 @@ func TestCounterNil(t *testing.T) {
// Shouldn't panic.
var c *progress.Counter
c.Add(1)
c.SetMax(42)
c.Done()
}
func TestCounterNoTick(t *testing.T) {
finalSeen := false
otherSeen := false
report := func(value, total uint64, d time.Duration, final bool) {
if final {
finalSeen = true
} else {
otherSeen = true
}
}
c := progress.New(0, 1, report)
time.Sleep(time.Millisecond)
c.Done()
test.Assert(t, finalSeen, "final call did not happen")
test.Assert(t, !otherSeen, "unexpected status update")
}

View file

@ -0,0 +1,84 @@
package progress
import (
"time"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/ui/signals"
)
// An UpdateFunc is a callback for a (progress) Updater.
//
// The final argument is true if Updater.Done has been called,
// which means that the current call will be the last.
type UpdateFunc func(runtime time.Duration, final bool)
// An Updater controls a goroutine that periodically calls an UpdateFunc.
//
// The UpdateFunc is also called when SIGUSR1 (or SIGINFO, on BSD) is received.
type Updater struct {
report UpdateFunc
start time.Time
stopped chan struct{} // Closed by run.
stop chan struct{} // Close to stop run.
tick *time.Ticker
}
// NewUpdater starts a new Updater.
func NewUpdater(interval time.Duration, report UpdateFunc) *Updater {
c := &Updater{
report: report,
start: time.Now(),
stopped: make(chan struct{}),
stop: make(chan struct{}),
}
if interval > 0 {
c.tick = time.NewTicker(interval)
}
go c.run()
return c
}
// Done tells an Updater to stop and waits for it to report its final value.
// Later calls do nothing.
func (c *Updater) Done() {
if c == nil || c.stop == nil {
return
}
if c.tick != nil {
c.tick.Stop()
}
close(c.stop)
<-c.stopped // Wait for last progress report.
c.stop = nil
}
func (c *Updater) run() {
defer close(c.stopped)
defer func() {
// Must be a func so that time.Since isn't called at defer time.
c.report(time.Since(c.start), true)
}()
var tick <-chan time.Time
if c.tick != nil {
tick = c.tick.C
}
signalsCh := signals.GetProgressChannel()
for {
var now time.Time
select {
case now = <-tick:
case sig := <-signalsCh:
debug.Log("Signal received: %v\n", sig)
now = time.Now()
case <-c.stop:
return
}
c.report(now.Sub(c.start), false)
}
}

View file

@ -0,0 +1,52 @@
package progress_test
import (
"testing"
"time"
"github.com/restic/restic/internal/test"
"github.com/restic/restic/internal/ui/progress"
)
func TestUpdater(t *testing.T) {
finalSeen := false
var ncalls int
report := func(d time.Duration, final bool) {
if final {
finalSeen = true
}
ncalls++
}
c := progress.NewUpdater(10*time.Millisecond, report)
time.Sleep(100 * time.Millisecond)
c.Done()
test.Assert(t, finalSeen, "final call did not happen")
test.Assert(t, ncalls > 0, "no progress was reported")
}
func TestUpdaterStopTwice(t *testing.T) {
c := progress.NewUpdater(0, func(runtime time.Duration, final bool) {})
c.Done()
c.Done()
}
func TestUpdaterNoTick(t *testing.T) {
finalSeen := false
otherSeen := false
report := func(d time.Duration, final bool) {
if final {
finalSeen = true
} else {
otherSeen = true
}
}
c := progress.NewUpdater(0, report)
time.Sleep(time.Millisecond)
c.Done()
test.Assert(t, finalSeen, "final call did not happen")
test.Assert(t, !otherSeen, "unexpected status update")
}