package internal

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Shared fixture values used by the tests and mock handlers in this file.
const (
	testToken          = "test" // token given to NewClient and expected by validationHandler
	testRecordContent  = "acme" // record content passed to AddRRSet and served by the GetRRSet mock
	testRecordContent2 = "foo"  // content of the pre-existing record in the "success update" case
	testTTL            = 10     // TTL passed to AddRRSet and required by handleAddRRSet
)

// setupTest starts an httptest server backed by a fresh ServeMux and returns
// a Client whose base URL points at that server, together with the mux so the
// caller can register its own routes.
// The server is closed automatically through t.Cleanup.
func setupTest(t *testing.T) (*Client, *http.ServeMux) {
	t.Helper()

	mux := http.NewServeMux()
	server := httptest.NewServer(mux)
	t.Cleanup(server.Close)

	// Fail right away on an unparsable server URL instead of leaving the
	// client with a nil base URL that would only blow up later.
	baseURL, err := url.Parse(server.URL)
	require.NoError(t, err)

	client := NewClient(testToken)
	client.baseURL = baseURL

	return client, mux
}

// TestClient_GetZone checks that GetZone returns the zone served by the mock
// API for a GET request on the zone endpoint.
func TestClient_GetZone(t *testing.T) {
	cl, mux := setupTest(t)

	want := Zone{Name: "example.com"}

	handler := validationHandler{method: http.MethodGet, next: handleJSONResponse(want)}
	mux.Handle("/v2/zones/example.com", handler)

	got, err := cl.GetZone(context.Background(), "example.com")
	require.NoError(t, err)
	assert.Equal(t, want, got)
}

// TestClient_GetZone_error checks that GetZone returns an error when the mock
// API answers with an error response.
func TestClient_GetZone_error(t *testing.T) {
	cl, mux := setupTest(t)

	failing := validationHandler{method: http.MethodGet, next: handleAPIError()}
	mux.Handle("/v2/zones/example.com", failing)

	_, err := cl.GetZone(context.Background(), "example.com")
	require.Error(t, err)
}

// TestClient_GetRRSet checks that GetRRSet returns the record set served by
// the mock API for a GET request on the TXT record endpoint.
func TestClient_GetRRSet(t *testing.T) {
	cl, mux := setupTest(t)

	want := RRSet{
		TTL:     testTTL,
		Records: []Records{{Content: []string{testRecordContent}}},
	}

	handler := validationHandler{method: http.MethodGet, next: handleJSONResponse(want)}
	mux.Handle("/v2/zones/example.com/foo.example.com/TXT", handler)

	got, err := cl.GetRRSet(context.Background(), "example.com", "foo.example.com")
	require.NoError(t, err)
	assert.Equal(t, want, got)
}

// TestClient_GetRRSet_error checks that GetRRSet returns an error when the
// mock API answers with an error response.
func TestClient_GetRRSet_error(t *testing.T) {
	cl, mux := setupTest(t)

	failing := validationHandler{method: http.MethodGet, next: handleAPIError()}
	mux.Handle("/v2/zones/example.com/foo.example.com/TXT", failing)

	_, err := cl.GetRRSet(context.Background(), "example.com", "foo.example.com")
	require.Error(t, err)
}

// TestClient_DeleteRRSet checks that DeleteRRSet succeeds when the mock API
// accepts the DELETE request.
func TestClient_DeleteRRSet(t *testing.T) {
	cl, mux := setupTest(t)

	// The route has no trailing dot while the record name passed below does,
	// so the request only matches if the client drops that dot.
	route := "/v2/zones/test.example.com/my.test.example.com/" + txtRecordType
	mux.Handle(route, validationHandler{method: http.MethodDelete})

	require.NoError(t, cl.DeleteRRSet(context.Background(), "test.example.com", "my.test.example.com."))
}

// TestClient_DeleteRRSet_error checks that DeleteRRSet reports a failure when
// the mock API answers the DELETE request with an error response.
func TestClient_DeleteRRSet_error(t *testing.T) {
	client, mux := setupTest(t)

	mux.Handle("/v2/zones/test.example.com/my.test.example.com/"+txtRecordType, validationHandler{
		method: http.MethodDelete,
		next:   handleAPIError(),
	})

	err := client.DeleteRRSet(context.Background(), "test.example.com", "my.test.example.com.")
	// The handler replies with a 500 API error, so the call must fail.
	require.Error(t, err)
}

// TestClient_AddRRSet runs AddRRSet against table-driven mock routes covering
// the creation of a record set, the update of an existing one, and a record
// name for which no route is registered.
func TestClient_AddRRSet(t *testing.T) {
	testCases := []struct {
		desc       string
		zone       string
		recordName string
		value      string
		handlers   map[string]http.Handler
		wantErr    bool
	}{
		{
			desc:       "success add",
			zone:       "test.example.com",
			recordName: "my.test.example.com",
			value:      testRecordContent,
			handlers: map[string]http.Handler{
				// createRRSet
				// Only POST is accepted on this route; validationHandler
				// rejects any other method with a 405.
				"/v2/zones/test.example.com/my.test.example.com/" + txtRecordType: validationHandler{
					method: http.MethodPost,
					next:   handleAddRRSet([]Records{{Content: []string{testRecordContent}}}),
				},
			},
		},
		{
			desc:       "success update",
			zone:       "test.example.com",
			recordName: "my.test.example.com",
			value:      testRecordContent,
			handlers: map[string]http.Handler{
				"/v2/zones/test.example.com/my.test.example.com/" + txtRecordType: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
					switch req.Method {
					case http.MethodGet: // GetRRSet
						// Serve an already existing record set.
						data := RRSet{
							TTL:     testTTL,
							Records: []Records{{Content: []string{testRecordContent2}}},
						}
						handleJSONResponse(data).ServeHTTP(rw, req)
					case http.MethodPut: // updateRRSet
						// The update must carry the new value followed by the existing one.
						expected := []Records{
							{Content: []string{testRecordContent}},
							{Content: []string{testRecordContent2}},
						}
						handleAddRRSet(expected).ServeHTTP(rw, req)
					default:
						http.Error(rw, "wrong method", http.StatusMethodNotAllowed)
					}
				}),
			},
		},
		{
			// No handler is registered for this case, so every request
			// falls through to the mux default response.
			desc:       "not in the zone",
			zone:       "test.example.com",
			recordName: "notfound.example.com",
			value:      testRecordContent,
			wantErr:    true,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			cl, mux := setupTest(t)

			for pattern, handler := range test.handlers {
				mux.Handle(pattern, handler)
			}

			err := cl.AddRRSet(context.Background(), test.zone, test.recordName, test.value, testTTL)
			if test.wantErr {
				require.Error(t, err)
				return
			}

			require.NoError(t, err)
		})
	}
}

// validationHandler is an http.Handler that checks the authorization header
// and the HTTP method of a request before delegating to next.
type validationHandler struct {
	method string       // the only HTTP method accepted
	next   http.Handler // called once validation passed; may be nil
}

// ServeHTTP validates the request and, when it is acceptable, hands it to the
// next handler if one is set.
// A missing or unexpected authorization header yields a 403 with a JSON
// APIError body; an unexpected HTTP method yields a 405.
func (v validationHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	wantAuth := fmt.Sprintf("%s %s", tokenTypeHeader, testToken)

	switch {
	case req.Header.Get(authorizationHeader) != wantAuth:
		rw.WriteHeader(http.StatusForbidden)
		_ = json.NewEncoder(rw).Encode(APIError{Message: "token up for parsing was not passed through the context"})
	case req.Method != v.method:
		http.Error(rw, "wrong method", http.StatusMethodNotAllowed)
	case v.next != nil:
		v.next.ServeHTTP(rw, req)
	}
}

// handleAPIError builds a handler that always replies with a 500 status and a
// JSON-encoded APIError whose message is "oops".
func handleAPIError() http.HandlerFunc {
	return http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		apiErr := APIError{Message: "oops"}

		rw.WriteHeader(http.StatusInternalServerError)
		// Best effort: an encoding failure is irrelevant for this mock.
		_ = json.NewEncoder(rw).Encode(apiErr)
	})
}

// handleJSONResponse builds a handler that writes data to the response body
// as JSON. When the value cannot be encoded, the handler replies with a 500
// carrying the encoder's error text.
func handleJSONResponse(data interface{}) http.HandlerFunc {
	return http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		if encodeErr := json.NewEncoder(rw).Encode(data); encodeErr != nil {
			http.Error(rw, encodeErr.Error(), http.StatusInternalServerError)
		}
	})
}

// handleAddRRSet builds a handler that decodes the request body as an RRSet
// and verifies it: the TTL must equal testTTL and the records must deeply
// equal expected. Any decoding or validation failure is reported as a 500;
// a valid request leaves the response untouched.
func handleAddRRSet(expected []Records) http.HandlerFunc {
	return func(rw http.ResponseWriter, req *http.Request) {
		var payload RRSet

		if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
			return
		}

		switch {
		case payload.TTL != testTTL:
			http.Error(rw, "wrong ttl", http.StatusInternalServerError)
		case !reflect.DeepEqual(payload.Records, expected):
			http.Error(rw, "wrong resource records", http.StatusInternalServerError)
		}
	}
}