From bd03af2febc5e223e8edf4fade1638d43720329a Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Sun, 5 May 2024 11:37:35 +0200 Subject: [PATCH] dump: add GetOrCompute to bloblru cache --- internal/bloblru/cache.go | 48 +++++++++++++++++++++++++++++++++++++-- internal/dump/common.go | 14 ++++-------- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/internal/bloblru/cache.go b/internal/bloblru/cache.go index 302ecc769..4477e37a9 100644 --- a/internal/bloblru/cache.go +++ b/internal/bloblru/cache.go @@ -20,13 +20,15 @@ type Cache struct { c *simplelru.LRU[restic.ID, []byte] free, size int // Current and max capacity, in bytes. + inProgress map[restic.ID]chan struct{} } // New constructs a blob cache that stores at most size bytes worth of blobs. func New(size int) *Cache { c := &Cache{ - free: size, - size: size, + free: size, + size: size, + inProgress: make(map[restic.ID]chan struct{}), } // NewLRU wants us to specify some max. number of entries, else it errors. @@ -85,6 +87,48 @@ func (c *Cache) Get(id restic.ID) ([]byte, bool) { return blob, ok } +func (c *Cache) GetOrCompute(id restic.ID, compute func() ([]byte, error)) ([]byte, error) { + // check if already cached + blob, ok := c.Get(id) + if ok { + return blob, nil + } + + // check for parallel download or start our own + finish := make(chan struct{}) + c.mu.Lock() + waitForResult, isDownloading := c.inProgress[id] + if !isDownloading { + c.inProgress[id] = finish + + // remove progress channel once finished here + defer func() { + c.mu.Lock() + delete(c.inProgress, id) + c.mu.Unlock() + close(finish) + }() + } + c.mu.Unlock() + + if isDownloading { + // wait for result of parallel download + <-waitForResult + blob, ok := c.Get(id) + if ok { + return blob, nil + } + } + + // download it + blob, err := compute() + if err == nil { + c.Add(id, blob) + } + + return blob, err +} + func (c *Cache) evict(key restic.ID, blob []byte) { debug.Log("bloblru.Cache: evict %v, %d bytes", key, cap(blob)) c.free += cap(blob) + overhead diff --git a/internal/dump/common.go b/internal/dump/common.go index 116762b5a..62145ba9c 100644 --- a/internal/dump/common.go +++ b/internal/dump/common.go @@ -143,15 +143,11 @@ func (d *Dumper) writeNode(ctx context.Context, w io.Writer, node *restic.Node) for i := uint(0); i < d.repo.Connections(); i++ { wg.Go(func() error { for task := range loaderCh { - var err error - blob, ok := d.cache.Get(task.id) - if !ok { - blob, err = d.repo.LoadBlob(ctx, restic.DataBlob, task.id, nil) - if err != nil { - return err - } - - d.cache.Add(task.id, blob) + blob, err := d.cache.GetOrCompute(task.id, func() ([]byte, error) { + return d.repo.LoadBlob(ctx, restic.DataBlob, task.id, nil) + }) + if err != nil { + return err } select {