feat: Add context to storagedriver.(FileWriter).Commit()

This commit changes the storagedriver.FileWriter interface
by adding context.Context as an argument to its Commit
func.
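
Concretely, the method signature in the FileWriter
interface changes from

    Commit() error

to

    Commit(context.Context) error

(see the storagedriver diff below).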

We pass the context where needed throughout the
distribution codebase to all the writers and tests.

The S3 driver writer unfortunately must retain the context
passed down to it from upstream so that it continues to
implement the io.Writer and io.Closer interfaces, which do
not accept a context in any of their funcs.
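
A minimal sketch of that pattern (hypothetical names, not
the actual S3 writer code): the context is captured once at
construction time and consulted from funcs whose signatures
leave no room for one.

    import (
        "context"
        "io"
    )

    // writer retains the construction-time context because
    // io.Writer and io.Closer cannot accept one per call.
    type writer struct {
        ctx context.Context
    }

    var _ io.Writer = (*writer)(nil) // compile-time interface check

    func (w *writer) Write(p []byte) (int, error) {
        // Honor upstream cancellation via the stored context.
        if err := w.ctx.Err(); err != nil {
            return 0, err
        }
        // ... upload p using w.ctx ...
        return len(p), nil
    }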

Co-authored-by: Cory Snider <corhere@gmail.com>
Signed-off-by: Milos Gajdos <milosthegajdos@gmail.com>
Milos Gajdos 2023-10-18 10:34:10 +01:00
parent 915ad2d5a6
commit cb0d083d8d
10 changed files with 35 additions and 33 deletions


@@ -57,7 +57,7 @@ func (bw *blobWriter) StartedAt() time.Time {
 func (bw *blobWriter) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
 	dcontext.GetLogger(ctx).Debug("(*blobWriter).Commit")
-	if err := bw.fileWriter.Commit(); err != nil {
+	if err := bw.fileWriter.Commit(ctx); err != nil {
 		return distribution.Descriptor{}, err
 	}


@@ -544,7 +544,7 @@ func (w *writer) Cancel(ctx context.Context) error {
 	return err
 }
 
-func (w *writer) Commit() error {
+func (w *writer) Commit(ctx context.Context) error {
 	if w.closed {
 		return fmt.Errorf("already closed")
 	} else if w.committed {


@@ -120,7 +120,7 @@ func TestCommitAfterMove(t *testing.T) {
 		t.Fatalf("writer.Write: unexpected error: %v", err)
 	}
 
-	err = writer.Commit()
+	err = writer.Commit(ctx)
 	if err != nil {
 		t.Fatalf("writer.Commit: unexpected error: %v", err)
 	}


@@ -142,7 +142,7 @@ func (d *driver) PutContent(ctx context.Context, subPath string, contents []byte
 		writer.Cancel(ctx)
 		return err
 	}
-	return writer.Commit()
+	return writer.Commit(ctx)
 }
 
 // Reader retrieves an io.ReadCloser for the content stored at "path" with a
@@ -397,7 +397,7 @@ func (fw *fileWriter) Cancel(ctx context.Context) error {
 	return os.Remove(fw.file.Name())
 }
 
-func (fw *fileWriter) Commit() error {
+func (fw *fileWriter) Commit(ctx context.Context) error {
 	if fw.closed {
 		return fmt.Errorf("already closed")
 	} else if fw.committed {


@@ -460,12 +460,11 @@ func putContentsClose(wc *storage.Writer, contents []byte) error {
 // Commit flushes all content written to this FileWriter and makes it
 // available for future calls to StorageDriver.GetContent and
 // StorageDriver.Reader.
-func (w *writer) Commit() error {
+func (w *writer) Commit(ctx context.Context) error {
 	if err := w.checkClosed(); err != nil {
 		return err
 	}
 	w.closed = true
-	ctx := context.TODO()
 
 	// no session started yet just perform a simple upload
 	if w.sessionURI == "" {


@@ -4,6 +4,7 @@
 package gcs
 
 import (
+	"context"
 	"fmt"
 	"os"
 	"testing"
@@ -122,7 +123,7 @@ func TestCommitEmpty(t *testing.T) {
 	if err != nil {
 		t.Fatalf("driver.Writer: unexpected error: %v", err)
 	}
-	err = writer.Commit()
+	err = writer.Commit(context.Background())
 	if err != nil {
 		t.Fatalf("writer.Commit: unexpected error: %v", err)
 	}
@@ -169,7 +170,7 @@ func TestCommit(t *testing.T) {
 	if err != nil {
 		t.Fatalf("writer.Write: unexpected error: %v", err)
 	}
-	err = writer.Commit()
+	err = writer.Commit(context.Background())
 	if err != nil {
 		t.Fatalf("writer.Commit: unexpected error: %v", err)
 	}


@@ -320,7 +320,7 @@ func (w *writer) Cancel(ctx context.Context) error {
 	return w.d.root.delete(w.f.path())
 }
 
-func (w *writer) Commit() error {
+func (w *writer) Commit(ctx context.Context) error {
 	if w.closed {
 		return fmt.Errorf("already closed")
 	} else if w.committed {


@@ -696,7 +696,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto
 		if err != nil {
 			return nil, err
 		}
-		return d.newWriter(key, *resp.UploadId, nil), nil
+		return d.newWriter(ctx, key, *resp.UploadId, nil), nil
 	}
 
 	listMultipartUploadsInput := &s3.ListMultipartUploadsInput{
@@ -743,7 +743,7 @@ func (d *driver) Writer(ctx context.Context, path string, appendParam bool) (sto
 		}
 		allParts = append(allParts, partsList.Parts...)
 	}
-	return d.newWriter(key, *multi.UploadId, allParts), nil
+	return d.newWriter(ctx, key, *multi.UploadId, allParts), nil
 }
 
 // resp.NextUploadIdMarker must have at least one element or we would have returned not found
// resp.NextUploadIdMarker must have at least one element or we would have returned not found
@@ -1338,6 +1338,7 @@ func (b *buffer) Clear() {
 // cleanly resumed in the future. This is violated if Close is called after less
 // than a full chunk is written.
 type writer struct {
+	ctx       context.Context
 	driver    *driver
 	key       string
 	uploadID  string
@@ -1350,12 +1351,13 @@ type writer struct {
 	cancelled bool
 }
 
-func (d *driver) newWriter(key, uploadID string, parts []*s3.Part) storagedriver.FileWriter {
+func (d *driver) newWriter(ctx context.Context, key, uploadID string, parts []*s3.Part) storagedriver.FileWriter {
 	var size int64
 	for _, part := range parts {
 		size += *part.Size
 	}
 	return &writer{
+		ctx:      ctx,
 		driver:   d,
 		key:      key,
 		uploadID: uploadID,
@@ -1394,7 +1396,7 @@ func (w *writer) Write(p []byte) (int, error) {
 			sort.Sort(completedUploadedParts)
 
-			_, err := w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
+			_, err := w.driver.S3.CompleteMultipartUploadWithContext(w.ctx, &s3.CompleteMultipartUploadInput{
 				Bucket:   aws.String(w.driver.Bucket),
 				Key:      aws.String(w.key),
 				UploadId: aws.String(w.uploadID),
@@ -1403,7 +1405,7 @@ func (w *writer) Write(p []byte) (int, error) {
 				},
 			})
 			if err != nil {
-				if _, aErr := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
+				if _, aErr := w.driver.S3.AbortMultipartUploadWithContext(w.ctx, &s3.AbortMultipartUploadInput{
 					Bucket:   aws.String(w.driver.Bucket),
 					Key:      aws.String(w.key),
 					UploadId: aws.String(w.uploadID),
@@ -1413,7 +1415,7 @@ func (w *writer) Write(p []byte) (int, error) {
 				return 0, err
 			}
 
-			resp, err := w.driver.S3.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
+			resp, err := w.driver.S3.CreateMultipartUploadWithContext(w.ctx, &s3.CreateMultipartUploadInput{
 				Bucket:      aws.String(w.driver.Bucket),
 				Key:         aws.String(w.key),
 				ContentType: w.driver.getContentType(),
@@ -1429,7 +1431,7 @@ func (w *writer) Write(p []byte) (int, error) {
 			// If the entire written file is smaller than minChunkSize, we need to make
 			// a new part from scratch :double sad face:
 			if w.size < minChunkSize {
-				resp, err := w.driver.S3.GetObject(&s3.GetObjectInput{
+				resp, err := w.driver.S3.GetObjectWithContext(w.ctx, &s3.GetObjectInput{
 					Bucket: aws.String(w.driver.Bucket),
 					Key:    aws.String(w.key),
 				})
@@ -1451,7 +1453,7 @@ func (w *writer) Write(p []byte) (int, error) {
 				}
 			} else {
 				// Otherwise we can use the old file as the new first part
-				copyPartResp, err := w.driver.S3.UploadPartCopy(&s3.UploadPartCopyInput{
+				copyPartResp, err := w.driver.S3.UploadPartCopyWithContext(w.ctx, &s3.UploadPartCopyInput{
 					Bucket:     aws.String(w.driver.Bucket),
 					CopySource: aws.String(w.driver.Bucket + "/" + w.key),
 					Key:        aws.String(w.key),
@@ -1536,7 +1538,7 @@ func (w *writer) Cancel(ctx context.Context) error {
 		return fmt.Errorf("already committed")
 	}
 	w.cancelled = true
-	_, err := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
+	_, err := w.driver.S3.AbortMultipartUploadWithContext(ctx, &s3.AbortMultipartUploadInput{
 		Bucket:   aws.String(w.driver.Bucket),
 		Key:      aws.String(w.key),
 		UploadId: aws.String(w.uploadID),
@@ -1544,7 +1546,7 @@ func (w *writer) Cancel(ctx context.Context) error {
 	return err
 }
 
-func (w *writer) Commit() error {
+func (w *writer) Commit(ctx context.Context) error {
 	if w.closed {
 		return fmt.Errorf("already closed")
 	} else if w.committed {
@@ -1576,7 +1578,7 @@ func (w *writer) Commit() error {
 	// Solution: we upload the empty i.e. 0 byte part as a single part and then append it
 	// to the completedUploadedParts slice used to complete the Multipart upload.
 	if len(w.parts) == 0 {
-		resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{
+		resp, err := w.driver.S3.UploadPartWithContext(w.ctx, &s3.UploadPartInput{
 			Bucket:     aws.String(w.driver.Bucket),
 			Key:        aws.String(w.key),
 			PartNumber: aws.Int64(1),
@@ -1595,7 +1597,7 @@ func (w *writer) Commit() error {
 	sort.Sort(completedUploadedParts)
 
-	_, err = w.driver.S3.CompleteMultipartUpload(&s3.CompleteMultipartUploadInput{
+	_, err = w.driver.S3.CompleteMultipartUploadWithContext(w.ctx, &s3.CompleteMultipartUploadInput{
 		Bucket:   aws.String(w.driver.Bucket),
 		Key:      aws.String(w.key),
 		UploadId: aws.String(w.uploadID),
@@ -1604,7 +1606,7 @@ func (w *writer) Commit() error {
 		},
 	})
 	if err != nil {
-		if _, aErr := w.driver.S3.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
+		if _, aErr := w.driver.S3.AbortMultipartUploadWithContext(w.ctx, &s3.AbortMultipartUploadInput{
 			Bucket:   aws.String(w.driver.Bucket),
 			Key:      aws.String(w.key),
 			UploadId: aws.String(w.uploadID),
@@ -1634,7 +1636,7 @@ func (w *writer) flush() error {
 	partSize := buf.Len()
 	partNumber := aws.Int64(int64(len(w.parts) + 1))
 
-	resp, err := w.driver.S3.UploadPart(&s3.UploadPartInput{
+	resp, err := w.driver.S3.UploadPartWithContext(w.ctx, &s3.UploadPartInput{
 		Bucket:     aws.String(w.driver.Bucket),
 		Key:        aws.String(w.key),
 		PartNumber: partNumber,


@@ -123,7 +123,7 @@ type FileWriter interface {
 	// Commit flushes all content written to this FileWriter and makes it
 	// available for future calls to StorageDriver.GetContent and
 	// StorageDriver.Reader.
-	Commit() error
+	Commit(context.Context) error
 }
 
 // PathRegexp is the regular expression which each file path must match. A


@@ -291,7 +291,7 @@ func (suite *DriverSuite) TestWriteReadLargeStreams(c *check.C) {
 	c.Assert(err, check.IsNil)
 	c.Assert(written, check.Equals, fileSize)
 
-	err = writer.Commit()
+	err = writer.Commit(context.Background())
 	c.Assert(err, check.IsNil)
 	err = writer.Close()
 	c.Assert(err, check.IsNil)
@@ -446,7 +446,7 @@ func (suite *DriverSuite) testContinueStreamAppend(c *check.C, chunkSize int64)
 	c.Assert(err, check.IsNil)
 	c.Assert(nn, check.Equals, int64(len(fullContents[curSize:])))
 
-	err = writer.Commit()
+	err = writer.Commit(context.Background())
 	c.Assert(err, check.IsNil)
 	err = writer.Close()
 	c.Assert(err, check.IsNil)
@@ -484,7 +484,7 @@ func (suite *DriverSuite) TestWriteZeroByteStreamThenAppend(c *check.C) {
 	c.Assert(err, check.IsNil)
 
 	// Close the Writer
-	err = writer.Commit()
+	err = writer.Commit(context.Background())
 	c.Assert(err, check.IsNil)
 	err = writer.Close()
 	c.Assert(err, check.IsNil)
@@ -512,7 +512,7 @@ func (suite *DriverSuite) TestWriteZeroByteStreamThenAppend(c *check.C) {
 	c.Assert(nn, check.Equals, int64(len(contentsChunk1)))
 
 	// Close the AppendWriter
-	err = awriter.Commit()
+	err = awriter.Commit(context.Background())
 	c.Assert(err, check.IsNil)
 	err = awriter.Close()
 	c.Assert(err, check.IsNil)
@@ -561,7 +561,7 @@ func (suite *DriverSuite) TestWriteZeroByteContentThenAppend(c *check.C) {
 	c.Assert(nn, check.Equals, int64(len(contentsChunk1)))
 
 	// Close the AppendWriter
-	err = awriter.Commit()
+	err = awriter.Commit(context.Background())
 	c.Assert(err, check.IsNil)
 	err = awriter.Close()
 	c.Assert(err, check.IsNil)
@@ -1156,7 +1156,7 @@ func (suite *DriverSuite) benchmarkStreamFiles(c *check.C, size int64) {
 	c.Assert(err, check.IsNil)
 	c.Assert(written, check.Equals, size)
 
-	err = writer.Commit()
+	err = writer.Commit(context.Background())
 	c.Assert(err, check.IsNil)
 	err = writer.Close()
 	c.Assert(err, check.IsNil)
@@ -1248,7 +1248,7 @@ func (suite *DriverSuite) testFileStreams(c *check.C, size int64) {
 	c.Assert(err, check.IsNil)
 	c.Assert(nn, check.Equals, size)
 
-	err = writer.Commit()
+	err = writer.Commit(context.Background())
 	c.Assert(err, check.IsNil)
 	err = writer.Close()
 	c.Assert(err, check.IsNil)
@@ -1284,7 +1284,7 @@ func (suite *DriverSuite) writeReadCompareStreams(c *check.C, filename string, c
 	c.Assert(err, check.IsNil)
 	c.Assert(nn, check.Equals, int64(len(contents)))
 
-	err = writer.Commit()
+	err = writer.Commit(context.Background())
 	c.Assert(err, check.IsNil)
 	err = writer.Close()
 	c.Assert(err, check.IsNil)