package handler

import (
	"net/http"
	"strings"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api"
	"git.frostfs.info/TrueCloudLab/frostfs-s3-gw/api/middleware"
	"github.com/stretchr/testify/require"
)

// TestCORSOriginWildcard creates a bucket with the "public-read" canned ACL,
// stores a CORS rule that allows GET from any origin ("*"), and reads the
// configuration back. Besides the status codes, it checks that the returned
// configuration still carries the wildcard origin, so the test actually covers
// what its name claims.
func TestCORSOriginWildcard(t *testing.T) {
	body := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
	</CORSRule>
</CORSConfiguration>
`
	hc := prepareHandlerContext(t)

	bktName := "bucket-for-cors"
	box, _ := createAccessBox(t)
	w, r := prepareTestRequest(hc, bktName, "", nil)
	ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
	r = r.WithContext(ctx)
	r.Header.Add(api.AmzACL, "public-read")
	hc.Handler().CreateBucketHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
	ctx = middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
	r = r.WithContext(ctx)
	hc.Handler().PutBucketCorsHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	// The read-back request is issued without an access box in its context.
	w, r = prepareTestPayloadRequest(hc, bktName, "", nil)
	hc.Handler().GetBucketCorsHandler(w, r)
	assertStatus(t, w, http.StatusOK)
	// The wildcard origin must survive the put/get round trip.
	require.Contains(t, w.Body.String(), "<AllowedOrigin>*</AllowedOrigin>")
}

// TestPreflight stores a CORS rule that allows only GET from
// http://www.example.com with the Authorization request header, then sends
// preflight requests and checks the resulting status code.
//
// For accepted requests it also checks the Access-Control-* response headers
// derived from the rule: the echoed origin, method and headers, the exposed
// headers, the credentials flag, and the max age.
func TestPreflight(t *testing.T) {
	body := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedOrigin>http://www.example.com</AllowedOrigin>
		<AllowedHeader>Authorization</AllowedHeader>
		<ExposeHeader>x-amz-*</ExposeHeader>
		<ExposeHeader>X-Amz-*</ExposeHeader>
		<MaxAgeSeconds>600</MaxAgeSeconds>
	</CORSRule>
</CORSConfiguration>
`
	hc := prepareHandlerContext(t)

	bktName := "bucket-preflight-test"
	box, _ := createAccessBox(t)
	w, r := prepareTestRequest(hc, bktName, "", nil)
	ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
	r = r.WithContext(ctx)
	hc.Handler().CreateBucketHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
	ctx = middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
	r = r.WithContext(ctx)
	hc.Handler().PutBucketCorsHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	for _, tc := range []struct {
		name           string
		origin         string
		method         string
		headers        string
		expectedStatus int
	}{
		{
			name:           "Valid",
			origin:         "http://www.example.com",
			method:         "GET",
			headers:        "Authorization",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "Empty origin",
			method:         "GET",
			headers:        "Authorization",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Empty request method",
			origin:         "http://www.example.com",
			headers:        "Authorization",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Not allowed method",
			origin:         "http://www.example.com",
			method:         "PUT",
			headers:        "Authorization",
			expectedStatus: http.StatusForbidden,
		},
		{
			name:           "Not allowed headers",
			origin:         "http://www.example.com",
			method:         "GET",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusForbidden,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// Declare a subtest-local recorder and request rather than
			// overwriting the enclosing test's w/r from inside the closure.
			w, r := prepareTestPayloadRequest(hc, bktName, "", nil)
			r.Header.Set(api.Origin, tc.origin)
			r.Header.Set(api.AccessControlRequestMethod, tc.method)
			r.Header.Set(api.AccessControlRequestHeaders, tc.headers)
			hc.Handler().Preflight(w, r)
			assertStatus(t, w, tc.expectedStatus)

			if tc.expectedStatus == http.StatusOK {
				require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin))
				require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods))
				require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders))
				require.Equal(t, "x-amz-*, X-Amz-*", w.Header().Get(api.AccessControlExposeHeaders))
				require.Equal(t, "true", w.Header().Get(api.AccessControlAllowCredentials))
				require.Equal(t, "600", w.Header().Get(api.AccessControlMaxAge))
			}
		})
	}
}

// TestPreflightWildcardOrigin stores a CORS rule that allows GET and PUT from
// any origin ("*") with any request header ("*"), then sends preflight requests
// and checks the resulting status code.
//
// For accepted requests the response must echo the request's origin, method
// and headers. The expose-headers and credentials headers must be absent, and
// the max age must be "0" because the rule sets none.
func TestPreflightWildcardOrigin(t *testing.T) {
	body := `
<CORSConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
	<CORSRule>
		<AllowedMethod>GET</AllowedMethod>
		<AllowedMethod>PUT</AllowedMethod>
		<AllowedOrigin>*</AllowedOrigin>
		<AllowedHeader>*</AllowedHeader>
	</CORSRule>
</CORSConfiguration>
`
	hc := prepareHandlerContext(t)

	bktName := "bucket-preflight-wildcard-test"
	box, _ := createAccessBox(t)
	w, r := prepareTestRequest(hc, bktName, "", nil)
	ctx := middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
	r = r.WithContext(ctx)
	hc.Handler().CreateBucketHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	w, r = prepareTestPayloadRequest(hc, bktName, "", strings.NewReader(body))
	ctx = middleware.SetBox(r.Context(), &middleware.Box{AccessBox: box})
	r = r.WithContext(ctx)
	hc.Handler().PutBucketCorsHandler(w, r)
	assertStatus(t, w, http.StatusOK)

	for _, tc := range []struct {
		name           string
		origin         string
		method         string
		headers        string
		expectedStatus int
	}{
		{
			name:           "Valid get",
			origin:         "http://www.example.com",
			method:         "GET",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "Valid put",
			origin:         "http://example.com",
			method:         "PUT",
			headers:        "Authorization, Content-Type",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "Empty origin",
			method:         "GET",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Empty request method",
			origin:         "http://www.example.com",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusBadRequest,
		},
		{
			name:           "Not allowed method",
			origin:         "http://www.example.com",
			method:         "DELETE",
			headers:        "Authorization, Last-Modified",
			expectedStatus: http.StatusForbidden,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// Declare a subtest-local recorder and request rather than
			// overwriting the enclosing test's w/r from inside the closure.
			w, r := prepareTestPayloadRequest(hc, bktName, "", nil)
			r.Header.Set(api.Origin, tc.origin)
			r.Header.Set(api.AccessControlRequestMethod, tc.method)
			r.Header.Set(api.AccessControlRequestHeaders, tc.headers)
			hc.Handler().Preflight(w, r)
			assertStatus(t, w, tc.expectedStatus)

			if tc.expectedStatus == http.StatusOK {
				require.Equal(t, tc.origin, w.Header().Get(api.AccessControlAllowOrigin))
				require.Equal(t, tc.method, w.Header().Get(api.AccessControlAllowMethods))
				require.Equal(t, tc.headers, w.Header().Get(api.AccessControlAllowHeaders))
				require.Empty(t, w.Header().Get(api.AccessControlExposeHeaders))
				require.Empty(t, w.Header().Get(api.AccessControlAllowCredentials))
				require.Equal(t, "0", w.Header().Get(api.AccessControlMaxAge))
			}
		})
	}
}