Merge pull request #3106 from MichaelEischer/parallel-tree-walk

Parallelize tree walk in prune and copy and add progress bar to check
Commit 72eec8c0c4 by Alexander Neumann, 2021-01-28 12:06:42 +01:00 (committed via GitHub)
15 changed files with 437 additions and 322 deletions

@ -0,0 +1,10 @@
Enhancement: Parallelize scan of snapshot content in copy and prune
The copy and the prune commands used to traverse the directories of
snapshots one by one to find used data. This snapshot traversal is
now parallelized, which can speed up this step several times.
In addition, the check command now reports how many snapshots have
already been processed.
https://github.com/restic/restic/pull/3106
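
To make the new behaviour concrete, here is a hedged caller-side sketch, not part of this commit, of how prune-style code can now hand all snapshot trees to a single parallel traversal. It assumes it lives in cmd/restic next to the newProgressMax helper shown further below; gopts, repo and snapshotTrees are stand-ins for the caller's state.

// findUsedBlobsExample is a hypothetical sketch, not code from this commit.
func findUsedBlobsExample(ctx context.Context, gopts GlobalOptions, repo restic.Repository, snapshotTrees restic.IDs) (restic.BlobSet, error) {
	usedBlobs := restic.NewBlobSet()
	bar := newProgressMax(!gopts.Quiet, uint64(len(snapshotTrees)), "snapshots") // passing nil instead of bar disables progress reporting
	defer bar.Done()
	// all roots are handed over at once; the traversal loads trees with several workers internally
	if err := restic.FindUsedBlobs(ctx, repo, snapshotTrees, usedBlobs, bar); err != nil {
		return nil, err
	}
	return usedBlobs, nil
}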

@ -240,7 +240,11 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error {
Verbosef("check snapshots, trees and blobs\n")
errChan = make(chan error)
go chkr.Structure(gopts.ctx, errChan)
go func() {
bar := newProgressMax(!gopts.Quiet, 0, "snapshots")
defer bar.Done()
chkr.Structure(gopts.ctx, bar, errChan)
}()
for err := range errChan {
errorsFound = true

@ -6,6 +6,7 @@ import (
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/restic"
"golang.org/x/sync/errgroup"
"github.com/spf13/cobra"
)
@ -103,12 +104,8 @@ func runCopy(opts CopyOptions, gopts GlobalOptions, args []string) error {
dstSnapshotByOriginal[*sn.ID()] = append(dstSnapshotByOriginal[*sn.ID()], sn)
}
cloner := &treeCloner{
srcRepo: srcRepo,
dstRepo: dstRepo,
visitedTrees: restic.NewIDSet(),
buf: nil,
}
// remember already processed trees across all snapshots
visitedTrees := restic.NewIDSet()
for sn := range FindFilteredSnapshots(ctx, srcRepo, opts.Hosts, opts.Tags, opts.Paths, args) {
Verbosef("\nsnapshot %s of %v at %s)\n", sn.ID().Str(), sn.Paths, sn.Time)
@ -133,7 +130,7 @@ func runCopy(opts CopyOptions, gopts GlobalOptions, args []string) error {
}
Verbosef(" copy started, this may take a while...\n")
if err := cloner.copyTree(ctx, *sn.Tree); err != nil {
if err := copyTree(ctx, srcRepo, dstRepo, visitedTrees, *sn.Tree); err != nil {
return err
}
debug.Log("tree copied")
@ -177,64 +174,64 @@ func similarSnapshots(sna *restic.Snapshot, snb *restic.Snapshot) bool {
return true
}
type treeCloner struct {
srcRepo restic.Repository
dstRepo restic.Repository
visitedTrees restic.IDSet
buf []byte
}
func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Repository,
visitedTrees restic.IDSet, rootTreeID restic.ID) error {
func (t *treeCloner) copyTree(ctx context.Context, treeID restic.ID) error {
// We have already processed this tree
if t.visitedTrees.Has(treeID) {
wg, ctx := errgroup.WithContext(ctx)
treeStream := restic.StreamTrees(ctx, wg, srcRepo, restic.IDs{rootTreeID}, func(treeID restic.ID) bool {
visited := visitedTrees.Has(treeID)
visitedTrees.Insert(treeID)
return visited
}, nil)
wg.Go(func() error {
// reused buffer
var buf []byte
for tree := range treeStream {
if tree.Error != nil {
return fmt.Errorf("LoadTree(%v) returned error %v", tree.ID.Str(), tree.Error)
}
// Do we already have this tree blob?
if !dstRepo.Index().Has(restic.BlobHandle{ID: tree.ID, Type: restic.TreeBlob}) {
newTreeID, err := dstRepo.SaveTree(ctx, tree.Tree)
if err != nil {
return fmt.Errorf("SaveTree(%v) returned error %v", tree.ID.Str(), err)
}
// Assurance only.
if newTreeID != tree.ID {
return fmt.Errorf("SaveTree(%v) returned unexpected id %s", tree.ID.Str(), newTreeID.Str())
}
}
// TODO: parallelize blob down/upload
for _, entry := range tree.Nodes {
// Recursion into directories is handled by StreamTrees
// Copy the blobs for this file.
for _, blobID := range entry.Content {
// Do we already have this data blob?
if dstRepo.Index().Has(restic.BlobHandle{ID: blobID, Type: restic.DataBlob}) {
continue
}
debug.Log("Copying blob %s\n", blobID.Str())
var err error
buf, err = srcRepo.LoadBlob(ctx, restic.DataBlob, blobID, buf)
if err != nil {
return fmt.Errorf("LoadBlob(%v) returned error %v", blobID, err)
}
_, _, err = dstRepo.SaveBlob(ctx, restic.DataBlob, buf, blobID, false)
if err != nil {
return fmt.Errorf("SaveBlob(%v) returned error %v", blobID, err)
}
}
}
}
return nil
}
tree, err := t.srcRepo.LoadTree(ctx, treeID)
if err != nil {
return fmt.Errorf("LoadTree(%v) returned error %v", treeID.Str(), err)
}
t.visitedTrees.Insert(treeID)
// Do we already have this tree blob?
if !t.dstRepo.Index().Has(restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}) {
newTreeID, err := t.dstRepo.SaveTree(ctx, tree)
if err != nil {
return fmt.Errorf("SaveTree(%v) returned error %v", treeID.Str(), err)
}
// Assurance only.
if newTreeID != treeID {
return fmt.Errorf("SaveTree(%v) returned unexpected id %s", treeID.Str(), newTreeID.Str())
}
}
// TODO: parallelize this stuff, likely only needed inside a tree.
for _, entry := range tree.Nodes {
// If it is a directory, recurse
if entry.Type == "dir" && entry.Subtree != nil {
if err := t.copyTree(ctx, *entry.Subtree); err != nil {
return err
}
}
// Copy the blobs for this file.
for _, blobID := range entry.Content {
// Do we already have this data blob?
if t.dstRepo.Index().Has(restic.BlobHandle{ID: blobID, Type: restic.DataBlob}) {
continue
}
debug.Log("Copying blob %s\n", blobID.Str())
t.buf, err = t.srcRepo.LoadBlob(ctx, restic.DataBlob, blobID, t.buf)
if err != nil {
return fmt.Errorf("LoadBlob(%v) returned error %v", blobID, err)
}
_, _, err = t.dstRepo.SaveBlob(ctx, restic.DataBlob, t.buf, blobID, false)
if err != nil {
return fmt.Errorf("SaveBlob(%v) returned error %v", blobID, err)
}
}
}
return nil
})
return wg.Wait()
}
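
For clarity, a small sketch of the skip callback that the new copyTree hands to StreamTrees above; the helper name is hypothetical, but the logic mirrors the diff: the callback records each tree and reports whether it was seen before, so subtrees shared between snapshots are copied only once.

// visitedTreesSkip is a hypothetical helper illustrating the callback used by copyTree.
func visitedTreesSkip(visitedTrees restic.IDSet) func(restic.ID) bool {
	return func(treeID restic.ID) bool {
		seen := visitedTrees.Has(treeID)
		visitedTrees.Insert(treeID) // no-op if the ID is already in the set
		return seen                 // true: StreamTrees neither loads nor emits this tree again
	}
}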

@ -574,20 +574,14 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, ignoreSnapshots r
bar := newProgressMax(!gopts.Quiet, uint64(len(snapshotTrees)), "snapshots")
defer bar.Done()
for _, tree := range snapshotTrees {
debug.Log("process tree %v", tree)
err = restic.FindUsedBlobs(ctx, repo, tree, usedBlobs)
if err != nil {
if repo.Backend().IsNotExist(err) {
return nil, errors.Fatal("unable to load a tree from the repo: " + err.Error())
}
return nil, err
err = restic.FindUsedBlobs(ctx, repo, snapshotTrees, usedBlobs, bar)
if err != nil {
if repo.Backend().IsNotExist(err) {
return nil, errors.Fatal("unable to load a tree from the repo: " + err.Error())
}
debug.Log("processed tree %v", tree)
bar.Add(1)
return nil, err
}
return usedBlobs, nil
}

@ -166,7 +166,7 @@ func statsWalkSnapshot(ctx context.Context, snapshot *restic.Snapshot, repo rest
if statsOptions.countMode == countModeRawData {
// count just the sizes of unique blobs; we don't need to walk the tree
// ourselves in this case, since a nifty function does it for us
return restic.FindUsedBlobs(ctx, repo, *snapshot.Tree, stats.blobs)
return restic.FindUsedBlobs(ctx, repo, restic.IDs{*snapshot.Tree}, stats.blobs, nil)
}
err := walker.Walk(ctx, repo, *snapshot.Tree, restic.NewIDSet(), statsWalkTree(repo, stats))

@ -33,11 +33,14 @@ func newProgressMax(show bool, max uint64, description string) *progress.Counter
}
interval := calculateProgressInterval()
return progress.New(interval, func(v uint64, d time.Duration, final bool) {
status := fmt.Sprintf("[%s] %s %d / %d %s",
formatDuration(d),
formatPercent(v, max),
v, max, description)
return progress.New(interval, max, func(v uint64, max uint64, d time.Duration, final bool) {
var status string
if max == 0 {
status = fmt.Sprintf("[%s] %d %s", formatDuration(d), v, description)
} else {
status = fmt.Sprintf("[%s] %s %d / %d %s",
formatDuration(d), formatPercent(v, max), v, max, description)
}
if w := stdoutTerminalWidth(); w > 0 {
status = shortenStatus(w, status)

@ -308,200 +308,27 @@ func (e TreeError) Error() string {
return fmt.Sprintf("tree %v: %v", e.ID.Str(), e.Errors)
}
type treeJob struct {
restic.ID
error
*restic.Tree
}
// loadTreeWorker loads trees from repo and sends them to out.
func loadTreeWorker(ctx context.Context, repo restic.Repository,
in <-chan restic.ID, out chan<- treeJob,
wg *sync.WaitGroup) {
defer func() {
debug.Log("exiting")
wg.Done()
}()
var (
inCh = in
outCh = out
job treeJob
)
outCh = nil
for {
select {
case <-ctx.Done():
return
case treeID, ok := <-inCh:
if !ok {
return
}
debug.Log("load tree %v", treeID)
tree, err := repo.LoadTree(ctx, treeID)
debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err)
job = treeJob{ID: treeID, error: err, Tree: tree}
outCh = out
inCh = nil
case outCh <- job:
debug.Log("sent tree %v", job.ID)
outCh = nil
inCh = in
}
}
}
// checkTreeWorker checks the trees received and sends out errors to errChan.
func (c *Checker) checkTreeWorker(ctx context.Context, in <-chan treeJob, out chan<- error, wg *sync.WaitGroup) {
defer func() {
debug.Log("exiting")
wg.Done()
}()
func (c *Checker) checkTreeWorker(ctx context.Context, trees <-chan restic.TreeItem, out chan<- error) {
for job := range trees {
debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.Error)
var (
inCh = in
outCh = out
treeError TreeError
)
var errs []error
if job.Error != nil {
errs = append(errs, job.Error)
} else {
errs = c.checkTree(job.ID, job.Tree)
}
outCh = nil
for {
if len(errs) == 0 {
continue
}
treeError := TreeError{ID: job.ID, Errors: errs}
select {
case <-ctx.Done():
debug.Log("done channel closed, exiting")
return
case job, ok := <-inCh:
if !ok {
debug.Log("input channel closed, exiting")
return
}
debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.error)
var errs []error
if job.error != nil {
errs = append(errs, job.error)
} else {
errs = c.checkTree(job.ID, job.Tree)
}
if len(errs) > 0 {
debug.Log("checked tree %v: %v errors", job.ID, len(errs))
treeError = TreeError{ID: job.ID, Errors: errs}
outCh = out
inCh = nil
}
case outCh <- treeError:
case out <- treeError:
debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors))
outCh = nil
inCh = in
}
}
}
func (c *Checker) filterTrees(ctx context.Context, backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob) {
defer func() {
debug.Log("closing output channels")
close(loaderChan)
close(out)
}()
var (
inCh = in
outCh = out
loadCh = loaderChan
job treeJob
nextTreeID restic.ID
outstandingLoadTreeJobs = 0
)
outCh = nil
loadCh = nil
for {
if loadCh == nil && len(backlog) > 0 {
// process last added ids first, that is traverse the tree in depth-first order
ln := len(backlog) - 1
nextTreeID, backlog = backlog[ln], backlog[:ln]
// use a separate flag for processed trees to ensure that check still processes trees
// even when a file references a tree blob
c.blobRefs.Lock()
h := restic.BlobHandle{ID: nextTreeID, Type: restic.TreeBlob}
blobReferenced := c.blobRefs.M.Has(h)
// noop if already referenced
c.blobRefs.M.Insert(h)
c.blobRefs.Unlock()
if blobReferenced {
continue
}
loadCh = loaderChan
}
if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 {
debug.Log("backlog is empty, all channels nil, exiting")
return
}
select {
case <-ctx.Done():
return
case loadCh <- nextTreeID:
outstandingLoadTreeJobs++
loadCh = nil
case j, ok := <-inCh:
if !ok {
debug.Log("input channel closed")
inCh = nil
in = nil
continue
}
outstandingLoadTreeJobs--
debug.Log("input job tree %v", j.ID)
if j.error != nil {
debug.Log("received job with error: %v (tree %v, ID %v)", j.error, j.Tree, j.ID)
} else if j.Tree == nil {
debug.Log("received job with nil tree pointer: %v (ID %v)", j.error, j.ID)
// send a new job with the new error instead of the old one
j = treeJob{ID: j.ID, error: errors.New("tree is nil and error is nil")}
} else {
subtrees := j.Tree.Subtrees()
debug.Log("subtrees for tree %v: %v", j.ID, subtrees)
// iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection
for i := len(subtrees) - 1; i >= 0; i-- {
id := subtrees[i]
if id.IsNull() {
// We do not need to raise this error here, it is
// checked when the tree is checked. Just make sure
// that we do not add any null IDs to the backlog.
debug.Log("tree %v has nil subtree", j.ID)
continue
}
backlog = append(backlog, id)
}
}
job = j
outCh = out
inCh = nil
case outCh <- job:
debug.Log("tree sent to check: %v", job.ID)
outCh = nil
inCh = in
}
}
}
@ -527,10 +354,9 @@ func loadSnapshotTreeIDs(ctx context.Context, repo restic.Repository) (ids resti
// Structure checks that for all snapshots all referenced data blobs and
// subtrees are available in the index. errChan is closed after all trees have
// been traversed.
func (c *Checker) Structure(ctx context.Context, errChan chan<- error) {
defer close(errChan)
func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan chan<- error) {
trees, errs := loadSnapshotTreeIDs(ctx, c.repo)
p.SetMax(uint64(len(trees)))
debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs))
for _, err := range errs {
@ -541,19 +367,26 @@ func (c *Checker) Structure(ctx context.Context, errChan chan<- error) {
}
}
treeIDChan := make(chan restic.ID)
treeJobChan1 := make(chan treeJob)
treeJobChan2 := make(chan treeJob)
wg, ctx := errgroup.WithContext(ctx)
treeStream := restic.StreamTrees(ctx, wg, c.repo, trees, func(treeID restic.ID) bool {
// blobRefs may be accessed in parallel by checkTree
c.blobRefs.Lock()
h := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}
blobReferenced := c.blobRefs.M.Has(h)
// noop if already referenced
c.blobRefs.M.Insert(h)
c.blobRefs.Unlock()
return blobReferenced
}, p)
var wg sync.WaitGroup
defer close(errChan)
for i := 0; i < defaultParallelism; i++ {
wg.Add(2)
go loadTreeWorker(ctx, c.repo, treeIDChan, treeJobChan1, &wg)
go c.checkTreeWorker(ctx, treeJobChan2, errChan, &wg)
wg.Go(func() error {
c.checkTreeWorker(ctx, treeStream, errChan)
return nil
})
}
c.filterTrees(ctx, trees, treeIDChan, treeJobChan1, treeJobChan2)
wg.Wait()
}
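
As a usage note, a hedged caller-side sketch of the new Structure signature, mirroring how cmd_check drives it: the progress counter starts with a max of 0 and Structure fills in the real total via SetMax, while the caller drains errChan until Structure closes it. The function name, the quiet flag and the error handling are assumptions for illustration only.

// runStructureCheck is a hypothetical caller sketch, not code from this commit.
func runStructureCheck(ctx context.Context, chkr *checker.Checker, quiet bool) bool {
	errorsFound := false
	errChan := make(chan error)
	go func() {
		bar := newProgressMax(!quiet, 0, "snapshots") // max 0: Structure sets the real total via SetMax
		defer bar.Done()
		chkr.Structure(ctx, bar, errChan)
	}()
	for err := range errChan { // errChan is closed by Structure once all trees are traversed
		errorsFound = true
		fmt.Fprintf(os.Stderr, "error: %v\n", err)
	}
	return errorsFound
}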

@ -43,7 +43,9 @@ func checkPacks(chkr *checker.Checker) []error {
}
func checkStruct(chkr *checker.Checker) []error {
return collectErrors(context.TODO(), chkr.Structure)
return collectErrors(context.TODO(), func(ctx context.Context, errChan chan<- error) {
chkr.Structure(ctx, nil, errChan)
})
}
func checkData(chkr *checker.Checker) []error {

@ -30,7 +30,7 @@ func TestCheckRepo(t testing.TB, repo restic.Repository) {
// structure
errChan = make(chan error)
go chkr.Structure(context.TODO(), errChan)
go chkr.Structure(context.TODO(), nil, errChan)
for err := range errChan {
t.Error(err)

@ -368,7 +368,7 @@ func TestIndexSave(t *testing.T) {
defer cancel()
errCh := make(chan error)
go checker.Structure(ctx, errCh)
go checker.Structure(ctx, nil, errCh)
i := 0
for err := range errCh {
t.Errorf("checker returned error: %v", err)

@ -1,6 +1,12 @@
package restic
import "context"
import (
"context"
"sync"
"github.com/restic/restic/internal/ui/progress"
"golang.org/x/sync/errgroup"
)
// TreeLoader loads a tree from a repository.
type TreeLoader interface {
@ -9,31 +15,39 @@ type TreeLoader interface {
// FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data
// blobs) to the set blobs. Already seen tree blobs will not be visited again.
func FindUsedBlobs(ctx context.Context, repo TreeLoader, treeID ID, blobs BlobSet) error {
h := BlobHandle{ID: treeID, Type: TreeBlob}
if blobs.Has(h) {
return nil
}
blobs.Insert(h)
func FindUsedBlobs(ctx context.Context, repo TreeLoader, treeIDs IDs, blobs BlobSet, p *progress.Counter) error {
var lock sync.Mutex
tree, err := repo.LoadTree(ctx, treeID)
if err != nil {
return err
}
wg, ctx := errgroup.WithContext(ctx)
treeStream := StreamTrees(ctx, wg, repo, treeIDs, func(treeID ID) bool {
// locking is necessary as the goroutine below concurrently adds data blobs
lock.Lock()
h := BlobHandle{ID: treeID, Type: TreeBlob}
blobReferenced := blobs.Has(h)
// noop if already referenced
blobs.Insert(h)
lock.Unlock()
return blobReferenced
}, p)
for _, node := range tree.Nodes {
switch node.Type {
case "file":
for _, blob := range node.Content {
blobs.Insert(BlobHandle{ID: blob, Type: DataBlob})
wg.Go(func() error {
for tree := range treeStream {
if tree.Error != nil {
return tree.Error
}
case "dir":
err := FindUsedBlobs(ctx, repo, *node.Subtree, blobs)
if err != nil {
return err
lock.Lock()
for _, node := range tree.Nodes {
switch node.Type {
case "file":
for _, blob := range node.Content {
blobs.Insert(BlobHandle{ID: blob, Type: DataBlob})
}
}
}
lock.Unlock()
}
}
return nil
return nil
})
return wg.Wait()
}
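
A minimal external usage sketch of the new FindUsedBlobs signature; the wrapper name is hypothetical, and passing nil as the counter disables progress reporting, as the stats command does above.

// statsUsedBlobs is a hypothetical sketch of the new call shape used by the stats command.
func statsUsedBlobs(ctx context.Context, repo restic.Repository, treeID restic.ID, blobs restic.BlobSet) error {
	// nil progress counter: no reporting, as in statsWalkSnapshot above
	return restic.FindUsedBlobs(ctx, repo, restic.IDs{treeID}, blobs, nil)
}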

@ -15,6 +15,8 @@ import (
"github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/repository"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/test"
"github.com/restic/restic/internal/ui/progress"
)
func loadIDSet(t testing.TB, filename string) restic.BlobSet {
@ -92,9 +94,12 @@ func TestFindUsedBlobs(t *testing.T) {
snapshots = append(snapshots, sn)
}
p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
defer p.Done()
for i, sn := range snapshots {
usedBlobs := restic.NewBlobSet()
err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, usedBlobs)
err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*sn.Tree}, usedBlobs, p)
if err != nil {
t.Errorf("FindUsedBlobs returned error: %v", err)
continue
@ -105,6 +110,8 @@ func TestFindUsedBlobs(t *testing.T) {
continue
}
test.Equals(t, p.Get(), uint64(i+1))
goldenFilename := filepath.Join("testdata", fmt.Sprintf("used_blobs_snapshot%d", i))
want := loadIDSet(t, goldenFilename)
@ -119,6 +126,40 @@ func TestFindUsedBlobs(t *testing.T) {
}
}
func TestMultiFindUsedBlobs(t *testing.T) {
repo, cleanup := repository.TestRepository(t)
defer cleanup()
var snapshotTrees restic.IDs
for i := 0; i < findTestSnapshots; i++ {
sn := restic.TestCreateSnapshot(t, repo, findTestTime.Add(time.Duration(i)*time.Second), findTestDepth, 0)
t.Logf("snapshot %v saved, tree %v", sn.ID().Str(), sn.Tree.Str())
snapshotTrees = append(snapshotTrees, *sn.Tree)
}
want := restic.NewBlobSet()
for i := range snapshotTrees {
goldenFilename := filepath.Join("testdata", fmt.Sprintf("used_blobs_snapshot%d", i))
want.Merge(loadIDSet(t, goldenFilename))
}
p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {})
defer p.Done()
// run twice to check progress bar handling of duplicate tree roots
usedBlobs := restic.NewBlobSet()
for i := 1; i < 3; i++ {
err := restic.FindUsedBlobs(context.TODO(), repo, snapshotTrees, usedBlobs, p)
test.OK(t, err)
test.Equals(t, p.Get(), uint64(i*len(snapshotTrees)))
if !want.Equals(usedBlobs) {
t.Errorf("wrong list of blobs returned:\n missing blobs: %v\n extra blobs: %v",
want.Sub(usedBlobs), usedBlobs.Sub(want))
}
}
}
type ForbiddenRepo struct{}
func (r ForbiddenRepo) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) {
@ -133,12 +174,12 @@ func TestFindUsedBlobsSkipsSeenBlobs(t *testing.T) {
t.Logf("snapshot %v saved, tree %v", snapshot.ID().Str(), snapshot.Tree.Str())
usedBlobs := restic.NewBlobSet()
err := restic.FindUsedBlobs(context.TODO(), repo, *snapshot.Tree, usedBlobs)
err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*snapshot.Tree}, usedBlobs, nil)
if err != nil {
t.Fatalf("FindUsedBlobs returned error: %v", err)
}
err = restic.FindUsedBlobs(context.TODO(), ForbiddenRepo{}, *snapshot.Tree, usedBlobs)
err = restic.FindUsedBlobs(context.TODO(), ForbiddenRepo{}, restic.IDs{*snapshot.Tree}, usedBlobs, nil)
if err != nil {
t.Fatalf("FindUsedBlobs returned error: %v", err)
}
@ -154,7 +195,7 @@ func BenchmarkFindUsedBlobs(b *testing.B) {
for i := 0; i < b.N; i++ {
blobs := restic.NewBlobSet()
err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, blobs)
err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*sn.Tree}, blobs, nil)
if err != nil {
b.Error(err)
}

@ -0,0 +1,183 @@
package restic
import (
"context"
"errors"
"sync"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/ui/progress"
"golang.org/x/sync/errgroup"
)
const streamTreeParallelism = 5
// TreeItem is used to return either an error or the tree for a tree id
type TreeItem struct {
ID
Error error
*Tree
}
type trackedTreeItem struct {
TreeItem
rootIdx int
}
type trackedID struct {
ID
rootIdx int
}
// loadTreeWorker loads trees from repo and sends them to out.
func loadTreeWorker(ctx context.Context, repo TreeLoader,
in <-chan trackedID, out chan<- trackedTreeItem) {
for treeID := range in {
tree, err := repo.LoadTree(ctx, treeID.ID)
debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err)
job := trackedTreeItem{TreeItem: TreeItem{ID: treeID.ID, Error: err, Tree: tree}, rootIdx: treeID.rootIdx}
select {
case <-ctx.Done():
return
case out <- job:
}
}
}
func filterTrees(ctx context.Context, trees IDs, loaderChan chan<- trackedID,
in <-chan trackedTreeItem, out chan<- TreeItem, skip func(tree ID) bool, p *progress.Counter) {
var (
inCh = in
outCh chan<- TreeItem
loadCh chan<- trackedID
job TreeItem
nextTreeID trackedID
outstandingLoadTreeJobs = 0
)
rootCounter := make([]int, len(trees))
backlog := make([]trackedID, 0, len(trees))
for idx, id := range trees {
backlog = append(backlog, trackedID{ID: id, rootIdx: idx})
rootCounter[idx] = 1
}
for {
if loadCh == nil && len(backlog) > 0 {
// process last added ids first, that is traverse the tree in depth-first order
ln := len(backlog) - 1
nextTreeID, backlog = backlog[ln], backlog[:ln]
if skip(nextTreeID.ID) {
rootCounter[nextTreeID.rootIdx]--
if p != nil && rootCounter[nextTreeID.rootIdx] == 0 {
p.Add(1)
}
continue
}
loadCh = loaderChan
}
if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 {
debug.Log("backlog is empty, all channels nil, exiting")
return
}
select {
case <-ctx.Done():
return
case loadCh <- nextTreeID:
outstandingLoadTreeJobs++
loadCh = nil
case j, ok := <-inCh:
if !ok {
debug.Log("input channel closed")
inCh = nil
in = nil
continue
}
outstandingLoadTreeJobs--
rootCounter[j.rootIdx]--
debug.Log("input job tree %v", j.ID)
if j.Error != nil {
debug.Log("received job with error: %v (tree %v, ID %v)", j.Error, j.Tree, j.ID)
} else if j.Tree == nil {
debug.Log("received job with nil tree pointer: %v (ID %v)", j.Error, j.ID)
// send a new job with the new error instead of the old one
j = trackedTreeItem{TreeItem: TreeItem{ID: j.ID, Error: errors.New("tree is nil and error is nil")}, rootIdx: j.rootIdx}
} else {
subtrees := j.Tree.Subtrees()
debug.Log("subtrees for tree %v: %v", j.ID, subtrees)
// iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection
for i := len(subtrees) - 1; i >= 0; i-- {
id := subtrees[i]
if id.IsNull() {
// We do not need to raise this error here, it is
// checked when the tree is checked. Just make sure
// that we do not add any null IDs to the backlog.
debug.Log("tree %v has nil subtree", j.ID)
continue
}
backlog = append(backlog, trackedID{ID: id, rootIdx: j.rootIdx})
rootCounter[j.rootIdx]++
}
}
if p != nil && rootCounter[j.rootIdx] == 0 {
p.Add(1)
}
job = j.TreeItem
outCh = out
inCh = nil
case outCh <- job:
debug.Log("tree sent to process: %v", job.ID)
outCh = nil
inCh = in
}
}
}
// StreamTrees iteratively loads the given trees and their subtrees. The skip method
// is guaranteed to always be called from the same goroutine. To shutdown the started
// goroutines, either read all items from the channel or cancel the context. Then `Wait()`
// on the errgroup until all goroutines were stopped.
func StreamTrees(ctx context.Context, wg *errgroup.Group, repo TreeLoader, trees IDs, skip func(tree ID) bool, p *progress.Counter) <-chan TreeItem {
loaderChan := make(chan trackedID)
loadedTreeChan := make(chan trackedTreeItem)
treeStream := make(chan TreeItem)
var loadTreeWg sync.WaitGroup
for i := 0; i < streamTreeParallelism; i++ {
loadTreeWg.Add(1)
wg.Go(func() error {
defer loadTreeWg.Done()
loadTreeWorker(ctx, repo, loaderChan, loadedTreeChan)
return nil
})
}
// close once all loadTreeWorkers have completed
wg.Go(func() error {
loadTreeWg.Wait()
close(loadedTreeChan)
return nil
})
wg.Go(func() error {
defer close(loaderChan)
defer close(treeStream)
filterTrees(ctx, trees, loaderChan, loadedTreeChan, treeStream, skip, p)
return nil
})
return treeStream
}
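
A consumer sketch following the contract in the doc comment above: drain the returned channel (or cancel the context), then Wait on the errgroup. The function name is hypothetical and the skip callback here never skips anything.

// streamAllTrees is a hypothetical consumer of StreamTrees, not code from this commit.
func streamAllTrees(ctx context.Context, repo restic.TreeLoader, roots restic.IDs) error {
	wg, wgCtx := errgroup.WithContext(ctx)
	treeStream := restic.StreamTrees(wgCtx, wg, repo, roots, func(restic.ID) bool { return false }, nil)
	wg.Go(func() error {
		for item := range treeStream { // drain the stream so the loader goroutines can exit
			if item.Error != nil {
				return item.Error
			}
			// use item.ID and item.Tree here
		}
		return nil
	})
	return wg.Wait()
}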

@ -12,7 +12,7 @@ import (
//
// The final argument is true if Counter.Done has been called,
// which means that the current call will be the last.
type Func func(value uint64, runtime time.Duration, final bool)
type Func func(value uint64, total uint64, runtime time.Duration, final bool)
// A Counter tracks a running count and controls a goroutine that passes its
// value periodically to a Func.
@ -27,16 +27,19 @@ type Counter struct {
valueMutex sync.Mutex
value uint64
max uint64
}
// New starts a new Counter.
func New(interval time.Duration, report Func) *Counter {
func New(interval time.Duration, total uint64, report Func) *Counter {
c := &Counter{
report: report,
start: time.Now(),
stopped: make(chan struct{}),
stop: make(chan struct{}),
max: total,
}
if interval > 0 {
c.tick = time.NewTicker(interval)
}
@ -56,6 +59,16 @@ func (c *Counter) Add(v uint64) {
c.valueMutex.Unlock()
}
// SetMax sets the maximum expected counter value. This method is concurrency-safe.
func (c *Counter) SetMax(max uint64) {
if c == nil {
return
}
c.valueMutex.Lock()
c.max = max
c.valueMutex.Unlock()
}
// Done tells a Counter to stop and waits for it to report its final value.
func (c *Counter) Done() {
if c == nil {
@ -69,7 +82,8 @@ func (c *Counter) Done() {
*c = Counter{} // Prevent reuse.
}
func (c *Counter) get() uint64 {
// Get the current Counter value. This method is concurrency-safe.
func (c *Counter) Get() uint64 {
c.valueMutex.Lock()
v := c.value
c.valueMutex.Unlock()
@ -77,11 +91,19 @@ func (c *Counter) get() uint64 {
return v
}
func (c *Counter) getMax() uint64 {
c.valueMutex.Lock()
max := c.max
c.valueMutex.Unlock()
return max
}
func (c *Counter) run() {
defer close(c.stopped)
defer func() {
// Must be a func so that time.Since isn't called at defer time.
c.report(c.get(), time.Since(c.start), true)
c.report(c.Get(), c.getMax(), time.Since(c.start), true)
}()
var tick <-chan time.Time
@ -101,6 +123,6 @@ func (c *Counter) run() {
return
}
c.report(c.get(), now.Sub(c.start), false)
c.report(c.Get(), c.getMax(), now.Sub(c.start), false)
}
}
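
A short usage sketch of the extended Counter API; only the progress, fmt and time packages are assumed, and the function name is illustrative. The total can be given to New up front or filled in later via SetMax.

// counterExample is a hypothetical sketch of the extended progress.Counter API.
func counterExample() {
	report := func(value, total uint64, runtime time.Duration, final bool) {
		fmt.Printf("%d / %d after %v (final: %v)\n", value, total, runtime, final)
	}
	c := progress.New(time.Second, 0, report) // total not known yet
	c.SetMax(42)                              // concurrency-safe, may be called at any time
	c.Add(1)
	_ = c.Get() // current value, concurrency-safe
	c.Done()    // emits the final report; the counter must not be reused afterwards
}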

@ -10,23 +10,32 @@ import (
func TestCounter(t *testing.T) {
const N = 100
const startTotal = uint64(12345)
var (
finalSeen = false
increasing = true
last uint64
lastTotal = startTotal
ncalls int
nmaxChange int
)
report := func(value uint64, d time.Duration, final bool) {
finalSeen = true
report := func(value uint64, total uint64, d time.Duration, final bool) {
if final {
finalSeen = true
}
if value < last {
increasing = false
}
last = value
if total != lastTotal {
nmaxChange++
}
lastTotal = total
ncalls++
}
c := progress.New(10*time.Millisecond, report)
c := progress.New(10*time.Millisecond, startTotal, report)
done := make(chan struct{})
go func() {
@ -35,6 +44,7 @@ func TestCounter(t *testing.T) {
time.Sleep(time.Millisecond)
c.Add(1)
}
c.SetMax(42)
}()
<-done
@ -43,6 +53,8 @@ func TestCounter(t *testing.T) {
test.Assert(t, finalSeen, "final call did not happen")
test.Assert(t, increasing, "values not increasing")
test.Equals(t, uint64(N), last)
test.Equals(t, uint64(42), lastTotal)
test.Equals(t, int(1), nmaxChange)
t.Log("number of calls:", ncalls)
}
@ -58,14 +70,14 @@ func TestCounterNoTick(t *testing.T) {
finalSeen := false
otherSeen := false
report := func(value uint64, d time.Duration, final bool) {
report := func(value, total uint64, d time.Duration, final bool) {
if final {
finalSeen = true
} else {
otherSeen = true
}
}
c := progress.New(0, report)
c := progress.New(0, 1, report)
time.Sleep(time.Millisecond)
c.Done()