diff --git a/internal/archiver/tree_saver_test.go b/internal/archiver/tree_saver_test.go index 67ef21111..053b3e249 100644 --- a/internal/archiver/tree_saver_test.go +++ b/internal/archiver/tree_saver_test.go @@ -11,31 +11,38 @@ import ( "golang.org/x/sync/errgroup" ) -func newFutureBlobWithResponse() FutureBlob { +func treeSaveHelper(ctx context.Context, t restic.BlobType, buf *Buffer) FutureBlob { ch := make(chan SaveBlobResponse, 1) ch <- SaveBlobResponse{ id: restic.NewRandomID(), known: false, - length: 123, - sizeInRepo: 123, + length: len(buf.Data), + sizeInRepo: len(buf.Data), } return FutureBlob{ch: ch} } -func TestTreeSaver(t *testing.T) { +func setupTreeSaver() (context.Context, context.CancelFunc, *TreeSaver, func() error) { ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - wg, ctx := errgroup.WithContext(ctx) - saveFn := func(ctx context.Context, t restic.BlobType, buf *Buffer) FutureBlob { - return newFutureBlobWithResponse() - } errFn := func(snPath string, err error) error { return err } - b := NewTreeSaver(ctx, wg, uint(runtime.NumCPU()), saveFn, errFn) + b := NewTreeSaver(ctx, wg, uint(runtime.NumCPU()), treeSaveHelper, errFn) + + shutdown := func() error { + b.TriggerShutdown() + return wg.Wait() + } + + return ctx, cancel, b, shutdown +} + +func TestTreeSaver(t *testing.T) { + ctx, cancel, b, shutdown := setupTreeSaver() + defer cancel() var results []FutureNode @@ -52,9 +59,7 @@ func TestTreeSaver(t *testing.T) { tree.take(ctx) } - b.TriggerShutdown() - - err := wg.Wait() + err := shutdown() if err != nil { t.Fatal(err) } @@ -76,20 +81,9 @@ func TestTreeSaverError(t *testing.T) { for _, test := range tests { t.Run("", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel, b, shutdown := setupTreeSaver() defer cancel() - wg, ctx := errgroup.WithContext(ctx) - - saveFn := func(ctx context.Context, tpe restic.BlobType, buf *Buffer) FutureBlob { - return newFutureBlobWithResponse() - } - errFn := func(snPath string, err error) error { - return err - } - - b := NewTreeSaver(ctx, wg, uint(runtime.NumCPU()), saveFn, errFn) - var results []FutureNode for i := 0; i < test.trees; i++ { @@ -115,13 +109,10 @@ func TestTreeSaverError(t *testing.T) { tree.take(ctx) } - b.TriggerShutdown() - - err := wg.Wait() + err := shutdown() if err == nil { t.Errorf("expected error not found") } - if err != errTest { t.Fatalf("unexpected error found: %v", err) }