diff --git a/Gopkg.lock b/Gopkg.lock index 02c706dfe..42fc2791e 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -187,6 +187,12 @@ packages = [".","google","internal","jws","jwt"] revision = "f95fa95eaa936d9d87489b15d1d18b97c1ba9c28" +[[projects]] + branch = "master" + name = "golang.org/x/sync" + packages = ["errgroup"] + revision = "fd80eb99c8f653c847d294a001bdf2a3a6f768f5" + [[projects]] branch = "master" name = "golang.org/x/sys" @@ -214,6 +220,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "f0a207197cb502238ac87ca8e07b2640c02ec380a50b036e09ef87e40e31ca2d" + inputs-digest = "a7d099b3ce195ffc37adedb05a4386be38e6158925a1c0fe579efdc20fa11f6a" solver-name = "gps-cdcl" solver-version = 1 diff --git a/cmd/restic/cmd_debug.go b/cmd/restic/cmd_debug.go index 9ac7abad3..8f25933f9 100644 --- a/cmd/restic/cmd_debug.go +++ b/cmd/restic/cmd_debug.go @@ -15,8 +15,6 @@ import ( "github.com/restic/restic/internal/pack" "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" - - "github.com/restic/restic/internal/worker" ) var cmdDebug = &cobra.Command{ @@ -52,26 +50,18 @@ func prettyPrintJSON(wr io.Writer, item interface{}) error { } func debugPrintSnapshots(repo *repository.Repository, wr io.Writer) error { - for id := range repo.List(context.TODO(), restic.SnapshotFile) { + return repo.List(context.TODO(), restic.SnapshotFile, func(id restic.ID, size int64) error { snapshot, err := restic.LoadSnapshot(context.TODO(), repo, id) if err != nil { - fmt.Fprintf(os.Stderr, "LoadSnapshot(%v): %v", id.Str(), err) - continue + return err } fmt.Fprintf(wr, "snapshot_id: %v\n", id) - err = prettyPrintJSON(wr, snapshot) - if err != nil { - return err - } - } - - return nil + return prettyPrintJSON(wr, snapshot) + }) } -const dumpPackWorkers = 10 - // Pack is the struct used in printPacks. 
type Pack struct { Name string `json:"name"` @@ -88,49 +78,21 @@ type Blob struct { } func printPacks(repo *repository.Repository, wr io.Writer) error { - f := func(ctx context.Context, job worker.Job) (interface{}, error) { - name := job.Data.(string) - h := restic.Handle{Type: restic.DataFile, Name: name} + return repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { + h := restic.Handle{Type: restic.DataFile, Name: id.String()} - blobInfo, err := repo.Backend().Stat(ctx, h) + blobs, err := pack.List(repo.Key(), restic.ReaderAt(repo.Backend(), h), size) if err != nil { - return nil, err + fmt.Fprintf(os.Stderr, "error for pack %v: %v\n", id.Str(), err) + return nil } - blobs, err := pack.List(repo.Key(), restic.ReaderAt(repo.Backend(), h), blobInfo.Size) - if err != nil { - return nil, err - } - - return blobs, nil - } - - jobCh := make(chan worker.Job) - resCh := make(chan worker.Job) - wp := worker.New(context.TODO(), dumpPackWorkers, f, jobCh, resCh) - - go func() { - for name := range repo.Backend().List(context.TODO(), restic.DataFile) { - jobCh <- worker.Job{Data: name} - } - close(jobCh) - }() - - for job := range resCh { - name := job.Data.(string) - - if job.Error != nil { - fmt.Fprintf(os.Stderr, "error for pack %v: %v\n", name, job.Error) - continue - } - - entries := job.Result.([]restic.Blob) p := Pack{ - Name: name, - Blobs: make([]Blob, len(entries)), + Name: id.String(), + Blobs: make([]Blob, len(blobs)), } - for i, blob := range entries { + for i, blob := range blobs { p.Blobs[i] = Blob{ Type: blob.Type, Length: blob.Length, @@ -139,16 +101,14 @@ func printPacks(repo *repository.Repository, wr io.Writer) error { } } - prettyPrintJSON(os.Stdout, p) - } - - wp.Wait() + return prettyPrintJSON(os.Stdout, p) + }) return nil } func dumpIndexes(repo restic.Repository) error { - for id := range repo.List(context.TODO(), restic.IndexFile) { + return repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error { fmt.Printf("index_id: %v\n", id) idx, err := repository.LoadIndex(context.TODO(), repo, id) @@ -156,13 +116,8 @@ func dumpIndexes(repo restic.Repository) error { return err } - err = idx.Dump(os.Stdout) - if err != nil { - return err - } - } - - return nil + return idx.Dump(os.Stdout) + }) } func runDebugDump(gopts GlobalOptions, args []string) error { diff --git a/cmd/restic/cmd_key.go b/cmd/restic/cmd_key.go index e89a69b63..7552c778d 100644 --- a/cmd/restic/cmd_key.go +++ b/cmd/restic/cmd_key.go @@ -32,11 +32,11 @@ func listKeys(ctx context.Context, s *repository.Repository) error { tab.Header = fmt.Sprintf(" %-10s %-10s %-10s %s", "ID", "User", "Host", "Created") tab.RowFormat = "%s%-10s %-10s %-10s %s" - for id := range s.List(ctx, restic.KeyFile) { + err := s.List(ctx, restic.KeyFile, func(id restic.ID, size int64) error { k, err := repository.LoadKey(ctx, s, id.String()) if err != nil { Warnf("LoadKey() failed: %v\n", err) - continue + return nil } var current string @@ -47,6 +47,10 @@ func listKeys(ctx context.Context, s *repository.Repository) error { } tab.Rows = append(tab.Rows, []interface{}{current, id.Str(), k.Username, k.Hostname, k.Created.Format(TimeFormat)}) + return nil + }) + if err != nil { + return err } return tab.Write(globalOptions.stdout) diff --git a/cmd/restic/cmd_list.go b/cmd/restic/cmd_list.go index 0a7e9ca01..431085ff5 100644 --- a/cmd/restic/cmd_list.go +++ b/cmd/restic/cmd_list.go @@ -73,9 +73,8 @@ func runList(opts GlobalOptions, args []string) error { return errors.Fatal("invalid type") } - 
for id := range repo.List(opts.ctx, t) { + return repo.List(opts.ctx, t, func(id restic.ID, size int64) error { Printf("%s\n", id) - } - - return nil + return nil + }) } diff --git a/cmd/restic/cmd_prune.go b/cmd/restic/cmd_prune.go index 1383d15a4..6baf7ead3 100644 --- a/cmd/restic/cmd_prune.go +++ b/cmd/restic/cmd_prune.go @@ -120,8 +120,12 @@ func pruneRepository(gopts GlobalOptions, repo restic.Repository) error { } Verbosef("counting files in repo\n") - for range repo.List(ctx, restic.DataFile) { + err = repo.List(ctx, restic.DataFile, func(restic.ID, int64) error { stats.packs++ + return nil + }) + if err != nil { + return err } Verbosef("building new index for repo\n") diff --git a/cmd/restic/cmd_rebuild_index.go b/cmd/restic/cmd_rebuild_index.go index 9480374ea..55bcfa047 100644 --- a/cmd/restic/cmd_rebuild_index.go +++ b/cmd/restic/cmd_rebuild_index.go @@ -48,8 +48,12 @@ func rebuildIndex(ctx context.Context, repo restic.Repository, ignorePacks resti Verbosef("counting files in repo\n") var packs uint64 - for range repo.List(ctx, restic.DataFile) { + err := repo.List(ctx, restic.DataFile, func(restic.ID, int64) error { packs++ + return nil + }) + if err != nil { + return err } bar := newProgressMax(!globalOptions.Quiet, packs-uint64(len(ignorePacks)), "packs") @@ -61,8 +65,12 @@ func rebuildIndex(ctx context.Context, repo restic.Repository, ignorePacks resti Verbosef("finding old index files\n") var supersedes restic.IDs - for id := range repo.List(ctx, restic.IndexFile) { + err = repo.List(ctx, restic.IndexFile, func(id restic.ID, size int64) error { supersedes = append(supersedes, id) + return nil + }) + if err != nil { + return err } id, err := idx.Save(ctx, repo, supersedes) diff --git a/cmd/restic/find.go b/cmd/restic/find.go index 8b227fa55..e48a6ab55 100644 --- a/cmd/restic/find.go +++ b/cmd/restic/find.go @@ -58,7 +58,13 @@ func FindFilteredSnapshots(ctx context.Context, repo *repository.Repository, hos return } - for _, sn := range restic.FindFilteredSnapshots(ctx, repo, host, tags, paths) { + snapshots, err := restic.FindFilteredSnapshots(ctx, repo, host, tags, paths) + if err != nil { + Warnf("could not load snapshots: %v\n", err) + return + } + + for _, sn := range snapshots { select { case <-ctx.Done(): return diff --git a/doc/100_references.rst b/doc/100_references.rst index 53a1d2eac..dc2b40ee7 100644 --- a/doc/100_references.rst +++ b/doc/100_references.rst @@ -658,10 +658,30 @@ REST Backend ************ Restic can interact with HTTP Backend that respects the following REST -API. The following values are valid for ``{type}``: ``data``, ``keys``, -``locks``, ``snapshots``, ``index``, ``config``. ``{path}`` is a path to -the repository, so that multiple different repositories can be accessed. -The default path is ``/``. +API. + +The following values are valid for ``{type}``: + + * ``data`` + * ``keys`` + * ``locks`` + * ``snapshots`` + * ``index`` + * ``config`` + +The API version is selected via the ``Accept`` HTTP header in the request. The +following values are defined: + + * ``application/vnd.x.restic.rest.v1+json`` or empty: Select API version 1 + * ``application/vnd.x.restic.rest.v2+json``: Select API version 2 + +The server will respond with the value of the highest version it supports in +the ``Content-Type`` HTTP response header for the HTTP requests which should +return JSON. Any different value for this header means API version 1. 
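A minimal client-side sketch of this negotiation (the helper name listVersion and the bare net/http usage are illustrative; the rest backend later in this diff does the same thing with its own constants):

    package restexample

    import "net/http"

    const (
        mediaTypeV1 = "application/vnd.x.restic.rest.v1+json"
        mediaTypeV2 = "application/vnd.x.restic.rest.v2+json"
    )

    // listVersion asks the server for API v2 and reports which version it
    // actually selected, based on the Content-Type response header.
    func listVersion(client *http.Client, url string) (int, error) {
        req, err := http.NewRequest(http.MethodGet, url, nil)
        if err != nil {
            return 0, err
        }
        // request the newest format this client understands
        req.Header.Set("Accept", mediaTypeV2)

        resp, err := client.Do(req)
        if err != nil {
            return 0, err
        }
        defer resp.Body.Close()

        // the server answers with the highest version it supports; any
        // other Content-Type value means API version 1
        if resp.Header.Get("Content-Type") == mediaTypeV2 {
            return 2, nil
        }
        return 1, nil
    }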
+ +The placeholder ``{path}`` in this document is a path to the repository, so +that multiple different repositories can be accessed. The default path is +``/``. POST {path}?create=true ======================= @@ -701,10 +721,48 @@ saved, an HTTP error otherwise. GET {path}/{type}/ ================== -Returns a JSON array containing the names of all the blobs stored for a -given type. +API version 1 +------------- -Response format: JSON +Returns a JSON array containing the names of all the blobs stored for a given +type, example: + +.. code:: json + + [ + "245bc4c430d393f74fbe7b13325e30dbde9fb0745e50caad57c446c93d20096b", + "85b420239efa1132c41cea0065452a40ebc20c6f8e0b132a5b2f5848360973ec", + "8e2006bb5931a520f3c7009fe278d1ebb87eb72c3ff92a50c30e90f1b8cf3e60", + "e75c8c407ea31ba399ab4109f28dd18c4c68303d8d86cc275432820c42ce3649" + ] + +API version 2 +------------- + +Returns a JSON array containing an object for each file of the given type. The +objects have two keys: ``name`` for the file name, and ``size`` for the size in +bytes. + +.. code:: json + + [ + { + "name": "245bc4c430d393f74fbe7b13325e30dbde9fb0745e50caad57c446c93d20096b", + "size": 2341058 + }, + { + "name": "85b420239efa1132c41cea0065452a40ebc20c6f8e0b132a5b2f5848360973ec", + "size": 2908900 + }, + { + "name": "8e2006bb5931a520f3c7009fe278d1ebb87eb72c3ff92a50c30e90f1b8cf3e60", + "size": 3030712 + }, + { + "name": "e75c8c407ea31ba399ab4109f28dd18c4c68303d8d86cc275432820c42ce3649", + "size": 2804 + } + ] HEAD {path}/{type}/{name} ========================= diff --git a/internal/archiver/archive_reader_test.go b/internal/archiver/archive_reader_test.go index 56e5fec5f..fafc0ed1a 100644 --- a/internal/archiver/archive_reader_test.go +++ b/internal/archiver/archive_reader_test.go @@ -135,8 +135,12 @@ func (e errReader) Read([]byte) (int, error) { func countSnapshots(t testing.TB, repo restic.Repository) int { snapshots := 0 - for range repo.List(context.TODO(), restic.SnapshotFile) { + err := repo.List(context.TODO(), restic.SnapshotFile, func(id restic.ID, size int64) error { snapshots++ + return nil + }) + if err != nil { + t.Fatal(err) } return snapshots } diff --git a/internal/archiver/archiver_duplication_test.go b/internal/archiver/archiver_duplication_test.go index 783dce11c..bdcecf0c6 100644 --- a/internal/archiver/archiver_duplication_test.go +++ b/internal/archiver/archiver_duplication_test.go @@ -60,10 +60,8 @@ func forgetfulBackend() restic.Backend { return nil } - be.ListFn = func(ctx context.Context, t restic.FileType) <-chan string { - ch := make(chan string) - close(ch) - return ch + be.ListFn = func(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { + return nil } be.DeleteFn = func(ctx context.Context) error { diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index e578ab3de..6a16a36fc 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -131,9 +131,13 @@ func BenchmarkArchiveDirectory(b *testing.B) { } } -func countPacks(repo restic.Repository, t restic.FileType) (n uint) { - for range repo.Backend().List(context.TODO(), t) { +func countPacks(t testing.TB, repo restic.Repository, tpe restic.FileType) (n uint) { + err := repo.Backend().List(context.TODO(), tpe, func(restic.FileInfo) error { n++ + return nil + }) + if err != nil { + t.Fatal(err) } return n @@ -158,7 +162,7 @@ func archiveWithDedup(t testing.TB) { t.Logf("archived snapshot %v", sn.ID().Str()) // get archive stats - cnt.before.packs = 
countPacks(repo, restic.DataFile) + cnt.before.packs = countPacks(t, repo, restic.DataFile) cnt.before.dataBlobs = repo.Index().Count(restic.DataBlob) cnt.before.treeBlobs = repo.Index().Count(restic.TreeBlob) t.Logf("packs %v, data blobs %v, tree blobs %v", @@ -169,7 +173,7 @@ func archiveWithDedup(t testing.TB) { t.Logf("archived snapshot %v", sn2.ID().Str()) // get archive stats again - cnt.after.packs = countPacks(repo, restic.DataFile) + cnt.after.packs = countPacks(t, repo, restic.DataFile) cnt.after.dataBlobs = repo.Index().Count(restic.DataBlob) cnt.after.treeBlobs = repo.Index().Count(restic.TreeBlob) t.Logf("packs %v, data blobs %v, tree blobs %v", @@ -186,7 +190,7 @@ func archiveWithDedup(t testing.TB) { t.Logf("archived snapshot %v, parent %v", sn3.ID().Str(), sn2.ID().Str()) // get archive stats again - cnt.after2.packs = countPacks(repo, restic.DataFile) + cnt.after2.packs = countPacks(t, repo, restic.DataFile) cnt.after2.dataBlobs = repo.Index().Count(restic.DataBlob) cnt.after2.treeBlobs = repo.Index().Count(restic.TreeBlob) t.Logf("packs %v, data blobs %v, tree blobs %v", diff --git a/internal/backend/azure/azure.go b/internal/backend/azure/azure.go index b163ec992..ed401c868 100644 --- a/internal/backend/azure/azure.go +++ b/internal/backend/azure/azure.go @@ -242,7 +242,11 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, return restic.FileInfo{}, errors.Wrap(err, "blob.GetProperties") } - return restic.FileInfo{Size: int64(blob.Properties.ContentLength)}, nil + fi := restic.FileInfo{ + Size: int64(blob.Properties.ContentLength), + Name: h.Name, + } + return fi, nil } // Test returns true if a blob of the given type and name exists in the backend. @@ -271,17 +275,15 @@ func (be *Backend) Remove(ctx context.Context, h restic.Handle) error { return errors.Wrap(err, "client.RemoveObject") } -// List returns a channel that yields all names of blobs of type t. A -// goroutine is started for this. If the channel done is closed, sending -// stops. -func (be *Backend) List(ctx context.Context, t restic.FileType) <-chan string { +// List runs fn for each file in the backend which has the type t. When an +// error occurs (or fn returns an error), List stops and returns it. 
+func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { debug.Log("listing %v", t) - ch := make(chan string) prefix, _ := be.Basedir(t) // make sure prefix ends with a slash - if prefix[len(prefix)-1] != '/' { + if !strings.HasSuffix(prefix, "/") { prefix += "/" } @@ -290,53 +292,57 @@ func (be *Backend) List(ctx context.Context, t restic.FileType) <-chan string { Prefix: prefix, } - go func() { - defer close(ch) + for { + be.sem.GetToken() + obj, err := be.container.ListBlobs(params) + be.sem.ReleaseToken() - for { - be.sem.GetToken() - obj, err := be.container.ListBlobs(params) - be.sem.ReleaseToken() - - if err != nil { - return - } - - debug.Log("got %v objects", len(obj.Blobs)) - - for _, item := range obj.Blobs { - m := strings.TrimPrefix(item.Name, prefix) - if m == "" { - continue - } - - select { - case ch <- path.Base(m): - case <-ctx.Done(): - return - } - } - - if obj.NextMarker == "" { - break - } - params.Marker = obj.NextMarker + if err != nil { + return err } - }() - return ch + debug.Log("got %v objects", len(obj.Blobs)) + + for _, item := range obj.Blobs { + m := strings.TrimPrefix(item.Name, prefix) + if m == "" { + continue + } + + fi := restic.FileInfo{ + Name: path.Base(m), + Size: item.Properties.ContentLength, + } + + if ctx.Err() != nil { + return ctx.Err() + } + + err := fn(fi) + if err != nil { + return err + } + + if ctx.Err() != nil { + return ctx.Err() + } + + } + + if obj.NextMarker == "" { + break + } + params.Marker = obj.NextMarker + } + + return ctx.Err() } // Remove keys for a specified backend type. func (be *Backend) removeKeys(ctx context.Context, t restic.FileType) error { - for key := range be.List(ctx, restic.DataFile) { - err := be.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key}) - if err != nil { - return err - } - } - - return nil + return be.List(ctx, t, func(fi restic.FileInfo) error { + return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name}) + }) } // Delete removes all restic keys in the bucket. It will not remove the bucket itself. diff --git a/internal/backend/b2/b2.go b/internal/backend/b2/b2.go index edf9a14f3..64a333017 100644 --- a/internal/backend/b2/b2.go +++ b/internal/backend/b2/b2.go @@ -228,7 +228,7 @@ func (be *b2Backend) Stat(ctx context.Context, h restic.Handle) (bi restic.FileI debug.Log("Attrs() err %v", err) return restic.FileInfo{}, errors.Wrap(err, "Stat") } - return restic.FileInfo{Size: info.Size}, nil + return restic.FileInfo{Size: info.Size, Name: h.Name}, nil } // Test returns true if a blob of the given type and name exists in the backend. @@ -262,66 +262,76 @@ func (be *b2Backend) Remove(ctx context.Context, h restic.Handle) error { // List returns a channel that yields all names of blobs of type t. A // goroutine is started for this. If the channel done is closed, sending // stops. 
-func (be *b2Backend) List(ctx context.Context, t restic.FileType) <-chan string { +func (be *b2Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { debug.Log("List %v", t) - ch := make(chan string) + + prefix, _ := be.Basedir(t) + cur := &b2.Cursor{Prefix: prefix} ctx, cancel := context.WithCancel(ctx) + defer cancel() - go func() { - defer close(ch) - defer cancel() + for { + be.sem.GetToken() + objs, c, err := be.bucket.ListCurrentObjects(ctx, be.listMaxItems, cur) + be.sem.ReleaseToken() - prefix, _ := be.Basedir(t) - cur := &b2.Cursor{Prefix: prefix} - - for { - be.sem.GetToken() - objs, c, err := be.bucket.ListCurrentObjects(ctx, be.listMaxItems, cur) - be.sem.ReleaseToken() - if err != nil && err != io.EOF { - // TODO: return err to caller once err handling in List() is improved - debug.Log("List: %v", err) - return - } - debug.Log("returned %v items", len(objs)) - for _, obj := range objs { - // Skip objects returned that do not have the specified prefix. - if !strings.HasPrefix(obj.Name(), prefix) { - continue - } - - m := path.Base(obj.Name()) - if m == "" { - continue - } - - select { - case ch <- m: - case <-ctx.Done(): - return - } - } - if err == io.EOF { - return - } - cur = c + if err != nil && err != io.EOF { + debug.Log("List: %v", err) + return err } - }() - return ch + debug.Log("returned %v items", len(objs)) + for _, obj := range objs { + // Skip objects returned that do not have the specified prefix. + if !strings.HasPrefix(obj.Name(), prefix) { + continue + } + + m := path.Base(obj.Name()) + if m == "" { + continue + } + + if ctx.Err() != nil { + return ctx.Err() + } + + attrs, err := obj.Attrs(ctx) + if err != nil { + return err + } + + fi := restic.FileInfo{ + Name: m, + Size: attrs.Size, + } + + err = fn(fi) + if err != nil { + return err + } + + if ctx.Err() != nil { + return ctx.Err() + } + } + + if err == io.EOF { + return ctx.Err() + } + cur = c + } + + return ctx.Err() } // Remove keys for a specified backend type. func (be *b2Backend) removeKeys(ctx context.Context, t restic.FileType) error { debug.Log("removeKeys %v", t) - for key := range be.List(ctx, t) { - err := be.Remove(ctx, restic.Handle{Type: t, Name: key}) - if err != nil { - return err - } - } - return nil + return be.List(ctx, t, func(fi restic.FileInfo) error { + return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name}) + }) } // Delete removes all restic keys in the bucket. It will not remove the bucket itself. diff --git a/internal/backend/gs/gs.go b/internal/backend/gs/gs.go index 8d0e66d23..e88d49f45 100644 --- a/internal/backend/gs/gs.go +++ b/internal/backend/gs/gs.go @@ -333,7 +333,7 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInf return restic.FileInfo{}, errors.Wrap(err, "service.Objects.Get") } - return restic.FileInfo{Size: int64(obj.Size)}, nil + return restic.FileInfo{Size: int64(obj.Size), Name: h.Name}, nil } // Test returns true if a blob of the given type and name exists in the backend. @@ -370,69 +370,72 @@ func (be *Backend) Remove(ctx context.Context, h restic.Handle) error { return errors.Wrap(err, "client.RemoveObject") } -// List returns a channel that yields all names of blobs of type t. A -// goroutine is started for this. If the channel done is closed, sending -// stops. -func (be *Backend) List(ctx context.Context, t restic.FileType) <-chan string { +// List runs fn for each file in the backend which has the type t. 
When an +// error occurs (or fn returns an error), List stops and returns it. +func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { debug.Log("listing %v", t) - ch := make(chan string) prefix, _ := be.Basedir(t) // make sure prefix ends with a slash - if prefix[len(prefix)-1] != '/' { + if !strings.HasSuffix(prefix, "/") { prefix += "/" } - go func() { - defer close(ch) + ctx, cancel := context.WithCancel(ctx) + defer cancel() - listReq := be.service.Objects.List(be.bucketName).Prefix(prefix).MaxResults(int64(be.listMaxItems)) - for { - be.sem.GetToken() - obj, err := listReq.Do() - be.sem.ReleaseToken() + listReq := be.service.Objects.List(be.bucketName).Context(ctx).Prefix(prefix).MaxResults(int64(be.listMaxItems)) + for { + be.sem.GetToken() + obj, err := listReq.Do() + be.sem.ReleaseToken() - if err != nil { - fmt.Fprintf(os.Stderr, "error listing %v: %v\n", prefix, err) - return - } - - debug.Log("returned %v items", len(obj.Items)) - - for _, item := range obj.Items { - m := strings.TrimPrefix(item.Name, prefix) - if m == "" { - continue - } - - select { - case ch <- path.Base(m): - case <-ctx.Done(): - return - } - } - - if obj.NextPageToken == "" { - break - } - listReq.PageToken(obj.NextPageToken) + if err != nil { + return err } - }() - return ch + debug.Log("returned %v items", len(obj.Items)) + + for _, item := range obj.Items { + m := strings.TrimPrefix(item.Name, prefix) + if m == "" { + continue + } + + if ctx.Err() != nil { + return ctx.Err() + } + + fi := restic.FileInfo{ + Name: path.Base(m), + Size: int64(item.Size), + } + + err := fn(fi) + if err != nil { + return err + } + + if ctx.Err() != nil { + return ctx.Err() + } + } + + if obj.NextPageToken == "" { + break + } + listReq.PageToken(obj.NextPageToken) + } + + return ctx.Err() } // Remove keys for a specified backend type. func (be *Backend) removeKeys(ctx context.Context, t restic.FileType) error { - for key := range be.List(ctx, restic.DataFile) { - err := be.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key}) - if err != nil { - return err - } - } - - return nil + return be.List(ctx, t, func(fi restic.FileInfo) error { + return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name}) + }) } // Delete removes all restic keys in the bucket. It will not remove the bucket itself. diff --git a/internal/backend/local/layout_test.go b/internal/backend/local/layout_test.go index 05eba96f0..f2531a332 100644 --- a/internal/backend/local/layout_test.go +++ b/internal/backend/local/layout_test.go @@ -49,8 +49,13 @@ func TestLayout(t *testing.T) { } datafiles := make(map[string]bool) - for id := range be.List(context.TODO(), restic.DataFile) { - datafiles[id] = false + err = be.List(context.TODO(), restic.DataFile, func(fi restic.FileInfo) error { + datafiles[fi.Name] = false + return nil + }) + + if err != nil { + t.Fatalf("List() returned error %v", err) } if len(datafiles) == 0 { diff --git a/internal/backend/local/local.go b/internal/backend/local/local.go index b7afd96c1..8bf949f37 100644 --- a/internal/backend/local/local.go +++ b/internal/backend/local/local.go @@ -191,7 +191,7 @@ func (b *Local) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, err return restic.FileInfo{}, errors.Wrap(err, "Stat") } - return restic.FileInfo{Size: fi.Size()}, nil + return restic.FileInfo{Size: fi.Size(), Name: h.Name}, nil } // Test returns true if a blob of the given type and name exists in the backend. 
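Stat in every backend touched here now fills in the Name field of restic.FileInfo in addition to Size (the backend test suite later in this diff asserts this). A minimal usage sketch; statSize, be and name are placeholders:

    package statexample

    import (
        "context"

        "github.com/restic/restic/internal/restic"
    )

    // statSize returns the size of a data file; with this change the
    // returned FileInfo also echoes the requested name.
    func statSize(ctx context.Context, be restic.Backend, name string) (int64, error) {
        fi, err := be.Stat(ctx, restic.Handle{Type: restic.DataFile, Name: name})
        if err != nil {
            return 0, err
        }
        _ = fi.Name // now set to name; previously only Size was populated
        return fi.Size, nil
    }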
@@ -226,52 +226,48 @@ func isFile(fi os.FileInfo) bool { return fi.Mode()&(os.ModeType|os.ModeCharDevice) == 0 } -// List returns a channel that yields all names of blobs of type t. A -// goroutine is started for this. -func (b *Local) List(ctx context.Context, t restic.FileType) <-chan string { +// List runs fn for each file in the backend which has the type t. When an +// error occurs (or fn returns an error), List stops and returns it. +func (b *Local) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { debug.Log("List %v", t) - ch := make(chan string) - - go func() { - defer close(ch) - - basedir, subdirs := b.Basedir(t) - err := fs.Walk(basedir, func(path string, fi os.FileInfo, err error) error { - debug.Log("walk on %v\n", path) - if err != nil { - return err - } - - if path == basedir { - return nil - } - - if !isFile(fi) { - return nil - } - - if fi.IsDir() && !subdirs { - return filepath.SkipDir - } - - debug.Log("send %v\n", filepath.Base(path)) - - select { - case ch <- filepath.Base(path): - case <-ctx.Done(): - return nil - } - - return nil - }) - + basedir, subdirs := b.Basedir(t) + return fs.Walk(basedir, func(path string, fi os.FileInfo, err error) error { + debug.Log("walk on %v\n", path) if err != nil { - debug.Log("Walk %v", err) + return err } - }() - return ch + if path == basedir { + return nil + } + + if !isFile(fi) { + return nil + } + + if fi.IsDir() && !subdirs { + return filepath.SkipDir + } + + debug.Log("send %v\n", filepath.Base(path)) + + rfi := restic.FileInfo{ + Name: filepath.Base(path), + Size: fi.Size(), + } + + if ctx.Err() != nil { + return ctx.Err() + } + + err = fn(rfi) + if err != nil { + return err + } + + return ctx.Err() + }) } // Delete removes the repository and all files. diff --git a/internal/backend/mem/mem_backend.go b/internal/backend/mem/mem_backend.go index 9e7de51d5..576ff8140 100644 --- a/internal/backend/mem/mem_backend.go +++ b/internal/backend/mem/mem_backend.go @@ -143,7 +143,7 @@ func (be *MemoryBackend) Stat(ctx context.Context, h restic.Handle) (restic.File return restic.FileInfo{}, errNotFound } - return restic.FileInfo{Size: int64(len(e))}, nil + return restic.FileInfo{Size: int64(len(e)), Name: h.Name}, nil } // Remove deletes a file from the backend. @@ -163,34 +163,40 @@ func (be *MemoryBackend) Remove(ctx context.Context, h restic.Handle) error { } // List returns a channel which yields entries from the backend. -func (be *MemoryBackend) List(ctx context.Context, t restic.FileType) <-chan string { +func (be *MemoryBackend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { + entries := make(map[string]int64) + be.m.Lock() - defer be.m.Unlock() - - ch := make(chan string) - - var ids []string - for entry := range be.data { + for entry, buf := range be.data { if entry.Type != t { continue } - ids = append(ids, entry.Name) + + entries[entry.Name] = int64(len(buf)) + } + be.m.Unlock() + + for name, size := range entries { + fi := restic.FileInfo{ + Name: name, + Size: size, + } + + if ctx.Err() != nil { + return ctx.Err() + } + + err := fn(fi) + if err != nil { + return err + } + + if ctx.Err() != nil { + return ctx.Err() + } } - debug.Log("list %v: %v", t, ids) - - go func() { - defer close(ch) - for _, id := range ids { - select { - case ch <- id: - case <-ctx.Done(): - return - } - } - }() - - return ch + return ctx.Err() } // Location returns the location of the backend (RAM). 
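The mem backend above shows the contract every backend now shares: List iterates synchronously, hands each restic.FileInfo to the callback, and stops as soon as the callback or the context returns an error. A sketch of how a caller can use that to stop early (firstNSizes and the errStop sentinel are illustrative, not part of this change):

    package listexample

    import (
        "context"
        "errors"

        "github.com/restic/restic/internal/restic"
    )

    // firstNSizes collects the sizes of at most n data files via the new
    // callback-based List.
    func firstNSizes(ctx context.Context, be restic.Backend, n int) (map[string]int64, error) {
        errStop := errors.New("stop listing") // placeholder sentinel

        sizes := make(map[string]int64)
        err := be.List(ctx, restic.DataFile, func(fi restic.FileInfo) error {
            sizes[fi.Name] = fi.Size
            if len(sizes) >= n {
                // returning a non-nil error stops List and is passed back
                return errStop
            }
            return nil
        })
        if err != nil && err != errStop {
            return nil, err
        }
        return sizes, nil
    }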
diff --git a/internal/backend/rest/rest.go b/internal/backend/rest/rest.go index 825202f20..1f60b1f2f 100644 --- a/internal/backend/rest/rest.go +++ b/internal/backend/rest/rest.go @@ -30,6 +30,11 @@ type restBackend struct { backend.Layout } +const ( + contentTypeV1 = "application/vnd.x.restic.rest.v1" + contentTypeV2 = "application/vnd.x.restic.rest.v2" +) + // Open opens the REST backend with the given config. func Open(cfg Config, rt http.RoundTripper) (*restBackend, error) { client := &http.Client{Transport: rt} @@ -111,8 +116,15 @@ func (b *restBackend) Save(ctx context.Context, h restic.Handle, rd io.Reader) ( // make sure that client.Post() cannot close the reader by wrapping it rd = ioutil.NopCloser(rd) + req, err := http.NewRequest(http.MethodPost, b.Filename(h), rd) + if err != nil { + return errors.Wrap(err, "NewRequest") + } + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Accept", contentTypeV2) + b.sem.GetToken() - resp, err := ctxhttp.Post(ctx, b.client, b.Filename(h), "binary/octet-stream", rd) + resp, err := ctxhttp.Do(ctx, b.client, req) b.sem.ReleaseToken() if resp != nil { @@ -180,7 +192,8 @@ func (b *restBackend) Load(ctx context.Context, h restic.Handle, length int, off if length > 0 { byteRange = fmt.Sprintf("bytes=%d-%d", offset, offset+int64(length)-1) } - req.Header.Add("Range", byteRange) + req.Header.Set("Range", byteRange) + req.Header.Set("Accept", contentTypeV2) debug.Log("Load(%v) send range %v", h, byteRange) b.sem.GetToken() @@ -214,8 +227,14 @@ func (b *restBackend) Stat(ctx context.Context, h restic.Handle) (restic.FileInf return restic.FileInfo{}, err } + req, err := http.NewRequest(http.MethodHead, b.Filename(h), nil) + if err != nil { + return restic.FileInfo{}, errors.Wrap(err, "NewRequest") + } + req.Header.Set("Accept", contentTypeV2) + b.sem.GetToken() - resp, err := ctxhttp.Head(ctx, b.client, b.Filename(h)) + resp, err := ctxhttp.Do(ctx, b.client, req) b.sem.ReleaseToken() if err != nil { return restic.FileInfo{}, errors.Wrap(err, "client.Head") @@ -241,6 +260,7 @@ func (b *restBackend) Stat(ctx context.Context, h restic.Handle) (restic.FileInf bi := restic.FileInfo{ Size: resp.ContentLength, + Name: h.Name, } return bi, nil @@ -266,6 +286,8 @@ func (b *restBackend) Remove(ctx context.Context, h restic.Handle) error { if err != nil { return errors.Wrap(err, "http.NewRequest") } + req.Header.Set("Accept", contentTypeV2) + b.sem.GetToken() resp, err := ctxhttp.Do(ctx, b.client, req) b.sem.ReleaseToken() @@ -291,56 +313,105 @@ func (b *restBackend) Remove(ctx context.Context, h restic.Handle) error { return errors.Wrap(resp.Body.Close(), "Close") } -// List returns a channel that yields all names of blobs of type t. A -// goroutine is started for this. If the channel done is closed, sending -// stops. -func (b *restBackend) List(ctx context.Context, t restic.FileType) <-chan string { - ch := make(chan string) - +// List runs fn for each file in the backend which has the type t. When an +// error occurs (or fn returns an error), List stops and returns it. 
+func (b *restBackend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { url := b.Dirname(restic.Handle{Type: t}) if !strings.HasSuffix(url, "/") { url += "/" } + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return errors.Wrap(err, "NewRequest") + } + req.Header.Set("Accept", contentTypeV2) + b.sem.GetToken() - resp, err := ctxhttp.Get(ctx, b.client, url) + resp, err := ctxhttp.Do(ctx, b.client, req) b.sem.ReleaseToken() - if resp != nil { - defer func() { - _, _ = io.Copy(ioutil.Discard, resp.Body) - e := resp.Body.Close() - - if err == nil { - err = errors.Wrap(e, "Close") - } - }() - } - if err != nil { - close(ch) - return ch + return errors.Wrap(err, "Get") } + if resp.Header.Get("Content-Type") == contentTypeV2 { + return b.listv2(ctx, t, resp, fn) + } + + return b.listv1(ctx, t, resp, fn) +} + +// listv1 uses the REST protocol v1, where a list HTTP request (e.g. `GET +// /data/`) only returns the names of the files, so we need to issue an HTTP +// HEAD request for each file. +func (b *restBackend) listv1(ctx context.Context, t restic.FileType, resp *http.Response, fn func(restic.FileInfo) error) error { + debug.Log("parsing API v1 response") dec := json.NewDecoder(resp.Body) var list []string - if err = dec.Decode(&list); err != nil { - close(ch) - return ch + if err := dec.Decode(&list); err != nil { + return errors.Wrap(err, "Decode") } - go func() { - defer close(ch) - for _, m := range list { - select { - case ch <- m: - case <-ctx.Done(): - return - } + for _, m := range list { + fi, err := b.Stat(ctx, restic.Handle{Name: m, Type: t}) + if err != nil { + return err } - }() - return ch + if ctx.Err() != nil { + return ctx.Err() + } + + fi.Name = m + err = fn(fi) + if err != nil { + return err + } + + if ctx.Err() != nil { + return ctx.Err() + } + } + + return ctx.Err() +} + +// listv2 uses the REST protocol v2, where a list HTTP request (e.g. `GET +// /data/`) returns the names and sizes of all files. +func (b *restBackend) listv2(ctx context.Context, t restic.FileType, resp *http.Response, fn func(restic.FileInfo) error) error { + debug.Log("parsing API v2 response") + dec := json.NewDecoder(resp.Body) + + var list []struct { + Name string `json:"name"` + Size int64 `json:"size"` + } + if err := dec.Decode(&list); err != nil { + return errors.Wrap(err, "Decode") + } + + for _, item := range list { + if ctx.Err() != nil { + return ctx.Err() + } + + fi := restic.FileInfo{ + Name: item.Name, + Size: item.Size, + } + + err := fn(fi) + if err != nil { + return err + } + + if ctx.Err() != nil { + return ctx.Err() + } + } + + return ctx.Err() } // Close closes all open files. @@ -352,14 +423,9 @@ func (b *restBackend) Close() error { // Remove keys for a specified backend type. func (b *restBackend) removeKeys(ctx context.Context, t restic.FileType) error { - for key := range b.List(ctx, restic.DataFile) { - err := b.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key}) - if err != nil { - return err - } - } - - return nil + return b.List(ctx, t, func(fi restic.FileInfo) error { + return b.Remove(ctx, restic.Handle{Type: t, Name: fi.Name}) + }) } // Delete removes all data in the backend. 
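The listv1/listv2 split above is the heart of the REST change: v1 responses only carry names, so the backend issues one extra HEAD (Stat) request per file to learn the size, while v2 delivers name and size in a single response (the integration test below counts exactly these requests). A reduced sketch of the two decode paths; decodeList is an illustrative helper, and the v2 media type string matches the constant above:

    package restexample

    import (
        "encoding/json"
        "io"
    )

    // decodeList turns a list response body into (name, size) pairs,
    // depending on the negotiated content type.
    func decodeList(contentType string, body io.Reader) (map[string]int64, error) {
        res := make(map[string]int64)

        if contentType == "application/vnd.x.restic.rest.v2" {
            // v2: names and sizes arrive together in one response
            var list []struct {
                Name string `json:"name"`
                Size int64  `json:"size"`
            }
            if err := json.NewDecoder(body).Decode(&list); err != nil {
                return nil, err
            }
            for _, item := range list {
                res[item.Name] = item.Size
            }
            return res, nil
        }

        // v1: only names are returned; the real backend issues one HEAD
        // (Stat) request per name to learn the size, recorded here as -1
        var names []string
        if err := json.NewDecoder(body).Decode(&names); err != nil {
            return nil, err
        }
        for _, name := range names {
            res[name] = -1
        }
        return res, nil
    }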
diff --git a/internal/backend/rest/rest_int_test.go b/internal/backend/rest/rest_int_test.go new file mode 100644 index 000000000..ea4e265fd --- /dev/null +++ b/internal/backend/rest/rest_int_test.go @@ -0,0 +1,150 @@ +package rest_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "strconv" + "testing" + + "github.com/restic/restic/internal/backend/rest" + "github.com/restic/restic/internal/restic" +) + +func TestListAPI(t *testing.T) { + var tests = []struct { + Name string + + ContentType string // response header + Data string // response data + Requests int + + Result []restic.FileInfo + }{ + { + Name: "content-type-unknown", + ContentType: "application/octet-stream", + Data: `[ + "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", + "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", + "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b" + ]`, + Result: []restic.FileInfo{ + {Name: "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", Size: 4386}, + {Name: "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", Size: 15214}, + {Name: "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b", Size: 33393}, + }, + Requests: 4, + }, + { + Name: "content-type-v1", + ContentType: "application/vnd.x.restic.rest.v1", + Data: `[ + "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", + "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", + "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b" + ]`, + Result: []restic.FileInfo{ + {Name: "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", Size: 4386}, + {Name: "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", Size: 15214}, + {Name: "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b", Size: 33393}, + }, + Requests: 4, + }, + { + Name: "content-type-v2", + ContentType: "application/vnd.x.restic.rest.v2", + Data: `[ + {"name": "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", "size": 1001}, + {"name": "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", "size": 1002}, + {"name": "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b", "size": 1003} + ]`, + Result: []restic.FileInfo{ + {Name: "1122e6749358b057fa1ac6b580a0fbe7a9a5fbc92e82743ee21aaf829624a985", Size: 1001}, + {Name: "3b6ec1af8d4f7099d0445b12fdb75b166ba19f789e5c48350c423dc3b3e68352", Size: 1002}, + {Name: "8271d221a60e0058e6c624f248d0080fc04f4fac07a28584a9b89d0eb69e189b", Size: 1003}, + }, + Requests: 1, + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + numRequests := 0 + srv := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + numRequests++ + t.Logf("req %v %v, accept: %v", req.Method, req.URL.Path, req.Header["Accept"]) + + var err error + switch { + case req.Method == "GET": + // list files in data/ + res.Header().Set("Content-Type", test.ContentType) + _, err = res.Write([]byte(test.Data)) + + if err != nil { + t.Fatal(err) + } + return + case req.Method == "HEAD": + // stat file in data/, use the first two bytes in the name + // of the file as the size :) + filename := req.URL.Path[6:] + len, err := strconv.ParseInt(filename[:4], 16, 64) + if err != nil { + t.Fatal(err) + } + + res.Header().Set("Content-Length", fmt.Sprintf("%d", len)) + res.WriteHeader(http.StatusOK) + return + } + + t.Errorf("unhandled request %v %v", req.Method, req.URL.Path) + 
})) + defer srv.Close() + + srvURL, err := url.Parse(srv.URL) + if err != nil { + t.Fatal(err) + } + + cfg := rest.Config{ + Connections: 5, + URL: srvURL, + } + + be, err := rest.Open(cfg, http.DefaultTransport) + if err != nil { + t.Fatal(err) + } + + var list []restic.FileInfo + err = be.List(context.TODO(), restic.DataFile, func(fi restic.FileInfo) error { + list = append(list, fi) + return nil + }) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(list, test.Result) { + t.Fatalf("wrong response returned, want:\n %v\ngot: %v", test.Result, list) + } + + if numRequests != test.Requests { + t.Fatalf("wrong number of HTTP requests executed, want %d, got %d", test.Requests, numRequests) + } + + defer func() { + err = be.Close() + if err != nil { + t.Fatal(err) + } + }() + }) + } +} diff --git a/internal/backend/s3/s3.go b/internal/backend/s3/s3.go index b0818e9b2..b33e76f64 100644 --- a/internal/backend/s3/s3.go +++ b/internal/backend/s3/s3.go @@ -365,7 +365,7 @@ func (be *Backend) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInf return restic.FileInfo{}, errors.Wrap(err, "Stat") } - return restic.FileInfo{Size: fi.Size}, nil + return restic.FileInfo{Size: fi.Size, Name: h.Name}, nil } // Test returns true if a blob of the given type and name exists in the backend. @@ -402,54 +402,59 @@ func (be *Backend) Remove(ctx context.Context, h restic.Handle) error { return errors.Wrap(err, "client.RemoveObject") } -// List returns a channel that yields all names of blobs of type t. A -// goroutine is started for this. If the channel done is closed, sending -// stops. -func (be *Backend) List(ctx context.Context, t restic.FileType) <-chan string { +// List runs fn for each file in the backend which has the type t. When an +// error occurs (or fn returns an error), List stops and returns it. +func (be *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { debug.Log("listing %v", t) - ch := make(chan string) prefix, recursive := be.Basedir(t) // make sure prefix ends with a slash - if prefix[len(prefix)-1] != '/' { + if !strings.HasSuffix(prefix, "/") { prefix += "/" } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + // NB: unfortunately we can't protect this with be.sem.GetToken() here. // Doing so would enable a deadlock situation (gh-1399), as ListObjects() // starts its own goroutine and returns results via a channel. listresp := be.client.ListObjects(be.cfg.Bucket, prefix, recursive, ctx.Done()) - go func() { - defer close(ch) - for obj := range listresp { - m := strings.TrimPrefix(obj.Key, prefix) - if m == "" { - continue - } - - select { - case ch <- path.Base(m): - case <-ctx.Done(): - return - } + for obj := range listresp { + m := strings.TrimPrefix(obj.Key, prefix) + if m == "" { + continue } - }() - return ch + fi := restic.FileInfo{ + Name: path.Base(m), + Size: obj.Size, + } + + if ctx.Err() != nil { + return ctx.Err() + } + + err := fn(fi) + if err != nil { + return err + } + + if ctx.Err() != nil { + return ctx.Err() + } + } + + return ctx.Err() } // Remove keys for a specified backend type. 
func (be *Backend) removeKeys(ctx context.Context, t restic.FileType) error { - for key := range be.List(ctx, restic.DataFile) { - err := be.Remove(ctx, restic.Handle{Type: restic.DataFile, Name: key}) - if err != nil { - return err - } - } - - return nil + return be.List(ctx, restic.DataFile, func(fi restic.FileInfo) error { + return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name}) + }) } // Delete removes all restic keys in the bucket. It will not remove the bucket itself. diff --git a/internal/backend/sftp/layout_test.go b/internal/backend/sftp/layout_test.go index db1f1a870..81e5f3240 100644 --- a/internal/backend/sftp/layout_test.go +++ b/internal/backend/sftp/layout_test.go @@ -56,9 +56,10 @@ func TestLayout(t *testing.T) { } datafiles := make(map[string]bool) - for id := range be.List(context.TODO(), restic.DataFile) { - datafiles[id] = false - } + err = be.List(context.TODO(), restic.DataFile, func(fi restic.FileInfo) error { + datafiles[fi.Name] = false + return nil + }) if len(datafiles) == 0 { t.Errorf("List() returned zero data files") diff --git a/internal/backend/sftp/sftp.go b/internal/backend/sftp/sftp.go index 7dfa2951e..a0e20101a 100644 --- a/internal/backend/sftp/sftp.go +++ b/internal/backend/sftp/sftp.go @@ -376,7 +376,7 @@ func (r *SFTP) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, erro return restic.FileInfo{}, errors.Wrap(err, "Lstat") } - return restic.FileInfo{Size: fi.Size()}, nil + return restic.FileInfo{Size: fi.Size(), Name: h.Name}, nil } // Test returns true if a blob of the given type and name exists in the backend. @@ -408,47 +408,54 @@ func (r *SFTP) Remove(ctx context.Context, h restic.Handle) error { return r.c.Remove(r.Filename(h)) } -// List returns a channel that yields all names of blobs of type t. A -// goroutine is started for this. If the channel done is closed, sending -// stops. -func (r *SFTP) List(ctx context.Context, t restic.FileType) <-chan string { +// List runs fn for each file in the backend which has the type t. When an +// error occurs (or fn returns an error), List stops and returns it. 
+func (r *SFTP) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { debug.Log("List %v", t) - ch := make(chan string) - - go func() { - defer close(ch) - - basedir, subdirs := r.Basedir(t) - walker := r.c.Walk(basedir) - for walker.Step() { - if walker.Err() != nil { - continue - } - - if walker.Path() == basedir { - continue - } - - if walker.Stat().IsDir() && !subdirs { - walker.SkipDir() - continue - } - - if !walker.Stat().Mode().IsRegular() { - continue - } - - select { - case ch <- path.Base(walker.Path()): - case <-ctx.Done(): - return - } + basedir, subdirs := r.Basedir(t) + walker := r.c.Walk(basedir) + for walker.Step() { + if walker.Err() != nil { + return walker.Err() } - }() - return ch + if walker.Path() == basedir { + continue + } + if walker.Stat().IsDir() && !subdirs { + walker.SkipDir() + continue + } + + fi := walker.Stat() + if !fi.Mode().IsRegular() { + continue + } + + debug.Log("send %v\n", path.Base(walker.Path())) + + rfi := restic.FileInfo{ + Name: path.Base(walker.Path()), + Size: fi.Size(), + } + + if ctx.Err() != nil { + return ctx.Err() + } + + err := fn(rfi) + if err != nil { + return err + } + + if ctx.Err() != nil { + return ctx.Err() + } + } + + return ctx.Err() } var closeTimeout = 2 * time.Second diff --git a/internal/backend/swift/swift.go b/internal/backend/swift/swift.go index 48aeba600..27df0d55a 100644 --- a/internal/backend/swift/swift.go +++ b/internal/backend/swift/swift.go @@ -6,7 +6,6 @@ import ( "io" "net/http" "path" - "path/filepath" "strings" "time" @@ -203,7 +202,7 @@ func (be *beSwift) Stat(ctx context.Context, h restic.Handle) (bi restic.FileInf return restic.FileInfo{}, errors.Wrap(err, "conn.Object") } - return restic.FileInfo{Size: obj.Bytes}, nil + return restic.FileInfo{Size: obj.Bytes, Name: h.Name}, nil } // Test returns true if a blob of the given type and name exists in the backend. @@ -237,61 +236,62 @@ func (be *beSwift) Remove(ctx context.Context, h restic.Handle) error { return errors.Wrap(err, "conn.ObjectDelete") } -// List returns a channel that yields all names of blobs of type t. A -// goroutine is started for this. If the channel done is closed, sending -// stops. -func (be *beSwift) List(ctx context.Context, t restic.FileType) <-chan string { +// List runs fn for each file in the backend which has the type t. When an +// error occurs (or fn returns an error), List stops and returns it. 
+func (be *beSwift) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { debug.Log("listing %v", t) - ch := make(chan string) prefix, _ := be.Basedir(t) prefix += "/" - go func() { - defer close(ch) + err := be.conn.ObjectsWalk(be.container, &swift.ObjectsOpts{Prefix: prefix}, + func(opts *swift.ObjectsOpts) (interface{}, error) { + be.sem.GetToken() + newObjects, err := be.conn.Objects(be.container, opts) + be.sem.ReleaseToken() - err := be.conn.ObjectsWalk(be.container, &swift.ObjectsOpts{Prefix: prefix}, - func(opts *swift.ObjectsOpts) (interface{}, error) { - be.sem.GetToken() - newObjects, err := be.conn.ObjectNames(be.container, opts) - be.sem.ReleaseToken() + if err != nil { + return nil, errors.Wrap(err, "conn.ObjectNames") + } + for _, obj := range newObjects { + m := path.Base(strings.TrimPrefix(obj.Name, prefix)) + if m == "" { + continue + } + fi := restic.FileInfo{ + Name: m, + Size: obj.Bytes, + } + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + err := fn(fi) if err != nil { - return nil, errors.Wrap(err, "conn.ObjectNames") + return nil, err } - for _, obj := range newObjects { - m := filepath.Base(strings.TrimPrefix(obj, prefix)) - if m == "" { - continue - } - select { - case ch <- m: - case <-ctx.Done(): - return nil, io.EOF - } + if ctx.Err() != nil { + return nil, ctx.Err() } - return newObjects, nil - }) + } + return newObjects, nil + }) - if err != nil { - debug.Log("ObjectsWalk returned error: %v", err) - } - }() + if err != nil { + return err + } - return ch + return ctx.Err() } // Remove keys for a specified backend type. func (be *beSwift) removeKeys(ctx context.Context, t restic.FileType) error { - for key := range be.List(ctx, t) { - err := be.Remove(ctx, restic.Handle{Type: t, Name: key}) - if err != nil { - return err - } - } - - return nil + return be.List(ctx, t, func(fi restic.FileInfo) error { + return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name}) + }) } // IsNotExist returns true if the error is caused by a not existing file. 
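Note that every cloud backend in this change implements removeKeys the same way: the new List drives the deletion, and the first failed Remove aborts the iteration through the returned error. The pattern is generic over restic.Backend; a sketch with an illustrative function name:

    package backendexample

    import (
        "context"

        "github.com/restic/restic/internal/restic"
    )

    // removeAll deletes every file of the given type; the first failing
    // Remove stops the listing because its error is returned from the
    // callback.
    func removeAll(ctx context.Context, be restic.Backend, t restic.FileType) error {
        return be.List(ctx, t, func(fi restic.FileInfo) error {
            return be.Remove(ctx, restic.Handle{Type: t, Name: fi.Name})
        })
    }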
diff --git a/internal/backend/test/tests.go b/internal/backend/test/tests.go index 9e88a12ba..f6fe0644d 100644 --- a/internal/backend/test/tests.go +++ b/internal/backend/test/tests.go @@ -249,17 +249,17 @@ func (s *Suite) TestList(t *testing.T) { b := s.open(t) defer s.close(t, b) - list1 := restic.NewIDSet() + list1 := make(map[restic.ID]int64) for i := 0; i < numTestFiles; i++ { - data := []byte(fmt.Sprintf("random test blob %v", i)) + data := test.Random(rand.Int(), rand.Intn(100)+55) id := restic.Hash(data) h := restic.Handle{Type: restic.DataFile, Name: id.String()} err := b.Save(context.TODO(), h, bytes.NewReader(data)) if err != nil { t.Fatal(err) } - list1.Insert(id) + list1[id] = int64(len(data)) } t.Logf("wrote %v files", len(list1)) @@ -272,7 +272,7 @@ func (s *Suite) TestList(t *testing.T) { for _, test := range tests { t.Run(fmt.Sprintf("max-%v", test.maxItems), func(t *testing.T) { - list2 := restic.NewIDSet() + list2 := make(map[restic.ID]int64) type setter interface { SetListMaxItems(int) @@ -283,19 +283,37 @@ func (s *Suite) TestList(t *testing.T) { s.SetListMaxItems(test.maxItems) } - for name := range b.List(context.TODO(), restic.DataFile) { - id, err := restic.ParseID(name) + err := b.List(context.TODO(), restic.DataFile, func(fi restic.FileInfo) error { + id, err := restic.ParseID(fi.Name) if err != nil { t.Fatal(err) } - list2.Insert(id) + list2[id] = fi.Size + return nil + }) + + if err != nil { + t.Fatalf("List returned error %v", err) } t.Logf("loaded %v IDs from backend", len(list2)) - if !list1.Equals(list2) { - t.Errorf("lists are not equal, list1 %d entries, list2 %d entries", - len(list1), len(list2)) + for id, size := range list1 { + size2, ok := list2[id] + if !ok { + t.Errorf("id %v not returned by List()", id.Str()) + } + + if size != size2 { + t.Errorf("wrong size for id %v returned: want %v, got %v", id.Str(), size, size2) + } + } + + for id := range list2 { + _, ok := list1[id] + if !ok { + t.Errorf("extra id %v returned by List()", id.Str()) + } } }) } @@ -312,6 +330,123 @@ func (s *Suite) TestList(t *testing.T) { } } +// TestListCancel tests that the context is respected and the error is returned by List. 
+func (s *Suite) TestListCancel(t *testing.T) { + seedRand(t) + + numTestFiles := 5 + + b := s.open(t) + defer s.close(t, b) + + testFiles := make([]restic.Handle, 0, numTestFiles) + + for i := 0; i < numTestFiles; i++ { + data := []byte(fmt.Sprintf("random test blob %v", i)) + id := restic.Hash(data) + h := restic.Handle{Type: restic.DataFile, Name: id.String()} + err := b.Save(context.TODO(), h, bytes.NewReader(data)) + if err != nil { + t.Fatal(err) + } + testFiles = append(testFiles, h) + } + + t.Run("Cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + + // pass in a cancelled context + err := b.List(ctx, restic.DataFile, func(fi restic.FileInfo) error { + t.Errorf("got FileInfo %v for cancelled context", fi) + return nil + }) + + if errors.Cause(err) != context.Canceled { + t.Fatalf("expected error not found, want %v, got %v", context.Canceled, errors.Cause(err)) + } + }) + + t.Run("First", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + i := 0 + err := b.List(ctx, restic.DataFile, func(fi restic.FileInfo) error { + i++ + // cancel the context on the first file + if i == 1 { + cancel() + } + return nil + }) + + if err != context.Canceled { + t.Fatalf("expected error not found, want %v, got %v", context.Canceled, err) + } + + if i != 1 { + t.Fatalf("wrong number of files returned by List, want %v, got %v", 1, i) + } + }) + + t.Run("Last", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + i := 0 + err := b.List(ctx, restic.DataFile, func(fi restic.FileInfo) error { + // cancel the context at the last file + i++ + if i == numTestFiles { + cancel() + } + return nil + }) + + if err != context.Canceled { + t.Fatalf("expected error not found, want %v, got %v", context.Canceled, err) + } + + if i != numTestFiles { + t.Fatalf("wrong number of files returned by List, want %v, got %v", numTestFiles, i) + } + }) + + t.Run("Timeout", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + // rather large timeout, let's try to get at least one item + timeout := time.Second + + ctxTimeout, _ := context.WithTimeout(ctx, timeout) + + i := 0 + // pass in a context with a timeout + err := b.List(ctxTimeout, restic.DataFile, func(fi restic.FileInfo) error { + i++ + + // wait until the context is cancelled + <-ctxTimeout.Done() + return nil + }) + + if err != context.DeadlineExceeded { + t.Fatalf("expected error not found, want %#v, got %#v", context.DeadlineExceeded, err) + } + + if i > 1 { + t.Fatalf("wrong number of files returned by List, want <= 1, got %v", i) + } + }) + + err := s.delayedRemove(t, b, testFiles...) 
+ if err != nil { + t.Fatal(err) + } +} + type errorCloser struct { io.Reader l int @@ -366,8 +501,12 @@ func (s *Suite) TestSave(t *testing.T) { fi, err := b.Stat(context.TODO(), h) test.OK(t, err) + if fi.Name != h.Name { + t.Errorf("Stat() returned wrong name, want %q, got %q", h.Name, fi.Name) + } + if fi.Size != int64(len(data)) { - t.Fatalf("Stat() returned different size, want %q, got %d", len(data), fi.Size) + t.Errorf("Stat() returned different size, want %q, got %d", len(data), fi.Size) } err = b.Remove(context.TODO(), h) @@ -556,10 +695,16 @@ func delayedList(t testing.TB, b restic.Backend, tpe restic.FileType, max int, m list := restic.NewIDSet() start := time.Now() for i := 0; i < max; i++ { - for s := range b.List(context.TODO(), tpe) { - id := restic.TestParseID(s) + err := b.List(context.TODO(), tpe, func(fi restic.FileInfo) error { + id := restic.TestParseID(fi.Name) list.Insert(id) + return nil + }) + + if err != nil { + t.Fatal(err) } + if len(list) < max && time.Since(start) < maxwait { time.Sleep(500 * time.Millisecond) } diff --git a/internal/checker/checker.go b/internal/checker/checker.go index 6f134dbb7..aec691857 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -12,6 +12,7 @@ import ( "github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/hashing" "github.com/restic/restic/internal/restic" + "golang.org/x/sync/errgroup" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/pack" @@ -192,13 +193,14 @@ func (c *Checker) Packs(ctx context.Context, errChan chan<- error) { debug.Log("listing repository packs") repoPacks := restic.NewIDSet() - for id := range c.repo.List(ctx, restic.DataFile) { - select { - case <-ctx.Done(): - return - default: - } + + err := c.repo.List(ctx, restic.DataFile, func(id restic.ID, size int64) error { repoPacks.Insert(id) + return nil + }) + + if err != nil { + errChan <- err } // orphaned: present in the repo but not in c.packs @@ -719,42 +721,58 @@ func (c *Checker) ReadData(ctx context.Context, p *restic.Progress, errChan chan p.Start() defer p.Done() - worker := func(wg *sync.WaitGroup, in <-chan restic.ID) { - defer wg.Done() - for { - var id restic.ID - var ok bool + g, ctx := errgroup.WithContext(ctx) + ch := make(chan restic.ID) + // start producer for channel ch + g.Go(func() error { + defer close(ch) + return c.repo.List(ctx, restic.DataFile, func(id restic.ID, size int64) error { select { case <-ctx.Done(): - return - case id, ok = <-in: - if !ok { - return + case ch <- id: + } + return nil + }) + }) + + // run workers + for i := 0; i < defaultParallelism; i++ { + g.Go(func() error { + for { + var id restic.ID + var ok bool + + select { + case <-ctx.Done(): + return nil + case id, ok = <-ch: + if !ok { + return nil + } + } + + err := checkPack(ctx, c.repo, id) + p.Report(restic.Stat{Blobs: 1}) + if err == nil { + continue + } + + select { + case <-ctx.Done(): + return nil + case errChan <- err: } } + }) + } - err := checkPack(ctx, c.repo, id) - p.Report(restic.Stat{Blobs: 1}) - if err == nil { - continue - } - - select { - case <-ctx.Done(): - return - case errChan <- err: - } + err := g.Wait() + if err != nil { + select { + case <-ctx.Done(): + return + case errChan <- err: } } - - ch := c.repo.List(ctx, restic.DataFile) - - var wg sync.WaitGroup - for i := 0; i < defaultParallelism; i++ { - wg.Add(1) - go worker(&wg, ch) - } - - wg.Wait() } diff --git a/internal/errors/wrap.go b/internal/errors/errors.go similarity index 56% rename from 
internal/errors/wrap.go rename to internal/errors/errors.go index 99d6e88ba..aa2ff6f2b 100644 --- a/internal/errors/wrap.go +++ b/internal/errors/errors.go @@ -1,11 +1,10 @@ package errors -import "github.com/pkg/errors" +import ( + "net/url" -// Cause returns the cause of an error. -func Cause(err error) error { - return errors.Cause(err) -} + "github.com/pkg/errors" +) // New creates a new error based on message. Wrapped so that this package does // not appear in the stack trace. @@ -22,3 +21,29 @@ var Wrap = errors.Wrap // Wrapf returns an error annotating err with the format specifier. If err is // nil, Wrapf returns nil. var Wrapf = errors.Wrapf + +// Cause returns the cause of an error. It will also unwrap certain errors, +// e.g. *url.Error returned by the net/http client. +func Cause(err error) error { + type Causer interface { + Cause() error + } + + for { + // unwrap *url.Error + if urlErr, ok := err.(*url.Error); ok { + err = urlErr.Err + continue + } + + // if err is a Causer, return the cause for this error. + if c, ok := err.(Causer); ok { + err = c.Cause() + continue + } + + break + } + + return err +} diff --git a/internal/fuse/file_test.go b/internal/fuse/file_test.go index 121c218f5..622b5dd80 100644 --- a/internal/fuse/file_test.go +++ b/internal/fuse/file_test.go @@ -35,11 +35,17 @@ func testRead(t testing.TB, f *file, offset, length int, data []byte) { } func firstSnapshotID(t testing.TB, repo restic.Repository) (first restic.ID) { - for id := range repo.List(context.TODO(), restic.SnapshotFile) { + err := repo.List(context.TODO(), restic.SnapshotFile, func(id restic.ID, size int64) error { if first.IsNull() { first = id } + return nil + }) + + if err != nil { + t.Fatal(err) } + return first } diff --git a/internal/fuse/snapshots_dir.go b/internal/fuse/snapshots_dir.go index 762056611..671a51ba2 100644 --- a/internal/fuse/snapshots_dir.go +++ b/internal/fuse/snapshots_dir.go @@ -227,18 +227,24 @@ func isElem(e string, list []string) bool { const minSnapshotsReloadTime = 60 * time.Second // update snapshots if repository has changed -func updateSnapshots(ctx context.Context, root *Root) { +func updateSnapshots(ctx context.Context, root *Root) error { if time.Since(root.lastCheck) < minSnapshotsReloadTime { - return + return nil + } + + snapshots, err := restic.FindFilteredSnapshots(ctx, root.repo, root.cfg.Host, root.cfg.Tags, root.cfg.Paths) + if err != nil { + return err } - snapshots := restic.FindFilteredSnapshots(ctx, root.repo, root.cfg.Host, root.cfg.Tags, root.cfg.Paths) if root.snCount != len(snapshots) { root.snCount = len(snapshots) root.repo.LoadIndex(ctx) root.snapshots = snapshots } root.lastCheck = time.Now() + + return nil } // read snapshot timestamps from the current repository-state. 
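The reworked Cause in internal/errors above matters for the new cancellation tests: HTTP-based backends usually return the context error wrapped in a *url.Error, which the plain pkg/errors Cause would not unwrap. A minimal sketch of the behaviour (the function name and URL are placeholders):

    package errorsexample

    import (
        "context"
        "fmt"
        "net/url"

        "github.com/restic/restic/internal/errors"
    )

    // causeDemo shows that Cause unwraps a *url.Error down to the
    // underlying context error.
    func causeDemo() {
        // a cancelled HTTP request typically surfaces like this
        err := &url.Error{Op: "Get", URL: "https://host/data/", Err: context.Canceled}

        // Cause peels off the *url.Error (and any Causer layers), so
        // cancellation can still be detected by the caller
        fmt.Println(errors.Cause(err) == context.Canceled) // prints true
    }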
diff --git a/internal/index/index.go b/internal/index/index.go index 4c9ebeac3..732dc4b9c 100644 --- a/internal/index/index.go +++ b/internal/index/index.go @@ -115,13 +115,13 @@ func Load(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Ind index := newIndex() - for id := range repo.List(ctx, restic.IndexFile) { + err := repo.List(ctx, restic.IndexFile, func(id restic.ID, size int64) error { p.Report(restic.Stat{Blobs: 1}) debug.Log("Load index %v", id.Str()) idx, err := loadIndexJSON(ctx, repo, id) if err != nil { - return nil, err + return err } res := make(map[restic.ID]Pack) @@ -144,12 +144,18 @@ func Load(ctx context.Context, repo restic.Repository, p *restic.Progress) (*Ind } if err = index.AddPack(jpack.ID, 0, entries); err != nil { - return nil, err + return err } } results[id] = res index.IndexIDs.Insert(id) + + return nil + }) + + if err != nil { + return nil, err } for superID, list := range supersedes { diff --git a/internal/index/index_test.go b/internal/index/index_test.go index 28829afe9..00e9a523e 100644 --- a/internal/index/index_test.go +++ b/internal/index/index_test.go @@ -28,7 +28,7 @@ func createFilledRepo(t testing.TB, snapshots int, dup float32) (restic.Reposito } func validateIndex(t testing.TB, repo restic.Repository, idx *Index) { - for id := range repo.List(context.TODO(), restic.DataFile) { + err := repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { p, ok := idx.Packs[id] if !ok { t.Errorf("pack %v missing from index", id.Str()) @@ -37,6 +37,11 @@ func validateIndex(t testing.TB, repo restic.Repository, idx *Index) { if !p.ID.Equal(id) { t.Errorf("pack %v has invalid ID: want %v, got %v", id.Str(), id, p.ID) } + return nil + }) + + if err != nil { + t.Fatal(err) } } @@ -308,7 +313,14 @@ func TestIndexAddRemovePack(t *testing.T) { t.Fatalf("Load() returned error %v", err) } - packID := <-repo.List(context.TODO(), restic.DataFile) + var packID restic.ID + err = repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { + packID = id + return nil + }) + if err != nil { + t.Fatal(err) + } t.Logf("selected pack %v", packID.Str()) diff --git a/internal/list/list.go b/internal/list/list.go index 04916b906..ffcc08729 100644 --- a/internal/list/list.go +++ b/internal/list/list.go @@ -11,7 +11,7 @@ const listPackWorkers = 10 // Lister combines lists packs in a repo and blobs in a pack. 
type Lister interface { - List(context.Context, restic.FileType) <-chan restic.ID + List(context.Context, restic.FileType, func(restic.ID, int64) error) error ListPack(context.Context, restic.ID) ([]restic.Blob, int64, error) } @@ -55,17 +55,19 @@ func AllPacks(ctx context.Context, repo Lister, ignorePacks restic.IDSet, ch cha go func() { defer close(jobCh) - for id := range repo.List(ctx, restic.DataFile) { + + _ = repo.List(ctx, restic.DataFile, func(id restic.ID, size int64) error { if ignorePacks.Has(id) { - continue + return nil } select { case jobCh <- worker.Job{Data: id}: case <-ctx.Done(): - return + return ctx.Err() } - } + return nil + }) }() wp.Wait() diff --git a/internal/migrations/s3_layout.go b/internal/migrations/s3_layout.go index 3d27f0d83..12ffef0ff 100644 --- a/internal/migrations/s3_layout.go +++ b/internal/migrations/s3_layout.go @@ -59,14 +59,14 @@ func (m *S3Layout) moveFiles(ctx context.Context, be *s3.Backend, l backend.Layo fmt.Fprintf(os.Stderr, "renaming file returned error: %v\n", err) } - for name := range be.List(ctx, t) { - h := restic.Handle{Type: t, Name: name} + return be.List(ctx, t, func(fi restic.FileInfo) error { + h := restic.Handle{Type: t, Name: fi.Name} debug.Log("move %v", h) - retry(maxErrors, printErr, func() error { + return retry(maxErrors, printErr, func() error { return be.Rename(h, l) }) - } + }) return nil } diff --git a/internal/mock/backend.go b/internal/mock/backend.go index b011131c4..29543c5fe 100644 --- a/internal/mock/backend.go +++ b/internal/mock/backend.go @@ -15,7 +15,7 @@ type Backend struct { SaveFn func(ctx context.Context, h restic.Handle, rd io.Reader) error LoadFn func(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) StatFn func(ctx context.Context, h restic.Handle) (restic.FileInfo, error) - ListFn func(ctx context.Context, t restic.FileType) <-chan string + ListFn func(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error RemoveFn func(ctx context.Context, h restic.Handle) error TestFn func(ctx context.Context, h restic.Handle) (bool, error) DeleteFn func(ctx context.Context) error @@ -77,14 +77,12 @@ func (m *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, e } // List items of type t. -func (m *Backend) List(ctx context.Context, t restic.FileType) <-chan string { +func (m *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { if m.ListFn == nil { - ch := make(chan string) - close(ch) - return ch + return nil } - return m.ListFn(ctx, t) + return m.ListFn(ctx, t, fn) } // Remove data from the backend. diff --git a/internal/repository/key.go b/internal/repository/key.go index fd3ef1a1c..63c35b8e8 100644 --- a/internal/repository/key.go +++ b/internal/repository/key.go @@ -113,42 +113,48 @@ func OpenKey(ctx context.Context, s *Repository, name string, password string) ( // given password. If none could be found, ErrNoKeyFound is returned. When // maxKeys is reached, ErrMaxKeysReached is returned. When setting maxKeys to // zero, all keys in the repo are checked. 
-func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int) (*Key, error) { +func SearchKey(ctx context.Context, s *Repository, password string, maxKeys int) (k *Key, err error) { checked := 0 // try at most maxKeysForSearch keys in repo - for name := range s.Backend().List(ctx, restic.KeyFile) { + err = s.Backend().List(ctx, restic.KeyFile, func(fi restic.FileInfo) error { if maxKeys > 0 && checked > maxKeys { - return nil, ErrMaxKeysReached + return ErrMaxKeysReached } - _, err := restic.ParseID(name) + _, err := restic.ParseID(fi.Name) if err != nil { - debug.Log("rejecting key with invalid name: %v", name) - continue + debug.Log("rejecting key with invalid name: %v", fi.Name) + return nil } - debug.Log("trying key %q", name) - key, err := OpenKey(ctx, s, name, password) + debug.Log("trying key %q", fi.Name) + key, err := OpenKey(ctx, s, fi.Name, password) if err != nil { - debug.Log("key %v returned error %v", name, err) + debug.Log("key %v returned error %v", fi.Name, err) // ErrUnauthenticated means the password is wrong, try the next key if errors.Cause(err) == crypto.ErrUnauthenticated { - continue + return nil } - if err != nil { - debug.Log("unable to open key %v: %v\n", err) - continue - } + return err } - debug.Log("successfully opened key %v", name) - return key, nil + debug.Log("successfully opened key %v", fi.Name) + k = key + return nil + }) + + if err != nil { + return nil, err } - return nil, ErrNoKeyFound + if k == nil { + return nil, ErrNoKeyFound + } + + return k, nil } // LoadKey loads a key from the backend. diff --git a/internal/repository/parallel.go b/internal/repository/parallel.go index 5f87f94a5..154b58bfa 100644 --- a/internal/repository/parallel.go +++ b/internal/repository/parallel.go @@ -2,10 +2,10 @@ package repository import ( "context" - "sync" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/restic" + "golang.org/x/sync/errgroup" ) // ParallelWorkFunc gets one file ID to work on. If an error is returned, @@ -17,47 +17,36 @@ type ParallelWorkFunc func(ctx context.Context, id string) error type ParallelIDWorkFunc func(ctx context.Context, id restic.ID) error // FilesInParallel runs n workers of f in parallel, on the IDs that -// repo.List(t) yield. If f returns an error, the process is aborted and the +// repo.List(t) yields. If f returns an error, the process is aborted and the // first error is returned. 
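The rewritten FilesInParallel that follows uses the same producer/worker shape as the checker above: one goroutine feeds a channel from List, n workers drain it, and errgroup cancels the rest and returns the first error. A standalone sketch of that shape under stated assumptions (runParallel, workFn and the string items are invented for illustration, not part of the change):

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// runParallel fans items out to n workers and returns the first error.
func runParallel(ctx context.Context, items []string, n int, workFn func(context.Context, string) error) error {
	g, ctx := errgroup.WithContext(ctx)

	ch := make(chan string)
	g.Go(func() error {
		defer close(ch)
		for _, item := range items {
			select {
			case ch <- item:
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		return nil
	})

	for i := 0; i < n; i++ {
		g.Go(func() error {
			for item := range ch {
				if err := workFn(ctx, item); err != nil {
					return err // cancels ctx, which also stops the producer
				}
			}
			return nil
		})
	}

	return g.Wait()
}

func main() {
	err := runParallel(context.Background(), []string{"a", "b", "c"}, 2,
		func(ctx context.Context, s string) error {
			fmt.Println("processing", s)
			return nil
		})
	fmt.Println("err:", err)
}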
-func FilesInParallel(ctx context.Context, repo restic.Lister, t restic.FileType, n uint, f ParallelWorkFunc) error { - wg := &sync.WaitGroup{} - ch := repo.List(ctx, t) - errors := make(chan error, n) +func FilesInParallel(ctx context.Context, repo restic.Lister, t restic.FileType, n int, f ParallelWorkFunc) error { + g, ctx := errgroup.WithContext(ctx) - for i := 0; uint(i) < n; i++ { - wg.Add(1) - go func() { - defer wg.Done() + ch := make(chan string, n) + g.Go(func() error { + defer close(ch) + return repo.List(ctx, t, func(fi restic.FileInfo) error { + select { + case <-ctx.Done(): + case ch <- fi.Name: + } + return nil + }) + }) - for { - select { - case id, ok := <-ch: - if !ok { - return - } - - err := f(ctx, id) - if err != nil { - errors <- err - return - } - case <-ctx.Done(): - return + for i := 0; i < n; i++ { + g.Go(func() error { + for name := range ch { + err := f(ctx, name) + if err != nil { + return err } } - }() + return nil + }) } - wg.Wait() - - select { - case err := <-errors: - return err - default: - break - } - - return nil + return g.Wait() } // ParallelWorkFuncParseID converts a function that takes a restic.ID to a diff --git a/internal/repository/parallel_test.go b/internal/repository/parallel_test.go index 9fa3687bb..7b4c4a583 100644 --- a/internal/repository/parallel_test.go +++ b/internal/repository/parallel_test.go @@ -74,24 +74,25 @@ var lister = testIDs{ "34dd044c228727f2226a0c9c06a3e5ceb5e30e31cb7854f8fa1cde846b395a58", } -func (tests testIDs) List(ctx context.Context, t restic.FileType) <-chan string { - ch := make(chan string) +func (tests testIDs) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) error { + for i := 0; i < 500; i++ { + for _, id := range tests { + if ctx.Err() != nil { + return ctx.Err() + } - go func() { - defer close(ch) + fi := restic.FileInfo{ + Name: id, + } - for i := 0; i < 500; i++ { - for _, id := range tests { - select { - case ch <- id: - case <-ctx.Done(): - return - } + err := fn(fi) + if err != nil { + return err } } - }() + } - return ch + return nil } func TestFilesInParallel(t *testing.T) { @@ -100,7 +101,7 @@ func TestFilesInParallel(t *testing.T) { return nil } - for n := uint(1); n < 5; n++ { + for n := 1; n < 5; n++ { err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f) rtest.OK(t, err) } @@ -109,7 +110,6 @@ func TestFilesInParallel(t *testing.T) { var errTest = errors.New("test error") func TestFilesInParallelWithError(t *testing.T) { - f := func(ctx context.Context, id string) error { time.Sleep(1 * time.Millisecond) @@ -120,8 +120,10 @@ func TestFilesInParallelWithError(t *testing.T) { return nil } - for n := uint(1); n < 5; n++ { + for n := 1; n < 5; n++ { err := repository.FilesInParallel(context.TODO(), lister, restic.DataFile, n*100, f) - rtest.Equals(t, errTest, err) + if err != errTest { + t.Fatalf("wrong error returned, want %q, got %v", errTest, err) + } } } diff --git a/internal/repository/repack_test.go b/internal/repository/repack_test.go index 2d29a589a..458362c59 100644 --- a/internal/repository/repack_test.go +++ b/internal/repository/repack_test.go @@ -16,7 +16,7 @@ func randomSize(min, max int) int { } func random(t testing.TB, length int) []byte { - rd := restic.NewRandReader(rand.New(rand.NewSource(int64(length)))) + rd := restic.NewRandReader(rand.New(rand.NewSource(rand.Int63()))) buf := make([]byte, length) _, err := io.ReadFull(rd, buf) if err != nil { @@ -74,7 +74,7 @@ func selectBlobs(t *testing.T, repo restic.Repository, p 
float32) (list1, list2 blobs := restic.NewBlobSet() - for id := range repo.List(context.TODO(), restic.DataFile) { + err := repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { entries, _, err := repo.ListPack(context.TODO(), id) if err != nil { t.Fatalf("error listing pack %v: %v", id, err) @@ -84,7 +84,7 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 h := restic.BlobHandle{ID: entry.ID, Type: entry.Type} if blobs.Has(h) { t.Errorf("ignoring duplicate blob %v", h) - continue + return nil } blobs.Insert(h) @@ -93,8 +93,11 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 } else { list2.Insert(restic.BlobHandle{ID: entry.ID, Type: entry.Type}) } - } + return nil + }) + if err != nil { + t.Fatal(err) } return list1, list2 @@ -102,8 +105,13 @@ func selectBlobs(t *testing.T, repo restic.Repository, p float32) (list1, list2 func listPacks(t *testing.T, repo restic.Repository) restic.IDSet { list := restic.NewIDSet() - for id := range repo.List(context.TODO(), restic.DataFile) { + err := repo.List(context.TODO(), restic.DataFile, func(id restic.ID, size int64) error { list.Insert(id) + return nil + }) + + if err != nil { + t.Fatal(err) } return list @@ -153,15 +161,15 @@ func rebuildIndex(t *testing.T, repo restic.Repository) { t.Fatal(err) } - for id := range repo.List(context.TODO(), restic.IndexFile) { + err = repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error { h := restic.Handle{ Type: restic.IndexFile, Name: id.String(), } - err = repo.Backend().Remove(context.TODO(), h) - if err != nil { - t.Fatal(err) - } + return repo.Backend().Remove(context.TODO(), h) + }) + if err != nil { + t.Fatal(err) } _, err = idx.Save(context.TODO(), repo, nil) @@ -181,6 +189,10 @@ func TestRepack(t *testing.T) { repo, cleanup := repository.TestRepository(t) defer cleanup() + seed := rand.Int63() + rand.Seed(seed) + t.Logf("rand seed is %v", seed) + createRandomBlobs(t, repo, 100, 0.7) packsBefore := listPacks(t, repo) diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 193ec1ca7..e772cd8ef 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -536,22 +536,15 @@ func (r *Repository) KeyName() string { return r.keyName } -// List returns a channel that yields all IDs of type t in the backend. -func (r *Repository) List(ctx context.Context, t restic.FileType) <-chan restic.ID { - out := make(chan restic.ID) - go func() { - defer close(out) - for strID := range r.be.List(ctx, t) { - if id, err := restic.ParseID(strID); err == nil { - select { - case out <- id: - case <-ctx.Done(): - return - } - } +// List runs fn for all files of type t in the repo. 
+func (r *Repository) List(ctx context.Context, t restic.FileType, fn func(restic.ID, int64) error) error { + return r.be.List(ctx, t, func(fi restic.FileInfo) error { + id, err := restic.ParseID(fi.Name) + if err != nil { + debug.Log("unable to parse %v as an ID", fi.Name) + } - }() - return out + return fn(id, fi.Size) + }) } // ListPack returns the list of blobs saved in the pack id and the length of diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index a90f0959b..60c1190ce 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -369,7 +369,7 @@ func TestRepositoryIncrementalIndex(t *testing.T) { packEntries := make(map[restic.ID]map[restic.ID]struct{}) - for id := range repo.List(context.TODO(), restic.IndexFile) { + err := repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error { idx, err := repository.LoadIndex(context.TODO(), repo, id) rtest.OK(t, err) @@ -380,6 +380,10 @@ func TestRepositoryIncrementalIndex(t *testing.T) { packEntries[pb.PackID][id] = struct{}{} } + return nil + }) + if err != nil { + t.Fatal(err) } for packID, ids := range packEntries { diff --git a/internal/restic/backend.go b/internal/restic/backend.go index 9174b5522..51198215f 100644 --- a/internal/restic/backend.go +++ b/internal/restic/backend.go @@ -32,10 +32,12 @@ type Backend interface { // Stat returns information about the File identified by h. Stat(ctx context.Context, h Handle) (FileInfo, error) - // List returns a channel that yields all names of files of type t in an - // arbitrary order. A goroutine is started for this, which is stopped when - // ctx is cancelled. - List(ctx context.Context, t FileType) <-chan string + // List runs fn for each file in the backend which has the type t. When an + // error occurs (or fn returns an error), List stops and returns it. + // + // The function fn is called in the same Goroutine that List() is called + // from. + List(ctx context.Context, t FileType, fn func(FileInfo) error) error // IsNotExist returns true if the error was caused by a non-existing file // in the backend. @@ -45,6 +47,8 @@ type Backend interface { Delete(ctx context.Context) error } -// FileInfo is returned by Stat() and contains information about a file in the -// backend. -type FileInfo struct{ Size int64 } +// FileInfo contains information about a file in the backend. +type FileInfo struct { + Size int64 + Name string +} diff --git a/internal/restic/backend_find.go b/internal/restic/backend_find.go index 02521a25d..722a42dd2 100644 --- a/internal/restic/backend_find.go +++ b/internal/restic/backend_find.go @@ -20,15 +20,23 @@ var ErrMultipleIDMatches = errors.New("multiple IDs with prefix found") func Find(be Lister, t FileType, prefix string) (string, error) { match := "" - // TODO: optimize by sorting list etc.
- for name := range be.List(context.TODO(), t) { - if prefix == name[:len(prefix)] { + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + err := be.List(ctx, t, func(fi FileInfo) error { + if prefix == fi.Name[:len(prefix)] { if match == "" { - match = name + match = fi.Name } else { - return "", ErrMultipleIDMatches + return ErrMultipleIDMatches } } + + return nil + }) + + if err != nil { + return "", err } if match != "" { @@ -45,8 +53,17 @@ const minPrefixLength = 8 func PrefixLength(be Lister, t FileType) (int, error) { // load all IDs of the given type list := make([]string, 0, 100) - for name := range be.List(context.TODO(), t) { - list = append(list, name) + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + err := be.List(ctx, t, func(fi FileInfo) error { + list = append(list, fi.Name) + return nil + }) + + if err != nil { + return 0, err } // select prefixes of length l, test if the last one is the same as the current one diff --git a/internal/restic/backend_find_test.go b/internal/restic/backend_find_test.go index 032c8a9d9..2cec35b1f 100644 --- a/internal/restic/backend_find_test.go +++ b/internal/restic/backend_find_test.go @@ -6,11 +6,11 @@ import ( ) type mockBackend struct { - list func(context.Context, FileType) <-chan string + list func(context.Context, FileType, func(FileInfo) error) error } -func (m mockBackend) List(ctx context.Context, t FileType) <-chan string { - return m.list(ctx, t) +func (m mockBackend) List(ctx context.Context, t FileType, fn func(FileInfo) error) error { + return m.list(ctx, t, fn) } var samples = IDs{ @@ -28,19 +28,14 @@ func TestPrefixLength(t *testing.T) { list := samples m := mockBackend{} - m.list = func(ctx context.Context, t FileType) <-chan string { - ch := make(chan string) - go func() { - defer close(ch) - for _, id := range list { - select { - case ch <- id.String(): - case <-ctx.Done(): - return - } + m.list = func(ctx context.Context, t FileType, fn func(FileInfo) error) error { + for _, id := range list { + err := fn(FileInfo{Name: id.String()}) + if err != nil { + return err } - }() - return ch + } + return nil } l, err := PrefixLength(m, SnapshotFile) diff --git a/internal/restic/lock.go b/internal/restic/lock.go index 177b0d707..882f970f3 100644 --- a/internal/restic/lock.go +++ b/internal/restic/lock.go @@ -157,15 +157,14 @@ func (l *Lock) checkForOtherLocks(ctx context.Context) error { } func eachLock(ctx context.Context, repo Repository, f func(ID, *Lock, error) error) error { - for id := range repo.List(ctx, LockFile) { + return repo.List(ctx, LockFile, func(id ID, size int64) error { lock, err := LoadLock(ctx, repo, id) - err = f(id, lock, err) if err != nil { return err } - } - return nil + return f(id, lock, err) + }) } // createLock acquires the lock by creating a file in the repository. 
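After these changes every caller consumes listings through a callback rather than a channel: the callback runs in the calling goroutine, and returning a non-nil error aborts the listing. A minimal sketch of the caller side, assuming an already opened repository (collectIDs is an illustrative helper, not part of the change):

package example

import (
	"context"

	"github.com/restic/restic/internal/restic"
)

// collectIDs gathers all file IDs of the given type. Since the callback runs
// in the same goroutine as the caller, appending to a local slice needs no
// locking; returning a non-nil error would stop the listing early.
func collectIDs(ctx context.Context, repo restic.Repository, t restic.FileType) (restic.IDs, error) {
	var ids restic.IDs
	err := repo.List(ctx, t, func(id restic.ID, size int64) error {
		ids = append(ids, id)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return ids, nil
}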
diff --git a/internal/restic/lock_test.go b/internal/restic/lock_test.go index a3b4936c9..daadd479f 100644 --- a/internal/restic/lock_test.go +++ b/internal/restic/lock_test.go @@ -227,21 +227,29 @@ func TestLockRefresh(t *testing.T) { rtest.OK(t, err) var lockID *restic.ID - for id := range repo.List(context.TODO(), restic.LockFile) { + err = repo.List(context.TODO(), restic.LockFile, func(id restic.ID, size int64) error { if lockID != nil { t.Error("more than one lock found") } lockID = &id + return nil + }) + if err != nil { + t.Fatal(err) } rtest.OK(t, lock.Refresh(context.TODO())) var lockID2 *restic.ID - for id := range repo.List(context.TODO(), restic.LockFile) { + err = repo.List(context.TODO(), restic.LockFile, func(id restic.ID, size int64) error { if lockID2 != nil { t.Error("more than one lock found") } lockID2 = &id + return nil + }) + if err != nil { + t.Fatal(err) } rtest.Assert(t, !lockID.Equal(*lockID2), diff --git a/internal/restic/repository.go b/internal/restic/repository.go index 2daae41b2..1ab4a2766 100644 --- a/internal/restic/repository.go +++ b/internal/restic/repository.go @@ -26,7 +26,12 @@ type Repository interface { LookupBlobSize(ID, BlobType) (uint, error) - List(context.Context, FileType) <-chan ID + // List calls the function fn for each file of type t in the repository. + // When an error is returned by fn, processing stops and List() returns the + // error. + // + // The function fn is called in the same Goroutine List() was called from. + List(ctx context.Context, t FileType, fn func(ID, int64) error) error ListPack(context.Context, ID) ([]Blob, int64, error) Flush(context.Context) error @@ -46,7 +51,7 @@ type Repository interface { // Lister allows listing files in a backend. type Lister interface { - List(context.Context, FileType) <-chan string + List(context.Context, FileType, func(FileInfo) error) error } // Index keeps track of the blobs are stored within files. diff --git a/internal/restic/snapshot.go b/internal/restic/snapshot.go index 47b123240..4622bb530 100644 --- a/internal/restic/snapshot.go +++ b/internal/restic/snapshot.go @@ -64,15 +64,21 @@ func LoadSnapshot(ctx context.Context, repo Repository, id ID) (*Snapshot, error // LoadAllSnapshots returns a list of all snapshots in the repo. 
func LoadAllSnapshots(ctx context.Context, repo Repository) (snapshots []*Snapshot, err error) { - for id := range repo.List(ctx, SnapshotFile) { + err = repo.List(ctx, SnapshotFile, func(id ID, size int64) error { sn, err := LoadSnapshot(ctx, repo, id) if err != nil { - return nil, err + return err } snapshots = append(snapshots, sn) + return nil + }) + + if err != nil { + return nil, err } - return + + return snapshots, nil } func (sn Snapshot) String() string { diff --git a/internal/restic/snapshot_find.go b/internal/restic/snapshot_find.go index 4c239fb1e..b5d0a8276 100644 --- a/internal/restic/snapshot_find.go +++ b/internal/restic/snapshot_find.go @@ -20,26 +20,31 @@ func FindLatestSnapshot(ctx context.Context, repo Repository, targets []string, found bool ) - for snapshotID := range repo.List(ctx, SnapshotFile) { + err := repo.List(ctx, SnapshotFile, func(snapshotID ID, size int64) error { snapshot, err := LoadSnapshot(ctx, repo, snapshotID) if err != nil { - return ID{}, errors.Errorf("Error listing snapshot: %v", err) + return errors.Errorf("Error loading snapshot %v: %v", snapshotID.Str(), err) } if snapshot.Time.Before(latest) || (hostname != "" && hostname != snapshot.Hostname) { - continue + return nil } if !snapshot.HasTagList(tagLists) { - continue + return nil } if !snapshot.HasPaths(targets) { - continue + return nil } latest = snapshot.Time latestID = snapshotID found = true + return nil + }) + + if err != nil { + return ID{}, err } if !found { @@ -64,20 +69,27 @@ func FindSnapshot(repo Repository, s string) (ID, error) { // FindFilteredSnapshots yields Snapshots filtered from the list of all // snapshots. -func FindFilteredSnapshots(ctx context.Context, repo Repository, host string, tags []TagList, paths []string) Snapshots { +func FindFilteredSnapshots(ctx context.Context, repo Repository, host string, tags []TagList, paths []string) (Snapshots, error) { results := make(Snapshots, 0, 20) - for id := range repo.List(ctx, SnapshotFile) { + err := repo.List(ctx, SnapshotFile, func(id ID, size int64) error { sn, err := LoadSnapshot(ctx, repo, id) if err != nil { fmt.Fprintf(os.Stderr, "could not load snapshot %v: %v\n", id.Str(), err) - continue + return nil } + if (host != "" && host != sn.Hostname) || !sn.HasTagList(tags) || !sn.HasPaths(paths) { - continue + return nil } results = append(results, sn) + return nil + }) + + if err != nil { + return nil, err } - return results + + return results, nil } diff --git a/vendor/golang.org/x/sync/AUTHORS b/vendor/golang.org/x/sync/AUTHORS new file mode 100644 index 000000000..15167cd74 --- /dev/null +++ b/vendor/golang.org/x/sync/AUTHORS @@ -0,0 +1,3 @@ +# This source code refers to The Go Authors for copyright purposes. +# The master list of authors is in the main Go distribution, +# visible at http://tip.golang.org/AUTHORS. diff --git a/vendor/golang.org/x/sync/CONTRIBUTING.md b/vendor/golang.org/x/sync/CONTRIBUTING.md new file mode 100644 index 000000000..88dff59bc --- /dev/null +++ b/vendor/golang.org/x/sync/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to Go + +Go is an open source project. + +It is the work of hundreds of contributors. We appreciate your help! + + +## Filing issues + +When [filing an issue](https://golang.org/issue/new), make sure to answer these five questions: + +1. What version of Go are you using (`go version`)? +2. What operating system and processor architecture are you using? +3. What did you do? +4. What did you expect to see? +5. What did you see instead? 
+ +General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker. +The gophers there will answer or ask you to file an issue if you've tripped over a bug. + +## Contributing code + +Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html) +before sending patches. + +**We do not accept GitHub pull requests** +(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review). + +Unless otherwise noted, the Go source files are distributed under +the BSD-style license found in the LICENSE file. + diff --git a/vendor/golang.org/x/sync/CONTRIBUTORS b/vendor/golang.org/x/sync/CONTRIBUTORS new file mode 100644 index 000000000..1c4577e96 --- /dev/null +++ b/vendor/golang.org/x/sync/CONTRIBUTORS @@ -0,0 +1,3 @@ +# This source code was written by the Go contributors. +# The master list of contributors is in the main Go distribution, +# visible at http://tip.golang.org/CONTRIBUTORS. diff --git a/vendor/golang.org/x/sync/LICENSE b/vendor/golang.org/x/sync/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/vendor/golang.org/x/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/sync/PATENTS b/vendor/golang.org/x/sync/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/golang.org/x/sync/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. 
+ +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/sync/README.md b/vendor/golang.org/x/sync/README.md new file mode 100644 index 000000000..1f8436cc9 --- /dev/null +++ b/vendor/golang.org/x/sync/README.md @@ -0,0 +1,18 @@ +# Go Sync + +This repository provides Go concurrency primitives in addition to the +ones provided by the language and "sync" and "sync/atomic" packages. + +## Download/Install + +The easiest way to install is to run `go get -u golang.org/x/sync`. You can +also manually git clone the repository to `$GOPATH/src/golang.org/x/sync`. + +## Report Issues / Send Patches + +This repository uses Gerrit for code changes. To learn how to submit changes to +this repository, see https://golang.org/doc/contribute.html. + +The main issue tracker for the sync repository is located at +https://github.com/golang/go/issues. Prefix your issue with "x/sync:" in the +subject line, so it is easy to find. diff --git a/vendor/golang.org/x/sync/codereview.cfg b/vendor/golang.org/x/sync/codereview.cfg new file mode 100644 index 000000000..3f8b14b64 --- /dev/null +++ b/vendor/golang.org/x/sync/codereview.cfg @@ -0,0 +1 @@ +issuerepo: golang/go diff --git a/vendor/golang.org/x/sync/errgroup/errgroup.go b/vendor/golang.org/x/sync/errgroup/errgroup.go new file mode 100644 index 000000000..533438d91 --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/errgroup.go @@ -0,0 +1,67 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package errgroup provides synchronization, error propagation, and Context +// cancelation for groups of goroutines working on subtasks of a common task. +package errgroup + +import ( + "sync" + + "golang.org/x/net/context" +) + +// A Group is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero Group is valid and does not cancel on error. +type Group struct { + cancel func() + + wg sync.WaitGroup + + errOnce sync.Once + err error +} + +// WithContext returns a new Group and an associated Context derived from ctx. +// +// The derived Context is canceled the first time a function passed to Go +// returns a non-nil error or the first time Wait returns, whichever occurs +// first. 
+func WithContext(ctx context.Context) (*Group, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &Group{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the first non-nil error (if any) from them. +func (g *Group) Wait() error { + g.wg.Wait() + if g.cancel != nil { + g.cancel() + } + return g.err +} + +// Go calls the given function in a new goroutine. +// +// The first call to return a non-nil error cancels the group; its error will be +// returned by Wait. +func (g *Group) Go(f func() error) { + g.wg.Add(1) + + go func() { + defer g.wg.Done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() +} diff --git a/vendor/golang.org/x/sync/errgroup/errgroup_example_md5all_test.go b/vendor/golang.org/x/sync/errgroup/errgroup_example_md5all_test.go new file mode 100644 index 000000000..714b5aea7 --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/errgroup_example_md5all_test.go @@ -0,0 +1,101 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package errgroup_test + +import ( + "crypto/md5" + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + + "golang.org/x/net/context" + "golang.org/x/sync/errgroup" +) + +// Pipeline demonstrates the use of a Group to implement a multi-stage +// pipeline: a version of the MD5All function with bounded parallelism from +// https://blog.golang.org/pipelines. +func ExampleGroup_pipeline() { + m, err := MD5All(context.Background(), ".") + if err != nil { + log.Fatal(err) + } + + for k, sum := range m { + fmt.Printf("%s:\t%x\n", k, sum) + } +} + +type result struct { + path string + sum [md5.Size]byte +} + +// MD5All reads all the files in the file tree rooted at root and returns a map +// from file path to the MD5 sum of the file's contents. If the directory walk +// fails or any read operation fails, MD5All returns an error. +func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) { + // ctx is canceled when g.Wait() returns. When this version of MD5All returns + // - even in case of error! - we know that all of the goroutines have finished + // and the memory they were using can be garbage-collected. + g, ctx := errgroup.WithContext(ctx) + paths := make(chan string) + + g.Go(func() error { + defer close(paths) + return filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.Mode().IsRegular() { + return nil + } + select { + case paths <- path: + case <-ctx.Done(): + return ctx.Err() + } + return nil + }) + }) + + // Start a fixed number of goroutines to read and digest files. + c := make(chan result) + const numDigesters = 20 + for i := 0; i < numDigesters; i++ { + g.Go(func() error { + for path := range paths { + data, err := ioutil.ReadFile(path) + if err != nil { + return err + } + select { + case c <- result{path, md5.Sum(data)}: + case <-ctx.Done(): + return ctx.Err() + } + } + return nil + }) + } + go func() { + g.Wait() + close(c) + }() + + m := make(map[string][md5.Size]byte) + for r := range c { + m[r.path] = r.sum + } + // Check whether any of the goroutines failed. Since g is accumulating the + // errors, we don't need to send them (or check for them) in the individual + // results sent on the channel. 
+ if err := g.Wait(); err != nil { + return nil, err + } + return m, nil +} diff --git a/vendor/golang.org/x/sync/errgroup/errgroup_test.go b/vendor/golang.org/x/sync/errgroup/errgroup_test.go new file mode 100644 index 000000000..6a9696efc --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/errgroup_test.go @@ -0,0 +1,176 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package errgroup_test + +import ( + "errors" + "fmt" + "net/http" + "os" + "testing" + + "golang.org/x/net/context" + "golang.org/x/sync/errgroup" +) + +var ( + Web = fakeSearch("web") + Image = fakeSearch("image") + Video = fakeSearch("video") +) + +type Result string +type Search func(ctx context.Context, query string) (Result, error) + +func fakeSearch(kind string) Search { + return func(_ context.Context, query string) (Result, error) { + return Result(fmt.Sprintf("%s result for %q", kind, query)), nil + } +} + +// JustErrors illustrates the use of a Group in place of a sync.WaitGroup to +// simplify goroutine counting and error handling. This example is derived from +// the sync.WaitGroup example at https://golang.org/pkg/sync/#example_WaitGroup. +func ExampleGroup_justErrors() { + var g errgroup.Group + var urls = []string{ + "http://www.golang.org/", + "http://www.google.com/", + "http://www.somestupidname.com/", + } + for _, url := range urls { + // Launch a goroutine to fetch the URL. + url := url // https://golang.org/doc/faq#closures_and_goroutines + g.Go(func() error { + // Fetch the URL. + resp, err := http.Get(url) + if err == nil { + resp.Body.Close() + } + return err + }) + } + // Wait for all HTTP fetches to complete. + if err := g.Wait(); err == nil { + fmt.Println("Successfully fetched all URLs.") + } +} + +// Parallel illustrates the use of a Group for synchronizing a simple parallel +// task: the "Google Search 2.0" function from +// https://talks.golang.org/2012/concurrency.slide#46, augmented with a Context +// and error-handling. 
+func ExampleGroup_parallel() { + Google := func(ctx context.Context, query string) ([]Result, error) { + g, ctx := errgroup.WithContext(ctx) + + searches := []Search{Web, Image, Video} + results := make([]Result, len(searches)) + for i, search := range searches { + i, search := i, search // https://golang.org/doc/faq#closures_and_goroutines + g.Go(func() error { + result, err := search(ctx, query) + if err == nil { + results[i] = result + } + return err + }) + } + if err := g.Wait(); err != nil { + return nil, err + } + return results, nil + } + + results, err := Google(context.Background(), "golang") + if err != nil { + fmt.Fprintln(os.Stderr, err) + return + } + for _, result := range results { + fmt.Println(result) + } + + // Output: + // web result for "golang" + // image result for "golang" + // video result for "golang" +} + +func TestZeroGroup(t *testing.T) { + err1 := errors.New("errgroup_test: 1") + err2 := errors.New("errgroup_test: 2") + + cases := []struct { + errs []error + }{ + {errs: []error{}}, + {errs: []error{nil}}, + {errs: []error{err1}}, + {errs: []error{err1, nil}}, + {errs: []error{err1, nil, err2}}, + } + + for _, tc := range cases { + var g errgroup.Group + + var firstErr error + for i, err := range tc.errs { + err := err + g.Go(func() error { return err }) + + if firstErr == nil && err != nil { + firstErr = err + } + + if gErr := g.Wait(); gErr != firstErr { + t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ + "g.Wait() = %v; want %v", + g, tc.errs[:i+1], err, firstErr) + } + } + } +} + +func TestWithContext(t *testing.T) { + errDoom := errors.New("group_test: doomed") + + cases := []struct { + errs []error + want error + }{ + {want: nil}, + {errs: []error{nil}, want: nil}, + {errs: []error{errDoom}, want: errDoom}, + {errs: []error{errDoom, nil}, want: errDoom}, + } + + for _, tc := range cases { + g, ctx := errgroup.WithContext(context.Background()) + + for _, err := range tc.errs { + err := err + g.Go(func() error { return err }) + } + + if err := g.Wait(); err != tc.want { + t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ + "g.Wait() = %v; want %v", + g, tc.errs, err, tc.want) + } + + canceled := false + select { + case <-ctx.Done(): + canceled = true + default: + } + if !canceled { + t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ + "ctx.Done() was not closed", + g, tc.errs) + } + } +} diff --git a/vendor/golang.org/x/sync/semaphore/semaphore.go b/vendor/golang.org/x/sync/semaphore/semaphore.go new file mode 100644 index 000000000..e9d2d79a9 --- /dev/null +++ b/vendor/golang.org/x/sync/semaphore/semaphore.go @@ -0,0 +1,131 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package semaphore provides a weighted semaphore implementation. +package semaphore // import "golang.org/x/sync/semaphore" + +import ( + "container/list" + "sync" + + // Use the old context because packages that depend on this one + // (e.g. cloud.google.com/go/...) must run on Go 1.6. + // TODO(jba): update to "context" when possible. + "golang.org/x/net/context" +) + +type waiter struct { + n int64 + ready chan<- struct{} // Closed when semaphore acquired. +} + +// NewWeighted creates a new weighted semaphore with the given +// maximum combined weight for concurrent access. 
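The weighted semaphore below pairs naturally with errgroup when bounded concurrency and error propagation are both needed; the restic code in this change only uses errgroup, so the following is just a hedged usage sketch (maxWorkers and the loop bound are arbitrary):

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
	"golang.org/x/sync/semaphore"
)

func main() {
	const maxWorkers = 4
	sem := semaphore.NewWeighted(maxWorkers)
	g, ctx := errgroup.WithContext(context.Background())

	for i := 0; i < 16; i++ {
		i := i
		// Acquire before launching, so at most maxWorkers goroutines run at once.
		if err := sem.Acquire(ctx, 1); err != nil {
			break // ctx was cancelled by an earlier failure
		}
		g.Go(func() error {
			defer sem.Release(1)
			fmt.Println("working on item", i)
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		fmt.Println("error:", err)
	}
}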
+func NewWeighted(n int64) *Weighted { + w := &Weighted{size: n} + return w +} + +// Weighted provides a way to bound concurrent access to a resource. +// The callers can request access with a given weight. +type Weighted struct { + size int64 + cur int64 + mu sync.Mutex + waiters list.List +} + +// Acquire acquires the semaphore with a weight of n, blocking only until ctx +// is done. On success, returns nil. On failure, returns ctx.Err() and leaves +// the semaphore unchanged. +// +// If ctx is already done, Acquire may still succeed without blocking. +func (s *Weighted) Acquire(ctx context.Context, n int64) error { + s.mu.Lock() + if s.size-s.cur >= n && s.waiters.Len() == 0 { + s.cur += n + s.mu.Unlock() + return nil + } + + if n > s.size { + // Don't make other Acquire calls block on one that's doomed to fail. + s.mu.Unlock() + <-ctx.Done() + return ctx.Err() + } + + ready := make(chan struct{}) + w := waiter{n: n, ready: ready} + elem := s.waiters.PushBack(w) + s.mu.Unlock() + + select { + case <-ctx.Done(): + err := ctx.Err() + s.mu.Lock() + select { + case <-ready: + // Acquired the semaphore after we were canceled. Rather than trying to + // fix up the queue, just pretend we didn't notice the cancelation. + err = nil + default: + s.waiters.Remove(elem) + } + s.mu.Unlock() + return err + + case <-ready: + return nil + } +} + +// TryAcquire acquires the semaphore with a weight of n without blocking. +// On success, returns true. On failure, returns false and leaves the semaphore unchanged. +func (s *Weighted) TryAcquire(n int64) bool { + s.mu.Lock() + success := s.size-s.cur >= n && s.waiters.Len() == 0 + if success { + s.cur += n + } + s.mu.Unlock() + return success +} + +// Release releases the semaphore with a weight of n. +func (s *Weighted) Release(n int64) { + s.mu.Lock() + s.cur -= n + if s.cur < 0 { + s.mu.Unlock() + panic("semaphore: bad release") + } + for { + next := s.waiters.Front() + if next == nil { + break // No more waiters blocked. + } + + w := next.Value.(waiter) + if s.size-s.cur < w.n { + // Not enough tokens for the next waiter. We could keep going (to try to + // find a waiter with a smaller request), but under load that could cause + // starvation for large requests; instead, we leave all remaining waiters + // blocked. + // + // Consider a semaphore used as a read-write lock, with N tokens, N + // readers, and one writer. Each reader can Acquire(1) to obtain a read + // lock. The writer can Acquire(N) to obtain a write lock, excluding all + // of the readers. If we allow the readers to jump ahead in the queue, + // the writer will starve — there is always one token available for every + // reader. + break + } + + s.cur += w.n + s.waiters.Remove(next) + close(w.ready) + } + s.mu.Unlock() +} diff --git a/vendor/golang.org/x/sync/semaphore/semaphore_bench_test.go b/vendor/golang.org/x/sync/semaphore/semaphore_bench_test.go new file mode 100644 index 000000000..1e3ab75f5 --- /dev/null +++ b/vendor/golang.org/x/sync/semaphore/semaphore_bench_test.go @@ -0,0 +1,131 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7 + +package semaphore_test + +import ( + "fmt" + "testing" + + "golang.org/x/net/context" + "golang.org/x/sync/semaphore" +) + +// weighted is an interface matching a subset of *Weighted. It allows +// alternate implementations for testing and benchmarking. 
+type weighted interface { + Acquire(context.Context, int64) error + TryAcquire(int64) bool + Release(int64) +} + +// semChan implements Weighted using a channel for +// comparing against the condition variable-based implementation. +type semChan chan struct{} + +func newSemChan(n int64) semChan { + return semChan(make(chan struct{}, n)) +} + +func (s semChan) Acquire(_ context.Context, n int64) error { + for i := int64(0); i < n; i++ { + s <- struct{}{} + } + return nil +} + +func (s semChan) TryAcquire(n int64) bool { + if int64(len(s))+n > int64(cap(s)) { + return false + } + + for i := int64(0); i < n; i++ { + s <- struct{}{} + } + return true +} + +func (s semChan) Release(n int64) { + for i := int64(0); i < n; i++ { + <-s + } +} + +// acquireN calls Acquire(size) on sem N times and then calls Release(size) N times. +func acquireN(b *testing.B, sem weighted, size int64, N int) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < N; j++ { + sem.Acquire(context.Background(), size) + } + for j := 0; j < N; j++ { + sem.Release(size) + } + } +} + +// tryAcquireN calls TryAcquire(size) on sem N times and then calls Release(size) N times. +func tryAcquireN(b *testing.B, sem weighted, size int64, N int) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < N; j++ { + if !sem.TryAcquire(size) { + b.Fatalf("TryAcquire(%v) = false, want true", size) + } + } + for j := 0; j < N; j++ { + sem.Release(size) + } + } +} + +func BenchmarkNewSeq(b *testing.B) { + for _, cap := range []int64{1, 128} { + b.Run(fmt.Sprintf("Weighted-%d", cap), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = semaphore.NewWeighted(cap) + } + }) + b.Run(fmt.Sprintf("semChan-%d", cap), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = newSemChan(cap) + } + }) + } +} + +func BenchmarkAcquireSeq(b *testing.B) { + for _, c := range []struct { + cap, size int64 + N int + }{ + {1, 1, 1}, + {2, 1, 1}, + {16, 1, 1}, + {128, 1, 1}, + {2, 2, 1}, + {16, 2, 8}, + {128, 2, 64}, + {2, 1, 2}, + {16, 8, 2}, + {128, 64, 2}, + } { + for _, w := range []struct { + name string + w weighted + }{ + {"Weighted", semaphore.NewWeighted(c.cap)}, + {"semChan", newSemChan(c.cap)}, + } { + b.Run(fmt.Sprintf("%s-acquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) { + acquireN(b, w.w, c.size, c.N) + }) + b.Run(fmt.Sprintf("%s-tryAcquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) { + tryAcquireN(b, w.w, c.size, c.N) + }) + } + } +} diff --git a/vendor/golang.org/x/sync/semaphore/semaphore_example_test.go b/vendor/golang.org/x/sync/semaphore/semaphore_example_test.go new file mode 100644 index 000000000..e75cd79f5 --- /dev/null +++ b/vendor/golang.org/x/sync/semaphore/semaphore_example_test.go @@ -0,0 +1,84 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package semaphore_test + +import ( + "context" + "fmt" + "log" + "runtime" + + "golang.org/x/sync/semaphore" +) + +// Example_workerPool demonstrates how to use a semaphore to limit the number of +// goroutines working on parallel tasks. +// +// This use of a semaphore mimics a typical “worker pool” pattern, but without +// the need to explicitly shut down idle workers when the work is done. 
+func Example_workerPool() { + ctx := context.TODO() + + var ( + maxWorkers = runtime.GOMAXPROCS(0) + sem = semaphore.NewWeighted(int64(maxWorkers)) + out = make([]int, 32) + ) + + // Compute the output using up to maxWorkers goroutines at a time. + for i := range out { + // When maxWorkers goroutines are in flight, Acquire blocks until one of the + // workers finishes. + if err := sem.Acquire(ctx, 1); err != nil { + log.Printf("Failed to acquire semaphore: %v", err) + break + } + + go func(i int) { + defer sem.Release(1) + out[i] = collatzSteps(i + 1) + }(i) + } + + // Acquire all of the tokens to wait for any remaining workers to finish. + // + // If you are already waiting for the workers by some other means (such as an + // errgroup.Group), you can omit this final Acquire call. + if err := sem.Acquire(ctx, int64(maxWorkers)); err != nil { + log.Printf("Failed to acquire semaphore: %v", err) + } + + fmt.Println(out) + + // Output: + // [0 1 7 2 5 8 16 3 19 6 14 9 9 17 17 4 12 20 20 7 7 15 15 10 23 10 111 18 18 18 106 5] +} + +// collatzSteps computes the number of steps to reach 1 under the Collatz +// conjecture. (See https://en.wikipedia.org/wiki/Collatz_conjecture.) +func collatzSteps(n int) (steps int) { + if n <= 0 { + panic("nonpositive input") + } + + for ; n > 1; steps++ { + if steps < 0 { + panic("too many steps") + } + + if n%2 == 0 { + n /= 2 + continue + } + + const maxInt = int(^uint(0) >> 1) + if n > (maxInt-1)/3 { + panic("overflow") + } + n = 3*n + 1 + } + + return steps +} diff --git a/vendor/golang.org/x/sync/semaphore/semaphore_test.go b/vendor/golang.org/x/sync/semaphore/semaphore_test.go new file mode 100644 index 000000000..2541b9068 --- /dev/null +++ b/vendor/golang.org/x/sync/semaphore/semaphore_test.go @@ -0,0 +1,171 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package semaphore_test + +import ( + "math/rand" + "runtime" + "sync" + "testing" + "time" + + "golang.org/x/net/context" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" +) + +const maxSleep = 1 * time.Millisecond + +func HammerWeighted(sem *semaphore.Weighted, n int64, loops int) { + for i := 0; i < loops; i++ { + sem.Acquire(context.Background(), n) + time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond) + sem.Release(n) + } +} + +func TestWeighted(t *testing.T) { + t.Parallel() + + n := runtime.GOMAXPROCS(0) + loops := 10000 / n + sem := semaphore.NewWeighted(int64(n)) + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + i := i + go func() { + defer wg.Done() + HammerWeighted(sem, int64(i), loops) + }() + } + wg.Wait() +} + +func TestWeightedPanic(t *testing.T) { + t.Parallel() + + defer func() { + if recover() == nil { + t.Fatal("release of an unacquired weighted semaphore did not panic") + } + }() + w := semaphore.NewWeighted(1) + w.Release(1) +} + +func TestWeightedTryAcquire(t *testing.T) { + t.Parallel() + + ctx := context.Background() + sem := semaphore.NewWeighted(2) + tries := []bool{} + sem.Acquire(ctx, 1) + tries = append(tries, sem.TryAcquire(1)) + tries = append(tries, sem.TryAcquire(1)) + + sem.Release(2) + + tries = append(tries, sem.TryAcquire(1)) + sem.Acquire(ctx, 1) + tries = append(tries, sem.TryAcquire(1)) + + want := []bool{true, false, true, false} + for i := range tries { + if tries[i] != want[i] { + t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i]) + } + } +} + +func TestWeightedAcquire(t *testing.T) { + t.Parallel() + + ctx := context.Background() + sem := semaphore.NewWeighted(2) + tryAcquire := func(n int64) bool { + ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + return sem.Acquire(ctx, n) == nil + } + + tries := []bool{} + sem.Acquire(ctx, 1) + tries = append(tries, tryAcquire(1)) + tries = append(tries, tryAcquire(1)) + + sem.Release(2) + + tries = append(tries, tryAcquire(1)) + sem.Acquire(ctx, 1) + tries = append(tries, tryAcquire(1)) + + want := []bool{true, false, true, false} + for i := range tries { + if tries[i] != want[i] { + t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i]) + } + } +} + +func TestWeightedDoesntBlockIfTooBig(t *testing.T) { + t.Parallel() + + const n = 2 + sem := semaphore.NewWeighted(n) + { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go sem.Acquire(ctx, n+1) + } + + g, ctx := errgroup.WithContext(context.Background()) + for i := n * 3; i > 0; i-- { + g.Go(func() error { + err := sem.Acquire(ctx, 1) + if err == nil { + time.Sleep(1 * time.Millisecond) + sem.Release(1) + } + return err + }) + } + if err := g.Wait(); err != nil { + t.Errorf("semaphore.NewWeighted(%v) failed to AcquireCtx(_, 1) with AcquireCtx(_, %v) pending", n, n+1) + } +} + +// TestLargeAcquireDoesntStarve times out if a large call to Acquire starves. +// Merely returning from the test function indicates success. 
+func TestLargeAcquireDoesntStarve(t *testing.T) { + t.Parallel() + + ctx := context.Background() + n := int64(runtime.GOMAXPROCS(0)) + sem := semaphore.NewWeighted(n) + running := true + + var wg sync.WaitGroup + wg.Add(int(n)) + for i := n; i > 0; i-- { + sem.Acquire(ctx, 1) + go func() { + defer func() { + sem.Release(1) + wg.Done() + }() + for running { + time.Sleep(1 * time.Millisecond) + sem.Release(1) + sem.Acquire(ctx, 1) + } + }() + } + + sem.Acquire(ctx, n) + running = false + sem.Release(n) + wg.Wait() +} diff --git a/vendor/golang.org/x/sync/singleflight/singleflight.go b/vendor/golang.org/x/sync/singleflight/singleflight.go new file mode 100644 index 000000000..9a4f8d59e --- /dev/null +++ b/vendor/golang.org/x/sync/singleflight/singleflight.go @@ -0,0 +1,111 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package singleflight provides a duplicate function call suppression +// mechanism. +package singleflight // import "golang.org/x/sync/singleflight" + +import "sync" + +// call is an in-flight or completed singleflight.Do call +type call struct { + wg sync.WaitGroup + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val interface{} + err error + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups int + chans []chan<- Result +} + +// Group represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type Group struct { + mu sync.Mutex // protects m + m map[string]*call // lazily initialized +} + +// Result holds the results of Do, so they can be passed +// on a channel. +type Result struct { + Val interface{} + Err error + Shared bool +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.mu.Unlock() + c.wg.Wait() + return c.val, c.err, true + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, c.dups > 0 +} + +// DoChan is like Do but returns a channel that will receive the +// results when they are ready. +func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result { + ch := make(chan Result, 1) + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + c.chans = append(c.chans, ch) + g.mu.Unlock() + return ch + } + c := &call{chans: []chan<- Result{ch}} + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + go g.doCall(c, key, fn) + + return ch +} + +// doCall handles the single call for a key. 
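The singleflight Group above deduplicates concurrent calls for the same key; a short usage sketch of the Do semantics (the key, the sleep and the returned value are invented for illustration):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/singleflight"
)

func main() {
	var g singleflight.Group
	var calls int32

	// Goroutines that ask for the same key while a call is in flight share its
	// result instead of triggering their own execution.
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			v, _, shared := g.Do("config", func() (interface{}, error) {
				atomic.AddInt32(&calls, 1)
				time.Sleep(10 * time.Millisecond) // keep the call in flight
				return "value-from-backend", nil
			})
			fmt.Println(v, "shared:", shared)
		}()
	}
	wg.Wait()

	// Typically 1; always between 1 and 10, depending on how the goroutines overlap.
	fmt.Println("underlying calls:", atomic.LoadInt32(&calls))
}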
+func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { + c.val, c.err = fn() + c.wg.Done() + + g.mu.Lock() + delete(g.m, key) + for _, ch := range c.chans { + ch <- Result{c.val, c.err, c.dups > 0} + } + g.mu.Unlock() +} + +// Forget tells the singleflight to forget about a key. Future calls +// to Do for this key will call the function rather than waiting for +// an earlier call to complete. +func (g *Group) Forget(key string) { + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() +} diff --git a/vendor/golang.org/x/sync/singleflight/singleflight_test.go b/vendor/golang.org/x/sync/singleflight/singleflight_test.go new file mode 100644 index 000000000..5e6f1b328 --- /dev/null +++ b/vendor/golang.org/x/sync/singleflight/singleflight_test.go @@ -0,0 +1,87 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package singleflight + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo(t *testing.T) { + var g Group + v, err, _ := g.Do("key", func() (interface{}, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoErr(t *testing.T) { + var g Group + someErr := errors.New("Some error") + v, err, _ := g.Do("key", func() (interface{}, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr %v", err, someErr) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress(t *testing.T) { + var g Group + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var calls int32 + fn := func() (interface{}, error) { + if atomic.AddInt32(&calls, 1) == 1 { + // First invocation. + wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err, _ := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if s, _ := v.(string); s != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) + } +} diff --git a/vendor/golang.org/x/sync/syncmap/map.go b/vendor/golang.org/x/sync/syncmap/map.go new file mode 100644 index 000000000..80e15847e --- /dev/null +++ b/vendor/golang.org/x/sync/syncmap/map.go @@ -0,0 +1,372 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package syncmap provides a concurrent map implementation. +// It is a prototype for a proposed addition to the sync package +// in the standard library. +// (https://golang.org/issue/18177) +package syncmap + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +// Map is a concurrent map with amortized-constant-time loads, stores, and deletes. +// It is safe for multiple goroutines to call a Map's methods concurrently. +// +// The zero Map is valid and empty. 
+// +// A Map must not be copied after first use. +type Map struct { + mu sync.Mutex + + // read contains the portion of the map's contents that are safe for + // concurrent access (with or without mu held). + // + // The read field itself is always safe to load, but must only be stored with + // mu held. + // + // Entries stored in read may be updated concurrently without mu, but updating + // a previously-expunged entry requires that the entry be copied to the dirty + // map and unexpunged with mu held. + read atomic.Value // readOnly + + // dirty contains the portion of the map's contents that require mu to be + // held. To ensure that the dirty map can be promoted to the read map quickly, + // it also includes all of the non-expunged entries in the read map. + // + // Expunged entries are not stored in the dirty map. An expunged entry in the + // clean map must be unexpunged and added to the dirty map before a new value + // can be stored to it. + // + // If the dirty map is nil, the next write to the map will initialize it by + // making a shallow copy of the clean map, omitting stale entries. + dirty map[interface{}]*entry + + // misses counts the number of loads since the read map was last updated that + // needed to lock mu to determine whether the key was present. + // + // Once enough misses have occurred to cover the cost of copying the dirty + // map, the dirty map will be promoted to the read map (in the unamended + // state) and the next store to the map will make a new dirty copy. + misses int +} + +// readOnly is an immutable struct stored atomically in the Map.read field. +type readOnly struct { + m map[interface{}]*entry + amended bool // true if the dirty map contains some key not in m. +} + +// expunged is an arbitrary pointer that marks entries which have been deleted +// from the dirty map. +var expunged = unsafe.Pointer(new(interface{})) + +// An entry is a slot in the map corresponding to a particular key. +type entry struct { + // p points to the interface{} value stored for the entry. + // + // If p == nil, the entry has been deleted and m.dirty == nil. + // + // If p == expunged, the entry has been deleted, m.dirty != nil, and the entry + // is missing from m.dirty. + // + // Otherwise, the entry is valid and recorded in m.read.m[key] and, if m.dirty + // != nil, in m.dirty[key]. + // + // An entry can be deleted by atomic replacement with nil: when m.dirty is + // next created, it will atomically replace nil with expunged and leave + // m.dirty[key] unset. + // + // An entry's associated value can be updated by atomic replacement, provided + // p != expunged. If p == expunged, an entry's associated value can be updated + // only after first setting m.dirty[key] = e so that lookups using the dirty + // map find the entry. + p unsafe.Pointer // *interface{} +} + +func newEntry(i interface{}) *entry { + return &entry{p: unsafe.Pointer(&i)} +} + +// Load returns the value stored in the map for a key, or nil if no +// value is present. +// The ok result indicates whether value was found in the map. +func (m *Map) Load(key interface{}) (value interface{}, ok bool) { + read, _ := m.read.Load().(readOnly) + e, ok := read.m[key] + if !ok && read.amended { + m.mu.Lock() + // Avoid reporting a spurious miss if m.dirty got promoted while we were + // blocked on m.mu. (If further loads of the same key will not miss, it's + // not worth copying the dirty map for this key.) 
+ read, _ = m.read.Load().(readOnly) + e, ok = read.m[key] + if !ok && read.amended { + e, ok = m.dirty[key] + // Regardless of whether the entry was present, record a miss: this key + // will take the slow path until the dirty map is promoted to the read + // map. + m.missLocked() + } + m.mu.Unlock() + } + if !ok { + return nil, false + } + return e.load() +} + +func (e *entry) load() (value interface{}, ok bool) { + p := atomic.LoadPointer(&e.p) + if p == nil || p == expunged { + return nil, false + } + return *(*interface{})(p), true +} + +// Store sets the value for a key. +func (m *Map) Store(key, value interface{}) { + read, _ := m.read.Load().(readOnly) + if e, ok := read.m[key]; ok && e.tryStore(&value) { + return + } + + m.mu.Lock() + read, _ = m.read.Load().(readOnly) + if e, ok := read.m[key]; ok { + if e.unexpungeLocked() { + // The entry was previously expunged, which implies that there is a + // non-nil dirty map and this entry is not in it. + m.dirty[key] = e + } + e.storeLocked(&value) + } else if e, ok := m.dirty[key]; ok { + e.storeLocked(&value) + } else { + if !read.amended { + // We're adding the first new key to the dirty map. + // Make sure it is allocated and mark the read-only map as incomplete. + m.dirtyLocked() + m.read.Store(readOnly{m: read.m, amended: true}) + } + m.dirty[key] = newEntry(value) + } + m.mu.Unlock() +} + +// tryStore stores a value if the entry has not been expunged. +// +// If the entry is expunged, tryStore returns false and leaves the entry +// unchanged. +func (e *entry) tryStore(i *interface{}) bool { + p := atomic.LoadPointer(&e.p) + if p == expunged { + return false + } + for { + if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(i)) { + return true + } + p = atomic.LoadPointer(&e.p) + if p == expunged { + return false + } + } +} + +// unexpungeLocked ensures that the entry is not marked as expunged. +// +// If the entry was previously expunged, it must be added to the dirty map +// before m.mu is unlocked. +func (e *entry) unexpungeLocked() (wasExpunged bool) { + return atomic.CompareAndSwapPointer(&e.p, expunged, nil) +} + +// storeLocked unconditionally stores a value to the entry. +// +// The entry must be known not to be expunged. +func (e *entry) storeLocked(i *interface{}) { + atomic.StorePointer(&e.p, unsafe.Pointer(i)) +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func (m *Map) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) { + // Avoid locking if it's a clean hit. + read, _ := m.read.Load().(readOnly) + if e, ok := read.m[key]; ok { + actual, loaded, ok := e.tryLoadOrStore(value) + if ok { + return actual, loaded + } + } + + m.mu.Lock() + read, _ = m.read.Load().(readOnly) + if e, ok := read.m[key]; ok { + if e.unexpungeLocked() { + m.dirty[key] = e + } + actual, loaded, _ = e.tryLoadOrStore(value) + } else if e, ok := m.dirty[key]; ok { + actual, loaded, _ = e.tryLoadOrStore(value) + m.missLocked() + } else { + if !read.amended { + // We're adding the first new key to the dirty map. + // Make sure it is allocated and mark the read-only map as incomplete. + m.dirtyLocked() + m.read.Store(readOnly{m: read.m, amended: true}) + } + m.dirty[key] = newEntry(value) + actual, loaded = value, false + } + m.mu.Unlock() + + return actual, loaded +} + +// tryLoadOrStore atomically loads or stores a value if the entry is not +// expunged. 
+// +// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and +// returns with ok==false. +func (e *entry) tryLoadOrStore(i interface{}) (actual interface{}, loaded, ok bool) { + p := atomic.LoadPointer(&e.p) + if p == expunged { + return nil, false, false + } + if p != nil { + return *(*interface{})(p), true, true + } + + // Copy the interface after the first load to make this method more amenable + // to escape analysis: if we hit the "load" path or the entry is expunged, we + // shouldn't bother heap-allocating. + ic := i + for { + if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) { + return i, false, true + } + p = atomic.LoadPointer(&e.p) + if p == expunged { + return nil, false, false + } + if p != nil { + return *(*interface{})(p), true, true + } + } +} + +// Delete deletes the value for a key. +func (m *Map) Delete(key interface{}) { + read, _ := m.read.Load().(readOnly) + e, ok := read.m[key] + if !ok && read.amended { + m.mu.Lock() + read, _ = m.read.Load().(readOnly) + e, ok = read.m[key] + if !ok && read.amended { + delete(m.dirty, key) + } + m.mu.Unlock() + } + if ok { + e.delete() + } +} + +func (e *entry) delete() (hadValue bool) { + for { + p := atomic.LoadPointer(&e.p) + if p == nil || p == expunged { + return false + } + if atomic.CompareAndSwapPointer(&e.p, p, nil) { + return true + } + } +} + +// Range calls f sequentially for each key and value present in the map. +// If f returns false, range stops the iteration. +// +// Range does not necessarily correspond to any consistent snapshot of the Map's +// contents: no key will be visited more than once, but if the value for any key +// is stored or deleted concurrently, Range may reflect any mapping for that key +// from any point during the Range call. +// +// Range may be O(N) with the number of elements in the map even if f returns +// false after a constant number of calls. +func (m *Map) Range(f func(key, value interface{}) bool) { + // We need to be able to iterate over all of the keys that were already + // present at the start of the call to Range. + // If read.amended is false, then read.m satisfies that property without + // requiring us to hold m.mu for a long time. + read, _ := m.read.Load().(readOnly) + if read.amended { + // m.dirty contains keys not in read.m. Fortunately, Range is already O(N) + // (assuming the caller does not break out early), so a call to Range + // amortizes an entire copy of the map: we can promote the dirty copy + // immediately! 
+ m.mu.Lock() + read, _ = m.read.Load().(readOnly) + if read.amended { + read = readOnly{m: m.dirty} + m.read.Store(read) + m.dirty = nil + m.misses = 0 + } + m.mu.Unlock() + } + + for k, e := range read.m { + v, ok := e.load() + if !ok { + continue + } + if !f(k, v) { + break + } + } +} + +func (m *Map) missLocked() { + m.misses++ + if m.misses < len(m.dirty) { + return + } + m.read.Store(readOnly{m: m.dirty}) + m.dirty = nil + m.misses = 0 +} + +func (m *Map) dirtyLocked() { + if m.dirty != nil { + return + } + + read, _ := m.read.Load().(readOnly) + m.dirty = make(map[interface{}]*entry, len(read.m)) + for k, e := range read.m { + if !e.tryExpungeLocked() { + m.dirty[k] = e + } + } +} + +func (e *entry) tryExpungeLocked() (isExpunged bool) { + p := atomic.LoadPointer(&e.p) + for p == nil { + if atomic.CompareAndSwapPointer(&e.p, nil, expunged) { + return true + } + p = atomic.LoadPointer(&e.p) + } + return p == expunged +} diff --git a/vendor/golang.org/x/sync/syncmap/map_bench_test.go b/vendor/golang.org/x/sync/syncmap/map_bench_test.go new file mode 100644 index 000000000..b279b4f74 --- /dev/null +++ b/vendor/golang.org/x/sync/syncmap/map_bench_test.go @@ -0,0 +1,216 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package syncmap_test + +import ( + "fmt" + "reflect" + "sync/atomic" + "testing" + + "golang.org/x/sync/syncmap" +) + +type bench struct { + setup func(*testing.B, mapInterface) + perG func(b *testing.B, pb *testing.PB, i int, m mapInterface) +} + +func benchMap(b *testing.B, bench bench) { + for _, m := range [...]mapInterface{&DeepCopyMap{}, &RWMutexMap{}, &syncmap.Map{}} { + b.Run(fmt.Sprintf("%T", m), func(b *testing.B) { + m = reflect.New(reflect.TypeOf(m).Elem()).Interface().(mapInterface) + if bench.setup != nil { + bench.setup(b, m) + } + + b.ResetTimer() + + var i int64 + b.RunParallel(func(pb *testing.PB) { + id := int(atomic.AddInt64(&i, 1) - 1) + bench.perG(b, pb, id*b.N, m) + }) + }) + } +} + +func BenchmarkLoadMostlyHits(b *testing.B) { + const hits, misses = 1023, 1 + + benchMap(b, bench{ + setup: func(_ *testing.B, m mapInterface) { + for i := 0; i < hits; i++ { + m.LoadOrStore(i, i) + } + // Prime the map to get it into a steady state. + for i := 0; i < hits*2; i++ { + m.Load(i % hits) + } + }, + + perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) { + for ; pb.Next(); i++ { + m.Load(i % (hits + misses)) + } + }, + }) +} + +func BenchmarkLoadMostlyMisses(b *testing.B) { + const hits, misses = 1, 1023 + + benchMap(b, bench{ + setup: func(_ *testing.B, m mapInterface) { + for i := 0; i < hits; i++ { + m.LoadOrStore(i, i) + } + // Prime the map to get it into a steady state. + for i := 0; i < hits*2; i++ { + m.Load(i % hits) + } + }, + + perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) { + for ; pb.Next(); i++ { + m.Load(i % (hits + misses)) + } + }, + }) +} + +func BenchmarkLoadOrStoreBalanced(b *testing.B) { + const hits, misses = 128, 128 + + benchMap(b, bench{ + setup: func(b *testing.B, m mapInterface) { + if _, ok := m.(*DeepCopyMap); ok { + b.Skip("DeepCopyMap has quadratic running time.") + } + for i := 0; i < hits; i++ { + m.LoadOrStore(i, i) + } + // Prime the map to get it into a steady state. 
+ for i := 0; i < hits*2; i++ { + m.Load(i % hits) + } + }, + + perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) { + for ; pb.Next(); i++ { + j := i % (hits + misses) + if j < hits { + if _, ok := m.LoadOrStore(j, i); !ok { + b.Fatalf("unexpected miss for %v", j) + } + } else { + if v, loaded := m.LoadOrStore(i, i); loaded { + b.Fatalf("failed to store %v: existing value %v", i, v) + } + } + } + }, + }) +} + +func BenchmarkLoadOrStoreUnique(b *testing.B) { + benchMap(b, bench{ + setup: func(b *testing.B, m mapInterface) { + if _, ok := m.(*DeepCopyMap); ok { + b.Skip("DeepCopyMap has quadratic running time.") + } + }, + + perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) { + for ; pb.Next(); i++ { + m.LoadOrStore(i, i) + } + }, + }) +} + +func BenchmarkLoadOrStoreCollision(b *testing.B) { + benchMap(b, bench{ + setup: func(_ *testing.B, m mapInterface) { + m.LoadOrStore(0, 0) + }, + + perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) { + for ; pb.Next(); i++ { + m.LoadOrStore(0, 0) + } + }, + }) +} + +func BenchmarkRange(b *testing.B) { + const mapSize = 1 << 10 + + benchMap(b, bench{ + setup: func(_ *testing.B, m mapInterface) { + for i := 0; i < mapSize; i++ { + m.Store(i, i) + } + }, + + perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) { + for ; pb.Next(); i++ { + m.Range(func(_, _ interface{}) bool { return true }) + } + }, + }) +} + +// BenchmarkAdversarialAlloc tests performance when we store a new value +// immediately whenever the map is promoted to clean and otherwise load a +// unique, missing key. +// +// This forces the Load calls to always acquire the map's mutex. +func BenchmarkAdversarialAlloc(b *testing.B) { + benchMap(b, bench{ + perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) { + var stores, loadsSinceStore int64 + for ; pb.Next(); i++ { + m.Load(i) + if loadsSinceStore++; loadsSinceStore > stores { + m.LoadOrStore(i, stores) + loadsSinceStore = 0 + stores++ + } + } + }, + }) +} + +// BenchmarkAdversarialDelete tests performance when we periodically delete +// one key and add a different one in a large map. +// +// This forces the Load calls to always acquire the map's mutex and periodically +// makes a full copy of the map despite changing only one entry. +func BenchmarkAdversarialDelete(b *testing.B) { + const mapSize = 1 << 10 + + benchMap(b, bench{ + setup: func(_ *testing.B, m mapInterface) { + for i := 0; i < mapSize; i++ { + m.Store(i, i) + } + }, + + perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) { + for ; pb.Next(); i++ { + m.Load(i) + + if i%mapSize == 0 { + m.Range(func(k, _ interface{}) bool { + m.Delete(k) + return false + }) + m.Store(i, i) + } + } + }, + }) +} diff --git a/vendor/golang.org/x/sync/syncmap/map_reference_test.go b/vendor/golang.org/x/sync/syncmap/map_reference_test.go new file mode 100644 index 000000000..923c51b70 --- /dev/null +++ b/vendor/golang.org/x/sync/syncmap/map_reference_test.go @@ -0,0 +1,151 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package syncmap_test + +import ( + "sync" + "sync/atomic" +) + +// This file contains reference map implementations for unit-tests. + +// mapInterface is the interface Map implements. 
+type mapInterface interface { + Load(interface{}) (interface{}, bool) + Store(key, value interface{}) + LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) + Delete(interface{}) + Range(func(key, value interface{}) (shouldContinue bool)) +} + +// RWMutexMap is an implementation of mapInterface using a sync.RWMutex. +type RWMutexMap struct { + mu sync.RWMutex + dirty map[interface{}]interface{} +} + +func (m *RWMutexMap) Load(key interface{}) (value interface{}, ok bool) { + m.mu.RLock() + value, ok = m.dirty[key] + m.mu.RUnlock() + return +} + +func (m *RWMutexMap) Store(key, value interface{}) { + m.mu.Lock() + if m.dirty == nil { + m.dirty = make(map[interface{}]interface{}) + } + m.dirty[key] = value + m.mu.Unlock() +} + +func (m *RWMutexMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) { + m.mu.Lock() + actual, loaded = m.dirty[key] + if !loaded { + actual = value + if m.dirty == nil { + m.dirty = make(map[interface{}]interface{}) + } + m.dirty[key] = value + } + m.mu.Unlock() + return actual, loaded +} + +func (m *RWMutexMap) Delete(key interface{}) { + m.mu.Lock() + delete(m.dirty, key) + m.mu.Unlock() +} + +func (m *RWMutexMap) Range(f func(key, value interface{}) (shouldContinue bool)) { + m.mu.RLock() + keys := make([]interface{}, 0, len(m.dirty)) + for k := range m.dirty { + keys = append(keys, k) + } + m.mu.RUnlock() + + for _, k := range keys { + v, ok := m.Load(k) + if !ok { + continue + } + if !f(k, v) { + break + } + } +} + +// DeepCopyMap is an implementation of mapInterface using a Mutex and +// atomic.Value. It makes deep copies of the map on every write to avoid +// acquiring the Mutex in Load. +type DeepCopyMap struct { + mu sync.Mutex + clean atomic.Value +} + +func (m *DeepCopyMap) Load(key interface{}) (value interface{}, ok bool) { + clean, _ := m.clean.Load().(map[interface{}]interface{}) + value, ok = clean[key] + return value, ok +} + +func (m *DeepCopyMap) Store(key, value interface{}) { + m.mu.Lock() + dirty := m.dirty() + dirty[key] = value + m.clean.Store(dirty) + m.mu.Unlock() +} + +func (m *DeepCopyMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) { + clean, _ := m.clean.Load().(map[interface{}]interface{}) + actual, loaded = clean[key] + if loaded { + return actual, loaded + } + + m.mu.Lock() + // Reload clean in case it changed while we were waiting on m.mu. + clean, _ = m.clean.Load().(map[interface{}]interface{}) + actual, loaded = clean[key] + if !loaded { + dirty := m.dirty() + dirty[key] = value + actual = value + m.clean.Store(dirty) + } + m.mu.Unlock() + return actual, loaded +} + +func (m *DeepCopyMap) Delete(key interface{}) { + m.mu.Lock() + dirty := m.dirty() + delete(dirty, key) + m.clean.Store(dirty) + m.mu.Unlock() +} + +func (m *DeepCopyMap) Range(f func(key, value interface{}) (shouldContinue bool)) { + clean, _ := m.clean.Load().(map[interface{}]interface{}) + for k, v := range clean { + if !f(k, v) { + break + } + } +} + +func (m *DeepCopyMap) dirty() map[interface{}]interface{} { + clean, _ := m.clean.Load().(map[interface{}]interface{}) + dirty := make(map[interface{}]interface{}, len(clean)+1) + for k, v := range clean { + dirty[k] = v + } + return dirty +} diff --git a/vendor/golang.org/x/sync/syncmap/map_test.go b/vendor/golang.org/x/sync/syncmap/map_test.go new file mode 100644 index 000000000..c883f176f --- /dev/null +++ b/vendor/golang.org/x/sync/syncmap/map_test.go @@ -0,0 +1,172 @@ +// Copyright 2016 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package syncmap_test + +import ( + "math/rand" + "reflect" + "runtime" + "sync" + "testing" + "testing/quick" + + "golang.org/x/sync/syncmap" +) + +type mapOp string + +const ( + opLoad = mapOp("Load") + opStore = mapOp("Store") + opLoadOrStore = mapOp("LoadOrStore") + opDelete = mapOp("Delete") +) + +var mapOps = [...]mapOp{opLoad, opStore, opLoadOrStore, opDelete} + +// mapCall is a quick.Generator for calls on mapInterface. +type mapCall struct { + op mapOp + k, v interface{} +} + +func (c mapCall) apply(m mapInterface) (interface{}, bool) { + switch c.op { + case opLoad: + return m.Load(c.k) + case opStore: + m.Store(c.k, c.v) + return nil, false + case opLoadOrStore: + return m.LoadOrStore(c.k, c.v) + case opDelete: + m.Delete(c.k) + return nil, false + default: + panic("invalid mapOp") + } +} + +type mapResult struct { + value interface{} + ok bool +} + +func randValue(r *rand.Rand) interface{} { + b := make([]byte, r.Intn(4)) + for i := range b { + b[i] = 'a' + byte(rand.Intn(26)) + } + return string(b) +} + +func (mapCall) Generate(r *rand.Rand, size int) reflect.Value { + c := mapCall{op: mapOps[rand.Intn(len(mapOps))], k: randValue(r)} + switch c.op { + case opStore, opLoadOrStore: + c.v = randValue(r) + } + return reflect.ValueOf(c) +} + +func applyCalls(m mapInterface, calls []mapCall) (results []mapResult, final map[interface{}]interface{}) { + for _, c := range calls { + v, ok := c.apply(m) + results = append(results, mapResult{v, ok}) + } + + final = make(map[interface{}]interface{}) + m.Range(func(k, v interface{}) bool { + final[k] = v + return true + }) + + return results, final +} + +func applyMap(calls []mapCall) ([]mapResult, map[interface{}]interface{}) { + return applyCalls(new(syncmap.Map), calls) +} + +func applyRWMutexMap(calls []mapCall) ([]mapResult, map[interface{}]interface{}) { + return applyCalls(new(RWMutexMap), calls) +} + +func applyDeepCopyMap(calls []mapCall) ([]mapResult, map[interface{}]interface{}) { + return applyCalls(new(DeepCopyMap), calls) +} + +func TestMapMatchesRWMutex(t *testing.T) { + if err := quick.CheckEqual(applyMap, applyRWMutexMap, nil); err != nil { + t.Error(err) + } +} + +func TestMapMatchesDeepCopy(t *testing.T) { + if err := quick.CheckEqual(applyMap, applyDeepCopyMap, nil); err != nil { + t.Error(err) + } +} + +func TestConcurrentRange(t *testing.T) { + const mapSize = 1 << 10 + + m := new(syncmap.Map) + for n := int64(1); n <= mapSize; n++ { + m.Store(n, int64(n)) + } + + done := make(chan struct{}) + var wg sync.WaitGroup + defer func() { + close(done) + wg.Wait() + }() + for g := int64(runtime.GOMAXPROCS(0)); g > 0; g-- { + r := rand.New(rand.NewSource(g)) + wg.Add(1) + go func(g int64) { + defer wg.Done() + for i := int64(0); ; i++ { + select { + case <-done: + return + default: + } + for n := int64(1); n < mapSize; n++ { + if r.Int63n(mapSize) == 0 { + m.Store(n, n*i*g) + } else { + m.Load(n) + } + } + } + }(g) + } + + iters := 1 << 10 + if testing.Short() { + iters = 16 + } + for n := iters; n > 0; n-- { + seen := make(map[int64]bool, mapSize) + + m.Range(func(ki, vi interface{}) bool { + k, v := ki.(int64), vi.(int64) + if v%k != 0 { + t.Fatalf("while Storing multiples of %v, Range saw value %v", k, v) + } + if seen[k] { + t.Fatalf("Range visited key %v twice", k) + } + seen[k] = true + return true + }) + + if len(seen) != mapSize { + t.Fatalf("Range visited %v elements of %v-element Map", len(seen), mapSize) + } + 
} +}
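
The hunks above only vendor upstream golang.org/x/sync code; the diff itself contains no usage. As a rough, illustrative sketch (not part of this change, and assuming an ad-hoc main package) of the APIs the vendored packages expose, based on the signatures visible in the files above:

// Illustrative only; not part of the vendored diff above.
// Demonstrates the syncmap.Map and singleflight.Group APIs as declared
// in the vendored sources (Load/Store/LoadOrStore/Range and Do).
package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/singleflight"
	"golang.org/x/sync/syncmap"
)

func main() {
	// syncmap.Map: a concurrent map; the zero value is ready to use and
	// needs no external locking.
	var m syncmap.Map
	m.Store("config", "v1")
	if v, ok := m.Load("config"); ok {
		fmt.Println("loaded:", v)
	}
	// LoadOrStore returns the existing value when the key is present;
	// loaded reports whether an existing value was returned.
	actual, loaded := m.LoadOrStore("config", "v2")
	fmt.Println(actual, loaded) // "v1 true": the stored value wins

	m.Range(func(k, v interface{}) bool {
		fmt.Println(k, "=>", v)
		return true // keep iterating
	})

	// singleflight.Group: concurrent Do calls with the same key share one
	// execution of fn; the third result reports whether the returned value
	// was shared with other callers.
	var g singleflight.Group
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			v, err, shared := g.Do("fetch", func() (interface{}, error) {
				return "expensive result", nil
			})
			fmt.Println(v, err, shared)
		}()
	}
	wg.Wait()
}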