package handler

import (
	"bytes"
	"crypto/rand"
	"crypto/tls"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strconv"
	"strings"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/layer"
	"github.com/stretchr/testify/require"
)

// Shared fixtures for the SSE-C and multipart handler tests in this file.
//
// aes256Key is the base64 encoding of the 32-byte key
// "1234567890qwertyuiopasdfghjklzxc"; setEncryptHeaders sends it as the
// customer-key header. aes256KeyMD5 is sent next to it as the customer-key-MD5
// header (presumably the base64 MD5 digest of the raw key; not verified here).
//
// partNumberQuery and uploadIDQuery are the query parameter names used when
// building the upload-part and complete-multipart-upload test requests.
const (
	aes256Key       = "MTIzNDU2Nzg5MHF3ZXJ0eXVpb3Bhc2RmZ2hqa2x6eGM="
	aes256KeyMD5    = "NtkH/y2maPit+yUkhq4Q7A=="
	partNumberQuery = "partNumber"
	uploadIDQuery   = "uploadId"
)

// TestSimpleGetEncrypted uploads a small object with SSE-C headers, checks
// that the payload read straight from the mocked pool differs from the
// plaintext, and then checks that a GET carrying the same encryption headers
// returns the original content.
func TestSimpleGetEncrypted(t *testing.T) {
	const (
		bucket    = "bucket-for-sse-c"
		object    = "object-to-encrypt"
		plaintext = "content"
	)

	hc := prepareHandlerContext(t)
	bkt := createTestBucket(hc, bucket)
	putEncryptedObject(t, hc, bucket, object, plaintext)

	// Look the object up through the layer, then read its stored payload from
	// the mocked pool, bypassing the GET handler.
	headPrm := &layer.HeadObjectParams{BktInfo: bkt, Object: object}
	info, err := hc.Layer().GetObjectInfo(hc.Context(), headPrm)
	require.NoError(t, err)

	readPrm := layer.PrmObjectRead{Container: bkt.CID, Object: info.ID}
	stored, err := hc.MockedPool().ReadObject(hc.Context(), readPrm)
	require.NoError(t, err)

	raw, err := io.ReadAll(stored.Payload)
	require.NoError(t, err)
	require.NotEqual(t, plaintext, string(raw))

	// Only the body matters here; the response headers are ignored.
	decrypted, _ := getEncryptedObject(hc, bucket, object)
	require.Equal(t, plaintext, string(decrypted))
}

// TestGetEncryptedRange stores an encrypted object of 1<<16+11 bytes made of
// 'a' filler with distinct marker bytes at offset 0, around offset 1<<16 and
// at the very end, then checks ranged GETs covering the whole object, its
// beginning, a window straddling offset 1<<16, and its tail.
func TestGetEncryptedRange(t *testing.T) {
	const (
		bucket = "bucket-for-sse-c"
		object = "object-to-encrypt"
		// Offset around which the markers are placed (64kb, the cipher block
		// size according to the comment in TestMultipartEncrypted).
		block = 1 << 16
	)

	hc := prepareHandlerContext(t)
	createTestBucket(hc, bucket)

	// Layout: 'b' at 0, "cdef" at block-2..block+1, 'g' at block+10,
	// 'a' everywhere else; total length is block+11.
	content := "b" +
		strings.Repeat("a", block-3) +
		"cdef" +
		strings.Repeat("a", 8) +
		"g"
	putEncryptedObject(t, hc, bucket, object, content)

	whole := getEncryptedObjectRange(t, hc, bucket, object, 0, len(content)-1)
	require.Equalf(t, content, string(whole), "expected len: %d, actual len: %d", len(content), len(whole))

	head := getEncryptedObjectRange(t, hc, bucket, object, 0, 3)
	require.Equal(t, content[:4], string(head))

	straddle := getEncryptedObjectRange(t, hc, bucket, object, block-3, block+2)
	require.Equal(t, "acdefa", string(straddle))

	tail := getEncryptedObjectRange(t, hc, bucket, object, block+2, len(content)-1)
	require.Equal(t, "aaaaaaaag", string(tail))
}

// TestS3EncryptionSSECMultipartUpload uploads a 30 MiB object as six equal
// SSE-C encrypted parts, then verifies the full GET (body, content type, user
// metadata, content length) and re-reads the object through ranged GETs with
// two different step sizes.
func TestS3EncryptionSSECMultipartUpload(t *testing.T) {
	const (
		bucket    = "bucket-for-sse-c-multipart-s3-tests"
		object    = "multipart_enc"
		totalSize = 30 * 1024 * 1024
		partCount = 6
	)

	hc := prepareHandlerContext(t)
	createTestBucket(hc, bucket)

	metaKey := api.MetadataPrefix + "foo"
	reqHeaders := make(map[string]string, 2)
	reqHeaders[metaKey] = "bar"
	reqHeaders[api.ContentType] = "text/plain"

	uploaded := multipartUploadEncrypted(hc, bucket, object, reqHeaders, totalSize, totalSize/partCount)
	require.Equal(t, totalSize, len(uploaded))

	body, respHeader := getEncryptedObject(hc, bucket, object)
	equalDataSlices(t, uploaded, body)
	require.Equal(t, reqHeaders[api.ContentType], respHeader.Get(api.ContentType))
	// The metadata header is looked up by its exact (non-canonicalized) key.
	require.Equal(t, reqHeaders[metaKey], respHeader[metaKey][0])
	require.Equal(t, strconv.Itoa(totalSize), respHeader.Get(api.ContentLength))

	for _, step := range []int{1000000, 10000000} {
		checkContentUsingRangeEnc(hc, bucket, object, uploaded, step)
	}
}

// TestMultipartUploadGetRange uploads a 30 MiB object as six equal
// unencrypted parts, then verifies the full GET (body, content type, user
// metadata, content length) and re-reads the object through ranged GETs with
// two different step sizes.
func TestMultipartUploadGetRange(t *testing.T) {
	const (
		bucket    = "bucket-for-multipart-s3-tests"
		object    = "multipart_obj"
		totalSize = 30 * 1024 * 1024
		partCount = 6
	)

	hc := prepareHandlerContext(t)
	createTestBucket(hc, bucket)

	metaKey := api.MetadataPrefix + "foo"
	reqHeaders := make(map[string]string, 2)
	reqHeaders[metaKey] = "bar"
	reqHeaders[api.ContentType] = "text/plain"

	uploaded := multipartUpload(hc, bucket, object, reqHeaders, totalSize, totalSize/partCount)
	require.Equal(t, totalSize, len(uploaded))

	body, respHeader := getObject(hc, bucket, object)
	equalDataSlices(t, uploaded, body)
	require.Equal(t, reqHeaders[api.ContentType], respHeader.Get(api.ContentType))
	// The metadata header is looked up by its exact (non-canonicalized) key.
	require.Equal(t, reqHeaders[metaKey], respHeader[metaKey][0])
	require.Equal(t, strconv.Itoa(totalSize), respHeader.Get(api.ContentLength))

	for _, step := range []int{1000000, 10000000} {
		checkContentUsingRange(hc, bucket, object, uploaded, step)
	}
}

// equalDataSlices requires expected and actual to have the same length and
// the same bytes. Instead of dumping both slices on a mismatch, it reports
// the index of the first differing byte together with the total length.
func equalDataSlices(t *testing.T, expected, actual []byte) {
	require.Equal(t, len(expected), len(actual), "sizes don't match")

	// Fast path: identical content needs no per-byte scan.
	if bytes.Equal(expected, actual) {
		return
	}

	for pos, want := range expected {
		if got := actual[pos]; got != want {
			require.Equalf(t, want, got, "differ start with '%d' position, length: %d", pos, len(expected))
		}
	}
}

// checkContentUsingRangeEnc runs checkContentUsingRangeBase with ranged GETs
// that carry the SSE-C headers.
func checkContentUsingRangeEnc(hc *handlerContext, bktName, objName string, data []byte, step int) {
	const encrypted = true
	checkContentUsingRangeBase(hc, bktName, objName, data, step, encrypted)
}

// checkContentUsingRange runs checkContentUsingRangeBase with plain
// (unencrypted) ranged GETs.
func checkContentUsingRange(hc *handlerContext, bktName, objName string, data []byte, step int) {
	const encrypted = false
	checkContentUsingRangeBase(hc, bktName, objName, data, step, encrypted)
}

// checkContentUsingRangeBase reads the object back in consecutive ranges of
// at most step bytes and compares every range with the matching slice of
// data. When encrypted is true the ranges are requested with SSE-C headers.
func checkContentUsingRangeBase(hc *handlerContext, bktName, objName string, data []byte, step int, encrypted bool) {
	for first := 0; first < len(data); first += step {
		// Inclusive end of the range; the last range is clamped to the final
		// byte of data.
		last := first + step - 1
		if last >= len(data) {
			last = len(data) - 1
		}

		var chunk []byte
		if encrypted {
			chunk = getEncryptedObjectRange(hc.t, hc, bktName, objName, first, last)
		} else {
			chunk = getObjectRange(hc.t, hc, bktName, objName, first, last)
		}

		equalDataSlices(hc.t, data[first:last+1], chunk)
	}
}

// multipartUploadEncrypted uploads objLen bytes of random data to
// bktName/objName as an SSE-C multipart upload and returns the concatenated
// plaintext of all uploaded parts.
//
// headers are applied to the create-multipart-upload request. Every part is
// partsSize bytes long except possibly the last one, which carries whatever
// remains when objLen is not a multiple of partsSize. A non-positive partsSize
// with a positive objLen fails the test immediately, since no progress could
// be made.
func multipartUploadEncrypted(hc *handlerContext, bktName, objName string, headers map[string]string, objLen, partsSize int) (objData []byte) {
	if objLen > 0 && partsSize <= 0 {
		hc.t.Fatalf("multipart part size must be positive, got %d", partsSize)
	}

	multipartInfo := createMultipartUploadEncrypted(hc, bktName, objName, headers)

	var etags []string
	if objLen > 0 {
		// The final size is known, so avoid regrowing the buffer per part.
		objData = make([]byte, 0, objLen)
	}

	uploaded := 0
	for partNumber := 1; uploaded < objLen; partNumber++ {
		// Trim the final part to the bytes still missing, so a length that is
		// not a multiple of partsSize never yields a negative part size.
		size := partsSize
		if remaining := objLen - uploaded; remaining < size {
			size = remaining
		}

		etag, data := uploadPartEncrypted(hc, bktName, objName, multipartInfo.UploadID, partNumber, size)
		etags = append(etags, etag)
		objData = append(objData, data...)
		uploaded += size
	}

	completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, etags)
	return objData
}

// multipartUpload uploads objLen bytes of random data to bktName/objName as
// an unencrypted multipart upload and returns the concatenated content of all
// uploaded parts.
//
// headers are applied to the create-multipart-upload request. Every part is
// partsSize bytes long except possibly the last one, which carries whatever
// remains when objLen is not a multiple of partsSize. A non-positive partsSize
// with a positive objLen fails the test immediately, since no progress could
// be made.
func multipartUpload(hc *handlerContext, bktName, objName string, headers map[string]string, objLen, partsSize int) (objData []byte) {
	if objLen > 0 && partsSize <= 0 {
		hc.t.Fatalf("multipart part size must be positive, got %d", partsSize)
	}

	multipartInfo := createMultipartUpload(hc, bktName, objName, headers)

	var etags []string
	if objLen > 0 {
		// The final size is known, so avoid regrowing the buffer per part.
		objData = make([]byte, 0, objLen)
	}

	uploaded := 0
	for partNumber := 1; uploaded < objLen; partNumber++ {
		// Trim the final part to the bytes still missing, so a length that is
		// not a multiple of partsSize never yields a negative part size.
		size := partsSize
		if remaining := objLen - uploaded; remaining < size {
			size = remaining
		}

		etag, data := uploadPart(hc, bktName, objName, multipartInfo.UploadID, partNumber, size)
		etags = append(etags, etag)
		objData = append(objData, data...)
		uploaded += size
	}

	completeMultipartUpload(hc, bktName, objName, multipartInfo.UploadID, etags)
	return objData
}

// createMultipartUploadEncrypted initiates a multipart upload whose request
// carries the SSE-C headers in addition to headers.
func createMultipartUploadEncrypted(hc *handlerContext, bktName, objName string, headers map[string]string) *InitiateMultipartUploadResponse {
	const encrypted = true
	return createMultipartUploadBase(hc, bktName, objName, encrypted, headers)
}

// createMultipartUpload initiates a multipart upload without SSE-C headers;
// only headers are applied to the request.
func createMultipartUpload(hc *handlerContext, bktName, objName string, headers map[string]string) *InitiateMultipartUploadResponse {
	const encrypted = false
	return createMultipartUploadBase(hc, bktName, objName, encrypted, headers)
}

// createMultipartUploadBase sends a create-multipart-upload request for
// bktName/objName to the handler and returns the response structure filled by
// readResponse (which is given http.StatusOK as the expected status).
//
// When encrypted is true the SSE-C headers are set first; headers are applied
// afterwards, so they win on any key collision.
func createMultipartUploadBase(hc *handlerContext, bktName, objName string, encrypted bool, headers map[string]string) *InitiateMultipartUploadResponse {
	recorder, req := prepareTestRequest(hc, bktName, objName, nil)
	if encrypted {
		setEncryptHeaders(req)
	}
	for name, value := range headers {
		req.Header.Set(name, value)
	}

	hc.Handler().CreateMultipartUploadHandler(recorder, req)

	var initResp InitiateMultipartUploadResponse
	readResponse(hc.t, recorder, http.StatusOK, &initResp)
	return &initResp
}

// completeMultipartUpload completes the upload identified by uploadID with
// the given part ETags and asserts an http.StatusOK response.
func completeMultipartUpload(hc *handlerContext, bktName, objName, uploadID string, partsETags []string) {
	recorder := completeMultipartUploadBase(hc, bktName, objName, uploadID, partsETags)

	assertStatus(hc.t, recorder, http.StatusOK)
}

// completeMultipartUploadBase sends a complete-multipart-upload request for
// uploadID and returns the recorder without checking the status, so callers
// can assert on whatever response they expect.
//
// Parts are numbered from 1 in the order of partsETags.
func completeMultipartUploadBase(hc *handlerContext, bktName, objName, uploadID string, partsETags []string) *httptest.ResponseRecorder {
	parts := make([]*layer.CompletedPart, len(partsETags))
	for idx := range partsETags {
		parts[idx] = &layer.CompletedPart{
			ETag:       partsETags[idx],
			PartNumber: idx + 1,
		}
	}
	body := &CompleteMultipartUpload{Parts: parts}

	query := url.Values{uploadIDQuery: {uploadID}}

	recorder, req := prepareTestFullRequest(hc, bktName, objName, query, body)
	hc.Handler().CompleteMultipartUploadHandler(recorder, req)

	return recorder
}

// uploadPartEncrypted uploads part num of size random bytes with SSE-C
// headers and returns the part's ETag and its plaintext body.
func uploadPartEncrypted(hc *handlerContext, bktName, objName, uploadID string, num, size int) (string, []byte) {
	const encrypted = true
	return uploadPartBase(hc, bktName, objName, encrypted, uploadID, num, size)
}

// uploadPart uploads part num of size random bytes without SSE-C headers and
// returns the part's ETag and its body.
func uploadPart(hc *handlerContext, bktName, objName, uploadID string, num, size int) (string, []byte) {
	const encrypted = false
	return uploadPartBase(hc, bktName, objName, encrypted, uploadID, num, size)
}

// uploadPartBase generates size random bytes, uploads them as part num of the
// multipart upload uploadID, asserts an http.StatusOK response, and returns
// the ETag response header together with the generated bytes.
//
// When encrypted is true the request carries the SSE-C headers.
func uploadPartBase(hc *handlerContext, bktName, objName string, encrypted bool, uploadID string, num, size int) (string, []byte) {
	payload := make([]byte, size)
	_, err := rand.Read(payload)
	require.NoError(hc.t, err)

	query := url.Values{
		uploadIDQuery:   {uploadID},
		partNumberQuery: {strconv.Itoa(num)},
	}

	recorder, req := prepareTestRequestWithQuery(hc, bktName, objName, query, payload)
	if encrypted {
		setEncryptHeaders(req)
	}

	hc.Handler().UploadPartHandler(recorder, req)
	assertStatus(hc.t, recorder, http.StatusOK)

	etag := recorder.Header().Get(api.ETag)
	return etag, payload
}

// TestMultipartEncrypted uploads an SSE-C multipart object of two parts: a
// first part that ends 5 bytes short of a 64kb boundary past the 5MB minimum,
// and a 5-byte second part. It then checks the full GET against both parts
// concatenated, and a ranged GET covering exactly the second part.
func TestMultipartEncrypted(t *testing.T) {
	const (
		bucket = "bucket-for-sse-c-multipart"
		object = "object-to-encrypt-multipart"
	)
	// 5MB (min part size) + 64kb (cipher block size) - 5 (to check corner range)
	const firstPartSize = 5*1048576 + 1<<16 - 5
	const secondPartSize = 5

	hc := prepareHandlerContext(t)
	createTestBucket(hc, bucket)

	upload := createMultipartUploadEncrypted(hc, bucket, object, map[string]string{})
	firstETag, first := uploadPartEncrypted(hc, bucket, object, upload.UploadID, 1, firstPartSize)
	secondETag, second := uploadPartEncrypted(hc, bucket, object, upload.UploadID, 2, secondPartSize)
	completeMultipartUpload(hc, bucket, object, upload.UploadID, []string{firstETag, secondETag})

	expected := make([]byte, 0, len(first)+len(second))
	expected = append(expected, first...)
	expected = append(expected, second...)

	got, _ := getEncryptedObject(hc, bucket, object)
	require.Equal(t, len(first)+len(second), len(got))
	require.Equal(t, expected, got)

	// The range starts right after the first part and ends at the last byte.
	secondRange := getEncryptedObjectRange(t, hc, bucket, object, len(first), len(first)+len(second)-1)
	require.Equal(t, second, secondRange)
}

// putEncryptedObject PUTs content to bktName/objName with the SSE-C headers
// set and asserts an http.StatusOK response.
func putEncryptedObject(t *testing.T, tc *handlerContext, bktName, objName, content string) {
	recorder, req := prepareTestPayloadRequest(tc, bktName, objName, bytes.NewReader([]byte(content)))
	setEncryptHeaders(req)

	tc.Handler().PutObjectHandler(recorder, req)
	assertStatus(t, recorder, http.StatusOK)
}

// getEncryptedObject GETs bktName/objName with the SSE-C headers set and
// returns the response body and headers produced by getObjectBase.
func getEncryptedObject(hc *handlerContext, bktName, objName string) ([]byte, http.Header) {
	recorder, req := prepareTestRequest(hc, bktName, objName, nil)
	setEncryptHeaders(req)

	body, header := getObjectBase(hc, recorder, req)
	return body, header
}

// getObject GETs bktName/objName without encryption headers and returns the
// response body and headers produced by getObjectBase.
func getObject(hc *handlerContext, bktName, objName string) ([]byte, http.Header) {
	recorder, req := prepareTestRequest(hc, bktName, objName, nil)

	body, header := getObjectBase(hc, recorder, req)
	return body, header
}

// getObjectBase runs GetObjectHandler on the prepared request, asserts an
// http.StatusOK response, and returns the fully read response body together
// with the recorder's header map.
func getObjectBase(hc *handlerContext, w *httptest.ResponseRecorder, r *http.Request) ([]byte, http.Header) {
	hc.Handler().GetObjectHandler(w, r)
	assertStatus(hc.t, w, http.StatusOK)

	// Close the body obtained from Result, as required for any
	// http.Response, even though the recorder's body holds no real resource.
	body := w.Result().Body
	defer body.Close()

	content, err := io.ReadAll(body)
	require.NoError(hc.t, err)
	return content, w.Header()
}

// getEncryptedObjectRange GETs the inclusive byte range [start, end] of
// bktName/objName with the SSE-C headers set, asserts an
// http.StatusPartialContent response, and returns the fully read body.
func getEncryptedObjectRange(t *testing.T, tc *handlerContext, bktName, objName string, start, end int) []byte {
	w, r := prepareTestRequest(tc, bktName, objName, nil)
	setEncryptHeaders(r)
	r.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))

	tc.Handler().GetObjectHandler(w, r)
	assertStatus(t, w, http.StatusPartialContent)

	// Close the body obtained from Result, as required for any
	// http.Response, even though the recorder's body holds no real resource.
	body := w.Result().Body
	defer body.Close()

	content, err := io.ReadAll(body)
	require.NoError(t, err)
	return content
}

// setEncryptHeaders turns r into an SSE-C request: it sets the customer
// algorithm, key and key-MD5 headers from the file's test fixtures.
//
// It also attaches an empty TLS connection state to r (presumably because the
// handler only accepts SSE-C over TLS; confirm against the handler).
func setEncryptHeaders(r *http.Request) {
	r.TLS = new(tls.ConnectionState)

	h := r.Header
	h.Set(api.AmzServerSideEncryptionCustomerAlgorithm, layer.AESEncryptionAlgorithm)
	h.Set(api.AmzServerSideEncryptionCustomerKey, aes256Key)
	h.Set(api.AmzServerSideEncryptionCustomerKeyMD5, aes256KeyMD5)
}

// setHeaders copies every key/value pair of header into r.Header, replacing
// any values already present under the same (canonicalized) key.
func setHeaders(r *http.Request, header map[string]string) {
	for name := range header {
		r.Header.Set(name, header[name])
	}
}