package handler

import (
	"encoding/xml"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/data"
	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test"
	"github.com/stretchr/testify/require"
)

// TestCORSOriginWildcard checks that a CORS configuration with a wildcard
// AllowedOrigin can be put on a public-read bucket and read back. It covers
// two cases: a document that declares the S3 XML namespace, and (with
// hc.config.useDefaultXMLNS enabled) a document that omits the namespace.
func TestCORSOriginWildcard(t *testing.T) {
	body := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
	</CORSRule>
</CORSConfiguration>
`
	bodyNoXmlns := `
<CORSConfiguration>
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
	</CORSRule>
</CORSConfiguration>`
	hc := prepareHandlerContextWithMinCache(t)

	bktName := "bucket-for-cors"
	box, _ := createAccessBox(t)
	w, r := prepareTestRequest(hc, bktName, "", nil)
	ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
	r = r.WithContext(ctx)
	r.Header.Add(api.AmzACL, "public-read")
	hc.Handler().CreateBucketHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	putBucketCORS(hc, bktName, body)

	// Verify the stored configuration round-trips, not just that GET succeeds.
	resp := getBucketCORS(hc, bktName)
	requireEqualCORS(t, body, resp.Body.String())

	// NOTE(review): useDefaultXMLNS presumably makes the handler decode the
	// namespace-less document under the default S3 namespace. Only the
	// status is checked here, because the namespace of the returned
	// document is not established in this file.
	hc.config.useDefaultXMLNS = true
	putBucketCORS(hc, bktName, bodyNoXmlns)

	getBucketCORS(hc, bktName)
}

// TestPreflight exercises the CORS preflight (OPTIONS) handler against a
// bucket whose CORS rule allows only GET from http://www.example.com with the
// Authorization header. It checks:
//   - the response status for valid, incomplete and disallowed requests;
//   - for a successful preflight, the Access-Control-* response headers,
//     which are derived from the matched rule.
func TestPreflight(t *testing.T) {
	body := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedOrigin>http://www.example.com</AllowedOrigin>
		<AllowedHeader>Authorization</AllowedHeader>
		<ExposeHeader>x-amz-*</ExposeHeader>
		<ExposeHeader>X-Amz-*</ExposeHeader>
		<MaxAgeSeconds>600</MaxAgeSeconds>
	</CORSRule>
</CORSConfiguration>
`
	hc := prepareHandlerContext(t)

	bktName := "bucket-preflight-test"
	box, _ := createAccessBox(t)
	w, r := prepareTestRequest(hc, bktName, "", nil)
	ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
	r = r.WithContext(ctx)
	hc.Handler().CreateBucketHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	putBucketCORS(hc, bktName, body)

	for _, tc := range []struct {
		name           string
		origin         string
		method         string
		headers        string
		expectedStatus int
	}{
		{
			name:           "Valid",
			origin:         "http://www.example.com",
			method:         "GET",
			headers:        "Authorization",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "Empty origin",
			method:         "GET",
			headers:        "Authorization",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Empty request method",
			origin:         "http://www.example.com",
			headers:        "Authorization",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Not allowed method",
			origin:         "http://www.example.com",
			method:         "PUT",
			headers:        "Authorization",
			expectedStatus: http.StatusForbidden,
		},
		{
			name:           "Not allowed headers",
			origin:         "http://www.example.com",
			method:         "GET",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusForbidden,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// Use locals so subtests never share the outer request/recorder.
			w, r := prepareTestPayloadRequest(hc, bktName, "", nil)
			r.Header.Set(api.Origin, tc.origin)
			r.Header.Set(api.AccessControlRequestMethod, tc.method)
			r.Header.Set(api.AccessControlRequestHeaders, tc.headers)
			hc.Handler().Preflight(w, r)
			assertStatus(t, w, tc.expectedStatus)

			if tc.expectedStatus == http.StatusOK {
				require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin))
				require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods))
				require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders))
				require.Equal(t, "x-amz-*, X-Amz-*", w.Header().Get(api.AccessControlExposeHeaders))
				require.Equal(t, "true", w.Header().Get(api.AccessControlAllowCredentials))
				require.Equal(t, "600", w.Header().Get(api.AccessControlMaxAge))
			}
		})
	}
}

// TestPreflightWildcardOrigin exercises the preflight handler against a rule
// with AllowedOrigin "*" and AllowedHeader "*". For a successful preflight it
// checks:
//   - the request's origin, method and headers are echoed back;
//   - no expose/credentials headers are set;
//   - max-age is "0", since the rule sets no MaxAgeSeconds.
func TestPreflightWildcardOrigin(t *testing.T) {
	body := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedMethod>PUT</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
		<AllowedHeader>*</AllowedHeader>
	</CORSRule>
</CORSConfiguration>
`
	hc := prepareHandlerContext(t)

	bktName := "bucket-preflight-wildcard-test"
	box, _ := createAccessBox(t)
	w, r := prepareTestRequest(hc, bktName, "", nil)
	ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
	r = r.WithContext(ctx)
	hc.Handler().CreateBucketHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	putBucketCORS(hc, bktName, body)

	for _, tc := range []struct {
		name           string
		origin         string
		method         string
		headers        string
		expectedStatus int
	}{
		{
			name:           "Valid get",
			origin:         "http://www.example.com",
			method:         "GET",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "Valid put",
			origin:         "http://example.com",
			method:         "PUT",
			headers:        "Authorization, Content-Type",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "Empty origin",
			method:         "GET",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Empty request method",
			origin:         "http://www.example.com",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Not allowed method",
			origin:         "http://www.example.com",
			method:         "DELETE",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusForbidden,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// Use locals so subtests never share the outer request/recorder.
			w, r := prepareTestPayloadRequest(hc, bktName, "", nil)
			r.Header.Set(api.Origin, tc.origin)
			r.Header.Set(api.AccessControlRequestMethod, tc.method)
			r.Header.Set(api.AccessControlRequestHeaders, tc.headers)
			hc.Handler().Preflight(w, r)
			assertStatus(t, w, tc.expectedStatus)

			if tc.expectedStatus == http.StatusOK {
				require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin))
				require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods))
				require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders))
				require.Empty(t, w.Header().Get(api.AccessControlExposeHeaders))
				require.Empty(t, w.Header().Get(api.AccessControlAllowCredentials))
				require.Equal(t, "0", w.Header().Get(api.AccessControlMaxAge))
			}
		})
	}
}

func TestGetLatestCORSVersion(t *testing.T) {
	bodyTree := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedMethod>PUT</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
		<AllowedHeader>*</AllowedHeader>
	</CORSRule>
</CORSConfiguration>
`
	body := `
	<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
		<CORSRule>
			<AllowedMethod>DELETE</AllowedMethod>
			<AllowedOrigin>*</AllowedOrigin>
			<AllowedHeader>*</AllowedHeader>
		</CORSRule>
	</CORSConfiguration>
	`
	hc := prepareHandlerContextWithMinCache(t)

	bktName := "bucket-get-latest-cors"
	info := createBucket(hc, bktName)

	addCORSToTree(hc, bodyTree, info.BktInfo, info.BktInfo.CID)

	w := getBucketCORS(hc, bktName)
	requireEqualCORS(hc.t, bodyTree, w.Body.String())

	hc.tp.AddCORSObject(info.BktInfo, hc.corsCnrID, body)

	w = getBucketCORS(hc, bktName)
	requireEqualCORS(hc.t, body, w.Body.String())

	hc.tp.AddCORSObject(info.BktInfo, hc.corsCnrID, bodyTree)
	w = getBucketCORS(hc, bktName)
	requireEqualCORS(hc.t, bodyTree, w.Body.String())
}

func TestDeleteCORSVersions(t *testing.T) {
	body := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedMethod>PUT</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
		<AllowedHeader>*</AllowedHeader>
	</CORSRule>
</CORSConfiguration>
`
	newBody := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>HEAD</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
		<AllowedHeader>*</AllowedHeader>
	</CORSRule>
</CORSConfiguration>
`
	hc := prepareHandlerContext(t)

	bktName := "bucket-delete-tree-cors-versions"
	info := createBucket(hc, bktName)

	addCORSToTree(hc, body, info.BktInfo, info.BktInfo.CID)
	addCORSToTree(hc, body, info.BktInfo, hc.corsCnrID)
	require.Len(t, hc.tp.Objects(), 2)

	putBucketCORS(hc, bktName, body)
	require.Len(t, hc.tp.Objects(), 1)
	require.Equal(t, body, string(hc.tp.Objects()[0].Payload()))

	hc.tp.AddCORSObject(info.BktInfo, hc.corsCnrID, body)
	require.Len(t, hc.tp.Objects(), 2)

	putBucketCORS(hc, bktName, newBody)
	require.Len(t, hc.tp.Objects(), 1)
	require.Equal(t, newBody, string(hc.tp.Objects()[0].Payload()))

	addCORSToTree(hc, body, info.BktInfo, info.BktInfo.CID)
	addCORSToTree(hc, body, info.BktInfo, hc.corsCnrID)
	hc.tp.AddCORSObject(info.BktInfo, hc.corsCnrID, body)
	require.Len(t, hc.tp.Objects(), 4)

	deleteBucketCORS(hc, bktName)
	require.Len(t, hc.tp.Objects(), 0)
}

func TestDeleteCORSInDeleteBucket(t *testing.T) {
	body := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedMethod>PUT</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
		<AllowedHeader>*</AllowedHeader>
	</CORSRule>
</CORSConfiguration>
`

	hc := prepareHandlerContext(t)

	bktName := "bucket-delete-cors-in-delete-bucket"
	info := createBucket(hc, bktName)

	addCORSToTree(hc, body, info.BktInfo, hc.corsCnrID)
	addCORSToTree(hc, body, info.BktInfo, info.BktInfo.CID)
	hc.tp.AddCORSObject(info.BktInfo, hc.corsCnrID, body)
	require.Len(t, hc.tp.Objects(), 3)

	hc.owner = info.BktInfo.Owner
	deleteBucket(t, hc, bktName, http.StatusNoContent)
	require.Len(t, hc.tp.Objects(), 1) // CORS object in bucket container is not deleted
}

// addCORSToTree stores cors as the payload of a new object (random object ID)
// in container corsCnrID of the test object storage. It then adds a
// "bucket-cors" node to the "system" tree of bkt whose OID/CID metadata point
// at that object. The test fails immediately if the tree node cannot be added.
func addCORSToTree(hc *handlerContext, cors string, bkt *data.BucketInfo, corsCnrID cid.ID) {
	hc.t.Helper()

	var addr oid.Address
	addr.SetContainer(corsCnrID)
	addr.SetObject(oidtest.ID())

	var obj object.Object
	obj.SetPayload([]byte(cors))
	obj.SetPayloadSize(uint64(len(cors)))
	hc.tp.SetObject(addr, &obj)

	meta := map[string]string{
		"FileName": "bucket-cors",
		"OID":      addr.Object().EncodeToString(),
		"CID":      addr.Container().EncodeToString(),
	}

	_, err := hc.treeMock.AddNode(hc.context, bkt, "system", 0, meta)
	require.NoError(hc.t, err)
}

// requireEqualCORS decodes both XML documents into data.CORSConfiguration and
// requires the decoded values to be equal. Differences that don't affect the
// decoded struct (such as indentation) are therefore ignored. It fails the
// test if either document cannot be decoded.
func requireEqualCORS(t *testing.T, expected string, actual string) {
	t.Helper()

	expectedCORS := &data.CORSConfiguration{}
	err := xml.NewDecoder(strings.NewReader(expected)).Decode(expectedCORS)
	require.NoError(t, err)

	actualCORS := &data.CORSConfiguration{}
	err = xml.NewDecoder(strings.NewReader(actual)).Decode(actualCORS)
	require.NoError(t, err)

	require.Equal(t, expectedCORS, actualCORS)
}

// putBucketCORS sends body as a PutBucketCors request for bktName. It puts a
// freshly created access box in the request context and requires a 200 OK
// response.
func putBucketCORS(hc *handlerContext, bktName string, body string) {
	hc.t.Helper()

	w, r := prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
	box, _ := createAccessBox(hc.t)
	r = r.WithContext(middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box}))
	hc.Handler().PutBucketCorsHandler(w, r)
	assertStatus(hc.t, w, http.StatusOK)
}

// deleteBucketCORS sends a DeleteBucketCors request for bktName. It puts a
// freshly created access box in the request context and requires a
// 204 No Content response.
func deleteBucketCORS(hc *handlerContext, bktName string) {
	hc.t.Helper()

	w, r := prepareTestPayloadRequest(hc, bktName, "", nil)
	box, _ := createAccessBox(hc.t)
	r = r.WithContext(middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box}))
	hc.Handler().DeleteBucketCorsHandler(w, r)
	assertStatus(hc.t, w, http.StatusNoContent)
}

// getBucketCORS sends a GetBucketCors request for bktName and requires a
// 200 OK response. No access box is set in the request context. The recorder
// is returned so callers can inspect the response body.
func getBucketCORS(hc *handlerContext, bktName string) *httptest.ResponseRecorder {
	hc.t.Helper()

	w, r := prepareTestPayloadRequest(hc, bktName, "", nil)
	hc.Handler().GetBucketCorsHandler(w, r)
	assertStatus(hc.t, w, http.StatusOK)
	return w
}