package route53

import (
	"context"
	"os"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	awsconfig "github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials"
	"github.com/aws/aws-sdk-go-v2/service/route53"
	"github.com/go-acme/lego/v4/platform/tester"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// envDomain is the name of the environment variable registered with envTest
// as its domain variable and listed among its live-test requirements.
const envDomain = "R53_DOMAIN"

// envTest tracks the environment variables used by the tests in this file.
// Tests call ClearEnv before setting their own values and RestoreEnv when
// they finish, so environment changes do not leak between tests.
var envTest = tester.NewEnvTest(
	EnvAccessKeyID,
	EnvSecretAccessKey,
	EnvRegion,
	EnvHostedZoneID,
	EnvMaxRetries,
	EnvTTL,
	EnvPropagationTimeout,
	EnvPollingInterval).
	WithDomain(envDomain).
	// NOTE(review): presumably these are the variables that must be set for
	// live tests to run — confirm against the tester package.
	WithLiveTestRequirements(EnvAccessKeyID, EnvSecretAccessKey, EnvRegion, envDomain)

// endpointResolverMock is a test endpoint resolver that resolves every
// request to a single fixed URL.
type endpointResolverMock struct {
	// endpoint is the URL returned for every resolution.
	endpoint string
}

// ResolveEndpoint ignores the service, region and options it receives and
// always returns an endpoint whose URL is the configured one, with a nil error.
func (r endpointResolverMock) ResolveEndpoint(_, _ string, _ ...interface{}) (aws.Endpoint, error) {
	var resolved aws.Endpoint
	resolved.URL = r.endpoint

	return resolved, nil
}

// makeTestProvider returns a DNSProvider with the default configuration whose
// Route 53 client uses static dummy credentials, a mock region, a single
// retry attempt, and an endpoint resolver pointing at serverURL.
func makeTestProvider(t *testing.T, serverURL string) *DNSProvider {
	t.Helper()

	var awsCfg aws.Config
	awsCfg.Credentials = credentials.NewStaticCredentialsProvider("abc", "123", " ")
	awsCfg.Region = "mock-region"
	awsCfg.EndpointResolverWithOptions = endpointResolverMock{endpoint: serverURL}
	awsCfg.RetryMaxAttempts = 1

	provider := new(DNSProvider)
	provider.client = route53.NewFromConfig(awsCfg)
	provider.config = NewDefaultConfig()

	return provider
}

// Test_loadCredentials_FromEnv checks that the AWS default configuration
// exposes the access key pair set through the environment.
func Test_loadCredentials_FromEnv(t *testing.T) {
	defer envTest.RestoreEnv()
	envTest.ClearEnv()

	const (
		accessKeyID     = "123"
		secretAccessKey = "456"
	)

	_ = os.Setenv(EnvAccessKeyID, accessKeyID)
	_ = os.Setenv(EnvSecretAccessKey, secretAccessKey)
	_ = os.Setenv(EnvRegion, "us-east-1")

	ctx := context.Background()

	awsCfg, loadErr := awsconfig.LoadDefaultConfig(ctx)
	require.NoError(t, loadErr)

	got, retrieveErr := awsCfg.Credentials.Retrieve(ctx)
	require.NoError(t, retrieveErr, "Expected credentials to be set from environment")

	// SessionToken is left at its zero value: no token is set in the environment.
	want := aws.Credentials{
		AccessKeyID:     accessKeyID,
		SecretAccessKey: secretAccessKey,
		Source:          "EnvConfigCredentials",
	}

	assert.Equal(t, want, got)
}

// Test_loadRegion_FromEnv checks that the AWS default configuration reports
// the region set through the environment.
func Test_loadRegion_FromEnv(t *testing.T) {
	defer envTest.RestoreEnv()
	envTest.ClearEnv()

	const wantRegion = "foo"

	_ = os.Setenv(EnvRegion, wantRegion)

	awsCfg, loadErr := awsconfig.LoadDefaultConfig(context.Background())
	require.NoError(t, loadErr)

	assert.Equal(t, wantRegion, awsCfg.Region, "Region")
}

// Test_getHostedZoneID_FromEnv checks that a provider created from the
// environment returns the hosted zone ID set there, whatever FQDN is asked for.
func Test_getHostedZoneID_FromEnv(t *testing.T) {
	defer envTest.RestoreEnv()
	envTest.ClearEnv()

	const wantZoneID = "zoneID"

	_ = os.Setenv(EnvHostedZoneID, wantZoneID)

	p, err := NewDNSProvider()
	require.NoError(t, err)

	gotZoneID, err := p.getHostedZoneID(context.Background(), "whatever")
	require.NoError(t, err, "HostedZoneID")

	assert.Equal(t, wantZoneID, gotZoneID)
}

// TestNewDefaultConfig checks the Config returned by NewDefaultConfig, both
// with a cleared environment and with every tunable set through environment
// variables.
func TestNewDefaultConfig(t *testing.T) {
	defer envTest.RestoreEnv()

	testCases := []struct {
		desc     string
		envVars  map[string]string
		expected *Config
	}{
		{
			desc: "default configuration",
			expected: &Config{
				MaxRetries:         5,
				TTL:                10,
				PropagationTimeout: 2 * time.Minute,
				PollingInterval:    4 * time.Second,
			},
		},
		{
			// A non-empty description keeps the subtest identifiable in the
			// output and selectable with -run.
			desc: "configuration from environment variables",
			envVars: map[string]string{
				EnvMaxRetries:         "10",
				EnvTTL:                "99",
				EnvPropagationTimeout: "60",
				EnvPollingInterval:    "60",
				EnvHostedZoneID:       "abc123",
			},
			expected: &Config{
				MaxRetries:         10,
				TTL:                99,
				PropagationTimeout: 60 * time.Second,
				PollingInterval:    60 * time.Second,
				HostedZoneID:       "abc123",
			},
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			// Start each case from a clean environment so values from a
			// previous case cannot leak in.
			envTest.ClearEnv()
			for key, value := range test.envVars {
				_ = os.Setenv(key, value)
			}

			config := NewDefaultConfig()

			assert.Equal(t, test.expected, config)
		})
	}
}

func TestDNSProvider_Present(t *testing.T) {
	mockResponses := MockResponseMap{
		"/2013-04-01/hostedzonesbyname":        {StatusCode: 200, Body: ListHostedZonesByNameResponse},
		"/2013-04-01/hostedzone/ABCDEFG/rrset": {StatusCode: 200, Body: ChangeResourceRecordSetsResponse},
		"/2013-04-01/change/123456":            {StatusCode: 200, Body: GetChangeResponse},
		"/2013-04-01/hostedzone/ABCDEFG/rrset?name=_acme-challenge.example.com.&type=TXT": {
			StatusCode: 200,
			Body:       "",
		},
	}

	serverURL := setupTest(t, mockResponses)

	defer envTest.RestoreEnv()
	envTest.ClearEnv()
	provider := makeTestProvider(t, serverURL)

	domain := "example.com"
	keyAuth := "123456d=="

	err := provider.Present(domain, "", keyAuth)
	require.NoError(t, err, "Expected Present to return no error")
}

// Test_createAWSConfig exercises createAWSConfig with invalid and valid
// Config values. Error cases are matched on the exact error message. Success
// cases check either the static credentials returned by the resulting
// configuration or that the credentials did not come from the static
// provider, and optionally the resolved region.
func Test_createAWSConfig(t *testing.T) {
	testCases := []struct {
		desc             string
		env              map[string]string
		config           *Config
		wantCreds        aws.Credentials
		wantDefaultChain bool
		wantRegion       string
		wantErr          string
	}{
		{
			desc:    "config is nil",
			wantErr: "config is nil",
		},
		{
			desc:    "session token without access key id or secret access key",
			config:  &Config{SessionToken: "foo"},
			wantErr: "SessionToken must be supplied with AccessKeyID and SecretAccessKey",
		},
		{
			desc:    "access key id without secret access key",
			config:  &Config{AccessKeyID: "foo"},
			wantErr: "AccessKeyID and SecretAccessKey must be supplied together",
		},
		{
			// Mirror of the previous case: only the secret is supplied.
			desc:    "secret access key without access key id",
			config:  &Config{SecretAccessKey: "foo"},
			wantErr: "AccessKeyID and SecretAccessKey must be supplied together",
		},
		{
			desc:             "credentials from default chain",
			config:           &Config{},
			wantDefaultChain: true,
		},
		{
			desc: "static credentials",
			config: &Config{
				AccessKeyID:     "one",
				SecretAccessKey: "two",
			},
			wantCreds: aws.Credentials{
				AccessKeyID:     "one",
				SecretAccessKey: "two",
				SessionToken:    "",
				Source:          credentials.StaticCredentialsName,
			},
		},
		{
			desc: "static credentials with session token",
			config: &Config{
				AccessKeyID:     "one",
				SecretAccessKey: "two",
				SessionToken:    "three",
			},
			wantCreds: aws.Credentials{
				AccessKeyID:     "one",
				SecretAccessKey: "two",
				SessionToken:    "three",
				Source:          credentials.StaticCredentialsName,
			},
		},
		{
			desc:   "region from env",
			config: &Config{},
			env: map[string]string{
				"AWS_REGION": "foo",
			},
			wantDefaultChain: true,
			wantRegion:       "foo",
		},
		{
			desc: "static region",
			config: &Config{
				Region: "one",
			},
			env: map[string]string{
				"AWS_REGION": "foo",
			},
			wantDefaultChain: true,
			wantRegion:       "one",
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			defer envTest.RestoreEnv()
			envTest.ClearEnv()

			envTest.Apply(test.env)

			ctx := context.Background()

			cfg, err := createAWSConfig(ctx, test.config)
			requireErr(t, err, test.wantErr)

			// An expected error was matched above; nothing more to check.
			if err != nil {
				return
			}

			gotCreds, err := cfg.Credentials.Retrieve(ctx)

			if test.wantDefaultChain {
				// The retrieval error is deliberately not asserted here;
				// presumably the default chain may find no credentials in the
				// test environment. Only the credential source is checked.
				assert.NotEqual(t, credentials.StaticCredentialsName, gotCreds.Source)
			} else {
				require.NoError(t, err)
				assert.Equal(t, test.wantCreds, gotCreds)
			}

			if test.wantRegion != "" {
				assert.Equal(t, test.wantRegion, cfg.Region)
			}
		})
	}
}

// requireErr fails the test immediately unless err matches wantErr. An empty
// wantErr means no error is expected; otherwise err must be non-nil and its
// message must equal wantErr exactly.
func requireErr(t *testing.T, err error, wantErr string) {
	t.Helper()

	if wantErr != "" {
		// Covers both a missing error and a mismatched message.
		require.EqualError(t, err, wantErr)
		return
	}

	if err != nil {
		// force the assertion error.
		require.NoError(t, err)
	}
}