package handler

import (
	"encoding/base64"
	"encoding/xml"
	"fmt"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-http-gw/internal/data"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
	oidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id/test"
	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"
)

// TestPreflight checks the status code and CORS response headers that
// Handler.Preflight produces. It covers two configurations:
//   - "CORS object": rules stored as an object via setCORSObject;
//   - "CORS config": a rule set directly in the handler config (hc.cfg.cors).
func TestPreflight(t *testing.T) {
	hc := prepareHandlerContext(t)

	bktName := "bucket-preflight"
	cnrID, cnr, err := hc.prepareContainer(bktName, acl.Private)
	require.NoError(t, err)
	hc.frostfs.SetContainer(cnrID, cnr)

	// noCORSHeaders returns an expectation map asserting that none of the
	// CORS response headers are present in the response.
	noCORSHeaders := func() map[string]string {
		return map[string]string{
			fasthttp.HeaderAccessControlAllowOrigin:      "",
			fasthttp.HeaderAccessControlAllowMethods:     "",
			fasthttp.HeaderAccessControlAllowHeaders:     "",
			fasthttp.HeaderAccessControlExposeHeaders:    "",
			fasthttp.HeaderAccessControlMaxAge:           "",
			fasthttp.HeaderAccessControlAllowCredentials: "",
		}
	}

	// Each stored CORS object gets a strictly increasing creation epoch.
	// NOTE(review): presumably the handler picks the newest object by epoch —
	// confirm against getCORSConfig.
	var epoch uint64

	t.Run("CORS object", func(t *testing.T) {
		// Cases run sequentially and stored objects are never removed, so the
		// case without a configuration must stay first.
		for _, tc := range []struct {
			name            string
			corsConfig      *data.CORSConfiguration
			requestHeaders  map[string]string
			expectedHeaders map[string]string
			status          int
		}{
			{
				name:            "no CORS configuration",
				expectedHeaders: noCORSHeaders(),
				requestHeaders: map[string]string{
					fasthttp.HeaderOrigin:                     "http://example.com",
					fasthttp.HeaderAccessControlRequestMethod: "HEAD",
				},
				status: fasthttp.StatusNotFound,
			},
			{
				name: "specific allowed origin",
				corsConfig: &data.CORSConfiguration{
					CORSRules: []data.CORSRule{
						{
							AllowedOrigins: []string{"http://example.com"},
							AllowedMethods: []string{"GET", "HEAD"},
							AllowedHeaders: []string{"Content-Type"},
							ExposeHeaders:  []string{"x-amz-*", "X-Amz-*"},
							MaxAgeSeconds:  900,
						},
					},
				},
				requestHeaders: map[string]string{
					fasthttp.HeaderOrigin:                      "http://example.com",
					fasthttp.HeaderAccessControlRequestMethod:  "HEAD",
					fasthttp.HeaderAccessControlRequestHeaders: "Content-Type",
				},
				expectedHeaders: map[string]string{
					fasthttp.HeaderAccessControlAllowOrigin:      "http://example.com",
					fasthttp.HeaderAccessControlAllowMethods:     "GET, HEAD",
					fasthttp.HeaderAccessControlAllowHeaders:     "Content-Type",
					fasthttp.HeaderAccessControlExposeHeaders:    "x-amz-*, X-Amz-*",
					fasthttp.HeaderAccessControlMaxAge:           "900",
					fasthttp.HeaderAccessControlAllowCredentials: "true",
				},
				status: fasthttp.StatusOK,
			},
			{
				name: "wildcard allowed origin",
				corsConfig: &data.CORSConfiguration{
					CORSRules: []data.CORSRule{
						{
							AllowedOrigins: []string{"*"},
							AllowedMethods: []string{"GET", "HEAD"},
							AllowedHeaders: []string{"Content-Type"},
							ExposeHeaders:  []string{"x-amz-*", "X-Amz-*"},
							MaxAgeSeconds:  900,
						},
					},
				},
				requestHeaders: map[string]string{
					fasthttp.HeaderOrigin:                     "http://example.com",
					fasthttp.HeaderAccessControlRequestMethod: "HEAD",
				},
				expectedHeaders: map[string]string{
					fasthttp.HeaderAccessControlAllowOrigin:      "http://example.com",
					fasthttp.HeaderAccessControlAllowMethods:     "GET, HEAD",
					fasthttp.HeaderAccessControlAllowHeaders:     "",
					fasthttp.HeaderAccessControlExposeHeaders:    "x-amz-*, X-Amz-*",
					fasthttp.HeaderAccessControlMaxAge:           "900",
					fasthttp.HeaderAccessControlAllowCredentials: "",
				},
				status: fasthttp.StatusOK,
			},
			{
				name: "not allowed header",
				corsConfig: &data.CORSConfiguration{
					CORSRules: []data.CORSRule{
						{
							AllowedOrigins: []string{"*"},
							AllowedMethods: []string{"GET", "HEAD"},
							AllowedHeaders: []string{"Content-Type"},
						},
					},
				},
				requestHeaders: map[string]string{
					fasthttp.HeaderOrigin:                      "http://example.com",
					fasthttp.HeaderAccessControlRequestMethod:  "GET",
					fasthttp.HeaderAccessControlRequestHeaders: "Authorization",
				},
				expectedHeaders: noCORSHeaders(),
				status:          fasthttp.StatusForbidden,
			},
			{
				name: "empty Origin header",
				corsConfig: &data.CORSConfiguration{
					CORSRules: []data.CORSRule{
						{
							AllowedOrigins: []string{"*"},
							AllowedMethods: []string{"GET", "HEAD"},
						},
					},
				},
				expectedHeaders: noCORSHeaders(),
				status:          fasthttp.StatusBadRequest,
			},
			{
				name: "empty Access-Control-Request-Method header",
				corsConfig: &data.CORSConfiguration{
					CORSRules: []data.CORSRule{
						{
							AllowedOrigins: []string{"*"},
							AllowedMethods: []string{"GET", "HEAD"},
						},
					},
				},
				requestHeaders: map[string]string{
					fasthttp.HeaderOrigin: "http://example.com",
				},
				expectedHeaders: noCORSHeaders(),
				status:          fasthttp.StatusBadRequest,
			},
		} {
			t.Run(tc.name, func(t *testing.T) {
				// Only store an object when the case defines a configuration;
				// a nil config means "no CORS object at all".
				if tc.corsConfig != nil {
					epoch++
					setCORSObject(t, hc, cnrID, tc.corsConfig, epoch)
				}

				r := prepareCORSRequest(t, bktName, tc.requestHeaders)
				hc.Handler().Preflight(r)

				require.Equal(t, tc.status, r.Response.StatusCode())
				for k, v := range tc.expectedHeaders {
					// Name the header so a failure shows which one mismatched.
					require.Equal(t, v, string(r.Response.Header.Peek(k)), "header %s", k)
				}
			})
		}
	})

	t.Run("CORS config", func(t *testing.T) {
		// Restore the original config so later code sharing hc is unaffected.
		prevCORS := hc.cfg.cors
		t.Cleanup(func() { hc.cfg.cors = prevCORS })

		// The expectations below differ from the last stored object, so this
		// rule is the one applied by the handler.
		hc.cfg.cors = &data.CORSRule{
			AllowedOrigins:     []string{"*"},
			AllowedMethods:     []string{"GET", "HEAD"},
			AllowedHeaders:     []string{"Content-Type", "Content-Encoding"},
			ExposeHeaders:      []string{"x-amz-*", "X-Amz-*"},
			MaxAgeSeconds:      900,
			AllowedCredentials: true,
		}

		r := prepareCORSRequest(t, bktName, map[string]string{
			fasthttp.HeaderOrigin:                     "http://example.com",
			fasthttp.HeaderAccessControlRequestMethod: "GET",
		})
		hc.Handler().Preflight(r)

		require.Equal(t, fasthttp.StatusOK, r.Response.StatusCode())
		require.Equal(t, "900", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlMaxAge)))
		require.Equal(t, "*", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlAllowOrigin)))
		require.Equal(t, "GET, HEAD", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlAllowMethods)))
		require.Equal(t, "Content-Type, Content-Encoding", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlAllowHeaders)))
		require.Equal(t, "x-amz-*, X-Amz-*", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlExposeHeaders)))
		require.Equal(t, "true", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlAllowCredentials)))
	})
}

// TestSetCORSHeaders checks the CORS response headers that
// Handler.SetCORSHeaders adds to a regular (non-preflight) request. It covers
// two configurations:
//   - "CORS object": rules stored as an object via setCORSObject;
//   - "CORS config": a rule set directly in the handler config (hc.cfg.cors).
func TestSetCORSHeaders(t *testing.T) {
	hc := prepareHandlerContext(t)

	bktName := "bucket-set-cors-headers"
	cnrID, cnr, err := hc.prepareContainer(bktName, acl.Private)
	require.NoError(t, err)
	hc.frostfs.SetContainer(cnrID, cnr)

	// noCORSHeaders returns an expectation map asserting that none of the
	// headers checked by this test are present in the response.
	noCORSHeaders := func() map[string]string {
		return map[string]string{
			fasthttp.HeaderAccessControlAllowOrigin:      "",
			fasthttp.HeaderAccessControlAllowMethods:     "",
			fasthttp.HeaderVary:                          "",
			fasthttp.HeaderAccessControlAllowCredentials: "",
		}
	}

	// Each stored CORS object gets a strictly increasing creation epoch.
	// NOTE(review): presumably the handler picks the newest object by epoch —
	// confirm against getCORSConfig.
	var epoch uint64

	t.Run("CORS object", func(t *testing.T) {
		// Cases run sequentially and stored objects are never removed, so the
		// cases without a configuration must stay first.
		for _, tc := range []struct {
			name            string
			corsConfig      *data.CORSConfiguration
			requestHeaders  map[string]string
			expectedHeaders map[string]string
		}{
			{
				name:            "empty Origin header",
				expectedHeaders: noCORSHeaders(),
			},
			{
				name:            "no CORS configuration",
				expectedHeaders: noCORSHeaders(),
				requestHeaders: map[string]string{
					fasthttp.HeaderOrigin: "http://example.com",
				},
			},
			{
				name: "specific allowed origin",
				corsConfig: &data.CORSConfiguration{
					CORSRules: []data.CORSRule{
						{
							AllowedOrigins: []string{"http://example.com"},
							AllowedMethods: []string{"GET", "HEAD"},
						},
					},
				},
				requestHeaders: map[string]string{
					fasthttp.HeaderOrigin: "http://example.com",
				},
				expectedHeaders: map[string]string{
					fasthttp.HeaderAccessControlAllowOrigin:      "http://example.com",
					fasthttp.HeaderAccessControlAllowMethods:     "GET, HEAD",
					fasthttp.HeaderVary:                          fasthttp.HeaderOrigin,
					fasthttp.HeaderAccessControlAllowCredentials: "true",
				},
			},
			{
				name: "wildcard allowed origin, with credentials",
				corsConfig: &data.CORSConfiguration{
					CORSRules: []data.CORSRule{
						{
							AllowedOrigins: []string{"*"},
							AllowedMethods: []string{"GET", "HEAD"},
						},
					},
				},
				// A signed bearer token in Authorization makes the request
				// credentialed.
				requestHeaders: func() map[string]string {
					tkn := new(bearer.Token)
					require.NoError(t, tkn.Sign(hc.key.PrivateKey))

					t64 := base64.StdEncoding.EncodeToString(tkn.Marshal())
					require.NotEmpty(t, t64)

					return map[string]string{
						fasthttp.HeaderOrigin:        "http://example.com",
						fasthttp.HeaderAuthorization: "Bearer " + t64,
					}
				}(),
				expectedHeaders: map[string]string{
					fasthttp.HeaderAccessControlAllowOrigin:      "http://example.com",
					fasthttp.HeaderAccessControlAllowMethods:     "GET, HEAD",
					fasthttp.HeaderVary:                          fasthttp.HeaderOrigin,
					fasthttp.HeaderAccessControlAllowCredentials: "true",
				},
			},
			{
				name: "wildcard allowed origin, without credentials",
				corsConfig: &data.CORSConfiguration{
					CORSRules: []data.CORSRule{
						{
							AllowedOrigins: []string{"*"},
							AllowedMethods: []string{"GET", "HEAD"},
						},
					},
				},
				requestHeaders: map[string]string{
					fasthttp.HeaderOrigin: "http://example.com",
				},
				expectedHeaders: map[string]string{
					fasthttp.HeaderAccessControlAllowOrigin:      "*",
					fasthttp.HeaderAccessControlAllowMethods:     "GET, HEAD",
					fasthttp.HeaderVary:                          "",
					fasthttp.HeaderAccessControlAllowCredentials: "",
				},
			},
		} {
			t.Run(tc.name, func(t *testing.T) {
				// Only store an object when the case defines a configuration.
				// Storing a nil config would create an object with an empty
				// payload (xml.Marshal writes nothing for a nil pointer), so
				// "no CORS configuration" would test a malformed object instead.
				if tc.corsConfig != nil {
					epoch++
					setCORSObject(t, hc, cnrID, tc.corsConfig, epoch)
				}

				r := prepareCORSRequest(t, bktName, tc.requestHeaders)
				hc.Handler().SetCORSHeaders(r)

				require.Equal(t, fasthttp.StatusOK, r.Response.StatusCode())
				for k, v := range tc.expectedHeaders {
					// Name the header so a failure shows which one mismatched.
					require.Equal(t, v, string(r.Response.Header.Peek(k)), "header %s", k)
				}
			})
		}
	})

	t.Run("CORS config", func(t *testing.T) {
		// Restore the original config so later code sharing hc is unaffected.
		prevCORS := hc.cfg.cors
		t.Cleanup(func() { hc.cfg.cors = prevCORS })

		hc.cfg.cors = &data.CORSRule{
			AllowedOrigins:     []string{"*"},
			AllowedMethods:     []string{"GET", "HEAD"},
			AllowedHeaders:     []string{"Content-Type", "Content-Encoding"},
			ExposeHeaders:      []string{"x-amz-*", "X-Amz-*"},
			MaxAgeSeconds:      900,
			AllowedCredentials: true,
		}

		r := prepareCORSRequest(t, bktName, map[string]string{fasthttp.HeaderOrigin: "http://example.com"})
		hc.Handler().SetCORSHeaders(r)

		require.Equal(t, fasthttp.StatusOK, r.Response.StatusCode())
		require.Equal(t, "900", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlMaxAge)))
		require.Equal(t, "*", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlAllowOrigin)))
		require.Equal(t, "GET, HEAD", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlAllowMethods)))
		require.Equal(t, "Content-Type, Content-Encoding", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlAllowHeaders)))
		require.Equal(t, "x-amz-*, X-Amz-*", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlExposeHeaders)))
		require.Equal(t, "true", string(r.Response.Header.Peek(fasthttp.HeaderAccessControlAllowCredentials)))
	})
}

func TestCheckSubslice(t *testing.T) {
	for _, tc := range []struct {
		name     string
		allowed  []string
		actual   []string
		expected bool
	}{
		{
			name:     "empty allowed slice",
			allowed:  []string{},
			actual:   []string{"str1", "str2", "str3"},
			expected: false,
		},
		{
			name:     "empty actual slice",
			allowed:  []string{"str1", "str2", "str3"},
			actual:   []string{},
			expected: true,
		},
		{
			name:     "allowed wildcard",
			allowed:  []string{"str", "*"},
			actual:   []string{"str1", "str2", "str3"},
			expected: true,
		},
		{
			name:     "similar allowed and actual",
			allowed:  []string{"str1", "str2", "str3"},
			actual:   []string{"str1", "str2", "str3"},
			expected: true,
		},
		{
			name:     "allowed actual",
			allowed:  []string{"str", "str1", "str2", "str4"},
			actual:   []string{"str1", "str2"},
			expected: true,
		},
		{
			name:     "not allowed actual",
			allowed:  []string{"str", "str1", "str2", "str4"},
			actual:   []string{"str1", "str5"},
			expected: false,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.expected, checkSubslice(tc.allowed, tc.actual))
		})
	}
}

// setCORSObject XML-encodes corsConfig and stores it as an object in the
// handler's CORS container (hc.corsCnr) via hc.frostfs.SetObject.
//
// The object gets:
//   - a FilePath attribute of corsFilePathTemplate formatted with cnrID;
//   - owner hc.owner;
//   - a random object ID;
//   - the given creation epoch.
//
// A nil corsConfig produces an object with an empty payload, because
// xml.Marshal writes nothing for a nil pointer.
func setCORSObject(t *testing.T, hc *handlerContext, cnrID cid.ID, corsConfig *data.CORSConfiguration, epoch uint64) {
	// Report require failures at the caller's line rather than inside this helper.
	t.Helper()

	payload, err := xml.Marshal(corsConfig)
	require.NoError(t, err)

	a := object.NewAttribute()
	a.SetKey(object.AttributeFilePath)
	a.SetValue(fmt.Sprintf(corsFilePathTemplate, cnrID))

	objID := oidtest.ID()
	obj := object.New()
	obj.SetAttributes(*a)
	obj.SetOwnerID(hc.owner)
	obj.SetPayload(payload)
	obj.SetPayloadSize(uint64(len(payload)))
	obj.SetContainerID(hc.corsCnr)
	obj.SetID(objID)
	obj.SetCreationEpoch(epoch)

	var addr oid.Address
	addr.SetObject(objID)
	addr.SetContainer(hc.corsCnr)

	hc.frostfs.SetObject(addr, obj)
}