From cb0d083d8d7b33668f874c519cb13867216eb4fa Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Wed, 18 Oct 2023 10:34:10 +0100 Subject: [PATCH] feat: Add context to storagedriver.(Filewriter).Commit() This commit changes storagedriver.Filewriter interface by adding context.Context as an argument to its Commit func. We pass the context appropriately where need be throughout the distribution codebase to all the writers and tests. S3 driver writer unfortunately must maintain the context passed down to it from upstream so it contnues to implement io.Writer and io.Closer interfaces which do not allow accepting the context in any of their funcs. Co-authored-by: Cory Snider Signed-off-by: Milos Gajdos --- registry/storage/blobwriter.go | 2 +- registry/storage/driver/azure/azure.go | 2 +- registry/storage/driver/azure/azure_test.go | 2 +- registry/storage/driver/filesystem/driver.go | 4 +-- registry/storage/driver/gcs/gcs.go | 3 +- registry/storage/driver/gcs/gcs_test.go | 5 ++-- registry/storage/driver/inmemory/driver.go | 2 +- registry/storage/driver/s3-aws/s3.go | 30 ++++++++++--------- registry/storage/driver/storagedriver.go | 2 +- .../storage/driver/testsuites/testsuites.go | 16 +++++----- 10 files changed, 35 insertions(+), 33 deletions(-) diff --git a/registry/storage/blobwriter.go b/registry/storage/blobwriter.go index 48bd2a8ca..70c8b2afe 100644 --- a/registry/storage/blobwriter.go +++ b/registry/storage/blobwriter.go @@ -57,7 +57,7 @@ func (bw *blobWriter) StartedAt() time.Time { func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { dcontext.GetLogger(ctx).Debug("(*blobWriter).Commit") - if err := bw.fileWriter.Commit(); err != nil { + if err := bw.fileWriter.Commit(ctx); err != nil { return distribution.Descriptor{}, err } diff --git a/registry/storage/driver/azure/azure.go b/registry/storage/driver/azure/azure.go index 3e3c05ceb..585c8b432 100644 --- a/registry/storage/driver/azure/azure.go +++ b/registry/storage/driver/azure/azure.go @@ -544,7 +544,7 @@ func (w *writer) Cancel(ctx context.Context) error { return err } -func (w *writer) Commit() error { +func (w *writer) Commit(ctx context.Context) error { if w.closed { return fmt.Errorf("already closed") } else if w.committed { diff --git a/registry/storage/driver/azure/azure_test.go b/registry/storage/driver/azure/azure_test.go index c0d4d2e64..945cc0a90 100644 --- a/registry/storage/driver/azure/azure_test.go +++ b/registry/storage/driver/azure/azure_test.go @@ -120,7 +120,7 @@ func TestCommitAfterMove(t *testing.T) { t.Fatalf("writer.Write: unexpected error: %v", err) } - err = writer.Commit() + err = writer.Commit(ctx) if err != nil { t.Fatalf("writer.Commit: unexpected error: %v", err) } diff --git a/registry/storage/driver/filesystem/driver.go b/registry/storage/driver/filesystem/driver.go index ade7b2838..23033268f 100644 --- a/registry/storage/driver/filesystem/driver.go +++ b/registry/storage/driver/filesystem/driver.go @@ -142,7 +142,7 @@ func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte writer.Cancel(ctx) return err } - return writer.Commit() + return writer.Commit(ctx) } // Reader retrieves an io.ReadCloser for the content stored at "path" with a @@ -397,7 +397,7 @@ func (fw *fileWriter) Cancel(ctx context.Context) error { return os.Remove(fw.file.Name()) } -func (fw *fileWriter) Commit() error { +func (fw *fileWriter) Commit(ctx context.Context) error { if fw.closed { return fmt.Errorf("already closed") } else if fw.committed { diff --git a/registry/storage/driver/gcs/gcs.go b/registry/storage/driver/gcs/gcs.go index 81e1dde6a..5b276d65f 100644 --- a/registry/storage/driver/gcs/gcs.go +++ b/registry/storage/driver/gcs/gcs.go @@ -460,12 +460,11 @@ func putContentsClose(wc *storage.Writer, contents []byte) error { // Commit flushes all content written to this FileWriter and makes it // available for future calls to StorageDriver.GetContent and // StorageDriver.Reader. -func (w *writer) Commit() error { +func (w *writer) Commit(ctx context.Context) error { if err := w.checkClosed(); err != nil { return err } w.closed = true - ctx := context.TODO() // no session started yet just perform a simple upload if w.sessionURI == "" { diff --git a/registry/storage/driver/gcs/gcs_test.go b/registry/storage/driver/gcs/gcs_test.go index d9f0da3ec..65998e7bc 100644 --- a/registry/storage/driver/gcs/gcs_test.go +++ b/registry/storage/driver/gcs/gcs_test.go @@ -4,6 +4,7 @@ package gcs import ( + "context" "fmt" "os" "testing" @@ -122,7 +123,7 @@ func TestCommitEmpty(t *testing.T) { if err != nil { t.Fatalf("driver.Writer: unexpected error: %v", err) } - err = writer.Commit() + err = writer.Commit(context.Background()) if err != nil { t.Fatalf("writer.Commit: unexpected error: %v", err) } @@ -169,7 +170,7 @@ func TestCommit(t *testing.T) { if err != nil { t.Fatalf("writer.Write: unexpected error: %v", err) } - err = writer.Commit() + err = writer.Commit(context.Background()) if err != nil { t.Fatalf("writer.Commit: unexpected error: %v", err) } diff --git a/registry/storage/driver/inmemory/driver.go b/registry/storage/driver/inmemory/driver.go index d9c30f382..4c00ca404 100644 --- a/registry/storage/driver/inmemory/driver.go +++ b/registry/storage/driver/inmemory/driver.go @@ -320,7 +320,7 @@ func (w *writer) Cancel(ctx context.Context) error { return w.d.root.delete(w.f.path()) } -func (w *writer) Commit() error { +func (w *writer) Commit(ctx context.Context) error { if w.closed { return fmt.Errorf("already closed") } else if w.committed { diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go index 22c0a4491..a6428a918 100644 --- a/registry/storage/driver/s3-aws/s3.go +++ b/registry/storage/driver/s3-aws/s3.go @@ -696,7 +696,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto if err != nil { return nil, err } - return d.newWriter(key, *resp.UploadId, nil), nil + return d.newWriter(ctx, key, *resp.UploadId, nil), nil } listMultipartUploadsInput := &s3.ListMultipartUploadsInput{ @@ -743,7 +743,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto } allParts = append(allParts, partsList.Parts...) } - return d.newWriter(key, *multi.UploadId, allParts), nil + return d.newWriter(ctx, key, *multi.UploadId, allParts), nil } // resp.NextUploadIdMarker must have at least one element or we would have returned not found @@ -1338,6 +1338,7 @@ func (b *buffer) Clear() { // cleanly resumed in the future. This is violated if Close is called after less // than a full chunk is written. type writer struct { + ctx context.Context driver *driver key string uploadID string @@ -1350,12 +1351,13 @@ type writer struct { cancelled bool } -func (d *driver) newWriter(key, uploadID string, parts []*s3.Part) storagedriver.FileWriter { +func (d *driver) newWriter(ctx context.Context, key, uploadID string, parts []*s3.Part) storagedriver.FileWriter { var size int64 for _, part := range parts { size += *part.Size } return &writer{ + ctx: ctx, driver: d, key: key, uploadID: uploadID, @@ -1394,7 +1396,7 @@ func (w *writer) Write(p []byte) (int, error) { sort.Sort(completedUploadedParts) - _, err := w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{ + _, err := w.driver.S3.CompleteMultipartUploadWithContext(w.ctx, &s3.CompleteMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1403,7 +1405,7 @@ func (w *writer) Write(p []byte) (int, error) { }, }) if err != nil { - if _, aErr := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ + if _, aErr := w.driver.S3.AbortMultipartUploadWithContext(w.ctx, &s3.AbortMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1413,7 +1415,7 @@ func (w *writer) Write(p []byte) (int, error) { return 0, err } - resp, err := w.driver.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{ + resp, err := w.driver.S3.CreateMultipartUploadWithContext(w.ctx, &s3.CreateMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), ContentType: w.driver.getContentType(), @@ -1429,7 +1431,7 @@ func (w *writer) Write(p []byte) (int, error) { // If the entire written file is smaller than minChunkSize, we need to make // a new part from scratch :double sad face: if w.size < minChunkSize { - resp, err := w.driver.S3.GetObject(&s3.GetObjectInput{ + resp, err := w.driver.S3.GetObjectWithContext(w.ctx, &s3.GetObjectInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), }) @@ -1451,7 +1453,7 @@ func (w *writer) Write(p []byte) (int, error) { } } else { // Otherwise we can use the old file as the new first part - copyPartResp, err := w.driver.S3.UploadPartCopy(&s3.UploadPartCopyInput{ + copyPartResp, err := w.driver.S3.UploadPartCopyWithContext(w.ctx, &s3.UploadPartCopyInput{ Bucket: aws.String(w.driver.Bucket), CopySource: aws.String(w.driver.Bucket + "/" + w.key), Key: aws.String(w.key), @@ -1536,7 +1538,7 @@ func (w *writer) Cancel(ctx context.Context) error { return fmt.Errorf("already committed") } w.cancelled = true - _, err := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ + _, err := w.driver.S3.AbortMultipartUploadWithContext(ctx, &s3.AbortMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1544,7 +1546,7 @@ func (w *writer) Cancel(ctx context.Context) error { return err } -func (w *writer) Commit() error { +func (w *writer) Commit(ctx context.Context) error { if w.closed { return fmt.Errorf("already closed") } else if w.committed { @@ -1576,7 +1578,7 @@ func (w *writer) Commit() error { // Solution: we upload the empty i.e. 0 byte part as a single part and then append it // to the completedUploadedParts slice used to complete the Multipart upload. if len(w.parts) == 0 { - resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{ + resp, err := w.driver.S3.UploadPartWithContext(w.ctx, &s3.UploadPartInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), PartNumber: aws.Int64(1), @@ -1595,7 +1597,7 @@ func (w *writer) Commit() error { sort.Sort(completedUploadedParts) - _, err = w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{ + _, err = w.driver.S3.CompleteMultipartUploadWithContext(w.ctx, &s3.CompleteMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1604,7 +1606,7 @@ func (w *writer) Commit() error { }, }) if err != nil { - if _, aErr := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ + if _, aErr := w.driver.S3.AbortMultipartUploadWithContext(w.ctx, &s3.AbortMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1634,7 +1636,7 @@ func (w *writer) flush() error { partSize := buf.Len() partNumber := aws.Int64(int64(len(w.parts) + 1)) - resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{ + resp, err := w.driver.S3.UploadPartWithContext(w.ctx, &s3.UploadPartInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), PartNumber: partNumber, diff --git a/registry/storage/driver/storagedriver.go b/registry/storage/driver/storagedriver.go index ef81a60e2..f32893491 100644 --- a/registry/storage/driver/storagedriver.go +++ b/registry/storage/driver/storagedriver.go @@ -123,7 +123,7 @@ type FileWriter interface { // Commit flushes all content written to this FileWriter and makes it // available for future calls to StorageDriver.GetContent and // StorageDriver.Reader. - Commit() error + Commit(context.Context) error } // PathRegexp is the regular expression which each file path must match. A diff --git a/registry/storage/driver/testsuites/testsuites.go b/registry/storage/driver/testsuites/testsuites.go index c732f152b..6ad827fd0 100644 --- a/registry/storage/driver/testsuites/testsuites.go +++ b/registry/storage/driver/testsuites/testsuites.go @@ -291,7 +291,7 @@ func (suite *DriverSuite) TestWriteReadLargeStreams(c *check.C) { c.Assert(err, check.IsNil) c.Assert(written, check.Equals, fileSize) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -446,7 +446,7 @@ func (suite *DriverSuite) testContinueStreamAppend(c *check.C, chunkSize int64) c.Assert(err, check.IsNil) c.Assert(nn, check.Equals, int64(len(fullContents[curSize:]))) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -484,7 +484,7 @@ func (suite *DriverSuite) TestWriteZeroByteStreamThenAppend(c *check.C) { c.Assert(err, check.IsNil) // Close the Writer - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -512,7 +512,7 @@ func (suite *DriverSuite) TestWriteZeroByteStreamThenAppend(c *check.C) { c.Assert(nn, check.Equals, int64(len(contentsChunk1))) // Close the AppendWriter - err = awriter.Commit() + err = awriter.Commit(context.Background()) c.Assert(err, check.IsNil) err = awriter.Close() c.Assert(err, check.IsNil) @@ -561,7 +561,7 @@ func (suite *DriverSuite) TestWriteZeroByteContentThenAppend(c *check.C) { c.Assert(nn, check.Equals, int64(len(contentsChunk1))) // Close the AppendWriter - err = awriter.Commit() + err = awriter.Commit(context.Background()) c.Assert(err, check.IsNil) err = awriter.Close() c.Assert(err, check.IsNil) @@ -1156,7 +1156,7 @@ func (suite *DriverSuite) benchmarkStreamFiles(c *check.C, size int64) { c.Assert(err, check.IsNil) c.Assert(written, check.Equals, size) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -1248,7 +1248,7 @@ func (suite *DriverSuite) testFileStreams(c *check.C, size int64) { c.Assert(err, check.IsNil) c.Assert(nn, check.Equals, size) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -1284,7 +1284,7 @@ func (suite *DriverSuite) writeReadCompareStreams(c *check.C, filename string, c c.Assert(err, check.IsNil) c.Assert(nn, check.Equals, int64(len(contents))) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil)