diff --git a/docs/storage/driver/s3-aws/s3.go b/docs/storage/driver/s3-aws/s3.go index db61b4e7b..8683f80e1 100644 --- a/docs/storage/driver/s3-aws/s3.go +++ b/docs/storage/driver/s3-aws/s3.go @@ -60,6 +60,7 @@ type DriverParameters struct { Region string RegionEndpoint string Encrypt bool + KeyID string Secure bool ChunkSize int64 RootDirectory string @@ -100,6 +101,7 @@ type driver struct { Bucket string ChunkSize int64 Encrypt bool + KeyID string RootDirectory string StorageClass string } @@ -188,6 +190,11 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { return nil, fmt.Errorf("The secure parameter should be a boolean") } + keyID := parameters["keyid"] + if keyID == nil { + keyID = "" + } + chunkSize := int64(defaultChunkSize) chunkSizeParam := parameters["chunksize"] switch v := chunkSizeParam.(type) { @@ -243,6 +250,7 @@ func FromParameters(parameters map[string]interface{}) (*Driver, error) { region, fmt.Sprint(regionEndpoint), encryptBool, + fmt.Sprint(keyID), secureBool, chunkSize, fmt.Sprint(rootDirectory), @@ -317,6 +325,7 @@ func New(params DriverParameters) (*Driver, error) { Bucket: params.Bucket, ChunkSize: params.ChunkSize, Encrypt: params.Encrypt, + KeyID: params.KeyID, RootDirectory: params.RootDirectory, StorageClass: params.StorageClass, } @@ -353,6 +362,7 @@ func (d *driver) PutContent(ctx context.Context, path string, contents []byte) e ContentType: d.getContentType(), ACL: d.getACL(), ServerSideEncryption: d.getEncryptionMode(), + SSEKMSKeyId: d.getSSEKMSKeyID(), StorageClass: d.getStorageClass(), Body: bytes.NewReader(contents), }) @@ -390,6 +400,7 @@ func (d *driver) Writer(ctx context.Context, path string, append bool) (storaged ContentType: d.getContentType(), ACL: d.getACL(), ServerSideEncryption: d.getEncryptionMode(), + SSEKMSKeyId: d.getSSEKMSKeyID(), StorageClass: d.getStorageClass(), }) if err != nil { @@ -534,6 +545,7 @@ func (d *driver) Move(ctx context.Context, sourcePath string, destPath string) e ContentType: d.getContentType(), ACL: d.getACL(), ServerSideEncryption: d.getEncryptionMode(), + SSEKMSKeyId: d.getSSEKMSKeyID(), StorageClass: d.getStorageClass(), CopySource: aws.String(d.Bucket + "/" + d.s3Path(sourcePath)), }) @@ -645,9 +657,19 @@ func parseError(path string, err error) error { } func (d *driver) getEncryptionMode() *string { - if d.Encrypt { + if !d.Encrypt { + return nil + } + if d.KeyID == "" { return aws.String("AES256") } + return aws.String("aws:kms") +} + +func (d *driver) getSSEKMSKeyID() *string { + if d.KeyID != "" { + return aws.String(d.KeyID) + } return nil } diff --git a/docs/storage/driver/s3-aws/s3_test.go b/docs/storage/driver/s3-aws/s3_test.go index f12297bff..bb64ccf44 100644 --- a/docs/storage/driver/s3-aws/s3_test.go +++ b/docs/storage/driver/s3-aws/s3_test.go @@ -27,6 +27,7 @@ func init() { secretKey := os.Getenv("AWS_SECRET_KEY") bucket := os.Getenv("S3_BUCKET") encrypt := os.Getenv("S3_ENCRYPT") + keyID := os.Getenv("S3_KEY_ID") secure := os.Getenv("S3_SECURE") region := os.Getenv("AWS_REGION") root, err := ioutil.TempDir("", "driver-") @@ -60,6 +61,7 @@ func init() { region, regionEndpoint, encryptBool, + keyID, secureBool, minChunkSize, rootDirectory,