From 4d55a62ada22931ce7e54ef133d3472c40e81148 Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Sun, 5 May 2024 12:00:25 +0200 Subject: [PATCH] bloblru: add test for GetOrCompute --- internal/bloblru/cache_test.go | 67 ++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/internal/bloblru/cache_test.go b/internal/bloblru/cache_test.go index aa6f4465c..b2becd256 100644 --- a/internal/bloblru/cache_test.go +++ b/internal/bloblru/cache_test.go @@ -1,11 +1,14 @@ package bloblru import ( + "context" + "fmt" "math/rand" "testing" "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" + "golang.org/x/sync/errgroup" ) func TestCache(t *testing.T) { @@ -52,6 +55,70 @@ func TestCache(t *testing.T) { rtest.Equals(t, cacheSize, c.free) } +func TestCacheGetOrCompute(t *testing.T) { + var id1, id2 restic.ID + id1[0] = 1 + id2[0] = 2 + + const ( + kiB = 1 << 10 + cacheSize = 64*kiB + 3*overhead + ) + + c := New(cacheSize) + + e := fmt.Errorf("broken") + _, err := c.GetOrCompute(id1, func() ([]byte, error) { + return nil, e + }) + rtest.Equals(t, e, err, "expected error was not returned") + + // fill buffer + data1 := make([]byte, 10*kiB) + blob, err := c.GetOrCompute(id1, func() ([]byte, error) { + return data1, nil + }) + rtest.OK(t, err) + rtest.Equals(t, &data1[0], &blob[0], "wrong buffer returend") + + // now the buffer should be returned without calling the compute function + blob, err = c.GetOrCompute(id1, func() ([]byte, error) { + return nil, e + }) + rtest.OK(t, err) + rtest.Equals(t, &data1[0], &blob[0], "wrong buffer returend") + + // check concurrency + wg, _ := errgroup.WithContext(context.TODO()) + wait := make(chan struct{}) + calls := make(chan struct{}, 10) + + // start a bunch of blocking goroutines + for i := 0; i < 10; i++ { + wg.Go(func() error { + buf, err := c.GetOrCompute(id2, func() ([]byte, error) { + // block to ensure that multiple requests are waiting in parallel + <-wait + calls <- struct{}{} + return make([]byte, 42), nil + }) + if len(buf) != 42 { + return fmt.Errorf("wrong buffer") + } + return err + }) + } + + close(wait) + rtest.OK(t, wg.Wait()) + close(calls) + count := 0 + for range calls { + count++ + } + rtest.Equals(t, 1, count, "expected exactly one call of the compute function") +} + func BenchmarkAdd(b *testing.B) { const ( MiB = 1 << 20