package dmapi

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func mockContext(sessionID string) context.Context {
	if sessionID == "" {
		sessionID = "xxx"
	}

	return context.WithValue(context.Background(), sessionIDKey, sessionID)
}

func TestClient_login_apikey(t *testing.T) {
	testCases := []struct {
		desc               string
		apiKey             string
		expectedError      bool
		expectedStatusCode int
		expectedAuthSid    string
	}{
		{
			desc:               "correct key",
			apiKey:             correctAPIKey,
			expectedStatusCode: 0,
			expectedAuthSid:    correctAPIKey,
		},
		{
			desc:               "incorrect key",
			apiKey:             incorrectAPIKey,
			expectedStatusCode: 2200,
			expectedError:      true,
		},
		{
			desc:               "server error",
			apiKey:             serverErrorAPIKey,
			expectedStatusCode: -500,
			expectedError:      true,
		},
		{
			desc:               "non-ok status code",
			apiKey:             "333",
			expectedStatusCode: 2202,
			expectedError:      true,
		},
	}

	mux, serverURL := setupTest(t)

	mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodPost, r.Method)

		switch r.FormValue("api-key") {
		case correctAPIKey:
			_, _ = io.WriteString(w, "Status-Code: 0\nStatus-Text: OK\nAuth-Sid: 123\n\ncom\nnet")
		case incorrectAPIKey:
			_, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error")
		case serverErrorAPIKey:
			http.NotFound(w, r)
		default:
			_, _ = io.WriteString(w, "Status-Code: 2202\nStatus-Text: OK\n\ncom\nnet")
		}
	})

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			client := NewClient(AuthInfo{APIKey: test.apiKey})
			client.BaseURL = serverURL

			response, err := client.login(context.Background())
			if test.expectedError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
				require.NotNil(t, response)
				assert.Equal(t, test.expectedStatusCode, response.StatusCode)
				assert.Equal(t, test.expectedAuthSid, response.AuthSid)
			}
		})
	}
}

func TestClient_login_username(t *testing.T) {
	testCases := []struct {
		desc               string
		username           string
		password           string
		expectedError      bool
		expectedStatusCode int
		expectedAuthSid    string
	}{
		{
			desc:               "correct username and password",
			username:           correctUsername,
			password:           "go-acme",
			expectedError:      false,
			expectedStatusCode: 0,
			expectedAuthSid:    correctAPIKey,
		},
		{
			desc:               "incorrect username",
			username:           incorrectUsername,
			password:           "go-acme",
			expectedStatusCode: 2200,
			expectedError:      true,
		},
		{
			desc:               "server error",
			username:           serverErrorUsername,
			password:           "go-acme",
			expectedStatusCode: -500,
			expectedError:      true,
		},
		{
			desc:               "non-ok status code",
			username:           "random",
			password:           "go-acme",
			expectedStatusCode: 2202,
			expectedError:      true,
		},
	}

	mux, serverURL := setupTest(t)

	mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodPost, r.Method)

		switch r.FormValue("username") {
		case correctUsername:
			_, _ = io.WriteString(w, "Status-Code: 0\nStatus-Text: OK\nAuth-Sid: 123\n\ncom\nnet")
		case incorrectUsername:
			_, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error")
		case serverErrorUsername:
			http.NotFound(w, r)
		default:
			_, _ = io.WriteString(w, "Status-Code: 2202\nStatus-Text: OK\n\ncom\nnet")
		}
	})

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			client := NewClient(AuthInfo{Username: test.username, Password: test.password})
			client.BaseURL = serverURL

			response, err := client.login(context.Background())
			if test.expectedError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
				require.NotNil(t, response)
				assert.Equal(t, test.expectedStatusCode, response.StatusCode)
				assert.Equal(t, test.expectedAuthSid, response.AuthSid)
			}
		})
	}
}

func TestClient_logout(t *testing.T) {
	testCases := []struct {
		desc               string
		authSid            string
		expectedError      bool
		expectedStatusCode int
	}{
		{
			desc:               "correct auth-sid",
			authSid:            correctAPIKey,
			expectedStatusCode: 0,
		},
		{
			desc:               "incorrect auth-sid",
			authSid:            incorrectAPIKey,
			expectedStatusCode: 2200,
		},
		{
			desc:          "already logged out",
			authSid:       "",
			expectedError: true,
		},
		{
			desc:          "server error",
			authSid:       serverErrorAPIKey,
			expectedError: true,
		},
	}

	mux, serverURL := setupTest(t)

	mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodPost, r.Method)

		switch r.FormValue("auth-sid") {
		case correctAPIKey:
			_, _ = io.WriteString(w, "Status-Code: 0\nStatus-Text: OK\n")
		case incorrectAPIKey:
			_, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error")
		default:
			http.NotFound(w, r)
		}
	})

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			client := NewClient(AuthInfo{APIKey: "12345"})
			client.BaseURL = serverURL
			client.token = &Token{SessionID: test.authSid}

			response, err := client.Logout(mockContext(test.authSid))
			if test.expectedError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
				require.NotNil(t, response)
				assert.Equal(t, test.expectedStatusCode, response.StatusCode)
			}
		})
	}
}

func TestClient_CreateAuthenticatedContext(t *testing.T) {
	mux := http.NewServeMux()
	server := httptest.NewServer(mux)
	t.Cleanup(server.Close)

	id := atomic.Int32{}
	id.Add(100)

	mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodPost, r.Method)

		switch r.FormValue("username") {
		case correctUsername:
			_, _ = fmt.Fprintf(w, "Status-Code: 0\nStatus-Text: OK\nAuth-Sid: %d\n\ncom\nnet", id.Load())
			id.Add(100)

		default:
			_, _ = io.WriteString(w, "Status-Code: 2200\nStatus-Text: Authentication error")
		}
	})

	client := NewClient(AuthInfo{Username: correctUsername, Password: "secret"})
	client.HTTPClient = server.Client()
	client.BaseURL = server.URL

	ctx, err := client.CreateAuthenticatedContext(context.Background())
	require.NoError(t, err)

	assert.Equal(t, "100", getSessionID(ctx))

	// the token is not expired then we use the "cache".
	client.muToken.Lock()
	client.token.SessionID = "cache"
	client.muToken.Unlock()

	ctx, err = client.CreateAuthenticatedContext(context.Background())
	require.NoError(t, err)

	assert.Equal(t, "cache", getSessionID(ctx))

	// force the expiration of the token
	client.muToken.Lock()
	client.token.ExpireAt = time.Now().UTC().Add(-1 * time.Hour)
	client.muToken.Unlock()

	ctx, err = client.CreateAuthenticatedContext(context.Background())
	require.NoError(t, err)

	assert.Equal(t, "200", getSessionID(ctx))
}