From cb0d083d8d7b33668f874c519cb13867216eb4fa Mon Sep 17 00:00:00 2001 From: Milos Gajdos Date: Wed, 18 Oct 2023 10:34:10 +0100 Subject: [PATCH] feat: Add context to storagedriver.(Filewriter).Commit() This commit changes storagedriver.Filewriter interface by adding context.Context as an argument to its Commit func. We pass the context appropriately where need be throughout the distribution codebase to all the writers and tests. S3 driver writer unfortunately must maintain the context passed down to it from upstream so it contnues to implement io.Writer and io.Closer interfaces which do not allow accepting the context in any of their funcs. Co-authored-by: Cory Snider Signed-off-by: Milos Gajdos --- registry/storage/blobwriter.go | 2 +- registry/storage/driver/azure/azure.go | 2 +- registry/storage/driver/azure/azure_test.go | 2 +- registry/storage/driver/filesystem/driver.go | 4 +-- registry/storage/driver/gcs/gcs.go | 3 +- registry/storage/driver/gcs/gcs_test.go | 5 ++-- registry/storage/driver/inmemory/driver.go | 2 +- registry/storage/driver/s3-aws/s3.go | 30 ++++++++++--------- registry/storage/driver/storagedriver.go | 2 +- .../storage/driver/testsuites/testsuites.go | 16 +++++----- 10 files changed, 35 insertions(+), 33 deletions(-) diff --git a/registry/storage/blobwriter.go b/registry/storage/blobwriter.go index 48bd2a8c..70c8b2af 100644 --- a/registry/storage/blobwriter.go +++ b/registry/storage/blobwriter.go @@ -57,7 +57,7 @@ func (bw *blobWriter) StartedAt() time.Time { func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { dcontext.GetLogger(ctx).Debug("(*blobWriter).Commit") - if err := bw.fileWriter.Commit(); err != nil { + if err := bw.fileWriter.Commit(ctx); err != nil { return distribution.Descriptor{}, err } diff --git a/registry/storage/driver/azure/azure.go b/registry/storage/driver/azure/azure.go index 3e3c05ce..585c8b43 100644 --- a/registry/storage/driver/azure/azure.go +++ b/registry/storage/driver/azure/azure.go @@ -544,7 +544,7 @@ func (w *writer) Cancel(ctx context.Context) error { return err } -func (w *writer) Commit() error { +func (w *writer) Commit(ctx context.Context) error { if w.closed { return fmt.Errorf("already closed") } else if w.committed { diff --git a/registry/storage/driver/azure/azure_test.go b/registry/storage/driver/azure/azure_test.go index c0d4d2e6..945cc0a9 100644 --- a/registry/storage/driver/azure/azure_test.go +++ b/registry/storage/driver/azure/azure_test.go @@ -120,7 +120,7 @@ func TestCommitAfterMove(t *testing.T) { t.Fatalf("writer.Write: unexpected error: %v", err) } - err = writer.Commit() + err = writer.Commit(ctx) if err != nil { t.Fatalf("writer.Commit: unexpected error: %v", err) } diff --git a/registry/storage/driver/filesystem/driver.go b/registry/storage/driver/filesystem/driver.go index ade7b283..23033268 100644 --- a/registry/storage/driver/filesystem/driver.go +++ b/registry/storage/driver/filesystem/driver.go @@ -142,7 +142,7 @@ func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte writer.Cancel(ctx) return err } - return writer.Commit() + return writer.Commit(ctx) } // Reader retrieves an io.ReadCloser for the content stored at "path" with a @@ -397,7 +397,7 @@ func (fw *fileWriter) Cancel(ctx context.Context) error { return os.Remove(fw.file.Name()) } -func (fw *fileWriter) Commit() error { +func (fw *fileWriter) Commit(ctx context.Context) error { if fw.closed { return fmt.Errorf("already closed") } else if fw.committed { diff --git a/registry/storage/driver/gcs/gcs.go b/registry/storage/driver/gcs/gcs.go index 81e1dde6..5b276d65 100644 --- a/registry/storage/driver/gcs/gcs.go +++ b/registry/storage/driver/gcs/gcs.go @@ -460,12 +460,11 @@ func putContentsClose(wc *storage.Writer, contents []byte) error { // Commit flushes all content written to this FileWriter and makes it // available for future calls to StorageDriver.GetContent and // StorageDriver.Reader. -func (w *writer) Commit() error { +func (w *writer) Commit(ctx context.Context) error { if err := w.checkClosed(); err != nil { return err } w.closed = true - ctx := context.TODO() // no session started yet just perform a simple upload if w.sessionURI == "" { diff --git a/registry/storage/driver/gcs/gcs_test.go b/registry/storage/driver/gcs/gcs_test.go index d9f0da3e..65998e7b 100644 --- a/registry/storage/driver/gcs/gcs_test.go +++ b/registry/storage/driver/gcs/gcs_test.go @@ -4,6 +4,7 @@ package gcs import ( + "context" "fmt" "os" "testing" @@ -122,7 +123,7 @@ func TestCommitEmpty(t *testing.T) { if err != nil { t.Fatalf("driver.Writer: unexpected error: %v", err) } - err = writer.Commit() + err = writer.Commit(context.Background()) if err != nil { t.Fatalf("writer.Commit: unexpected error: %v", err) } @@ -169,7 +170,7 @@ func TestCommit(t *testing.T) { if err != nil { t.Fatalf("writer.Write: unexpected error: %v", err) } - err = writer.Commit() + err = writer.Commit(context.Background()) if err != nil { t.Fatalf("writer.Commit: unexpected error: %v", err) } diff --git a/registry/storage/driver/inmemory/driver.go b/registry/storage/driver/inmemory/driver.go index d9c30f38..4c00ca40 100644 --- a/registry/storage/driver/inmemory/driver.go +++ b/registry/storage/driver/inmemory/driver.go @@ -320,7 +320,7 @@ func (w *writer) Cancel(ctx context.Context) error { return w.d.root.delete(w.f.path()) } -func (w *writer) Commit() error { +func (w *writer) Commit(ctx context.Context) error { if w.closed { return fmt.Errorf("already closed") } else if w.committed { diff --git a/registry/storage/driver/s3-aws/s3.go b/registry/storage/driver/s3-aws/s3.go index 22c0a449..a6428a91 100644 --- a/registry/storage/driver/s3-aws/s3.go +++ b/registry/storage/driver/s3-aws/s3.go @@ -696,7 +696,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto if err != nil { return nil, err } - return d.newWriter(key, *resp.UploadId, nil), nil + return d.newWriter(ctx, key, *resp.UploadId, nil), nil } listMultipartUploadsInput := &s3.ListMultipartUploadsInput{ @@ -743,7 +743,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto } allParts = append(allParts, partsList.Parts...) } - return d.newWriter(key, *multi.UploadId, allParts), nil + return d.newWriter(ctx, key, *multi.UploadId, allParts), nil } // resp.NextUploadIdMarker must have at least one element or we would have returned not found @@ -1338,6 +1338,7 @@ func (b *buffer) Clear() { // cleanly resumed in the future. This is violated if Close is called after less // than a full chunk is written. type writer struct { + ctx context.Context driver *driver key string uploadID string @@ -1350,12 +1351,13 @@ type writer struct { cancelled bool } -func (d *driver) newWriter(key, uploadID string, parts []*s3.Part) storagedriver.FileWriter { +func (d *driver) newWriter(ctx context.Context, key, uploadID string, parts []*s3.Part) storagedriver.FileWriter { var size int64 for _, part := range parts { size += *part.Size } return &writer{ + ctx: ctx, driver: d, key: key, uploadID: uploadID, @@ -1394,7 +1396,7 @@ func (w *writer) Write(p []byte) (int, error) { sort.Sort(completedUploadedParts) - _, err := w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{ + _, err := w.driver.S3.CompleteMultipartUploadWithContext(w.ctx, &s3.CompleteMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1403,7 +1405,7 @@ func (w *writer) Write(p []byte) (int, error) { }, }) if err != nil { - if _, aErr := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ + if _, aErr := w.driver.S3.AbortMultipartUploadWithContext(w.ctx, &s3.AbortMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1413,7 +1415,7 @@ func (w *writer) Write(p []byte) (int, error) { return 0, err } - resp, err := w.driver.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{ + resp, err := w.driver.S3.CreateMultipartUploadWithContext(w.ctx, &s3.CreateMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), ContentType: w.driver.getContentType(), @@ -1429,7 +1431,7 @@ func (w *writer) Write(p []byte) (int, error) { // If the entire written file is smaller than minChunkSize, we need to make // a new part from scratch :double sad face: if w.size < minChunkSize { - resp, err := w.driver.S3.GetObject(&s3.GetObjectInput{ + resp, err := w.driver.S3.GetObjectWithContext(w.ctx, &s3.GetObjectInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), }) @@ -1451,7 +1453,7 @@ func (w *writer) Write(p []byte) (int, error) { } } else { // Otherwise we can use the old file as the new first part - copyPartResp, err := w.driver.S3.UploadPartCopy(&s3.UploadPartCopyInput{ + copyPartResp, err := w.driver.S3.UploadPartCopyWithContext(w.ctx, &s3.UploadPartCopyInput{ Bucket: aws.String(w.driver.Bucket), CopySource: aws.String(w.driver.Bucket + "/" + w.key), Key: aws.String(w.key), @@ -1536,7 +1538,7 @@ func (w *writer) Cancel(ctx context.Context) error { return fmt.Errorf("already committed") } w.cancelled = true - _, err := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ + _, err := w.driver.S3.AbortMultipartUploadWithContext(ctx, &s3.AbortMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1544,7 +1546,7 @@ func (w *writer) Cancel(ctx context.Context) error { return err } -func (w *writer) Commit() error { +func (w *writer) Commit(ctx context.Context) error { if w.closed { return fmt.Errorf("already closed") } else if w.committed { @@ -1576,7 +1578,7 @@ func (w *writer) Commit() error { // Solution: we upload the empty i.e. 0 byte part as a single part and then append it // to the completedUploadedParts slice used to complete the Multipart upload. if len(w.parts) == 0 { - resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{ + resp, err := w.driver.S3.UploadPartWithContext(w.ctx, &s3.UploadPartInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), PartNumber: aws.Int64(1), @@ -1595,7 +1597,7 @@ func (w *writer) Commit() error { sort.Sort(completedUploadedParts) - _, err = w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{ + _, err = w.driver.S3.CompleteMultipartUploadWithContext(w.ctx, &s3.CompleteMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1604,7 +1606,7 @@ func (w *writer) Commit() error { }, }) if err != nil { - if _, aErr := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ + if _, aErr := w.driver.S3.AbortMultipartUploadWithContext(w.ctx, &s3.AbortMultipartUploadInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), UploadId: aws.String(w.uploadID), @@ -1634,7 +1636,7 @@ func (w *writer) flush() error { partSize := buf.Len() partNumber := aws.Int64(int64(len(w.parts) + 1)) - resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{ + resp, err := w.driver.S3.UploadPartWithContext(w.ctx, &s3.UploadPartInput{ Bucket: aws.String(w.driver.Bucket), Key: aws.String(w.key), PartNumber: partNumber, diff --git a/registry/storage/driver/storagedriver.go b/registry/storage/driver/storagedriver.go index ef81a60e..f3289349 100644 --- a/registry/storage/driver/storagedriver.go +++ b/registry/storage/driver/storagedriver.go @@ -123,7 +123,7 @@ type FileWriter interface { // Commit flushes all content written to this FileWriter and makes it // available for future calls to StorageDriver.GetContent and // StorageDriver.Reader. - Commit() error + Commit(context.Context) error } // PathRegexp is the regular expression which each file path must match. A diff --git a/registry/storage/driver/testsuites/testsuites.go b/registry/storage/driver/testsuites/testsuites.go index c732f152..6ad827fd 100644 --- a/registry/storage/driver/testsuites/testsuites.go +++ b/registry/storage/driver/testsuites/testsuites.go @@ -291,7 +291,7 @@ func (suite *DriverSuite) TestWriteReadLargeStreams(c *check.C) { c.Assert(err, check.IsNil) c.Assert(written, check.Equals, fileSize) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -446,7 +446,7 @@ func (suite *DriverSuite) testContinueStreamAppend(c *check.C, chunkSize int64) c.Assert(err, check.IsNil) c.Assert(nn, check.Equals, int64(len(fullContents[curSize:]))) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -484,7 +484,7 @@ func (suite *DriverSuite) TestWriteZeroByteStreamThenAppend(c *check.C) { c.Assert(err, check.IsNil) // Close the Writer - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -512,7 +512,7 @@ func (suite *DriverSuite) TestWriteZeroByteStreamThenAppend(c *check.C) { c.Assert(nn, check.Equals, int64(len(contentsChunk1))) // Close the AppendWriter - err = awriter.Commit() + err = awriter.Commit(context.Background()) c.Assert(err, check.IsNil) err = awriter.Close() c.Assert(err, check.IsNil) @@ -561,7 +561,7 @@ func (suite *DriverSuite) TestWriteZeroByteContentThenAppend(c *check.C) { c.Assert(nn, check.Equals, int64(len(contentsChunk1))) // Close the AppendWriter - err = awriter.Commit() + err = awriter.Commit(context.Background()) c.Assert(err, check.IsNil) err = awriter.Close() c.Assert(err, check.IsNil) @@ -1156,7 +1156,7 @@ func (suite *DriverSuite) benchmarkStreamFiles(c *check.C, size int64) { c.Assert(err, check.IsNil) c.Assert(written, check.Equals, size) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -1248,7 +1248,7 @@ func (suite *DriverSuite) testFileStreams(c *check.C, size int64) { c.Assert(err, check.IsNil) c.Assert(nn, check.Equals, size) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil) @@ -1284,7 +1284,7 @@ func (suite *DriverSuite) writeReadCompareStreams(c *check.C, filename string, c c.Assert(err, check.IsNil) c.Assert(nn, check.Equals, int64(len(contents))) - err = writer.Commit() + err = writer.Commit(context.Background()) c.Assert(err, check.IsNil) err = writer.Close() c.Assert(err, check.IsNil)