forked from TrueCloudLab/restic
progress: extract progress updating into Updater struct
This allows reusing the code to create periodic progress updates.
This commit is contained in:
parent
52682b1c7b
commit
e499bbe3ae
6 changed files with 152 additions and 83 deletions
|
@ -37,7 +37,7 @@ func newProgressMax(show bool, max uint64, description string) *progress.Counter
|
|||
interval := calculateProgressInterval(show, false)
|
||||
canUpdateStatus := stdoutCanUpdateStatus()
|
||||
|
||||
return progress.New(interval, max, func(v uint64, max uint64, d time.Duration, final bool) {
|
||||
return progress.NewCounter(interval, max, func(v uint64, max uint64, d time.Duration, final bool) {
|
||||
var status string
|
||||
if max == 0 {
|
||||
status = fmt.Sprintf("[%s] %d %s",
|
||||
|
|
|
@ -93,7 +93,7 @@ func TestFindUsedBlobs(t *testing.T) {
|
|||
snapshots = append(snapshots, sn)
|
||||
}
|
||||
|
||||
p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
|
||||
p := progress.NewCounter(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
|
||||
defer p.Done()
|
||||
|
||||
for i, sn := range snapshots {
|
||||
|
@ -142,7 +142,7 @@ func TestMultiFindUsedBlobs(t *testing.T) {
|
|||
want.Merge(loadIDSet(t, goldenFilename))
|
||||
}
|
||||
|
||||
p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
|
||||
p := progress.NewCounter(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
|
||||
defer p.Done()
|
||||
|
||||
// run twice to check progress bar handling of duplicate tree roots
|
||||
|
|
|
@ -3,9 +3,6 @@ package progress
|
|||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/restic/restic/internal/debug"
|
||||
"github.com/restic/restic/internal/ui/signals"
|
||||
)
|
||||
|
||||
// A Func is a callback for a Counter.
|
||||
|
@ -19,32 +16,22 @@ type Func func(value uint64, total uint64, runtime time.Duration, final bool)
|
|||
//
|
||||
// The Func is also called when SIGUSR1 (or SIGINFO, on BSD) is received.
|
||||
type Counter struct {
|
||||
report Func
|
||||
start time.Time
|
||||
stopped chan struct{} // Closed by run.
|
||||
stop chan struct{} // Close to stop run.
|
||||
tick *time.Ticker
|
||||
Updater
|
||||
|
||||
valueMutex sync.Mutex
|
||||
value uint64
|
||||
max uint64
|
||||
}
|
||||
|
||||
// New starts a new Counter.
|
||||
func New(interval time.Duration, total uint64, report Func) *Counter {
|
||||
// NewCounter starts a new Counter.
|
||||
func NewCounter(interval time.Duration, total uint64, report Func) *Counter {
|
||||
c := &Counter{
|
||||
report: report,
|
||||
start: time.Now(),
|
||||
stopped: make(chan struct{}),
|
||||
stop: make(chan struct{}),
|
||||
max: total,
|
||||
max: total,
|
||||
}
|
||||
|
||||
if interval > 0 {
|
||||
c.tick = time.NewTicker(interval)
|
||||
}
|
||||
|
||||
go c.run()
|
||||
c.Updater = *NewUpdater(interval, func(runtime time.Duration, final bool) {
|
||||
v, max := c.Get()
|
||||
report(v, max, runtime, final)
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
|
@ -69,18 +56,6 @@ func (c *Counter) SetMax(max uint64) {
|
|||
c.valueMutex.Unlock()
|
||||
}
|
||||
|
||||
// Done tells a Counter to stop and waits for it to report its final value.
|
||||
func (c *Counter) Done() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if c.tick != nil {
|
||||
c.tick.Stop()
|
||||
}
|
||||
close(c.stop)
|
||||
<-c.stopped // Wait for last progress report.
|
||||
}
|
||||
|
||||
// Get returns the current value and the maximum of c.
|
||||
// This method is concurrency-safe.
|
||||
func (c *Counter) Get() (v, max uint64) {
|
||||
|
@ -91,32 +66,8 @@ func (c *Counter) Get() (v, max uint64) {
|
|||
return v, max
|
||||
}
|
||||
|
||||
func (c *Counter) run() {
|
||||
defer close(c.stopped)
|
||||
defer func() {
|
||||
// Must be a func so that time.Since isn't called at defer time.
|
||||
v, max := c.Get()
|
||||
c.report(v, max, time.Since(c.start), true)
|
||||
}()
|
||||
|
||||
var tick <-chan time.Time
|
||||
if c.tick != nil {
|
||||
tick = c.tick.C
|
||||
}
|
||||
signalsCh := signals.GetProgressChannel()
|
||||
for {
|
||||
var now time.Time
|
||||
|
||||
select {
|
||||
case now = <-tick:
|
||||
case sig := <-signalsCh:
|
||||
debug.Log("Signal received: %v\n", sig)
|
||||
now = time.Now()
|
||||
case <-c.stop:
|
||||
return
|
||||
}
|
||||
|
||||
v, max := c.Get()
|
||||
c.report(v, max, now.Sub(c.start), false)
|
||||
func (c *Counter) Done() {
|
||||
if c != nil {
|
||||
c.Updater.Done()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ func TestCounter(t *testing.T) {
|
|||
lastTotal = total
|
||||
ncalls++
|
||||
}
|
||||
c := progress.New(10*time.Millisecond, startTotal, report)
|
||||
c := progress.NewCounter(10*time.Millisecond, startTotal, report)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -63,24 +63,6 @@ func TestCounterNil(t *testing.T) {
|
|||
// Shouldn't panic.
|
||||
var c *progress.Counter
|
||||
c.Add(1)
|
||||
c.SetMax(42)
|
||||
c.Done()
|
||||
}
|
||||
|
||||
func TestCounterNoTick(t *testing.T) {
|
||||
finalSeen := false
|
||||
otherSeen := false
|
||||
|
||||
report := func(value, total uint64, d time.Duration, final bool) {
|
||||
if final {
|
||||
finalSeen = true
|
||||
} else {
|
||||
otherSeen = true
|
||||
}
|
||||
}
|
||||
c := progress.New(0, 1, report)
|
||||
time.Sleep(time.Millisecond)
|
||||
c.Done()
|
||||
|
||||
test.Assert(t, finalSeen, "final call did not happen")
|
||||
test.Assert(t, !otherSeen, "unexpected status update")
|
||||
}
|
||||
|
|
84
internal/ui/progress/updater.go
Normal file
84
internal/ui/progress/updater.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package progress
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/restic/restic/internal/debug"
|
||||
"github.com/restic/restic/internal/ui/signals"
|
||||
)
|
||||
|
||||
// An UpdateFunc is a callback for a (progress) Updater.
|
||||
//
|
||||
// The final argument is true if Updater.Done has been called,
|
||||
// which means that the current call will be the last.
|
||||
type UpdateFunc func(runtime time.Duration, final bool)
|
||||
|
||||
// An Updater controls a goroutine that periodically calls an UpdateFunc.
|
||||
//
|
||||
// The UpdateFunc is also called when SIGUSR1 (or SIGINFO, on BSD) is received.
|
||||
type Updater struct {
|
||||
report UpdateFunc
|
||||
start time.Time
|
||||
stopped chan struct{} // Closed by run.
|
||||
stop chan struct{} // Close to stop run.
|
||||
tick *time.Ticker
|
||||
}
|
||||
|
||||
// NewUpdater starts a new Updater.
|
||||
func NewUpdater(interval time.Duration, report UpdateFunc) *Updater {
|
||||
c := &Updater{
|
||||
report: report,
|
||||
start: time.Now(),
|
||||
stopped: make(chan struct{}),
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
|
||||
if interval > 0 {
|
||||
c.tick = time.NewTicker(interval)
|
||||
}
|
||||
|
||||
go c.run()
|
||||
return c
|
||||
}
|
||||
|
||||
// Done tells an Updater to stop and waits for it to report its final value.
|
||||
// Later calls do nothing.
|
||||
func (c *Updater) Done() {
|
||||
if c == nil || c.stop == nil {
|
||||
return
|
||||
}
|
||||
if c.tick != nil {
|
||||
c.tick.Stop()
|
||||
}
|
||||
close(c.stop)
|
||||
<-c.stopped // Wait for last progress report.
|
||||
c.stop = nil
|
||||
}
|
||||
|
||||
func (c *Updater) run() {
|
||||
defer close(c.stopped)
|
||||
defer func() {
|
||||
// Must be a func so that time.Since isn't called at defer time.
|
||||
c.report(time.Since(c.start), true)
|
||||
}()
|
||||
|
||||
var tick <-chan time.Time
|
||||
if c.tick != nil {
|
||||
tick = c.tick.C
|
||||
}
|
||||
signalsCh := signals.GetProgressChannel()
|
||||
for {
|
||||
var now time.Time
|
||||
|
||||
select {
|
||||
case now = <-tick:
|
||||
case sig := <-signalsCh:
|
||||
debug.Log("Signal received: %v\n", sig)
|
||||
now = time.Now()
|
||||
case <-c.stop:
|
||||
return
|
||||
}
|
||||
|
||||
c.report(now.Sub(c.start), false)
|
||||
}
|
||||
}
|
52
internal/ui/progress/updater_test.go
Normal file
52
internal/ui/progress/updater_test.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package progress_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/restic/restic/internal/test"
|
||||
"github.com/restic/restic/internal/ui/progress"
|
||||
)
|
||||
|
||||
func TestUpdater(t *testing.T) {
|
||||
finalSeen := false
|
||||
var ncalls int
|
||||
|
||||
report := func(d time.Duration, final bool) {
|
||||
if final {
|
||||
finalSeen = true
|
||||
}
|
||||
ncalls++
|
||||
}
|
||||
c := progress.NewUpdater(10*time.Millisecond, report)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
c.Done()
|
||||
|
||||
test.Assert(t, finalSeen, "final call did not happen")
|
||||
test.Assert(t, ncalls > 0, "no progress was reported")
|
||||
}
|
||||
|
||||
func TestUpdaterStopTwice(t *testing.T) {
|
||||
c := progress.NewUpdater(0, func(runtime time.Duration, final bool) {})
|
||||
c.Done()
|
||||
c.Done()
|
||||
}
|
||||
|
||||
func TestUpdaterNoTick(t *testing.T) {
|
||||
finalSeen := false
|
||||
otherSeen := false
|
||||
|
||||
report := func(d time.Duration, final bool) {
|
||||
if final {
|
||||
finalSeen = true
|
||||
} else {
|
||||
otherSeen = true
|
||||
}
|
||||
}
|
||||
c := progress.NewUpdater(0, report)
|
||||
time.Sleep(time.Millisecond)
|
||||
c.Done()
|
||||
|
||||
test.Assert(t, finalSeen, "final call did not happen")
|
||||
test.Assert(t, !otherSeen, "unexpected status update")
|
||||
}
|
Loading…
Reference in a new issue