progress/counter: Support updating the progress bar maximum

This commit is contained in:
Michael Eischer 2020-12-05 23:57:06 +01:00 committed by Alexander Neumann
parent eda8c67616
commit 505f8a2229
3 changed files with 47 additions and 13 deletions

View file

@ -33,11 +33,14 @@ func newProgressMax(show bool, max uint64, description string) *progress.Counter
} }
interval := calculateProgressInterval() interval := calculateProgressInterval()
return progress.New(interval, func(v uint64, d time.Duration, final bool) { return progress.New(interval, max, func(v uint64, max uint64, d time.Duration, final bool) {
status := fmt.Sprintf("[%s] %s %d / %d %s", var status string
formatDuration(d), if max == 0 {
formatPercent(v, max), status = fmt.Sprintf("[%s] %d %s", formatDuration(d), v, description)
v, max, description) } else {
status = fmt.Sprintf("[%s] %s %d / %d %s",
formatDuration(d), formatPercent(v, max), v, max, description)
}
if w := stdoutTerminalWidth(); w > 0 { if w := stdoutTerminalWidth(); w > 0 {
status = shortenStatus(w, status) status = shortenStatus(w, status)

View file

@ -12,7 +12,7 @@ import (
// //
// The final argument is true if Counter.Done has been called, // The final argument is true if Counter.Done has been called,
// which means that the current call will be the last. // which means that the current call will be the last.
type Func func(value uint64, runtime time.Duration, final bool) type Func func(value uint64, total uint64, runtime time.Duration, final bool)
// A Counter tracks a running count and controls a goroutine that passes its // A Counter tracks a running count and controls a goroutine that passes its
// value periodically to a Func. // value periodically to a Func.
@ -27,16 +27,19 @@ type Counter struct {
valueMutex sync.Mutex valueMutex sync.Mutex
value uint64 value uint64
max uint64
} }
// New starts a new Counter. // New starts a new Counter.
func New(interval time.Duration, report Func) *Counter { func New(interval time.Duration, total uint64, report Func) *Counter {
c := &Counter{ c := &Counter{
report: report, report: report,
start: time.Now(), start: time.Now(),
stopped: make(chan struct{}), stopped: make(chan struct{}),
stop: make(chan struct{}), stop: make(chan struct{}),
max: total,
} }
if interval > 0 { if interval > 0 {
c.tick = time.NewTicker(interval) c.tick = time.NewTicker(interval)
} }
@ -56,6 +59,16 @@ func (c *Counter) Add(v uint64) {
c.valueMutex.Unlock() c.valueMutex.Unlock()
} }
// SetMax sets the maximum expected counter value. This method is concurrency-safe.
func (c *Counter) SetMax(max uint64) {
if c == nil {
return
}
c.valueMutex.Lock()
c.max = max
c.valueMutex.Unlock()
}
// Done tells a Counter to stop and waits for it to report its final value. // Done tells a Counter to stop and waits for it to report its final value.
func (c *Counter) Done() { func (c *Counter) Done() {
if c == nil { if c == nil {
@ -77,11 +90,19 @@ func (c *Counter) get() uint64 {
return v return v
} }
func (c *Counter) getMax() uint64 {
c.valueMutex.Lock()
max := c.max
c.valueMutex.Unlock()
return max
}
func (c *Counter) run() { func (c *Counter) run() {
defer close(c.stopped) defer close(c.stopped)
defer func() { defer func() {
// Must be a func so that time.Since isn't called at defer time. // Must be a func so that time.Since isn't called at defer time.
c.report(c.get(), time.Since(c.start), true) c.report(c.get(), c.getMax(), time.Since(c.start), true)
}() }()
var tick <-chan time.Time var tick <-chan time.Time
@ -101,6 +122,6 @@ func (c *Counter) run() {
return return
} }
c.report(c.get(), now.Sub(c.start), false) c.report(c.get(), c.getMax(), now.Sub(c.start), false)
} }
} }

View file

@ -10,23 +10,30 @@ import (
func TestCounter(t *testing.T) { func TestCounter(t *testing.T) {
const N = 100 const N = 100
const startTotal = uint64(12345)
var ( var (
finalSeen = false finalSeen = false
increasing = true increasing = true
last uint64 last uint64
lastTotal = startTotal
ncalls int ncalls int
nmaxChange int
) )
report := func(value uint64, d time.Duration, final bool) { report := func(value uint64, total uint64, d time.Duration, final bool) {
finalSeen = true finalSeen = true
if value < last { if value < last {
increasing = false increasing = false
} }
last = value last = value
if total != lastTotal {
nmaxChange++
}
lastTotal = total
ncalls++ ncalls++
} }
c := progress.New(10*time.Millisecond, report) c := progress.New(10*time.Millisecond, startTotal, report)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -35,6 +42,7 @@ func TestCounter(t *testing.T) {
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
c.Add(1) c.Add(1)
} }
c.SetMax(42)
}() }()
<-done <-done
@ -43,6 +51,8 @@ func TestCounter(t *testing.T) {
test.Assert(t, finalSeen, "final call did not happen") test.Assert(t, finalSeen, "final call did not happen")
test.Assert(t, increasing, "values not increasing") test.Assert(t, increasing, "values not increasing")
test.Equals(t, uint64(N), last) test.Equals(t, uint64(N), last)
test.Equals(t, uint64(42), lastTotal)
test.Equals(t, int(1), nmaxChange)
t.Log("number of calls:", ncalls) t.Log("number of calls:", ncalls)
} }
@ -58,14 +68,14 @@ func TestCounterNoTick(t *testing.T) {
finalSeen := false finalSeen := false
otherSeen := false otherSeen := false
report := func(value uint64, d time.Duration, final bool) { report := func(value, total uint64, d time.Duration, final bool) {
if final { if final {
finalSeen = true finalSeen = true
} else { } else {
otherSeen = true otherSeen = true
} }
} }
c := progress.New(0, report) c := progress.New(0, 1, report)
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
c.Done() c.Done()