frostfs-s3-gw/api/handler/encryption_test.go

379 lines
12 KiB
Go
Raw Permalink Normal View History

package handler
import (
"bytes"
"crypto/rand"
"crypto/tls"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"testing"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
"github.com/stretchr/testify/require"
)
const (
aes256Key = "MTIzNDU2Nzg5MHF3ZXJ0eXVpb3Bhc2RmZ2hqa2x6eGM="
aes256KeyMD5 = "NtkH/y2maPit+yUkhq4Q7A=="
partNumberQuery = "partNumber"
uploadIDQuery = "uploadId"
)
func TestSimpleGetEncrypted(t *testing.T) {
tc := prepareHandlerContext(t)
bktName, objName := "bucket-for-sse-c", "object-to-encrypt"
bktInfo := createTestBucket(tc, bktName)
content := "content"
putEncryptedObject(t, tc, bktName, objName, content)
objInfo, err := tc.Layer().GetObjectInfo(tc.Context(), &layer.HeadObjectParams{BktInfo: bktInfo, Object: objName})
require.NoError(t, err)
obj, err := tc.MockedPool().ReadObject(tc.Context(), layer.PrmObjectRead{Container: bktInfo.CID, Object: objInfo.ID})
require.NoError(t, err)
encryptedContent, err := io.ReadAll(obj.Payload)
require.NoError(t, err)
require.NotEqual(t, content, string(encryptedContent))
response, _ := getEncryptedObject(tc, bktName, objName)
require.Equal(t, content, string(response))
}
func TestGetEncryptedRange(t *testing.T) {
tc := prepareHandlerContext(t)
bktName, objName := "bucket-for-sse-c", "object-to-encrypt"
createTestBucket(tc, bktName)
var sb strings.Builder
for i := 0; i < 1<<16+11; i++ {
switch i {
case 0:
sb.Write([]byte("b"))
case 1<<16 - 2:
sb.Write([]byte("c"))
case 1<<16 - 1:
sb.Write([]byte("d"))
case 1 << 16:
sb.Write([]byte("e"))
case 1<<16 + 1:
sb.Write([]byte("f"))
case 1<<16 + 10:
sb.Write([]byte("g"))
default:
sb.Write([]byte("a"))
}
}
content := sb.String()
putEncryptedObject(t, tc, bktName, objName, content)
full := getEncryptedObjectRange(t, tc, bktName, objName, 0, sb.Len()-1)
require.Equalf(t, content, string(full), "expected len: %d, actual len: %d", len(content), len(full))
beginning := getEncryptedObjectRange(t, tc, bktName, objName, 0, 3)
require.Equal(t, content[:4], string(beginning))
middle := getEncryptedObjectRange(t, tc, bktName, objName, 1<<16-3, 1<<16+2)
require.Equal(t, "acdefa", string(middle))
end := getEncryptedObjectRange(t, tc, bktName, objName, 1<<16+2, len(content)-1)
require.Equal(t, "aaaaaaaag", string(end))
}
func TestS3EncryptionSSECMultipartUpload(t *testing.T) {
tc := prepareHandlerContext(t)
bktName, objName := "bucket-for-sse-c-multipart-s3-tests", "multipart_enc"
createTestBucket(tc, bktName)
objLen := 30 * 1024 * 1024
partSize := objLen / 6
headerMetaKey := api.MetadataPrefix + "foo"
headers := map[string]string{
headerMetaKey: "bar",
api.ContentType: "text/plain",
}
data := multipartUploadEncrypted(tc, bktName, objName, headers, objLen, partSize)
require.Equal(t, objLen, len(data))
resData, resHeader := getEncryptedObject(tc, bktName, objName)
equalDataSlices(t, data, resData)
require.Equal(t, headers[api.ContentType], resHeader.Get(api.ContentType))
require.Equal(t, headers[headerMetaKey], resHeader[headerMetaKey][0])
require.Equal(t, strconv.Itoa(objLen), resHeader.Get(api.ContentLength))
checkContentUsingRangeEnc(tc, bktName, objName, data, 1000000)
checkContentUsingRangeEnc(tc, bktName, objName, data, 10000000)
}
func TestMultipartUploadGetRange(t *testing.T) {
hc := prepareHandlerContext(t)
bktName, objName := "bucket-for-multipart-s3-tests", "multipart_obj"
createTestBucket(hc, bktName)
objLen := 30 * 1024 * 1024
partSize := objLen / 6
headerMetaKey := api.MetadataPrefix + "foo"
headers := map[string]string{
headerMetaKey: "bar",
api.ContentType: "text/plain",
}
data := multipartUpload(hc, bktName, objName, headers, objLen, partSize)
require.Equal(t, objLen, len(data))
resData, resHeader := getObject(hc, bktName, objName)
equalDataSlices(t, data, resData)
require.Equal(t, headers[api.ContentType], resHeader.Get(api.ContentType))
require.Equal(t, headers[headerMetaKey], resHeader[headerMetaKey][0])
require.Equal(t, strconv.Itoa(objLen), resHeader.Get(api.ContentLength))
checkContentUsingRange(hc, bktName, objName, data, 1000000)
checkContentUsingRange(hc, bktName, objName, data, 10000000)
}
func equalDataSlices(t *testing.T, expected, actual []byte) {
require.Equal(t, len(expected), len(actual), "sizes don't match")
if bytes.Equal(expected, actual) {
return
}
for i := 0; i < len(expected); i++ {
if expected[i] != actual[i] {
require.Equalf(t, expected[i], actual[i], "differ start with '%d' position, length: %d", i, len(expected))
}
}
}
func checkContentUsingRangeEnc(hc *handlerContext, bktName, objName string, data []byte, step int) {
checkContentUsingRangeBase(hc, bktName, objName, data, step, true)
}
func checkContentUsingRange(hc *handlerContext, bktName, objName string, data []byte, step int) {
checkContentUsingRangeBase(hc, bktName, objName, data, step, false)
}
func checkContentUsingRangeBase(hc *handlerContext, bktName, objName string, data []byte, step int, encrypted bool) {
var off, toRead, end int
for off < len(data) {
toRead = len(data) - off
if toRead > step {
toRead = step
}
end = off + toRead - 1
var rangeData []byte
if encrypted {
rangeData = getEncryptedObjectRange(hc.t, hc, bktName, objName, off, end)
} else {
rangeData = getObjectRange(hc.t, hc, bktName, objName, off, end)
}
equalDataSlices(hc.t, data[off:end+1], rangeData)
off += step
}
}
func multipartUploadEncrypted(hc *handlerContext, bktName, objName string, headers map[string]string, objLen, partsSize int) (objData []byte) {
multipartInfo := createMultipartUploadEncrypted(hc, bktName, objName, headers)
var sum, currentPart int
var etags []string
adjustedSize := partsSize
for sum < objLen {
currentPart++
sum += partsSize
if sum > objLen {
adjustedSize = objLen - sum
}
etag, data := uploadPartEncrypted(hc, bktName, objName, multipartInfo.UploadID, currentPart, adjustedSize)
etags = append(etags, etag)
objData = append(objData, data...)
}
completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, etags)
return
}
func multipartUpload(hc *handlerContext, bktName, objName string, headers map[string]string, objLen, partsSize int) (objData []byte) {
multipartInfo := createMultipartUpload(hc, bktName, objName, headers)
var sum, currentPart int
var etags []string
adjustedSize := partsSize
for sum < objLen {
currentPart++
sum += partsSize
if sum > objLen {
adjustedSize = objLen - sum
}
etag, data := uploadPart(hc, bktName, objName, multipartInfo.UploadID, currentPart, adjustedSize)
etags = append(etags, etag)
objData = append(objData, data...)
}
completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, etags)
return
}
func createMultipartUploadEncrypted(hc *handlerContext, bktName, objName string, headers map[string]string) *InitiateMultipartUploadResponse {
return createMultipartUploadBase(hc, bktName, objName, true, headers)
}
func createMultipartUpload(hc *handlerContext, bktName, objName string, headers map[string]string) *InitiateMultipartUploadResponse {
return createMultipartUploadBase(hc, bktName, objName, false, headers)
}
func createMultipartUploadBase(hc *handlerContext, bktName, objName string, encrypted bool, headers map[string]string) *InitiateMultipartUploadResponse {
w, r := prepareTestRequest(hc, bktName, objName, nil)
if encrypted {
setEncryptHeaders(r)
}
setHeaders(r, headers)
hc.Handler().CreateMultipartUploadHandler(w, r)
multipartInitInfo := &InitiateMultipartUploadResponse{}
readResponse(hc.t, w, http.StatusOK, multipartInitInfo)
return multipartInitInfo
}
func completeMultipartUpload(hc *handlerContext, bktName, objName, uploadID string, partsETags []string) {
w := completeMultipartUploadBase(hc, bktName, objName, uploadID, partsETags)
assertStatus(hc.t, w, http.StatusOK)
}
func completeMultipartUploadBase(hc *handlerContext, bktName, objName, uploadID string, partsETags []string) *httptest.ResponseRecorder {
query := make(url.Values)
query.Set(uploadIDQuery, uploadID)
complete := &CompleteMultipartUpload{
Parts: []*layer.CompletedPart{},
}
for i, tag := range partsETags {
complete.Parts = append(complete.Parts, &layer.CompletedPart{
ETag: tag,
PartNumber: i + 1,
})
}
w, r := prepareTestFullRequest(hc, bktName, objName, query, complete)
hc.Handler().CompleteMultipartUploadHandler(w, r)
return w
}
func uploadPartEncrypted(hc *handlerContext, bktName, objName, uploadID string, num, size int) (string, []byte) {
return uploadPartBase(hc, bktName, objName, true, uploadID, num, size)
}
func uploadPart(hc *handlerContext, bktName, objName, uploadID string, num, size int) (string, []byte) {
return uploadPartBase(hc, bktName, objName, false, uploadID, num, size)
}
func uploadPartBase(hc *handlerContext, bktName, objName string, encrypted bool, uploadID string, num, size int) (string, []byte) {
partBody := make([]byte, size)
_, err := rand.Read(partBody)
require.NoError(hc.t, err)
query := make(url.Values)
query.Set(uploadIDQuery, uploadID)
query.Set(partNumberQuery, strconv.Itoa(num))
w, r := prepareTestRequestWithQuery(hc, bktName, objName, query, partBody)
if encrypted {
setEncryptHeaders(r)
}
hc.Handler().UploadPartHandler(w, r)
assertStatus(hc.t, w, http.StatusOK)
return w.Header().Get(api.ETag), partBody
}
func TestMultipartEncrypted(t *testing.T) {
partSize := 5*1048576 + 1<<16 - 5 // 5MB (min part size) + 64kb (cipher block size) - 5 (to check corner range)
hc := prepareHandlerContext(t)
bktName, objName := "bucket-for-sse-c-multipart", "object-to-encrypt-multipart"
createTestBucket(hc, bktName)
multipartInitInfo := createMultipartUploadEncrypted(hc, bktName, objName, map[string]string{})
part1ETag, part1 := uploadPartEncrypted(hc, bktName, objName, multipartInitInfo.UploadID, 1, partSize)
part2ETag, part2 := uploadPartEncrypted(hc, bktName, objName, multipartInitInfo.UploadID, 2, 5)
completeMultipartUpload(hc, bktName, objName, multipartInitInfo.UploadID, []string{part1ETag, part2ETag})
res, _ := getEncryptedObject(hc, bktName, objName)
require.Equal(t, len(part1)+len(part2), len(res))
require.Equal(t, append(part1, part2...), res)
part2Range := getEncryptedObjectRange(t, hc, bktName, objName, len(part1), len(part1)+len(part2)-1)
require.Equal(t, part2[0:], part2Range)
}
func putEncryptedObject(t *testing.T, tc *handlerContext, bktName, objName, content string) {
body := bytes.NewReader([]byte(content))
w, r := prepareTestPayloadRequest(tc, bktName, objName, body)
setEncryptHeaders(r)
tc.Handler().PutObjectHandler(w, r)
assertStatus(t, w, http.StatusOK)
}
func getEncryptedObject(hc *handlerContext, bktName, objName string) ([]byte, http.Header) {
w, r := prepareTestRequest(hc, bktName, objName, nil)
setEncryptHeaders(r)
return getObjectBase(hc, w, r)
}
func getObject(hc *handlerContext, bktName, objName string) ([]byte, http.Header) {
w, r := prepareTestRequest(hc, bktName, objName, nil)
return getObjectBase(hc, w, r)
}
func getObjectBase(hc *handlerContext, w *httptest.ResponseRecorder, r *http.Request) ([]byte, http.Header) {
hc.Handler().GetObjectHandler(w, r)
assertStatus(hc.t, w, http.StatusOK)
content, err := io.ReadAll(w.Result().Body)
require.NoError(hc.t, err)
return content, w.Header()
}
func getEncryptedObjectRange(t *testing.T, tc *handlerContext, bktName, objName string, start, end int) []byte {
w, r := prepareTestRequest(tc, bktName, objName, nil)
setEncryptHeaders(r)
r.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
tc.Handler().GetObjectHandler(w, r)
assertStatus(t, w, http.StatusPartialContent)
content, err := io.ReadAll(w.Result().Body)
require.NoError(t, err)
return content
}
func setEncryptHeaders(r *http.Request) {
r.TLS = &tls.ConnectionState{}
r.Header.Set(api.AmzServerSideEncryptionCustomerAlgorithm, layer.AESEncryptionAlgorithm)
r.Header.Set(api.AmzServerSideEncryptionCustomerKey, aes256Key)
r.Header.Set(api.AmzServerSideEncryptionCustomerKeyMD5, aes256KeyMD5)
}
func setHeaders(r *http.Request, header map[string]string) {
for key, val := range header {
r.Header.Set(key, val)
}
}