package internal

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func setupTest(t *testing.T) (*http.ServeMux, *Client) {
	t.Helper()

	mux := http.NewServeMux()
	server := httptest.NewServer(mux)
	t.Cleanup(server.Close)

	client := &Client{
		token:      "secret",
		endpoint:   server.URL,
		httpClient: server.Client(),
	}

	return mux, client
}

func TestClient_GetDomainID(t *testing.T) {
	type expected struct {
		domainID string
		error    bool
	}

	testCases := []struct {
		desc       string
		domainName string
		handler    http.HandlerFunc
		expected   expected
	}{
		{
			desc:       "success",
			domainName: "domain1.com.",
			handler: func(rw http.ResponseWriter, req *http.Request) {
				if req.Method != http.MethodGet {
					http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
					return
				}

				content := `
{
    "domains":[
      {
        "id": "09494b72-b65b-4297-9efb-187f65a0553e",
        "name": "domain1.com.",
        "ttl": 3600,
        "serial": 1351800668,
        "email": "nsadmin@example.org",
        "gslb": 0,
        "created_at": "2012-11-01T20:11:08.000000",
        "updated_at": null,
        "description": "memo"
      },
      {
        "id": "cf661142-e577-40b5-b3eb-75795cdc0cd7",
        "name": "domain2.com.",
        "ttl": 7200,
        "serial": 1351800670,
        "email": "nsadmin2@example.org",
        "gslb": 1,
        "created_at": "2012-11-01T20:11:08.000000",
        "updated_at": "2012-12-01T20:11:08.000000",
        "description": "memomemo"
      }
    ]
}
`
				_, err := fmt.Fprint(rw, content)
				if err != nil {
					http.Error(rw, err.Error(), http.StatusInternalServerError)
					return
				}
			},
			expected: expected{domainID: "09494b72-b65b-4297-9efb-187f65a0553e"},
		},
		{
			desc:       "non existing domain",
			domainName: "domain1.com.",
			handler: func(rw http.ResponseWriter, req *http.Request) {
				if req.Method != http.MethodGet {
					http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
					return
				}

				_, err := fmt.Fprint(rw, "{}")
				if err != nil {
					http.Error(rw, err.Error(), http.StatusInternalServerError)
					return
				}
			},
			expected: expected{error: true},
		},
		{
			desc:       "marshaling error",
			domainName: "domain1.com.",
			handler: func(rw http.ResponseWriter, req *http.Request) {
				if req.Method != http.MethodGet {
					http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
					return
				}

				_, err := fmt.Fprint(rw, "[]")
				if err != nil {
					http.Error(rw, err.Error(), http.StatusInternalServerError)
					return
				}
			},
			expected: expected{error: true},
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			mux, client := setupTest(t)

			mux.Handle("/v1/domains", test.handler)

			domainID, err := client.GetDomainID(test.domainName)

			if test.expected.error {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
				assert.Equal(t, test.expected.domainID, domainID)
			}
		})
	}
}

func TestClient_CreateRecord(t *testing.T) {
	testCases := []struct {
		desc        string
		handler     http.HandlerFunc
		expectError bool
	}{
		{
			desc: "success",
			handler: func(rw http.ResponseWriter, req *http.Request) {
				if req.Method != http.MethodPost {
					http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
					return
				}

				raw, err := io.ReadAll(req.Body)
				if err != nil {
					http.Error(rw, err.Error(), http.StatusBadRequest)
					return
				}
				defer req.Body.Close()

				if string(raw) != `{"name":"lego.com.","type":"TXT","data":"txtTXTtxt","ttl":300}` {
					http.Error(rw, fmt.Sprintf("invalid request body: %s", string(raw)), http.StatusBadRequest)
					return
				}
			},
		},
		{
			desc: "bad request",
			handler: func(rw http.ResponseWriter, req *http.Request) {
				if req.Method != http.MethodPost {
					http.Error(rw, fmt.Sprintf("%s: %s", http.StatusText(http.StatusMethodNotAllowed), req.Method), http.StatusMethodNotAllowed)
					return
				}

				http.Error(rw, "OOPS", http.StatusBadRequest)
			},
			expectError: true,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			mux, client := setupTest(t)

			mux.Handle("/v1/domains/lego/records", test.handler)

			domainID := "lego"

			record := Record{
				Name: "lego.com.",
				Type: "TXT",
				Data: "txtTXTtxt",
				TTL:  300,
			}

			err := client.CreateRecord(domainID, record)

			if test.expectError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}