vendor: update all dependencies

Nick Craig-Wood 2017-07-23 08:51:42 +01:00
parent 0b6fba34a3
commit eb87cf6f12
2008 changed files with 352633 additions and 1004750 deletions

File diff suppressed because it is too large


@@ -119,8 +119,8 @@ func TestNoPopulateLocationConstraintIfProvided(t *testing.T) {
t.Fatalf("expect no error, got %v", err)
}
v, _ := awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")
if v := len(v); v != 0 {
t.Errorf("expect no values, got %d", v)
if l := len(v); l != 0 {
t.Errorf("expect no values, got %d", l)
}
}
@@ -133,7 +133,7 @@ func TestNoPopulateLocationConstraintIfClassic(t *testing.T) {
t.Fatalf("expect no error, got %v", err)
}
v, _ := awsutil.ValuesAtPath(req.Params, "CreateBucketConfiguration.LocationConstraint")
if v := len(v); v != 0 {
t.Errorf("expect no values, got %d", v)
if l := len(v); l != 0 {
t.Errorf("expect no values, got %d", l)
}
}


@@ -44,3 +44,21 @@ func defaultInitRequestFn(r *request.Request) {
r.Handlers.Unmarshal.PushFront(copyMultipartStatusOKUnmarhsalError)
}
}
// bucketGetter is an accessor interface to grab the "Bucket" field from
// an S3 type.
type bucketGetter interface {
getBucket() string
}
// sseCustomerKeyGetter is an accessor interface to grab the "SSECustomerKey"
// field from an S3 type.
type sseCustomerKeyGetter interface {
getSSECustomerKey() string
}
// copySourceSSECustomerKeyGetter is an accessor interface to grab the
// "CopySourceSSECustomerKey" field from an S3 type.
type copySourceSSECustomerKeyGetter interface {
getCopySourceSSECustomerKey() string
}
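These unexported accessor interfaces let handler code read fields from the generated S3 input types without awsutil reflection. As a minimal sketch (the concrete getters live in generated code that is not part of this diff), a type satisfies bucketGetter like so:

// Hypothetical illustration: a generated input type implements getBucket
// by dereferencing its Bucket field.
type exampleInput struct {
	Bucket *string
}

func (i exampleInput) getBucket() string {
	if i.Bucket == nil {
		return ""
	}
	return *i.Bucket
}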

File diff suppressed because it is too large


@@ -8,7 +8,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
@@ -113,15 +112,9 @@ func updateEndpointForAccelerate(r *request.Request) {
// Attempts to retrieve the bucket name from the request input parameters.
// If no bucket is found, or the field is empty "", false will be returned.
func bucketNameFromReqParams(params interface{}) (string, bool) {
b, _ := awsutil.ValuesAtPath(params, "Bucket")
if len(b) == 0 {
return "", false
}
if bucket, ok := b[0].(*string); ok {
if bucketStr := aws.StringValue(bucket); bucketStr != "" {
return bucketStr, true
}
if iface, ok := params.(bucketGetter); ok {
b := iface.getBucket()
return b, len(b) > 0
}
return "", false


@@ -0,0 +1,500 @@
package s3manager
import (
"bytes"
"fmt"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
const (
// DefaultBatchSize is the batch size we initialize when constructing a batch delete client.
// This value is used when calling DeleteObjects. This represents how many objects to delete
// per DeleteObjects call.
DefaultBatchSize = 100
)
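For reference, S3's DeleteObjects API accepts at most 1,000 keys per request, so the default batch size of 100 stays well under that cap while still amortizing round trips.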
// BatchError will contain the key and bucket of the object that failed to
// either upload or download.
type BatchError struct {
Errors Errors
code string
message string
}
// Errors is a typed alias for a slice of errors to satisfy the error
// interface.
type Errors []Error
func (errs Errors) Error() string {
buf := bytes.NewBuffer(nil)
for i, err := range errs {
buf.WriteString(err.Error())
if i+1 < len(errs) {
buf.WriteString("\n")
}
}
return buf.String()
}
// Error will contain the original error, bucket, and key of the operation that failed
// during batch operations.
type Error struct {
OrigErr error
Bucket *string
Key *string
}
func newError(err error, bucket, key *string) Error {
return Error{
err,
bucket,
key,
}
}
func (err *Error) Error() string {
return fmt.Sprintf("failed to upload %q to %q:\n%s", err.Key, err.Bucket, err.OrigErr.Error())
}
// NewBatchError will return a BatchError that satisfies the awserr.Error interface.
func NewBatchError(code, message string, err []Error) awserr.Error {
return &BatchError{
Errors: err,
code: code,
message: message,
}
}
// Code will return the code associated with the batch error.
func (err *BatchError) Code() string {
return err.code
}
// Message will return the message associated with the batch error.
func (err *BatchError) Message() string {
return err.message
}
func (err *BatchError) Error() string {
return awserr.SprintError(err.Code(), err.Message(), "", err.Errors)
}
// OrigErr will return the underlying Errors slice. For batched operations the
// individual failures are collected there rather than in one original error.
func (err *BatchError) OrigErr() error {
return err.Errors
}
// BatchDeleteIterator is an interface that uses the scanner pattern to
// iterate through what needs to be deleted.
type BatchDeleteIterator interface {
Next() bool
Err() error
DeleteObject() BatchDeleteObject
}
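Any type that implements these three methods can drive BatchDelete. As an illustrative sketch (not part of the SDK), an iterator over a plain slice of keys in a single bucket:

// keyDeleteIterator is a hypothetical BatchDeleteIterator over a key slice.
type keyDeleteIterator struct {
	bucket string
	keys   []string
	pos    int
}

func (it *keyDeleteIterator) Next() bool {
	it.pos++
	return it.pos <= len(it.keys)
}

func (it *keyDeleteIterator) Err() error { return nil }

func (it *keyDeleteIterator) DeleteObject() BatchDeleteObject {
	return BatchDeleteObject{
		Object: &s3.DeleteObjectInput{
			Bucket: aws.String(it.bucket),
			Key:    aws.String(it.keys[it.pos-1]),
		},
	}
}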
// DeleteListIterator is an alternative iterator for the BatchDelete client. This will
// iterate through a list of objects and delete the objects.
//
// Example:
// iter := &s3manager.DeleteListIterator{
// Client: svc,
// Input: &s3.ListObjectsInput{
// Bucket: aws.String("bucket"),
// MaxKeys: aws.Int64(5),
// },
// Paginator: request.Pagination{
// NewRequest: func() (*request.Request, error) {
// var inCpy *ListObjectsInput
// if input != nil {
// tmp := *input
// inCpy = &tmp
// }
// req, _ := c.ListObjectsRequest(inCpy)
// return req, nil
// },
// },
// }
//
// batcher := s3manager.NewBatchDeleteWithClient(svc)
// if err := batcher.Delete(aws.BackgroundContext(), iter); err != nil {
// return err
// }
type DeleteListIterator struct {
Bucket *string
Paginator request.Pagination
objects []*s3.Object
}
// NewDeleteListIterator will return a new DeleteListIterator.
func NewDeleteListIterator(svc s3iface.S3API, input *s3.ListObjectsInput, opts ...func(*DeleteListIterator)) BatchDeleteIterator {
iter := &DeleteListIterator{
Bucket: input.Bucket,
Paginator: request.Pagination{
NewRequest: func() (*request.Request, error) {
var inCpy *s3.ListObjectsInput
if input != nil {
tmp := *input
inCpy = &tmp
}
req, _ := svc.ListObjectsRequest(inCpy)
return req, nil
},
},
}
for _, opt := range opts {
opt(iter)
}
return iter
}
// Next will use the S3API client to iterate through a list of objects.
func (iter *DeleteListIterator) Next() bool {
if len(iter.objects) > 0 {
iter.objects = iter.objects[1:]
}
if len(iter.objects) == 0 && iter.Paginator.Next() {
iter.objects = iter.Paginator.Page().(*s3.ListObjectsOutput).Contents
}
return len(iter.objects) > 0
}
// Err will return the last known error from Next.
func (iter *DeleteListIterator) Err() error {
return iter.Paginator.Err()
}
// DeleteObject will return the current object to be deleted.
func (iter *DeleteListIterator) DeleteObject() BatchDeleteObject {
return BatchDeleteObject{
Object: &s3.DeleteObjectInput{
Bucket: iter.Bucket,
Key: iter.objects[0].Key,
},
}
}
// BatchDelete will use the s3 package's service client to perform a batch
// delete.
type BatchDelete struct {
Client s3iface.S3API
BatchSize int
}
// NewBatchDeleteWithClient will return a new delete client that can delete a batched amount of
// objects.
//
// Example:
// batcher := s3manager.NewBatchDeleteWithClient(client)
//
// objects := []BatchDeleteObject{
// {
// Object: &s3.DeleteObjectInput {
// Key: aws.String("key"),
// Bucket: aws.String("bucket"),
// },
// },
// }
//
// if err := batcher.Delete(&s3manager.DeleteObjectsIterator{
// Objects: objects,
// }); err != nil {
// return err
// }
func NewBatchDeleteWithClient(client s3iface.S3API, options ...func(*BatchDelete)) *BatchDelete {
svc := &BatchDelete{
Client: client,
BatchSize: DefaultBatchSize,
}
for _, opt := range options {
opt(svc)
}
return svc
}
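The variadic options follow the functional-options pattern; a sketch of one that overrides the default batch size (the helper name is illustrative, not an SDK export):

// withBatchSize is a hypothetical option that overrides BatchSize.
func withBatchSize(n int) func(*BatchDelete) {
	return func(d *BatchDelete) { d.BatchSize = n }
}

// Usage: batcher := NewBatchDeleteWithClient(client, withBatchSize(500))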
// NewBatchDelete will return a new delete client that can delete a batched amount of
// objects.
//
// Example:
// batcher := s3manager.NewBatchDelete(sess)
//
// objects := []BatchDeleteObject{
// {
// Object: &s3.DeleteObjectInput {
// Key: aws.String("key"),
// Bucket: aws.String("bucket"),
// },
// },
// }
//
// if err := batcher.Delete(&s3manager.DeleteObjectsIterator{
// Objects: objects,
// }); err != nil {
// return err
// }
func NewBatchDelete(c client.ConfigProvider, options ...func(*BatchDelete)) *BatchDelete {
client := s3.New(c)
return NewBatchDeleteWithClient(client, options...)
}
// BatchDeleteObject is a wrapper object for calling the batch delete operation.
type BatchDeleteObject struct {
Object *s3.DeleteObjectInput
// After will run after each iteration during the batch process. This function will
// be executed whether or not the request was successful.
After func() error
}
// DeleteObjectsIterator implements the BatchDeleteIterator interface and allows for batched
// deletion of objects.
type DeleteObjectsIterator struct {
Objects []BatchDeleteObject
index int
inc bool
}
// Next will increment the default iterator's index and ensure that there
// is another object to iterate to.
func (iter *DeleteObjectsIterator) Next() bool {
if iter.inc {
iter.index++
} else {
iter.inc = true
}
return iter.index < len(iter.Objects)
}
// Err will return an error. Since this is just used to satisfy the BatchDeleteIterator interface
// this will only return nil.
func (iter *DeleteObjectsIterator) Err() error {
return nil
}
// DeleteObject will return the BatchDeleteObject at the current batched index.
func (iter *DeleteObjectsIterator) DeleteObject() BatchDeleteObject {
object := iter.Objects[iter.index]
return object
}
// Delete will use the iterator to queue up objects that need to be deleted.
// Once the batch size is met, this will call the deleteBatch function.
func (d *BatchDelete) Delete(ctx aws.Context, iter BatchDeleteIterator) error {
var errs []Error
objects := []BatchDeleteObject{}
var input *s3.DeleteObjectsInput
for iter.Next() {
o := iter.DeleteObject()
if input == nil {
input = initDeleteObjectsInput(o.Object)
}
parity := hasParity(input, o)
if parity {
input.Delete.Objects = append(input.Delete.Objects, &s3.ObjectIdentifier{
Key: o.Object.Key,
VersionId: o.Object.VersionId,
})
objects = append(objects, o)
}
if len(input.Delete.Objects) == d.BatchSize || !parity {
if err := deleteBatch(d, input, objects); err != nil {
errs = append(errs, err...)
}
objects = objects[:0]
input = nil
if !parity {
objects = append(objects, o)
input = initDeleteObjectsInput(o.Object)
input.Delete.Objects = append(input.Delete.Objects, &s3.ObjectIdentifier{
Key: o.Object.Key,
VersionId: o.Object.VersionId,
})
}
}
}
if input != nil && len(input.Delete.Objects) > 0 {
if err := deleteBatch(d, input, objects); err != nil {
errs = append(errs, err...)
}
}
if len(errs) > 0 {
return NewBatchError("BatchedDeleteIncomplete", "some objects have failed to be deleted.", errs)
}
return nil
}
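In short, Delete accumulates ObjectIdentifiers for as long as each object shares the pending input's Bucket, MFA, and RequestPayer values (hasParity); reaching BatchSize or hitting a parity break flushes the batch through deleteBatch, and a parity break additionally seeds the next batch with the object that broke it.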
func initDeleteObjectsInput(o *s3.DeleteObjectInput) *s3.DeleteObjectsInput {
return &s3.DeleteObjectsInput{
Bucket: o.Bucket,
MFA: o.MFA,
RequestPayer: o.RequestPayer,
Delete: &s3.Delete{},
}
}
// deleteBatch will delete a batch of items in the objects parameters.
func deleteBatch(d *BatchDelete, input *s3.DeleteObjectsInput, objects []BatchDeleteObject) []Error {
errs := []Error{}
if result, err := d.Client.DeleteObjects(input); err != nil {
for i := 0; i < len(input.Delete.Objects); i++ {
errs = append(errs, newError(err, input.Bucket, input.Delete.Objects[i].Key))
}
} else if len(result.Errors) > 0 {
for i := 0; i < len(result.Errors); i++ {
// err is nil in this branch; build the error from the per-key failure S3 returned.
errs = append(errs, newError(awserr.New(aws.StringValue(result.Errors[i].Code), aws.StringValue(result.Errors[i].Message), nil), input.Bucket, result.Errors[i].Key))
}
}
for _, object := range objects {
if object.After == nil {
continue
}
if err := object.After(); err != nil {
errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
}
}
return errs
}
func hasParity(o1 *s3.DeleteObjectsInput, o2 BatchDeleteObject) bool {
if o1.Bucket != nil && o2.Object.Bucket != nil {
if *o1.Bucket != *o2.Object.Bucket {
return false
}
} else if o1.Bucket != o2.Object.Bucket {
return false
}
if o1.MFA != nil && o2.Object.MFA != nil {
if *o1.MFA != *o2.Object.MFA {
return false
}
} else if o1.MFA != o2.Object.MFA {
return false
}
if o1.RequestPayer != nil && o2.Object.RequestPayer != nil {
if *o1.RequestPayer != *o2.Object.RequestPayer {
return false
}
} else if o1.RequestPayer != o2.Object.RequestPayer {
return false
}
return true
}
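Each field comparison in hasParity follows the same nil-aware pattern; isolated for a single field it reads as follows (a sketch, not SDK code):

// equalStringPtr mirrors the per-field check in hasParity: two non-nil
// pointers compare by value; otherwise both must be nil to match.
func equalStringPtr(a, b *string) bool {
	if a != nil && b != nil {
		return *a == *b
	}
	return a == b
}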
// BatchDownloadIterator is an interface that uses the scanner pattern to iterate
// through a series of objects to be downloaded.
type BatchDownloadIterator interface {
Next() bool
Err() error
DownloadObject() BatchDownloadObject
}
// BatchDownloadObject contains all necessary information to run a batch operation once.
type BatchDownloadObject struct {
Object *s3.GetObjectInput
Writer io.WriterAt
// After will run after each iteration during the batch process. This function will
// be executed whether or not the request was successful.
After func() error
}
// DownloadObjectsIterator implements the BatchDownloadIterator interface and allows for batched
// download of objects.
type DownloadObjectsIterator struct {
Objects []BatchDownloadObject
index int
inc bool
}
// Next will increment the default iterator's index and ensure that there
// is another object to iterate to.
func (batcher *DownloadObjectsIterator) Next() bool {
if batcher.inc {
batcher.index++
} else {
batcher.inc = true
}
return batcher.index < len(batcher.Objects)
}
// DownloadObject will return the BatchDownloadObject at the current batched index.
func (batcher *DownloadObjectsIterator) DownloadObject() BatchDownloadObject {
object := batcher.Objects[batcher.index]
return object
}
// Err will return an error. Since this is just used to satisfy the BatchDownloadIterator interface,
// this will only return nil.
func (batcher *DownloadObjectsIterator) Err() error {
return nil
}
// BatchUploadIterator is an interface that uses the scanner pattern to
// iterate through what needs to be uploaded.
type BatchUploadIterator interface {
Next() bool
Err() error
UploadObject() BatchUploadObject
}
// UploadObjectsIterator implements the BatchUploadIterator interface and allows for batched
// upload of objects.
type UploadObjectsIterator struct {
Objects []BatchUploadObject
index int
inc bool
}
// Next will increment the default iterator's index and ensure that there
// is another object to iterate to.
func (batcher *UploadObjectsIterator) Next() bool {
if batcher.inc {
batcher.index++
} else {
batcher.inc = true
}
return batcher.index < len(batcher.Objects)
}
// Err will return an error. Since this is just used to satisfy the BatchUploadIterator interface
// this will only return nil.
func (batcher *UploadObjectsIterator) Err() error {
return nil
}
// UploadObject will return the BatchUploadObject at the current batched index.
func (batcher *UploadObjectsIterator) UploadObject() BatchUploadObject {
object := batcher.Objects[batcher.index]
return object
}
// BatchUploadObject contains all necessary information to run a batch operation once.
type BatchUploadObject struct {
Object *UploadInput
// After will run after each iteration during the batch process. This function will
// be executed whether or not the request was successful.
After func() error
}
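Because After runs whether or not the request succeeded, it is a natural hook for per-object cleanup. A sketch that uses it to close a source file once its upload has been attempted (path and names illustrative):

f, err := os.Open("/tmp/foo") // hypothetical source file; requires the os import
if err != nil {
	return err
}
obj := BatchUploadObject{
	Object: &UploadInput{
		Bucket: aws.String("bucket"),
		Key:    aws.String("foo"),
		Body:   f,
	},
	After: f.Close, // os.File.Close satisfies func() error
}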


@@ -0,0 +1,975 @@
package s3manager
import (
"bytes"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
)
func TestHasParity(t *testing.T) {
cases := []struct {
o1 *s3.DeleteObjectsInput
o2 BatchDeleteObject
expected bool
}{
{
&s3.DeleteObjectsInput{},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{},
},
true,
},
{
&s3.DeleteObjectsInput{
Bucket: aws.String("foo"),
},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{
Bucket: aws.String("bar"),
},
},
false,
},
{
&s3.DeleteObjectsInput{},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{
Bucket: aws.String("foo"),
},
},
false,
},
{
&s3.DeleteObjectsInput{
Bucket: aws.String("foo"),
},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{},
},
false,
},
{
&s3.DeleteObjectsInput{
MFA: aws.String("foo"),
},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{
MFA: aws.String("bar"),
},
},
false,
},
{
&s3.DeleteObjectsInput{},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{
MFA: aws.String("foo"),
},
},
false,
},
{
&s3.DeleteObjectsInput{
MFA: aws.String("foo"),
},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{},
},
false,
},
{
&s3.DeleteObjectsInput{
RequestPayer: aws.String("foo"),
},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{
RequestPayer: aws.String("bar"),
},
},
false,
},
{
&s3.DeleteObjectsInput{},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{
RequestPayer: aws.String("foo"),
},
},
false,
},
{
&s3.DeleteObjectsInput{
RequestPayer: aws.String("foo"),
},
BatchDeleteObject{
Object: &s3.DeleteObjectInput{},
},
false,
},
}
for i, c := range cases {
if result := hasParity(c.o1, c.o2); result != c.expected {
t.Errorf("Case %d: expected %t, but received %t\n", i, c.expected, result)
}
}
}
func TestBatchDelete(t *testing.T) {
cases := []struct {
objects []BatchDeleteObject
size int
expected int
}{
{
[]BatchDeleteObject{
{
Object: &s3.DeleteObjectInput{
Key: aws.String("1"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("2"),
Bucket: aws.String("bucket2"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("3"),
Bucket: aws.String("bucket3"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("4"),
Bucket: aws.String("bucket4"),
},
},
},
1,
4,
},
{
[]BatchDeleteObject{
{
Object: &s3.DeleteObjectInput{
Key: aws.String("1"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("2"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("3"),
Bucket: aws.String("bucket3"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("4"),
Bucket: aws.String("bucket3"),
},
},
},
1,
4,
},
{
[]BatchDeleteObject{
{
Object: &s3.DeleteObjectInput{
Key: aws.String("1"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("2"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("3"),
Bucket: aws.String("bucket3"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("4"),
Bucket: aws.String("bucket3"),
},
},
},
4,
2,
},
{
[]BatchDeleteObject{
{
Object: &s3.DeleteObjectInput{
Key: aws.String("1"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("2"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("3"),
Bucket: aws.String("bucket3"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("4"),
Bucket: aws.String("bucket3"),
},
},
},
10,
2,
},
{
[]BatchDeleteObject{
{
Object: &s3.DeleteObjectInput{
Key: aws.String("1"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("2"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("3"),
Bucket: aws.String("bucket1"),
},
},
{
Object: &s3.DeleteObjectInput{
Key: aws.String("4"),
Bucket: aws.String("bucket3"),
},
},
},
2,
3,
},
}
count := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
count++
}))
svc := &mockS3Client{S3: buildS3SvcClient(server.URL)}
for i, c := range cases {
batcher := BatchDelete{
Client: svc,
BatchSize: c.size,
}
if err := batcher.Delete(aws.BackgroundContext(), &DeleteObjectsIterator{Objects: c.objects}); err != nil {
panic(err)
}
if count != c.expected {
t.Errorf("Case %d: expected %d, but received %d", i, c.expected, count)
}
count = 0
}
}
type mockS3Client struct {
*s3.S3
index int
objects []*s3.ListObjectsOutput
}
func (client *mockS3Client) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
object := client.objects[client.index]
client.index++
return object, nil
}
func TestBatchDeleteList(t *testing.T) {
count := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
count++
}))
objects := []*s3.ListObjectsOutput{
{
Contents: []*s3.Object{
{
Key: aws.String("1"),
},
},
NextMarker: aws.String("marker"),
IsTruncated: aws.Bool(true),
},
{
Contents: []*s3.Object{
{
Key: aws.String("2"),
},
},
NextMarker: aws.String("marker"),
IsTruncated: aws.Bool(true),
},
{
Contents: []*s3.Object{
{
Key: aws.String("3"),
},
},
IsTruncated: aws.Bool(false),
},
}
svc := &mockS3Client{S3: buildS3SvcClient(server.URL), objects: objects}
batcher := BatchDelete{
Client: svc,
BatchSize: 1,
}
input := &s3.ListObjectsInput{
Bucket: aws.String("bucket"),
}
iter := &DeleteListIterator{
Bucket: input.Bucket,
Paginator: request.Pagination{
NewRequest: func() (*request.Request, error) {
var inCpy *s3.ListObjectsInput
if input != nil {
tmp := *input
inCpy = &tmp
}
req, _ := svc.ListObjectsRequest(inCpy)
req.Handlers.Clear()
output, _ := svc.ListObjects(inCpy)
req.Data = output
return req, nil
},
},
}
if err := batcher.Delete(aws.BackgroundContext(), iter); err != nil {
t.Error(err)
}
if count != len(objects) {
t.Errorf("Expected %d, but received %d", len(objects), count)
}
}
func buildS3SvcClient(u string) *s3.S3 {
return s3.New(unit.Session, &aws.Config{
Endpoint: aws.String(u),
S3ForcePathStyle: aws.Bool(true),
DisableSSL: aws.Bool(true),
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
})
}
func TestBatchDeleteList_EmptyListObjects(t *testing.T) {
count := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
count++
}))
svc := &mockS3Client{S3: buildS3SvcClient(server.URL)}
batcher := BatchDelete{
Client: svc,
}
input := &s3.ListObjectsInput{
Bucket: aws.String("bucket"),
}
// Test DeleteListIterator in the case when the ListObjectsRequest responds
// with an empty listing.
// We need a new iterator with a fresh Pagination since
// Pagination.HasNextPage() is always true the first time Pagination.Next()
// called on it
iter := &DeleteListIterator{
Bucket: input.Bucket,
Paginator: request.Pagination{
NewRequest: func() (*request.Request, error) {
req, _ := svc.ListObjectsRequest(input)
// Simulate empty listing
req.Data = &s3.ListObjectsOutput{Contents: []*s3.Object{}}
return req, nil
},
},
}
if err := batcher.Delete(aws.BackgroundContext(), iter); err != nil {
t.Error(err)
}
if count != 1 {
t.Errorf("expect count to be 1, got %d", count)
}
}
func TestBatchDownload(t *testing.T) {
count := 0
expected := []struct {
bucket, key string
}{
{
key: "1",
bucket: "bucket1",
},
{
key: "2",
bucket: "bucket2",
},
{
key: "3",
bucket: "bucket3",
},
{
key: "4",
bucket: "bucket4",
},
}
received := []struct {
bucket, key string
}{}
payload := []string{
"1",
"2",
"3",
"4",
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
urlParts := strings.Split(r.URL.String(), "/")
received = append(received, struct{ bucket, key string }{urlParts[1], urlParts[2]})
w.Write([]byte(payload[count]))
count++
}))
svc := NewDownloaderWithClient(buildS3SvcClient(server.URL))
objects := []BatchDownloadObject{
{
Object: &s3.GetObjectInput{
Key: aws.String("1"),
Bucket: aws.String("bucket1"),
},
Writer: aws.NewWriteAtBuffer(make([]byte, 128)),
},
{
Object: &s3.GetObjectInput{
Key: aws.String("2"),
Bucket: aws.String("bucket2"),
},
Writer: aws.NewWriteAtBuffer(make([]byte, 128)),
},
{
Object: &s3.GetObjectInput{
Key: aws.String("3"),
Bucket: aws.String("bucket3"),
},
Writer: aws.NewWriteAtBuffer(make([]byte, 128)),
},
{
Object: &s3.GetObjectInput{
Key: aws.String("4"),
Bucket: aws.String("bucket4"),
},
Writer: aws.NewWriteAtBuffer(make([]byte, 128)),
},
}
iter := &DownloadObjectsIterator{Objects: objects}
if err := svc.DownloadWithIterator(aws.BackgroundContext(), iter); err != nil {
panic(err)
}
if count != len(objects) {
t.Errorf("Expected %d, but received %d", len(objects), count)
}
if len(expected) != len(received) {
t.Errorf("Expected %d, but received %d", len(expected), len(received))
}
for i := 0; i < len(expected); i++ {
if expected[i].key != received[i].key {
t.Errorf("Expected %q, but received %q", expected[i].key, received[i].key)
}
if expected[i].bucket != received[i].bucket {
t.Errorf("Expected %q, but received %q", expected[i].bucket, received[i].bucket)
}
}
for i, p := range payload {
b := iter.Objects[i].Writer.(*aws.WriteAtBuffer).Bytes()
b = bytes.Trim(b, "\x00")
if string(b) != p {
t.Errorf("Expected %q, but received %q", p, b)
}
}
}
func TestBatchUpload(t *testing.T) {
count := 0
expected := []struct {
bucket, key string
reqBody string
}{
{
key: "1",
bucket: "bucket1",
reqBody: "1",
},
{
key: "2",
bucket: "bucket2",
reqBody: "2",
},
{
key: "3",
bucket: "bucket3",
reqBody: "3",
},
{
key: "4",
bucket: "bucket4",
reqBody: "4",
},
}
received := []struct {
bucket, key, reqBody string
}{}
payload := []string{
"a",
"b",
"c",
"d",
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
urlParts := strings.Split(r.URL.String(), "/")
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Error(err)
}
received = append(received, struct{ bucket, key, reqBody string }{urlParts[1], urlParts[2], string(b)})
w.Write([]byte(payload[count]))
count++
}))
svc := NewUploaderWithClient(buildS3SvcClient(server.URL))
objects := []BatchUploadObject{
{
Object: &UploadInput{
Key: aws.String("1"),
Bucket: aws.String("bucket1"),
Body: bytes.NewBuffer([]byte("1")),
},
},
{
Object: &UploadInput{
Key: aws.String("2"),
Bucket: aws.String("bucket2"),
Body: bytes.NewBuffer([]byte("2")),
},
},
{
Object: &UploadInput{
Key: aws.String("3"),
Bucket: aws.String("bucket3"),
Body: bytes.NewBuffer([]byte("3")),
},
},
{
Object: &UploadInput{
Key: aws.String("4"),
Bucket: aws.String("bucket4"),
Body: bytes.NewBuffer([]byte("4")),
},
},
}
iter := &UploadObjectsIterator{Objects: objects}
if err := svc.UploadWithIterator(aws.BackgroundContext(), iter); err != nil {
panic(err)
}
if count != len(objects) {
t.Errorf("Expected %d, but received %d", len(objects), count)
}
if len(expected) != len(received) {
t.Errorf("Expected %d, but received %d", len(expected), len(received))
}
for i := 0; i < len(expected); i++ {
if expected[i].key != received[i].key {
t.Errorf("Expected %q, but received %q", expected[i].key, received[i].key)
}
if expected[i].bucket != received[i].bucket {
t.Errorf("Expected %q, but received %q", expected[i].bucket, received[i].bucket)
}
if expected[i].reqBody != received[i].reqBody {
t.Errorf("Expected %q, but received %q", expected[i].reqBody, received[i].reqBody)
}
}
}
type mockClient struct {
s3iface.S3API
Put func() (*s3.PutObjectOutput, error)
Get func() (*s3.GetObjectOutput, error)
List func() (*s3.ListObjectsOutput, error)
responses []response
}
type response struct {
out interface{}
err error
}
func (client *mockClient) PutObject(input *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
return client.Put()
}
func (client *mockClient) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
req, _ := client.S3API.PutObjectRequest(input)
req.Handlers.Clear()
req.Data, req.Error = client.Put()
return req, req.Data.(*s3.PutObjectOutput)
}
func (client *mockClient) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
return client.List()
}
func (client *mockClient) ListObjectsRequest(input *s3.ListObjectsInput) (*request.Request, *s3.ListObjectsOutput) {
req, _ := client.S3API.ListObjectsRequest(input)
req.Handlers.Clear()
req.Data, req.Error = client.List()
return req, req.Data.(*s3.ListObjectsOutput)
}
func TestBatchError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
index := 0
responses := []response{
{
&s3.PutObjectOutput{},
errors.New("Foo"),
},
{
&s3.PutObjectOutput{},
nil,
},
{
&s3.PutObjectOutput{},
nil,
},
{
&s3.PutObjectOutput{},
errors.New("Bar"),
},
}
svc := &mockClient{
S3API: buildS3SvcClient(server.URL),
Put: func() (*s3.PutObjectOutput, error) {
resp := responses[index]
index++
return resp.out.(*s3.PutObjectOutput), resp.err
},
List: func() (*s3.ListObjectsOutput, error) {
resp := responses[index]
index++
return resp.out.(*s3.ListObjectsOutput), resp.err
},
}
uploader := NewUploaderWithClient(svc)
objects := []BatchUploadObject{
{
Object: &UploadInput{
Key: aws.String("1"),
Bucket: aws.String("bucket1"),
Body: bytes.NewBuffer([]byte("1")),
},
},
{
Object: &UploadInput{
Key: aws.String("2"),
Bucket: aws.String("bucket2"),
Body: bytes.NewBuffer([]byte("2")),
},
},
{
Object: &UploadInput{
Key: aws.String("3"),
Bucket: aws.String("bucket3"),
Body: bytes.NewBuffer([]byte("3")),
},
},
{
Object: &UploadInput{
Key: aws.String("4"),
Bucket: aws.String("bucket4"),
Body: bytes.NewBuffer([]byte("4")),
},
},
}
iter := &UploadObjectsIterator{Objects: objects}
if err := uploader.UploadWithIterator(aws.BackgroundContext(), iter); err != nil {
if bErr, ok := err.(*BatchError); !ok {
t.Error("Expected BatchError, but received other")
} else {
if len(bErr.Errors) != 2 {
t.Errorf("Expected 2 errors, but received %d", len(bErr.Errors))
}
expected := []struct {
bucket, key string
}{
{
"bucket1",
"1",
},
{
"bucket4",
"4",
},
}
for i, expect := range expected {
if *bErr.Errors[i].Bucket != expect.bucket {
t.Errorf("Case %d: Invalid bucket expected %s, but received %s", i, expect.bucket, *bErr.Errors[i].Bucket)
}
if *bErr.Errors[i].Key != expect.key {
t.Errorf("Case %d: Invalid key expected %s, but received %s", i, expect.key, *bErr.Errors[i].Key)
}
}
}
} else {
t.Error("Expected error, but received nil")
}
if index != len(objects) {
t.Errorf("Expected %d, but received %d", len(objects), index)
}
}
type testAfterDeleteIter struct {
afterDelete bool
afterDownload bool
afterUpload bool
next bool
}
func (iter *testAfterDeleteIter) Next() bool {
next := !iter.next
iter.next = !iter.next
return next
}
func (iter *testAfterDeleteIter) Err() error {
return nil
}
func (iter *testAfterDeleteIter) DeleteObject() BatchDeleteObject {
return BatchDeleteObject{
Object: &s3.DeleteObjectInput{
Bucket: aws.String("foo"),
Key: aws.String("foo"),
},
After: func() error {
iter.afterDelete = true
return nil
},
}
}
type testAfterDownloadIter struct {
afterDownload bool
afterUpload bool
next bool
}
func (iter *testAfterDownloadIter) Next() bool {
next := !iter.next
iter.next = !iter.next
return next
}
func (iter *testAfterDownloadIter) Err() error {
return nil
}
func (iter *testAfterDownloadIter) DownloadObject() BatchDownloadObject {
return BatchDownloadObject{
Object: &s3.GetObjectInput{
Bucket: aws.String("foo"),
Key: aws.String("foo"),
},
Writer: aws.NewWriteAtBuffer([]byte{}),
After: func() error {
iter.afterDownload = true
return nil
},
}
}
type testAfterUploadIter struct {
afterUpload bool
next bool
}
func (iter *testAfterUploadIter) Next() bool {
next := !iter.next
iter.next = !iter.next
return next
}
func (iter *testAfterUploadIter) Err() error {
return nil
}
func (iter *testAfterUploadIter) UploadObject() BatchUploadObject {
return BatchUploadObject{
Object: &UploadInput{
Bucket: aws.String("foo"),
Key: aws.String("foo"),
Body: strings.NewReader("bar"),
},
After: func() error {
iter.afterUpload = true
return nil
},
}
}
func TestAfter(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
index := 0
responses := []response{
{
&s3.PutObjectOutput{},
nil,
},
{
&s3.GetObjectOutput{},
nil,
},
{
&s3.DeleteObjectOutput{},
nil,
},
}
svc := &mockClient{
S3API: buildS3SvcClient(server.URL),
Put: func() (*s3.PutObjectOutput, error) {
resp := responses[index]
index++
return resp.out.(*s3.PutObjectOutput), resp.err
},
Get: func() (*s3.GetObjectOutput, error) {
resp := responses[index]
index++
return resp.out.(*s3.GetObjectOutput), resp.err
},
List: func() (*s3.ListObjectsOutput, error) {
resp := responses[index]
index++
return resp.out.(*s3.ListObjectsOutput), resp.err
},
}
uploader := NewUploaderWithClient(svc)
downloader := NewDownloaderWithClient(svc)
deleter := NewBatchDeleteWithClient(svc)
deleteIter := &testAfterDeleteIter{}
downloadIter := &testAfterDownloadIter{}
uploadIter := &testAfterUploadIter{}
if err := uploader.UploadWithIterator(aws.BackgroundContext(), uploadIter); err != nil {
t.Error(err)
}
if err := downloader.DownloadWithIterator(aws.BackgroundContext(), downloadIter); err != nil {
t.Error(err)
}
if err := deleter.Delete(aws.BackgroundContext(), deleteIter); err != nil {
t.Error(err)
}
if !deleteIter.afterDelete {
t.Error("Expected 'afterDelete' to be true, but received false")
}
if !downloadIter.afterDownload {
t.Error("Expected 'afterDownload' to be true, but received false")
}
if !uploadIter.afterUpload {
t.Error("Expected 'afterUpload' to be true, but received false")
}
}


@@ -31,11 +31,15 @@ const DefaultDownloadConcurrency = 5
type Downloader struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultPartSize value will be used.
// if this value is set to zero, the DefaultDownloadPartSize value will be used.
//
// PartSize is ignored if the Range input parameter is provided.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
// If this is set to zero, the DefaultDownloadConcurrency value will be used.
//
// Concurrency is ignored if the Range input parameter is provided.
Concurrency int
// An S3 client to use when performing downloads.
@@ -130,6 +134,10 @@ type maxRetrier interface {
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
//
// If the GetObjectInput's Range value is provided, the downloader will perform a
// single GetObject request for that object's range. This causes the part size and
// concurrency configurations to be ignored.
func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
return d.DownloadWithContext(aws.BackgroundContext(), w, input, options...)
}
@@ -153,6 +161,10 @@ func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ..
// downloads, or in memory []byte wrapper using aws.WriteAtBuffer.
//
// It is safe to call this method concurrently across goroutines.
//
// If the GetObjectInput's Range value is provided, the downloader will perform a
// single GetObject request for that object's range. This causes the part size and
// concurrency configurations to be ignored.
func (d Downloader) DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
impl := downloader{w: w, in: input, cfg: d, ctx: ctx}
@@ -177,6 +189,66 @@ func (d Downloader) DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s
return impl.download()
}
// DownloadWithIterator will download a batched amount of objects in S3 and write them
// to the io.WriterAt specified in the iterator.
//
// Example:
// svc := s3manager.NewDownloader(session)
//
// fooFile, err := os.Create("/tmp/foo.file")
// if err != nil {
// return err
// }
//
// barFile, err := os.Create("/tmp/bar.file")
// if err != nil {
// return err
// }
//
// objects := []s3manager.BatchDownloadObject {
// {
// Input: &s3.GetObjectInput {
// Bucket: aws.String("bucket"),
// Key: aws.String("foo"),
// },
// Writer: fooFile,
// },
// {
// Input: &s3.GetObjectInput {
// Bucket: aws.String("bucket"),
// Key: aws.String("bar"),
// },
// Writer: barFile,
// },
// }
//
// iter := &s3manager.DownloadObjectsIterator{Objects: objects}
// if err := svc.DownloadWithIterator(aws.BackgroundContext(), iter); err != nil {
// return err
// }
func (d Downloader) DownloadWithIterator(ctx aws.Context, iter BatchDownloadIterator, opts ...func(*Downloader)) error {
var errs []Error
for iter.Next() {
object := iter.DownloadObject()
if _, err := d.DownloadWithContext(ctx, object.Writer, object.Object, opts...); err != nil {
errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
}
if object.After == nil {
continue
}
if err := object.After(); err != nil {
errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
}
}
if len(errs) > 0 {
return NewBatchError("BatchedDownloadIncomplete", "some objects have failed to download.", errs)
}
return nil
}
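Note that DownloadWithIterator does not fail fast: every object is attempted, each After hook runs regardless of its object's outcome, and the accumulated failures are returned together as a single BatchError.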
// downloader is the implementation structure used internally by Downloader.
type downloader struct {
ctx aws.Context
@@ -199,6 +271,14 @@ type downloader struct {
// download performs the implementation of the object download across ranged
// GETs.
func (d *downloader) download() (n int64, err error) {
// If range is specified fall back to single download of that range
// this enables the functionality of ranged gets with the downloader but
// at the cost of no multipart downloads.
if rng := aws.StringValue(d.in.Range); len(rng) > 0 {
d.downloadRange(rng)
return d.written, d.err
}
// Spin off first worker to check additional header information
d.getChunk()
@@ -285,14 +365,32 @@ func (d *downloader) getChunk() {
}
}
// downloadChunk downloads the chunk froom s3
// downloadRange downloads an Object given the passed-in Byte-Range value.
// The chunk used to download the range will be configured for that range.
func (d *downloader) downloadRange(rng string) {
if d.getErr() != nil {
return
}
chunk := dlchunk{w: d.w, start: d.pos}
// Ranges specified will short circuit the multipart download
chunk.withRange = rng
if err := d.downloadChunk(chunk); err != nil {
d.setErr(err)
}
// Update the position based on the amount of data received.
d.pos = d.written
}
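Because withRange is set, the chunk bypasses the computed byte range entirely, so a ranged download issues exactly one GetObject request; TestDownload_WithRange below asserts exactly this.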
// downloadChunk downloads the chunk from s3
func (d *downloader) downloadChunk(chunk dlchunk) error {
in := &s3.GetObjectInput{}
awsutil.Copy(in, d.in)
// Get the next byte range of data
rng := fmt.Sprintf("bytes=%d-%d", chunk.start, chunk.start+chunk.size-1)
in.Range = &rng
in.Range = aws.String(chunk.ByteRange())
var n int64
var err error
@@ -417,12 +515,18 @@ type dlchunk struct {
start int64
size int64
cur int64
// specifies the byte range the chunk should be downloaded with.
withRange string
}
// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
// position to its end (or EOF).
//
// If a range is specified on the dlchunk the size will be ignored when writing,
// as the total size may not be known ahead of time.
func (c *dlchunk) Write(p []byte) (n int, err error) {
if c.cur >= c.size {
if c.cur >= c.size && len(c.withRange) == 0 {
return 0, io.EOF
}
@@ -431,3 +535,13 @@ func (c *dlchunk) Write(p []byte) (n int, err error) {
return
}
// ByteRange returns an HTTP Byte-Range header value that should be used by the
// client to request the chunk's range.
func (c *dlchunk) ByteRange() string {
if len(c.withRange) != 0 {
return c.withRange
}
return fmt.Sprintf("bytes=%d-%d", c.start, c.start+c.size-1)
}
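For the default 5MB part size this reproduces the ranges asserted in TestDownloadOrder; a worked sketch:

// Sketch: the second default-sized chunk of a multipart download.
c := dlchunk{start: 5242880, size: 5242880}
fmt.Println(c.ByteRange()) // bytes=5242880-10485759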


@@ -6,6 +6,7 @@ import (
"io"
"io/ioutil"
"net/http"
"reflect"
"regexp"
"strconv"
"strings"
@@ -13,8 +14,6 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
@@ -199,16 +198,30 @@ func TestDownloadOrder(t *testing.T) {
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf12MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"}, *ranges)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(len(buf12MB)), n; e != a {
t.Errorf("expect %d buffer length, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectRngs := []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"}
if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v ranges, got %v", e, a)
}
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
if count != 0 {
t.Errorf("expect 0 count, got %d", count)
}
}
func TestDownloadZero(t *testing.T) {
@@ -221,10 +234,21 @@ func TestDownloadZero(t *testing.T) {
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(0), n)
assert.Equal(t, []string{"GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-5242879"}, *ranges)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if n != 0 {
t.Errorf("expect 0 bytes read, got %d", n)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectRngs := []string{"bytes=0-5242879"}
if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v ranges, got %v", e, a)
}
}
func TestDownloadSetPartSize(t *testing.T) {
@@ -240,11 +264,24 @@ func TestDownloadSetPartSize(t *testing.T) {
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(3), n)
assert.Equal(t, []string{"GetObject", "GetObject", "GetObject"}, *names)
assert.Equal(t, []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"}, *ranges)
assert.Equal(t, []byte{1, 2, 3}, w.Bytes())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(3), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectRngs := []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"}
if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v ranges, got %v", e, a)
}
expectBytes := []byte{1, 2, 3}
if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
t.Errorf("expect %v bytes, got %v", e, a)
}
}
func TestDownloadError(t *testing.T) {
@@ -269,10 +306,24 @@ func TestDownloadError(t *testing.T) {
Key: aws.String("key"),
})
assert.NotNil(t, err)
assert.Equal(t, int64(1), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
assert.Equal(t, []byte{1}, w.Bytes())
if err == nil {
t.Fatalf("expect error, got none")
}
aerr := err.(awserr.Error)
if e, a := "BadRequest", aerr.Code(); e != a {
t.Errorf("expect %s error code, got %s", e, a)
}
if e, a := int64(1), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectBytes := []byte{1}
if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
t.Errorf("expect %v bytes, got %v", e, a)
}
}
func TestDownloadNonChunk(t *testing.T) {
@@ -287,15 +338,24 @@ func TestDownloadNonChunk(t *testing.T) {
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject"}, *names)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(len(buf2MB)), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
if count != 0 {
t.Errorf("expect 0 count, got %d", count)
}
}
func TestDownloadNoContentRangeLength(t *testing.T) {
@@ -310,15 +370,24 @@ func TestDownloadNoContentRangeLength(t *testing.T) {
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(len(buf2MB)), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
if count != 0 {
t.Errorf("expect 0 count, got %d", count)
}
}
func TestDownloadContentRangeTotalAny(t *testing.T) {
@@ -333,15 +402,24 @@ func TestDownloadContentRangeTotalAny(t *testing.T) {
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(len(buf2MB)), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(len(buf2MB)), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
count := 0
for _, b := range w.Bytes() {
count += int(b)
}
assert.Equal(t, 0, count)
if count != 0 {
t.Errorf("expect 0 count, got %d", count)
}
}
func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) {
@@ -360,10 +438,19 @@ func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) {
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(3), n)
assert.Equal(t, []string{"GetObject", "GetObject"}, *names)
assert.Equal(t, []byte("123"), w.Bytes())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(3), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject", "GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
if e, a := "123", string(w.Bytes()); e != a {
t.Errorf("expect %q response, got %q", e, a)
}
}
func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) {
@@ -381,10 +468,19 @@ func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) {
Key: aws.String("key"),
})
assert.Nil(t, err)
assert.Equal(t, int64(3), n)
assert.Equal(t, []string{"GetObject"}, *names)
assert.Equal(t, []byte("abc"), w.Bytes())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(3), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
if e, a := "abc", string(w.Bytes()); e != a {
t.Errorf("expect %q response, got %q", e, a)
}
}
func TestDownloadPartBodyRetry_FailRetry(t *testing.T) {
@@ -402,10 +498,22 @@ func TestDownloadPartBodyRetry_FailRetry(t *testing.T) {
Key: aws.String("key"),
})
assert.Error(t, err)
assert.Equal(t, int64(2), n)
assert.Equal(t, []string{"GetObject"}, *names)
assert.Equal(t, []byte("ab"), w.Bytes())
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := "unexpected EOF", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %q error message to be in %q", e, a)
}
if e, a := int64(2), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
if e, a := "ab", string(w.Bytes()); e != a {
t.Errorf("expect %q response, got %q", e, a)
}
}
func TestDownloadWithContextCanceled(t *testing.T) {
@@ -435,6 +543,41 @@ func TestDownloadWithContextCanceled(t *testing.T) {
}
}
func TestDownload_WithRange(t *testing.T) {
s, names, ranges := dlLoggingSvc([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
d := s3manager.NewDownloaderWithClient(s, func(d *s3manager.Downloader) {
d.Concurrency = 10 // should be ignored
d.PartSize = 1 // should be ignored
})
w := &aws.WriteAtBuffer{}
n, err := d.Download(w, &s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
Range: aws.String("bytes=2-6"),
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := int64(5), n; e != a {
t.Errorf("expect %d bytes read, got %d", e, a)
}
expectCalls := []string{"GetObject"}
if e, a := expectCalls, *names; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v API calls, got %v", e, a)
}
expectRngs := []string{"bytes=2-6"}
if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v ranges, got %v", e, a)
}
expectBytes := []byte{2, 3, 4, 5, 6}
if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) {
t.Errorf("expect %v bytes, got %v", e, a)
}
}
func TestDownload_WithFailure(t *testing.T) {
svc := s3.New(unit.Session)
svc.Handlers.Send.Clear()


@@ -215,7 +215,7 @@ func WithUploaderRequestOptions(opts ...request.Option) func(*Uploader) {
type Uploader struct {
// The buffer size (in bytes) to use when buffering data into chunks and
// sending them as parts to S3. The minimum allowed part size is 5MB, and
// if this value is set to zero, the DefaultPartSize value will be used.
// if this value is set to zero, the DefaultUploadPartSize value will be used.
PartSize int64
// The number of goroutines to spin up in parallel when sending parts.
@@ -373,6 +373,61 @@ func (u Uploader) UploadWithContext(ctx aws.Context, input *UploadInput, opts ..
return i.upload()
}
// UploadWithIterator will upload a batched amount of objects to S3. This operation uses
// the iterator pattern to know which object to upload next. Since this is an interface, it
// allows for custom-defined functionality.
//
// Example:
// svc := s3manager.NewUploader(sess)
//
// objects := []BatchUploadObject{
// {
// Object: &s3manager.UploadInput {
// Key: aws.String("key"),
// Bucket: aws.String("bucket"),
// },
// },
// }
//
// iter := &s3managee.UploadObjectsIterator{Objects: objects}
// if err := svc.UploadWithIterator(aws.BackgroundContext(), iter); err != nil {
// return err
// }
func (u Uploader) UploadWithIterator(ctx aws.Context, iter BatchUploadIterator, opts ...func(*Uploader)) error {
var errs []Error
for iter.Next() {
object := iter.UploadObject()
if _, err := u.UploadWithContext(ctx, object.Object, opts...); err != nil {
s3Err := Error{
OrigErr: err,
Bucket: object.Object.Bucket,
Key: object.Object.Key,
}
errs = append(errs, s3Err)
}
if object.After == nil {
continue
}
if err := object.After(); err != nil {
s3Err := Error{
OrigErr: err,
Bucket: object.Object.Bucket,
Key: object.Object.Key,
}
errs = append(errs, s3Err)
}
}
if len(errs) > 0 {
return NewBatchError("BatchedUploadIncomplete", "some objects have failed to upload.", errs)
}
return nil
}
// internal structure to manage an upload to S3.
type uploader struct {
ctx aws.Context


@@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"regexp"
"sort"
"strings"
"sync"
@@ -21,7 +22,6 @@ import (
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/stretchr/testify/assert"
)
var emptyList = []string{}
@@ -111,32 +111,77 @@ func TestUploadOrderMulti(t *testing.T) {
ContentType: aws.String("content/type"),
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
assert.Equal(t, "https://location", resp.Location)
assert.Equal(t, "UPLOAD-ID", resp.UploadID)
assert.Equal(t, aws.String("VERSION-ID"), resp.VersionID)
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
expected := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}
if !reflect.DeepEqual(expected, *ops) {
t.Errorf("Expected %v, but received %v", expected, *ops)
}
if "https://location" != resp.Location {
t.Errorf("Expected %q, but received %q", "https://location", resp.Location)
}
if "UPLOAD-ID" != resp.UploadID {
t.Errorf("Expected %q, but received %q", "UPLOAD-ID", resp.UploadID)
}
if "VERSION-ID" != *resp.VersionID {
t.Errorf("Expected %q, but received %q", "VERSION-ID", resp.VersionID)
}
// Validate input values
// UploadPart
assert.Equal(t, "UPLOAD-ID", val((*args)[1], "UploadId"))
assert.Equal(t, "UPLOAD-ID", val((*args)[2], "UploadId"))
assert.Equal(t, "UPLOAD-ID", val((*args)[3], "UploadId"))
for i := 1; i < 5; i++ {
v := val((*args)[i], "UploadId")
if "UPLOAD-ID" != v {
t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v)
}
}
// CompleteMultipartUpload
assert.Equal(t, "UPLOAD-ID", val((*args)[4], "UploadId"))
assert.Equal(t, int64(1), val((*args)[4], "MultipartUpload.Parts[0].PartNumber"))
assert.Equal(t, int64(2), val((*args)[4], "MultipartUpload.Parts[1].PartNumber"))
assert.Equal(t, int64(3), val((*args)[4], "MultipartUpload.Parts[2].PartNumber"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[0].ETag"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[1].ETag"))
assert.Regexp(t, `^ETAG\d+$`, val((*args)[4], "MultipartUpload.Parts[2].ETag"))
v := val((*args)[4], "UploadId")
if "UPLOAD-ID" != v {
t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v)
}
for i := 0; i < 3; i++ {
e := val((*args)[4], fmt.Sprintf("MultipartUpload.Parts[%d].PartNumber", i))
if int64(i+1) != e.(int64) {
t.Errorf("Expected %d, but received %d", i+1, e)
}
}
vals := []string{
val((*args)[4], "MultipartUpload.Parts[0].ETag").(string),
val((*args)[4], "MultipartUpload.Parts[1].ETag").(string),
val((*args)[4], "MultipartUpload.Parts[2].ETag").(string),
}
for _, a := range vals {
if matched, err := regexp.MatchString(`^ETAG\d+$`, a); !matched || err != nil {
t.Errorf("Failed regexp expression `^ETAG\\d+$`")
}
}
// Custom headers
assert.Equal(t, "aws:kms", val((*args)[0], "ServerSideEncryption"))
assert.Equal(t, "KmsId", val((*args)[0], "SSEKMSKeyId"))
assert.Equal(t, "content/type", val((*args)[0], "ContentType"))
e := val((*args)[0], "ServerSideEncryption")
if e != "aws:kms" {
t.Errorf("Expected %q, but received %q", "aws:kms", e)
}
e = val((*args)[0], "SSEKMSKeyId")
if e != "KmsId" {
t.Errorf("Expected %q, but received %q", "KmsId", e)
}
e = val((*args)[0], "ContentType")
if e != "content/type" {
t.Errorf("Expected %q, but received %q", "content/type", e)
}
}
func TestUploadOrderMultiDifferentPartSize(t *testing.T) {
@@ -151,12 +196,22 @@ func TestUploadOrderMultiDifferentPartSize(t *testing.T) {
Body: bytes.NewReader(buf12MB),
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
vals := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}
if !reflect.DeepEqual(vals, *ops) {
t.Errorf("Expected %v, but received %v", vals, *ops)
}
// Part lengths
assert.Equal(t, 1024*1024*7, buflen(val((*args)[1], "Body")))
assert.Equal(t, 1024*1024*5, buflen(val((*args)[2], "Body")))
if len := buflen(val((*args)[1], "Body")); 1024*1024*7 != len {
t.Errorf("Expected %d, but received %d", 1024*1024*7, len)
}
if len := buflen(val((*args)[2], "Body")); 1024*1024*5 != len {
t.Errorf("Expected %d, but received %d", 1024*1024*5, len)
}
}
func TestUploadIncreasePartSize(t *testing.T) {
@@ -171,13 +226,27 @@ func TestUploadIncreasePartSize(t *testing.T) {
Body: bytes.NewReader(buf12MB),
})
assert.NoError(t, err)
assert.Equal(t, int64(s3manager.DefaultDownloadPartSize), mgr.PartSize)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if int64(s3manager.DefaultDownloadPartSize) != mgr.PartSize {
t.Errorf("Expected %d, but received %d", s3manager.DefaultDownloadPartSize, mgr.PartSize)
}
vals := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}
if !reflect.DeepEqual(vals, *ops) {
t.Errorf("Expected %v, but received %v", vals, *ops)
}
// Part lengths
assert.Equal(t, (1024*1024*6)+1, buflen(val((*args)[1], "Body")))
assert.Equal(t, (1024*1024*6)-1, buflen(val((*args)[2], "Body")))
if len := buflen(val((*args)[1], "Body")); (1024*1024*6)+1 != len {
t.Errorf("Expected %d, but received %d", (1024*1024*6)+1, len)
}
if len := buflen(val((*args)[2], "Body")); (1024*1024*6)-1 != len {
t.Errorf("Expected %d, but received %d", (1024*1024*6)-1, len)
}
}
func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
@@ -190,12 +259,22 @@ func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
Body: bytes.NewReader(buf12MB),
})
assert.Nil(t, resp)
assert.NotNil(t, err)
if resp != nil {
t.Errorf("Expected response to be nil, but received %v", resp)
}
if err == nil {
t.Errorf("Expected error, but received nil")
}
aerr := err.(awserr.Error)
assert.Equal(t, "ConfigError", aerr.Code())
assert.Contains(t, aerr.Message(), "part size must be at least")
if "ConfigError" != aerr.Code() {
t.Errorf("Expected %q, but received %q", "ConfigError", aerr.Code())
}
if !strings.Contains(aerr.Message(), "part size must be at least") {
t.Errorf("Expected message to contain %q, but received %q", "part size must be at least", aerr.Message())
}
}
func TestUploadOrderSingle(t *testing.T) {
@ -210,14 +289,37 @@ func TestUploadOrderSingle(t *testing.T) {
ContentType: aws.String("content/type"),
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, aws.String("VERSION-ID"), resp.VersionID)
assert.Equal(t, "", resp.UploadID)
assert.Equal(t, "aws:kms", val((*args)[0], "ServerSideEncryption"))
assert.Equal(t, "KmsId", val((*args)[0], "SSEKMSKeyId"))
assert.Equal(t, "content/type", val((*args)[0], "ContentType"))
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if vals := []string{"PutObject"}; !reflect.DeepEqual(vals, *ops) {
t.Errorf("Expected %v, but received %v", vals, *ops)
}
if len(resp.Location) == 0 {
t.Error("Expected Location to not be empty")
}
if e := "VERSION-ID"; e != *resp.VersionID {
t.Errorf("Expected %q, but received %q", e, resp.VersionID)
}
if len(resp.UploadID) > 0 {
t.Errorf("Expected empty string, but received %q", resp.UploadID)
}
if e, a := "aws:kms", val((*args)[0], "ServerSideEncryption").(string); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if e, a := "KmsId", val((*args)[0], "SSEKMSKeyId").(string); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if e, a := "content/type", val((*args)[0], "ContentType").(string); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
}
func TestUploadOrderSingleFailure(t *testing.T) {
@ -232,9 +334,17 @@ func TestUploadOrderSingleFailure(t *testing.T) {
Body: bytes.NewReader(buf2MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.Nil(t, resp)
if err == nil {
t.Error("Expected error, but received nil")
}
if vals := []string{"PutObject"}; !reflect.DeepEqual(vals, *ops) {
t.Errorf("Expected %v, but received %v", vals, *ops)
}
if resp != nil {
t.Errorf("Expected response to be nil, but received %v", resp)
}
}
func TestUploadOrderZero(t *testing.T) {
@ -246,11 +356,25 @@ func TestUploadOrderZero(t *testing.T) {
Body: bytes.NewReader(make([]byte, 0)),
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
assert.Equal(t, 0, buflen(val((*args)[0], "Body")))
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if vals := []string{"PutObject"}; !reflect.DeepEqual(vals, *ops) {
t.Errorf("Expected %v, but received %v", vals, *ops)
}
if len(resp.Location) == 0 {
t.Error("Expected Location to not be empty")
}
if len(resp.UploadID) > 0 {
t.Errorf("Expected empty string, but received %q", resp.UploadID)
}
if e, a := 0, buflen(val((*args)[0], "Body")); e != a {
t.Errorf("Expected %d, but received %d", e, a)
}
}
func TestUploadOrderMultiFailure(t *testing.T) {
@ -273,8 +397,13 @@ func TestUploadOrderMultiFailure(t *testing.T) {
Body: bytes.NewReader(buf12MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *ops)
if err == nil {
t.Error("Expected error, but received nil")
}
if e, a := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but received %v", e, a)
}
}
func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
@ -295,9 +424,14 @@ func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
Body: bytes.NewReader(buf12MB),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart",
"UploadPart", "CompleteMultipartUpload", "AbortMultipartUpload"}, *ops)
if err == nil {
t.Error("Expected error, but received nil")
}
if e, a := []string{"CreateMultipartUpload", "UploadPart", "UploadPart",
"UploadPart", "CompleteMultipartUpload", "AbortMultipartUpload"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but received %v", e, a)
}
}
func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
@ -316,8 +450,13 @@ func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload"}, *ops)
if err == nil {
t.Error("Expected error, but received nil")
}
if e, a := []string{"CreateMultipartUpload"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but received %v", e, a)
}
}
func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
@ -341,8 +480,13 @@ func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
assert.Error(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *ops)
if err == nil {
t.Error("Expected error, but received nil")
}
if e, a := []string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but received %v", e, a)
}
}
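// failreader stands in for an upload Body whose Read fails after a set
// number of calls; the ReadFail tests below expect its "random failure"
// error to surface through the uploader.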
type failreader struct {
@ -367,9 +511,17 @@ func TestUploadOrderReadFail1(t *testing.T) {
Body: &failreader{times: 1},
})
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).Code())
assert.EqualError(t, err.(awserr.Error).OrigErr(), "random failure")
assert.Equal(t, []string{}, *ops)
if e, a := "ReadRequestBody", err.(awserr.Error).Code(); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if e, a := "random failure", err.(awserr.Error).OrigErr().Error(); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if e, a := []string{}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but received %v", e, a)
}
}
func TestUploadOrderReadFail2(t *testing.T) {
@ -383,10 +535,21 @@ func TestUploadOrderReadFail2(t *testing.T) {
Body: &failreader{times: 2},
})
assert.Equal(t, "MultipartUpload", err.(awserr.Error).Code())
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).OrigErr().(awserr.Error).Code())
assert.Contains(t, err.(awserr.Error).OrigErr().Error(), "random failure")
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
if e, a := "MultipartUpload", err.(awserr.Error).Code(); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if e, a := "ReadRequestBody", err.(awserr.Error).OrigErr().(awserr.Error).Code(); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if errStr := err.(awserr.Error).OrigErr().Error(); !strings.Contains(errStr, "random failure") {
t.Errorf("Expected error to contains 'random failure', but was %q", errStr)
}
if e, a := []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but receievd %v", e, a)
}
}
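// sizedReader stands in for an upload Body of a fixed byte size; the
// buffered-reader tests below also configure it to return a trailing error
// such as io.EOF once drained.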
type sizedReader struct {
@ -421,8 +584,13 @@ func TestUploadOrderMultiBufferedReader(t *testing.T) {
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if e, a := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but receievd %v", e, a)
}
// Part lengths
parts := []int{
@ -431,7 +599,10 @@ func TestUploadOrderMultiBufferedReader(t *testing.T) {
buflen(val((*args)[3], "Body")),
}
sort.Ints(parts)
assert.Equal(t, []int{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts)
if e, a := []int{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but receievd %v", e, a)
}
}
func TestUploadOrderMultiBufferedReaderPartial(t *testing.T) {
@ -443,8 +614,13 @@ func TestUploadOrderMultiBufferedReaderPartial(t *testing.T) {
Body: &sizedReader{size: 1024 * 1024 * 12, err: io.EOF},
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if e, a := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but receievd %v", e, a)
}
// Part lengths
parts := []int{
@ -453,7 +629,10 @@ func TestUploadOrderMultiBufferedReaderPartial(t *testing.T) {
buflen(val((*args)[3], "Body")),
}
sort.Ints(parts)
assert.Equal(t, []int{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts)
if e, a := []int{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but receievd %v", e, a)
}
}
// TestUploadOrderMultiBufferedReaderEOF tests the edge case where the
@ -467,8 +646,13 @@ func TestUploadOrderMultiBufferedReaderEOF(t *testing.T) {
Body: &sizedReader{size: 1024 * 1024 * 10, err: io.EOF},
})
assert.NoError(t, err)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops)
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if e, a := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but receievd %v", e, a)
}
// Part lengths
parts := []int{
@ -476,7 +660,10 @@ func TestUploadOrderMultiBufferedReaderEOF(t *testing.T) {
buflen(val((*args)[2], "Body")),
}
sort.Ints(parts)
assert.Equal(t, []int{1024 * 1024 * 5, 1024 * 1024 * 5}, parts)
if e, a := []int{1024 * 1024 * 5, 1024 * 1024 * 5}, parts; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but receievd %v", e, a)
}
}
func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
@ -491,14 +678,30 @@ func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.Error(t, err)
assert.Nil(t, resp)
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
if err == nil {
t.Fatal("Expected an error, but received nil")
}
if resp != nil {
t.Errorf("Expected nil, but receievd %v", resp)
}
if e, a := []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but receievd %v", e, a)
}
aerr := err.(awserr.Error)
assert.Equal(t, "MultipartUpload", aerr.Code())
assert.Equal(t, "TotalPartsExceeded", aerr.OrigErr().(awserr.Error).Code())
assert.Contains(t, aerr.Error(), "configured MaxUploadParts (2)")
if e, a := "MultipartUpload", aerr.Code(); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if e, a := "TotalPartsExceeded", aerr.OrigErr().(awserr.Error).Code(); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if !strings.Contains(aerr.Error(), "configured MaxUploadParts (2)") {
t.Errorf("Expected error to contain 'configured MaxUploadParts (2)', but receievd %q", aerr.Error())
}
}
func TestUploadOrderSingleBufferedReader(t *testing.T) {
@ -510,10 +713,21 @@ func TestUploadOrderSingleBufferedReader(t *testing.T) {
Body: &sizedReader{size: 1024 * 1024 * 2},
})
assert.NoError(t, err)
assert.Equal(t, []string{"PutObject"}, *ops)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if e, a := []string{"PutObject"}, *ops; !reflect.DeepEqual(e, a) {
t.Errorf("Expected %v, but received %v", e, a)
}
if len(resp.Location) == 0 {
t.Error("Expected a value in Location but received empty string")
}
if len(resp.UploadID) > 0 {
t.Errorf("Expected empty string but received %q", resp.UploadID)
}
}
func TestUploadZeroLenObject(t *testing.T) {
@ -531,10 +745,20 @@ func TestUploadZeroLenObject(t *testing.T) {
Body: strings.NewReader(""),
})
assert.NoError(t, err)
assert.True(t, requestMade)
assert.NotEqual(t, "", resp.Location)
assert.Equal(t, "", resp.UploadID)
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if !requestMade {
t.Error("Expected request to have been made, but was not")
}
if len(resp.Location) == 0 {
t.Error("Expected a non-empty string value for Location")
}
if len(resp.UploadID) > 0 {
t.Errorf("Expected empty string, but received %q", resp.UploadID)
}
}
func TestUploadInputS3PutObjectInputPairity(t *testing.T) {
@ -550,8 +774,14 @@ func TestUploadInputS3PutObjectInputPairity(t *testing.T) {
bOnly = append(bOnly, k)
}
}
assert.Empty(t, aOnly, "s3.PutObjectInput")
assert.Empty(t, bOnly, "s3Manager.UploadInput")
if len(aOnly) > 0 {
t.Errorf("Expected no fields unique to s3.PutObjectInput, but received %v", aOnly)
}
if len(bOnly) > 0 {
t.Errorf("Expected no fields unique to s3manager.UploadInput, but received %v", bOnly)
}
}
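// testIncompleteReader simulates a Body that ends before its declared
// length; TestUploadUnexpectedEOF below verifies the multipart upload is
// aborted when that happens.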
type testIncompleteReader struct {
@ -582,13 +812,26 @@ func TestUploadUnexpectedEOF(t *testing.T) {
},
})
assert.Error(t, err)
assert.Equal(t, "CreateMultipartUpload", (*ops)[0])
assert.Equal(t, "UploadPart", (*ops)[1])
assert.Equal(t, "AbortMultipartUpload", (*ops)[len(*ops)-1])
if err == nil {
t.Error("Expected error, but received none")
}
if e, a := "CreateMultipartUpload", (*ops)[0]; e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if e, a := "UploadPart", (*ops)[1]; e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
if e, a := "AbortMultipartUpload", (*ops)[len(*ops)-1]; e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
// Part lengths
assert.Equal(t, 1024*1024*5, buflen(val((*args)[1], "Body")))
if e, a := 1024*1024*5, buflen(val((*args)[1], "Body")); e != a {
t.Errorf("Expected %d, but received %d", e, a)
}
}
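// compareStructType enumerates the field names of two struct types so the
// parity test above can flag fields that exist in only one of them.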
func compareStructType(a, b reflect.Type) map[string]int {
@ -668,8 +911,13 @@ func TestReaderAt(t *testing.T) {
Body: &fooReaderAt{},
})
assert.NoError(t, err)
assert.Equal(t, contentLen, "12")
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if e, a := "12", contentLen; e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
}
func TestSSE(t *testing.T) {

View file

@ -5,17 +5,27 @@ import (
"encoding/base64"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request"
)
var errSSERequiresSSL = awserr.New("ConfigError", "cannot send SSE keys over HTTP.", nil)
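// validateSSERequiresSSL fails the request with errSSERequiresSSL when an
// SSE customer key or copy-source customer key is present but the request
// is not being sent over HTTPS.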
func validateSSERequiresSSL(r *request.Request) {
if r.HTTPRequest.URL.Scheme != "https" {
p, _ := awsutil.ValuesAtPath(r.Params, "SSECustomerKey||CopySourceSSECustomerKey")
if len(p) > 0 {
if r.HTTPRequest.URL.Scheme == "https" {
return
}
if iface, ok := r.Params.(sseCustomerKeyGetter); ok {
if len(iface.getSSECustomerKey()) > 0 {
r.Error = errSSERequiresSSL
return
}
}
if iface, ok := r.Params.(copySourceSSECustomerKeyGetter); ok {
if len(iface.getCopySourceSSECustomerKey()) > 0 {
r.Error = errSSERequiresSSL
return
}
}
}
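For context, the accessor-interface pattern above trades the old reflection-based awsutil.ValuesAtPath lookup for a plain type assertion, which is cheaper and lets the generated input types opt in at compile time. A minimal sketch of the idea, using a hypothetical uploadInput type (illustrative only, not a real SDK type):

package main

import "fmt"

// sseCustomerKeyGetter mirrors the accessor interface defined above.
type sseCustomerKeyGetter interface {
	getSSECustomerKey() string
}

// uploadInput stands in for a generated S3 input type carrying an SSE key.
type uploadInput struct {
	SSECustomerKey *string
}

// getSSECustomerKey satisfies sseCustomerKeyGetter without reflection.
func (u *uploadInput) getSSECustomerKey() string {
	if u.SSECustomerKey == nil {
		return ""
	}
	return *u.SSECustomerKey
}

func main() {
	key := "secret-key-material"
	var params interface{} = &uploadInput{SSECustomerKey: &key}
	// Same shape as the check in validateSSERequiresSSL above.
	if iface, ok := params.(sseCustomerKeyGetter); ok && len(iface.getSSECustomerKey()) > 0 {
		fmt.Println("SSE customer key present; request must use HTTPS")
	}
}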