feat: Add context to storagedriver.(Filewriter).Commit()

This commit changes storagedriver.Filewriter interface
by adding context.Context as an argument to its Commit
func.

We pass the context appropriately where need be throughout
the distribution codebase to all the writers and tests.

S3 driver writer unfortunately must maintain the context
passed down to it from upstream so it contnues to
implement io.Writer and io.Closer interfaces which do not
allow accepting the context in any of their funcs.

Co-authored-by: Cory Snider <corhere@gmail.com>
Signed-off-by: Milos Gajdos <milosthegajdos@gmail.com>
This commit is contained in:
Milos Gajdos 2023-10-18 10:34:10 +01:00
parent 915ad2d5a6
commit cb0d083d8d
No known key found for this signature in database
10 changed files with 35 additions and 33 deletions

View file

@ -57,7 +57,7 @@ func (bw *blobWriter) StartedAt() time.Time {
func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) { func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
dcontext.GetLogger(ctx).Debug("(*blobWriter).Commit") dcontext.GetLogger(ctx).Debug("(*blobWriter).Commit")
if err := bw.fileWriter.Commit(); err != nil { if err := bw.fileWriter.Commit(ctx); err != nil {
return distribution.Descriptor{}, err return distribution.Descriptor{}, err
} }

View file

@ -544,7 +544,7 @@ func (w *writer) Cancel(ctx context.Context) error {
return err return err
} }
func (w *writer) Commit() error { func (w *writer) Commit(ctx context.Context) error {
if w.closed { if w.closed {
return fmt.Errorf("already closed") return fmt.Errorf("already closed")
} else if w.committed { } else if w.committed {

View file

@ -120,7 +120,7 @@ func TestCommitAfterMove(t *testing.T) {
t.Fatalf("writer.Write: unexpected error: %v", err) t.Fatalf("writer.Write: unexpected error: %v", err)
} }
err = writer.Commit() err = writer.Commit(ctx)
if err != nil { if err != nil {
t.Fatalf("writer.Commit: unexpected error: %v", err) t.Fatalf("writer.Commit: unexpected error: %v", err)
} }

View file

@ -142,7 +142,7 @@ func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte
writer.Cancel(ctx) writer.Cancel(ctx)
return err return err
} }
return writer.Commit() return writer.Commit(ctx)
} }
// Reader retrieves an io.ReadCloser for the content stored at "path" with a // Reader retrieves an io.ReadCloser for the content stored at "path" with a
@ -397,7 +397,7 @@ func (fw *fileWriter) Cancel(ctx context.Context) error {
return os.Remove(fw.file.Name()) return os.Remove(fw.file.Name())
} }
func (fw *fileWriter) Commit() error { func (fw *fileWriter) Commit(ctx context.Context) error {
if fw.closed { if fw.closed {
return fmt.Errorf("already closed") return fmt.Errorf("already closed")
} else if fw.committed { } else if fw.committed {

View file

@ -460,12 +460,11 @@ func putContentsClose(wc *storage.Writer, contents []byte) error {
// Commit flushes all content written to this FileWriter and makes it // Commit flushes all content written to this FileWriter and makes it
// available for future calls to StorageDriver.GetContent and // available for future calls to StorageDriver.GetContent and
// StorageDriver.Reader. // StorageDriver.Reader.
func (w *writer) Commit() error { func (w *writer) Commit(ctx context.Context) error {
if err := w.checkClosed(); err != nil { if err := w.checkClosed(); err != nil {
return err return err
} }
w.closed = true w.closed = true
ctx := context.TODO()
// no session started yet just perform a simple upload // no session started yet just perform a simple upload
if w.sessionURI == "" { if w.sessionURI == "" {

View file

@ -4,6 +4,7 @@
package gcs package gcs
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"testing" "testing"
@ -122,7 +123,7 @@ func TestCommitEmpty(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("driver.Writer: unexpected error: %v", err) t.Fatalf("driver.Writer: unexpected error: %v", err)
} }
err = writer.Commit() err = writer.Commit(context.Background())
if err != nil { if err != nil {
t.Fatalf("writer.Commit: unexpected error: %v", err) t.Fatalf("writer.Commit: unexpected error: %v", err)
} }
@ -169,7 +170,7 @@ func TestCommit(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("writer.Write: unexpected error: %v", err) t.Fatalf("writer.Write: unexpected error: %v", err)
} }
err = writer.Commit() err = writer.Commit(context.Background())
if err != nil { if err != nil {
t.Fatalf("writer.Commit: unexpected error: %v", err) t.Fatalf("writer.Commit: unexpected error: %v", err)
} }

View file

@ -320,7 +320,7 @@ func (w *writer) Cancel(ctx context.Context) error {
return w.d.root.delete(w.f.path()) return w.d.root.delete(w.f.path())
} }
func (w *writer) Commit() error { func (w *writer) Commit(ctx context.Context) error {
if w.closed { if w.closed {
return fmt.Errorf("already closed") return fmt.Errorf("already closed")
} else if w.committed { } else if w.committed {

View file

@ -696,7 +696,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto
if err != nil { if err != nil {
return nil, err return nil, err
} }
return d.newWriter(key, *resp.UploadId, nil), nil return d.newWriter(ctx, key, *resp.UploadId, nil), nil
} }
listMultipartUploadsInput := &s3.ListMultipartUploadsInput{ listMultipartUploadsInput := &s3.ListMultipartUploadsInput{
@ -743,7 +743,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto
} }
allParts = append(allParts, partsList.Parts...) allParts = append(allParts, partsList.Parts...)
} }
return d.newWriter(key, *multi.UploadId, allParts), nil return d.newWriter(ctx, key, *multi.UploadId, allParts), nil
} }
// resp.NextUploadIdMarker must have at least one element or we would have returned not found // resp.NextUploadIdMarker must have at least one element or we would have returned not found
@ -1338,6 +1338,7 @@ func (b *buffer) Clear() {
// cleanly resumed in the future. This is violated if Close is called after less // cleanly resumed in the future. This is violated if Close is called after less
// than a full chunk is written. // than a full chunk is written.
type writer struct { type writer struct {
ctx context.Context
driver *driver driver *driver
key string key string
uploadID string uploadID string
@ -1350,12 +1351,13 @@ type writer struct {
cancelled bool cancelled bool
} }
func (d *driver) newWriter(key, uploadID string, parts []*s3.Part) storagedriver.FileWriter { func (d *driver) newWriter(ctx context.Context, key, uploadID string, parts []*s3.Part) storagedriver.FileWriter {
var size int64 var size int64
for _, part := range parts { for _, part := range parts {
size += *part.Size size += *part.Size
} }
return &writer{ return &writer{
ctx: ctx,
driver: d, driver: d,
key: key, key: key,
uploadID: uploadID, uploadID: uploadID,
@ -1394,7 +1396,7 @@ func (w *writer) Write(p []byte) (int, error) {
sort.Sort(completedUploadedParts) sort.Sort(completedUploadedParts)
_, err := w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{ _, err := w.driver.S3.CompleteMultipartUploadWithContext(w.ctx, &s3.CompleteMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key), Key: aws.String(w.key),
UploadId: aws.String(w.uploadID), UploadId: aws.String(w.uploadID),
@ -1403,7 +1405,7 @@ func (w *writer) Write(p []byte) (int, error) {
}, },
}) })
if err != nil { if err != nil {
if _, aErr := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ if _, aErr := w.driver.S3.AbortMultipartUploadWithContext(w.ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key), Key: aws.String(w.key),
UploadId: aws.String(w.uploadID), UploadId: aws.String(w.uploadID),
@ -1413,7 +1415,7 @@ func (w *writer) Write(p []byte) (int, error) {
return 0, err return 0, err
} }
resp, err := w.driver.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{ resp, err := w.driver.S3.CreateMultipartUploadWithContext(w.ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key), Key: aws.String(w.key),
ContentType: w.driver.getContentType(), ContentType: w.driver.getContentType(),
@ -1429,7 +1431,7 @@ func (w *writer) Write(p []byte) (int, error) {
// If the entire written file is smaller than minChunkSize, we need to make // If the entire written file is smaller than minChunkSize, we need to make
// a new part from scratch :double sad face: // a new part from scratch :double sad face:
if w.size < minChunkSize { if w.size < minChunkSize {
resp, err := w.driver.S3.GetObject(&s3.GetObjectInput{ resp, err := w.driver.S3.GetObjectWithContext(w.ctx, &s3.GetObjectInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key), Key: aws.String(w.key),
}) })
@ -1451,7 +1453,7 @@ func (w *writer) Write(p []byte) (int, error) {
} }
} else { } else {
// Otherwise we can use the old file as the new first part // Otherwise we can use the old file as the new first part
copyPartResp, err := w.driver.S3.UploadPartCopy(&s3.UploadPartCopyInput{ copyPartResp, err := w.driver.S3.UploadPartCopyWithContext(w.ctx, &s3.UploadPartCopyInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
CopySource: aws.String(w.driver.Bucket + "/" + w.key), CopySource: aws.String(w.driver.Bucket + "/" + w.key),
Key: aws.String(w.key), Key: aws.String(w.key),
@ -1536,7 +1538,7 @@ func (w *writer) Cancel(ctx context.Context) error {
return fmt.Errorf("already committed") return fmt.Errorf("already committed")
} }
w.cancelled = true w.cancelled = true
_, err := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ _, err := w.driver.S3.AbortMultipartUploadWithContext(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key), Key: aws.String(w.key),
UploadId: aws.String(w.uploadID), UploadId: aws.String(w.uploadID),
@ -1544,7 +1546,7 @@ func (w *writer) Cancel(ctx context.Context) error {
return err return err
} }
func (w *writer) Commit() error { func (w *writer) Commit(ctx context.Context) error {
if w.closed { if w.closed {
return fmt.Errorf("already closed") return fmt.Errorf("already closed")
} else if w.committed { } else if w.committed {
@ -1576,7 +1578,7 @@ func (w *writer) Commit() error {
// Solution: we upload the empty i.e. 0 byte part as a single part and then append it // Solution: we upload the empty i.e. 0 byte part as a single part and then append it
// to the completedUploadedParts slice used to complete the Multipart upload. // to the completedUploadedParts slice used to complete the Multipart upload.
if len(w.parts) == 0 { if len(w.parts) == 0 {
resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{ resp, err := w.driver.S3.UploadPartWithContext(w.ctx, &s3.UploadPartInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key), Key: aws.String(w.key),
PartNumber: aws.Int64(1), PartNumber: aws.Int64(1),
@ -1595,7 +1597,7 @@ func (w *writer) Commit() error {
sort.Sort(completedUploadedParts) sort.Sort(completedUploadedParts)
_, err = w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{ _, err = w.driver.S3.CompleteMultipartUploadWithContext(w.ctx, &s3.CompleteMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key), Key: aws.String(w.key),
UploadId: aws.String(w.uploadID), UploadId: aws.String(w.uploadID),
@ -1604,7 +1606,7 @@ func (w *writer) Commit() error {
}, },
}) })
if err != nil { if err != nil {
if _, aErr := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{ if _, aErr := w.driver.S3.AbortMultipartUploadWithContext(w.ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key), Key: aws.String(w.key),
UploadId: aws.String(w.uploadID), UploadId: aws.String(w.uploadID),
@ -1634,7 +1636,7 @@ func (w *writer) flush() error {
partSize := buf.Len() partSize := buf.Len()
partNumber := aws.Int64(int64(len(w.parts) + 1)) partNumber := aws.Int64(int64(len(w.parts) + 1))
resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{ resp, err := w.driver.S3.UploadPartWithContext(w.ctx, &s3.UploadPartInput{
Bucket: aws.String(w.driver.Bucket), Bucket: aws.String(w.driver.Bucket),
Key: aws.String(w.key), Key: aws.String(w.key),
PartNumber: partNumber, PartNumber: partNumber,

View file

@ -123,7 +123,7 @@ type FileWriter interface {
// Commit flushes all content written to this FileWriter and makes it // Commit flushes all content written to this FileWriter and makes it
// available for future calls to StorageDriver.GetContent and // available for future calls to StorageDriver.GetContent and
// StorageDriver.Reader. // StorageDriver.Reader.
Commit() error Commit(context.Context) error
} }
// PathRegexp is the regular expression which each file path must match. A // PathRegexp is the regular expression which each file path must match. A

View file

@ -291,7 +291,7 @@ func (suite *DriverSuite) TestWriteReadLargeStreams(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(written, check.Equals, fileSize) c.Assert(written, check.Equals, fileSize)
err = writer.Commit() err = writer.Commit(context.Background())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = writer.Close() err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -446,7 +446,7 @@ func (suite *DriverSuite) testContinueStreamAppend(c *check.C, chunkSize int64)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(fullContents[curSize:]))) c.Assert(nn, check.Equals, int64(len(fullContents[curSize:])))
err = writer.Commit() err = writer.Commit(context.Background())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = writer.Close() err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -484,7 +484,7 @@ func (suite *DriverSuite) TestWriteZeroByteStreamThenAppend(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// Close the Writer // Close the Writer
err = writer.Commit() err = writer.Commit(context.Background())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = writer.Close() err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -512,7 +512,7 @@ func (suite *DriverSuite) TestWriteZeroByteStreamThenAppend(c *check.C) {
c.Assert(nn, check.Equals, int64(len(contentsChunk1))) c.Assert(nn, check.Equals, int64(len(contentsChunk1)))
// Close the AppendWriter // Close the AppendWriter
err = awriter.Commit() err = awriter.Commit(context.Background())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = awriter.Close() err = awriter.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -561,7 +561,7 @@ func (suite *DriverSuite) TestWriteZeroByteContentThenAppend(c *check.C) {
c.Assert(nn, check.Equals, int64(len(contentsChunk1))) c.Assert(nn, check.Equals, int64(len(contentsChunk1)))
// Close the AppendWriter // Close the AppendWriter
err = awriter.Commit() err = awriter.Commit(context.Background())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = awriter.Close() err = awriter.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -1156,7 +1156,7 @@ func (suite *DriverSuite) benchmarkStreamFiles(c *check.C, size int64) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(written, check.Equals, size) c.Assert(written, check.Equals, size)
err = writer.Commit() err = writer.Commit(context.Background())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = writer.Close() err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -1248,7 +1248,7 @@ func (suite *DriverSuite) testFileStreams(c *check.C, size int64) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, size) c.Assert(nn, check.Equals, size)
err = writer.Commit() err = writer.Commit(context.Background())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = writer.Close() err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -1284,7 +1284,7 @@ func (suite *DriverSuite) writeReadCompareStreams(c *check.C, filename string, c
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nn, check.Equals, int64(len(contents))) c.Assert(nn, check.Equals, int64(len(contents)))
err = writer.Commit() err = writer.Commit(context.Background())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = writer.Close() err = writer.Close()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)