package provisioner

import (
	"crypto/hmac"
	"crypto/sha256"
	"crypto/tls"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/pkg/errors"
	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/webhook"
	"go.step.sm/crypto/x509util"
	"go.step.sm/linkedca"
)

func TestWebhookController_isCertTypeOK(t *testing.T) {
	type test struct {
		wc   *WebhookController
		wh   *Webhook
		want bool
	}
	tests := map[string]test{
		"all/all": {
			wc:   &WebhookController{certType: linkedca.Webhook_ALL},
			wh:   &Webhook{CertType: linkedca.Webhook_ALL.String()},
			want: true,
		},
		"all/x509": {
			wc:   &WebhookController{certType: linkedca.Webhook_ALL},
			wh:   &Webhook{CertType: linkedca.Webhook_X509.String()},
			want: true,
		},
		"all/ssh": {
			wc:   &WebhookController{certType: linkedca.Webhook_ALL},
			wh:   &Webhook{CertType: linkedca.Webhook_SSH.String()},
			want: true,
		},
		`all/""`: {
			wc:   &WebhookController{certType: linkedca.Webhook_ALL},
			wh:   &Webhook{},
			want: true,
		},
		"x509/all": {
			wc:   &WebhookController{certType: linkedca.Webhook_X509},
			wh:   &Webhook{CertType: linkedca.Webhook_ALL.String()},
			want: true,
		},
		"x509/x509": {
			wc:   &WebhookController{certType: linkedca.Webhook_X509},
			wh:   &Webhook{CertType: linkedca.Webhook_X509.String()},
			want: true,
		},
		"x509/ssh": {
			wc:   &WebhookController{certType: linkedca.Webhook_X509},
			wh:   &Webhook{CertType: linkedca.Webhook_SSH.String()},
			want: false,
		},
		`x509/""`: {
			wc:   &WebhookController{certType: linkedca.Webhook_X509},
			wh:   &Webhook{},
			want: true,
		},
		"ssh/all": {
			wc:   &WebhookController{certType: linkedca.Webhook_SSH},
			wh:   &Webhook{CertType: linkedca.Webhook_ALL.String()},
			want: true,
		},
		"ssh/x509": {
			wc:   &WebhookController{certType: linkedca.Webhook_SSH},
			wh:   &Webhook{CertType: linkedca.Webhook_X509.String()},
			want: false,
		},
		"ssh/ssh": {
			wc:   &WebhookController{certType: linkedca.Webhook_SSH},
			wh:   &Webhook{CertType: linkedca.Webhook_SSH.String()},
			want: true,
		},
		`ssh/""`: {
			wc:   &WebhookController{certType: linkedca.Webhook_SSH},
			wh:   &Webhook{},
			want: true,
		},
	}
	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			assert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh))
		})
	}
}

func TestWebhookController_Enrich(t *testing.T) {
	type test struct {
		ctl                *WebhookController
		req                *webhook.RequestBody
		responses          []*webhook.ResponseBody
		expectErr          bool
		expectTemplateData any
	}
	tests := map[string]test{
		"ok/no enriching webhooks": {
			ctl: &WebhookController{
				client:       http.DefaultClient,
				webhooks:     []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
				TemplateData: nil,
			},
			req:                &webhook.RequestBody{},
			responses:          nil,
			expectErr:          false,
			expectTemplateData: nil,
		},
		"ok/one webhook": {
			ctl: &WebhookController{
				client:       http.DefaultClient,
				webhooks:     []*Webhook{{Name: "people", Kind: "ENRICHING"}},
				TemplateData: x509util.TemplateData{},
			},
			req:                &webhook.RequestBody{},
			responses:          []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
			expectErr:          false,
			expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}},
		},
		"ok/two webhooks": {
			ctl: &WebhookController{
				client: http.DefaultClient,
				webhooks: []*Webhook{
					{Name: "people", Kind: "ENRICHING"},
					{Name: "devices", Kind: "ENRICHING"},
				},
				TemplateData: x509util.TemplateData{},
			},
			req: &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{
				{Allow: true, Data: map[string]any{"role": "bar"}},
				{Allow: true, Data: map[string]any{"serial": "123"}},
			},
			expectErr: false,
			expectTemplateData: x509util.TemplateData{
				"Webhooks": map[string]any{
					"devices": map[string]any{"serial": "123"},
					"people":  map[string]any{"role": "bar"},
				},
			},
		},
		"ok/x509 only": {
			ctl: &WebhookController{
				client: http.DefaultClient,
				webhooks: []*Webhook{
					{Name: "people", Kind: "ENRICHING", CertType: linkedca.Webhook_SSH.String()},
					{Name: "devices", Kind: "ENRICHING"},
				},
				TemplateData: x509util.TemplateData{},
				certType:     linkedca.Webhook_X509,
			},
			req: &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{
				{Allow: true, Data: map[string]any{"role": "bar"}},
				{Allow: true, Data: map[string]any{"serial": "123"}},
			},
			expectErr: false,
			expectTemplateData: x509util.TemplateData{
				"Webhooks": map[string]any{
					"devices": map[string]any{"serial": "123"},
				},
			},
		},
		"deny": {
			ctl: &WebhookController{
				client:       http.DefaultClient,
				webhooks:     []*Webhook{{Name: "people", Kind: "ENRICHING"}},
				TemplateData: x509util.TemplateData{},
			},
			req:                &webhook.RequestBody{},
			responses:          []*webhook.ResponseBody{{Allow: false}},
			expectErr:          true,
			expectTemplateData: x509util.TemplateData{},
		},
	}
	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			for i, wh := range test.ctl.webhooks {
				var j = i
				ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					err := json.NewEncoder(w).Encode(test.responses[j])
					assert.FatalError(t, err)
				}))
				// nolint: gocritic // defer in loop isn't a memory leak
				defer ts.Close()
				wh.URL = ts.URL
			}

			err := test.ctl.Enrich(test.req)
			if (err != nil) != test.expectErr {
				t.Fatalf("Got err %v, want %v", err, test.expectErr)
			}
			assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData)
		})
	}
}

func TestWebhookController_Authorize(t *testing.T) {
	type test struct {
		ctl       *WebhookController
		req       *webhook.RequestBody
		responses []*webhook.ResponseBody
		expectErr bool
	}
	tests := map[string]test{
		"ok/no enriching webhooks": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
			},
			req:       &webhook.RequestBody{},
			responses: nil,
			expectErr: false,
		},
		"ok": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
			},
			req:       &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{{Allow: true}},
			expectErr: false,
		},
		"ok/ssh only": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}},
				certType: linkedca.Webhook_SSH,
			},
			req:       &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{{Allow: false}},
			expectErr: false,
		},
		"deny": {
			ctl: &WebhookController{
				client:   http.DefaultClient,
				webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
			},
			req:       &webhook.RequestBody{},
			responses: []*webhook.ResponseBody{{Allow: false}},
			expectErr: true,
		},
	}
	for name, test := range tests {
		t.Run(name, func(t *testing.T) {
			for i, wh := range test.ctl.webhooks {
				var j = i
				ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					err := json.NewEncoder(w).Encode(test.responses[j])
					assert.FatalError(t, err)
				}))
				// nolint: gocritic // defer in loop isn't a memory leak
				defer ts.Close()
				wh.URL = ts.URL
			}

			err := test.ctl.Authorize(test.req)
			if (err != nil) != test.expectErr {
				t.Fatalf("Got err %v, want %v", err, test.expectErr)
			}
		})
	}
}

func TestWebhook_Do(t *testing.T) {
	csr := parseCertificateRequest(t, "testdata/certs/ecdsa.csr")
	type test struct {
		webhook         Webhook
		dataArg         any
		webhookResponse webhook.ResponseBody
		expectPath      string
		errStatusCode   int
		serverErrMsg    string
		expectErr       error
		// expectToken     any
	}
	tests := map[string]test{
		"ok": {
			webhook: Webhook{
				ID:     "abc123",
				Secret: "c2VjcmV0Cg==",
			},
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
		},
		"ok/bearer": {
			webhook: Webhook{
				ID:          "abc123",
				Secret:      "c2VjcmV0Cg==",
				BearerToken: "mytoken",
			},
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
		},
		"ok/basic": {
			webhook: Webhook{
				ID:     "abc123",
				Secret: "c2VjcmV0Cg==",
				BasicAuth: struct {
					Username string
					Password string
				}{
					Username: "myuser",
					Password: "mypass",
				},
			},
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
		},
		"ok/templated-url": {
			webhook: Webhook{
				ID: "abc123",
				// scheme, host, port will come from test server
				URL:    "/users/{{ .username }}?region={{ .region }}",
				Secret: "c2VjcmV0Cg==",
			},
			dataArg: map[string]interface{}{"username": "areed", "region": "central"},
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
			expectPath: "/users/areed?region=central",
		},
		/*
			"ok/token from ssh template": {
				webhook: Webhook{
					ID:     "abc123",
					Secret: "c2VjcmV0Cg==",
				},
				webhookResponse: webhook.ResponseBody{
					Data: map[string]interface{}{"role": "dba"},
				},
				dataArg:     sshutil.TemplateData{sshutil.TokenKey: "token"},
				expectToken: "token",
			},
			"ok/token from x509 template": {
				webhook: Webhook{
					ID:     "abc123",
					Secret: "c2VjcmV0Cg==",
				},
				webhookResponse: webhook.ResponseBody{
					Data: map[string]interface{}{"role": "dba"},
				},
				dataArg:     x509util.TemplateData{sshutil.TokenKey: "token"},
				expectToken: "token",
			},
		*/
		"ok/allow": {
			webhook: Webhook{
				ID:     "abc123",
				Secret: "c2VjcmV0Cg==",
			},
			webhookResponse: webhook.ResponseBody{
				Allow: true,
			},
		},
		"fail/404": {
			webhook: Webhook{
				ID:     "abc123",
				Secret: "c2VjcmV0Cg==",
			},
			webhookResponse: webhook.ResponseBody{
				Data: map[string]interface{}{"role": "dba"},
			},
			errStatusCode: 404,
			serverErrMsg:  "item not found",
			expectErr:     errors.New("Webhook server responded with 404"),
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				id := r.Header.Get("X-Smallstep-Webhook-ID")
				assert.Equals(t, tc.webhook.ID, id)

				sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature"))
				assert.FatalError(t, err)

				body, err := io.ReadAll(r.Body)
				assert.FatalError(t, err)

				secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret)
				assert.FatalError(t, err)
				mac := hmac.New(sha256.New, secret).Sum(body)
				assert.True(t, hmac.Equal(sig, mac))

				switch {
				case tc.webhook.BearerToken != "":
					ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken)
					assert.Equals(t, ah, r.Header.Get("Authorization"))
				case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "":
					whReq, err := http.NewRequest("", "", http.NoBody)
					assert.FatalError(t, err)
					whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password)
					ah := whReq.Header.Get("Authorization")
					assert.Equals(t, ah, whReq.Header.Get("Authorization"))
				default:
					assert.Equals(t, "", r.Header.Get("Authorization"))
				}

				if tc.expectPath != "" {
					assert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery)
				}

				if tc.errStatusCode != 0 {
					http.Error(w, tc.serverErrMsg, tc.errStatusCode)
					return
				}

				reqBody := new(webhook.RequestBody)
				err = json.Unmarshal(body, reqBody)
				assert.FatalError(t, err)
				// assert.Equals(t, tc.expectToken, reqBody.Token)

				err = json.NewEncoder(w).Encode(tc.webhookResponse)
				assert.FatalError(t, err)
			}))
			defer ts.Close()

			tc.webhook.URL = ts.URL + tc.webhook.URL

			reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
			assert.FatalError(t, err)
			got, err := tc.webhook.Do(http.DefaultClient, reqBody, tc.dataArg)
			if tc.expectErr != nil {
				assert.Equals(t, tc.expectErr.Error(), err.Error())
				return
			}
			assert.FatalError(t, err)

			assert.Equals(t, got, &tc.webhookResponse)
		})
	}

	t.Run("disableTLSClientAuth", func(t *testing.T) {
		ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("{}"))
		}))
		ts.TLS.ClientAuth = tls.RequireAnyClientCert
		wh := Webhook{
			URL: ts.URL,
		}
		cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key")
		assert.FatalError(t, err)
		transport := http.DefaultTransport.(*http.Transport).Clone()
		transport.TLSClientConfig = &tls.Config{
			InsecureSkipVerify: true,
			Certificates:       []tls.Certificate{cert},
		}
		client := &http.Client{
			Transport: transport,
		}
		reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr))
		assert.FatalError(t, err)
		_, err = wh.Do(client, reqBody, nil)
		assert.FatalError(t, err)

		wh.DisableTLSClientAuth = true
		_, err = wh.Do(client, reqBody, nil)
		assert.Error(t, err)
	})
}