package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestHTTPResponseCarrierSetGet checks that a value stored through
// httpResponseCarrier.Set is returned by Get, and that Get yields an empty
// string for a key that has never been set.
func TestHTTPResponseCarrierSetGet(t *testing.T) {
	const (
		key   = "Key"
		value = "Value"
	)

	carrier := httpResponseCarrier{resp: httptest.NewRecorder()}

	// Fresh recorder: nothing has been written, so the lookup is empty.
	require.Equal(t, "", carrier.Get(key))

	carrier.Set(key, value)
	require.Equal(t, value, carrier.Get(key))
}

// TestHTTPResponseCarrierKeys checks that httpResponseCarrier.Keys reports no
// keys for a fresh recorder, and afterwards reports exactly the keys written
// through Set.
func TestHTTPResponseCarrierKeys(t *testing.T) {
	const (
		testKey1   = "Key1"
		testKey2   = "Key2"
		testKey3   = "Key3"
		testValue1 = "Value1"
		testValue2 = "Value2"
		testValue3 = "Value3"
	)

	respCarrier := httpResponseCarrier{}
	respCarrier.resp = httptest.NewRecorder()

	require.Empty(t, respCarrier.Keys())

	respCarrier.Set(testKey1, testValue1)
	respCarrier.Set(testKey2, testValue2)
	respCarrier.Set(testKey3, testValue3)

	// ElementsMatch asserts the same set of keys (count and membership) while
	// ignoring order. This test does not pin down the order of Keys.
	require.ElementsMatch(t, []string{testKey1, testKey2, testKey3}, respCarrier.Keys())
}

// TestHTTPRequestCarrierSet checks that httpRequestCarrier.Set writes the
// value into the headers of the *http.Response attached to the carried
// request (req.Response).
func TestHTTPRequestCarrierSet(t *testing.T) {
	const (
		testKey   = "Key"
		testValue = "Value"
	)

	reqCarrier := httpRequestCarrier{}
	reqCarrier.req = httptest.NewRequest(http.MethodGet, "/test", nil)
	reqCarrier.req.Response = httptest.NewRecorder().Result()

	actual := reqCarrier.req.Response.Header.Get(testKey)
	require.Equal(t, "", actual)

	reqCarrier.Set(testKey, testValue)
	actual = reqCarrier.req.Response.Header.Get(testKey)
	// Must be an exact comparison: a Contains(testValue, actual) check would
	// pass for an empty actual (every string contains ""), hiding a no-op Set.
	require.Equal(t, testValue, actual)
}

// TestHTTPRequestCarrierGet checks that httpRequestCarrier.Get reads values
// from the carried request's headers and yields "" for an absent key.
func TestHTTPRequestCarrierGet(t *testing.T) {
	const (
		key   = "Key"
		value = "Value"
	)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	carrier := httpRequestCarrier{req: req}

	require.Equal(t, "", carrier.Get(key))

	// Populate the header directly so only the Get path is under test.
	req.Header.Set(key, value)
	require.Equal(t, value, carrier.Get(key))
}

// TestHTTPRequestCarrierKeys checks that httpRequestCarrier.Keys reports no
// keys for a request without headers, and afterwards reports exactly the
// header keys set on the carried request.
func TestHTTPRequestCarrierKeys(t *testing.T) {
	const (
		testKey1   = "Key1"
		testKey2   = "Key2"
		testKey3   = "Key3"
		testValue1 = "Value1"
		testValue2 = "Value2"
		testValue3 = "Value3"
	)

	reqCarrier := httpRequestCarrier{}
	reqCarrier.req = httptest.NewRequest(http.MethodGet, "/test", nil)

	require.Empty(t, reqCarrier.Keys())

	reqCarrier.req.Header.Set(testKey1, testValue1)
	reqCarrier.req.Header.Set(testKey2, testValue2)
	reqCarrier.req.Header.Set(testKey3, testValue3)

	// ElementsMatch asserts the same set of keys (count and membership) while
	// ignoring order. This test does not pin down the order of Keys.
	require.ElementsMatch(t, []string{testKey1, testKey2, testKey3}, reqCarrier.Keys())
}