package sakuracloud

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"

	"github.com/sacloud/libsacloud/api"
	"github.com/sacloud/libsacloud/sacloud"
	"github.com/stretchr/testify/require"
)

type simpleResponse struct {
	*sacloud.DNS `json:"CommonServiceItem,omitempty"`
}

type apiQuery struct {
	Filter struct {
		Name          string `json:"Name"`
		ProviderClass string `json:"Provider.Class"`
	} `json:"Filter"`
}

func fakeAPIServer(handler func(rw http.ResponseWriter, req *http.Request)) func() {
	mux := http.NewServeMux()
	server := httptest.NewServer(mux)

	mux.HandleFunc("/is1a/api/cloud/1.1/commonserviceitem/", handler)

	backup := api.SakuraCloudAPIRoot
	api.SakuraCloudAPIRoot = server.URL
	return func() {
		api.SakuraCloudAPIRoot = backup
	}
}

func TestDNSProvider_addTXTRecord(t *testing.T) {
	searchResp := &api.SearchDNSResponse{}
	tearDown := fakeAPIServer(func(rw http.ResponseWriter, req *http.Request) {
		switch req.Method {
		case http.MethodGet:
			if len(searchResp.CommonServiceDNSItems) == 0 {
				q := &apiQuery{}
				if err := json.Unmarshal([]byte(req.URL.RawQuery), q); err != nil {
					http.Error(rw, err.Error(), http.StatusServiceUnavailable)
				}

				fakeZone := sacloud.CreateNewDNS(q.Filter.Name)
				fakeZone.ID = 123456789012
				searchResp = &api.SearchDNSResponse{CommonServiceDNSItems: []sacloud.DNS{*fakeZone}}
			}

			if err := json.NewEncoder(rw).Encode(searchResp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}
		case http.MethodPut: // Update
			resp := &simpleResponse{}
			if err := json.NewDecoder(req.Body).Decode(resp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}

			var items []sacloud.DNS
			for _, v := range searchResp.CommonServiceDNSItems {
				if resp.Name == v.Name {
					items = append(items, *resp.DNS)
				} else {
					items = append(items, v)
				}
			}
			searchResp.CommonServiceDNSItems = items

			if err := json.NewEncoder(rw).Encode(searchResp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}
		default:
			http.Error(rw, "OOPS", http.StatusServiceUnavailable)
		}
	})
	defer tearDown()

	config := NewDefaultConfig()
	config.Token = "token1"
	config.Secret = "secret1"

	p, err := NewDNSProviderConfig(config)
	require.NoError(t, err)

	err = p.addTXTRecord("test.example.com", "example.com", "dummyValue", 10)
	require.NoError(t, err)

	updZone, err := p.getHostedZone("example.com")
	require.NoError(t, err)
	require.NotNil(t, updZone)

	require.Len(t, updZone.Settings.DNS.ResourceRecordSets, 1)
}

func TestDNSProvider_cleanupTXTRecord(t *testing.T) {
	searchResp := &api.SearchDNSResponse{}

	tearDown := fakeAPIServer(func(rw http.ResponseWriter, req *http.Request) {
		switch req.Method {
		case http.MethodGet:
			if len(searchResp.CommonServiceDNSItems) == 0 {
				q := &apiQuery{}
				if err := json.Unmarshal([]byte(req.URL.RawQuery), q); err != nil {
					http.Error(rw, err.Error(), http.StatusServiceUnavailable)
				}

				fakeZone := sacloud.CreateNewDNS(q.Filter.Name)
				fakeZone.ID = 123456789012
				fakeZone.CreateNewRecord("test", "TXT", "dummyValue", 10)
				searchResp = &api.SearchDNSResponse{CommonServiceDNSItems: []sacloud.DNS{*fakeZone}}
			}

			if err := json.NewEncoder(rw).Encode(searchResp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}
		case http.MethodPut: // Update
			resp := &simpleResponse{}
			if err := json.NewDecoder(req.Body).Decode(resp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}

			var items []sacloud.DNS
			for _, v := range searchResp.CommonServiceDNSItems {
				if resp.Name == v.Name {
					items = append(items, *resp.DNS)
				} else {
					items = append(items, v)
				}
			}
			searchResp.CommonServiceDNSItems = items

			if err := json.NewEncoder(rw).Encode(searchResp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}
		default:
			http.Error(rw, "OOPS", http.StatusServiceUnavailable)
		}
	})
	defer tearDown()

	config := NewDefaultConfig()
	config.Token = "token2"
	config.Secret = "secret2"

	p, err := NewDNSProviderConfig(config)
	require.NoError(t, err)

	err = p.cleanupTXTRecord("test.example.com", "example.com")
	require.NoError(t, err)

	updZone, err := p.getHostedZone("example.com")
	require.NoError(t, err)
	require.NotNil(t, updZone)

	require.Len(t, updZone.Settings.DNS.ResourceRecordSets, 0)
}

func TestDNSProvider_addTXTRecord_concurrent(t *testing.T) {
	searchResp := &api.SearchDNSResponse{}

	tearDown := fakeAPIServer(func(rw http.ResponseWriter, req *http.Request) {
		switch req.Method {
		case http.MethodGet:
			if len(searchResp.CommonServiceDNSItems) == 0 {
				q := &apiQuery{}
				if err := json.Unmarshal([]byte(req.URL.RawQuery), q); err != nil {
					http.Error(rw, err.Error(), http.StatusServiceUnavailable)
				}

				fakeZone := sacloud.CreateNewDNS(q.Filter.Name)
				fakeZone.ID = 123456789012
				searchResp = &api.SearchDNSResponse{CommonServiceDNSItems: []sacloud.DNS{*fakeZone}}
			}

			if err := json.NewEncoder(rw).Encode(searchResp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}
		case http.MethodPut: // Update
			resp := &simpleResponse{}
			if err := json.NewDecoder(req.Body).Decode(resp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}

			var items []sacloud.DNS
			for _, v := range searchResp.CommonServiceDNSItems {
				if resp.Name == v.Name {
					items = append(items, *resp.DNS)
				} else {
					items = append(items, v)
				}
			}
			searchResp.CommonServiceDNSItems = items

			if err := json.NewEncoder(rw).Encode(searchResp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}
		default:
			http.Error(rw, "OOPS", http.StatusServiceUnavailable)
		}
	})
	defer tearDown()

	dummyRecordCount := 10

	var providers []*DNSProvider
	for i := 0; i < dummyRecordCount; i++ {
		config := NewDefaultConfig()
		config.Token = "token3"
		config.Secret = "secret3"

		p, err := NewDNSProviderConfig(config)
		require.NoError(t, err)

		providers = append(providers, p)
	}

	var wg sync.WaitGroup
	wg.Add(len(providers))

	for i, p := range providers {
		go func(fqdn string, client *DNSProvider) {
			err := client.addTXTRecord(fqdn, "example.com", "dummyValue", 10)
			require.NoError(t, err)
			wg.Done()
		}(fmt.Sprintf("test%d.example.com", i), p)
	}

	wg.Wait()

	updZone, err := providers[0].getHostedZone("example.com")
	require.NoError(t, err)
	require.NotNil(t, updZone)

	require.Len(t, updZone.Settings.DNS.ResourceRecordSets, dummyRecordCount)
}

func TestDNSProvider_cleanupTXTRecord_concurrent(t *testing.T) {
	dummyRecordCount := 10

	baseFakeZone := sacloud.CreateNewDNS("example.com")
	baseFakeZone.ID = 123456789012
	for i := 0; i < dummyRecordCount; i++ {
		baseFakeZone.AddRecord(baseFakeZone.CreateNewRecord(fmt.Sprintf("test%d", i), "TXT", "dummyValue", 10))
	}

	searchResp := &api.SearchDNSResponse{CommonServiceDNSItems: []sacloud.DNS{*baseFakeZone}}

	tearDown := fakeAPIServer(func(rw http.ResponseWriter, req *http.Request) {
		switch req.Method {
		case http.MethodGet:
			if err := json.NewEncoder(rw).Encode(searchResp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}
		case http.MethodPut: // Update
			resp := &simpleResponse{}
			if err := json.NewDecoder(req.Body).Decode(resp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}

			var items []sacloud.DNS
			for _, v := range searchResp.CommonServiceDNSItems {
				if resp.Name == v.Name {
					items = append(items, *resp.DNS)
				} else {
					items = append(items, v)
				}
			}
			searchResp.CommonServiceDNSItems = items

			if err := json.NewEncoder(rw).Encode(searchResp); err != nil {
				http.Error(rw, err.Error(), http.StatusServiceUnavailable)
			}
		default:
			http.Error(rw, "OOPS", http.StatusServiceUnavailable)
		}
	})
	defer tearDown()

	fakeZone := sacloud.CreateNewDNS("example.com")
	fakeZone.ID = 123456789012
	for i := 0; i < dummyRecordCount; i++ {
		fakeZone.AddRecord(fakeZone.CreateNewRecord(fmt.Sprintf("test%d", i), "TXT", "dummyValue", 10))
	}

	var providers []*DNSProvider
	for i := 0; i < dummyRecordCount; i++ {
		config := NewDefaultConfig()
		config.Token = "token4"
		config.Secret = "secret4"

		p, err := NewDNSProviderConfig(config)
		require.NoError(t, err)
		providers = append(providers, p)
	}

	var wg sync.WaitGroup
	wg.Add(len(providers))

	for i, p := range providers {
		go func(fqdn string, client *DNSProvider) {
			err := client.cleanupTXTRecord(fqdn, "example.com")
			require.NoError(t, err)
			wg.Done()
		}(fmt.Sprintf("test%d.example.com", i), p)
	}

	wg.Wait()

	updZone, err := providers[0].getHostedZone("example.com")
	require.NoError(t, err)
	require.NotNil(t, updZone)

	require.Len(t, updZone.Settings.DNS.ResourceRecordSets, 0)
}