package ca

import (
	"context"
	"crypto/tls"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"reflect"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/pkg/errors"
	"github.com/smallstep/certificates/api"
	"github.com/smallstep/certificates/authority"
	"github.com/smallstep/certificates/errs"
	"go.step.sm/crypto/jose"
	"go.step.sm/crypto/randutil"
)

// newLocalListener returns a TCP listener bound to an ephemeral loopback port.
// It tries IPv4 first and falls back to IPv6; if neither works it panics.
func newLocalListener() net.Listener {
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		panic(errors.Wrap(err, "failed to listen on a port"))
	}
	return ln
}

// setMinCertDuration overrides the package-level minCertDuration with d and
// returns a function that restores the value that was set before the call.
// Callers are expected to defer the returned function.
func setMinCertDuration(d time.Duration) func() {
	tmp := minCertDuration
	// Honor the requested duration instead of a hard-coded value.
	minCertDuration = d
	return func() {
		minCertDuration = tmp
	}
}

// startCABootstrapServer creates a CA from testdata/ca.json and serves its
// handler from a started TLS httptest server. Any setup failure panics.
func startCABootstrapServer() *httptest.Server {
	cfg, err := authority.LoadConfiguration("testdata/ca.json")
	if err != nil {
		panic(err)
	}

	ts := httptest.NewUnstartedServer(nil)
	// Configure the CA with the address the test server is listening on.
	cfg.Address = ts.Listener.Addr().String()

	testCA, err := New(cfg)
	if err != nil {
		panic(err)
	}

	ts.Config.Handler = testCA.srv.Handler
	ts.TLS = testCA.srv.TLSConfig
	ts.StartTLS()

	// Force the use of GetCertificate on IPs
	ts.TLS.Certificates = nil
	return ts
}

// startCAServer loads the configuration in configFile, creates a CA bound to
// a fresh local listener and starts serving it in a background goroutine.
//
// It returns the CA, the https URL it is reachable at, and any error from
// loading the configuration or creating the CA. The caller is responsible for
// stopping the returned CA.
func startCAServer(configFile string) (*CA, string, error) {
	config, err := authority.LoadConfiguration(configFile)
	if err != nil {
		return nil, "", err
	}
	listener := newLocalListener()
	addr := listener.Addr().String()
	config.Address = addr
	ca, err := New(config)
	if err != nil {
		// Do not leak the listening socket when the CA cannot be created.
		listener.Close()
		return nil, "", err
	}
	go func() {
		ca.srv.Serve(listener)
	}()
	return ca, "https://" + addr, nil
}

// mTLSMiddleware wraps next with a handler that requires a client certificate.
//
// Requests to "/version" are answered directly with a test version response
// that sets RequireClientAuthentication. Requests whose path starts with one
// of nonAuthenticatedPaths (optionally prefixed with "/1.0") are passed to
// next without any check. Every other request is passed to next only if the
// TLS connection carries at least one peer certificate; otherwise an
// unauthorized error is written.
func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/version" {
			api.JSON(w, api.VersionResponse{
				Version:                     "test",
				RequireClientAuthentication: true,
			})
			return
		}

		for _, s := range nonAuthenticatedPaths {
			if strings.HasPrefix(r.URL.Path, s) || strings.HasPrefix(r.URL.Path, "/1.0"+s) {
				next.ServeHTTP(w, r)
				// The request is fully handled; returning here keeps it from
				// falling through to the mTLS check and being answered twice.
				return
			}
		}

		if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
			api.WriteError(w, errs.Unauthorized("missing peer certificate"))
			return
		}
		next.ServeHTTP(w, r)
	})
}

// generateBootstrapToken returns a compact-serialized ES256 JWT signed with
// the key in testdata/secrets/ott_mariano_priv.jwk. The token is issued by
// "mariano" for subject, is valid for one minute from now, has ca+"/sign" as
// audience, carries sha in the "sha" claim and subject as its only SAN.
// It panics on any failure.
func generateBootstrapToken(ca, subject, sha string) string {
	type bootstrapClaims struct {
		SHA string `json:"sha"`
		jose.Claims
		SANS []string `json:"sans"`
	}

	issuedAt := time.Now()

	key, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password")))
	if err != nil {
		panic(err)
	}

	signerOpts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", key.KeyID)
	signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key.Key}, signerOpts)
	if err != nil {
		panic(err)
	}

	// Random token identifier.
	jti, err := randutil.ASCII(64)
	if err != nil {
		panic(err)
	}

	claims := bootstrapClaims{
		SHA: sha,
		Claims: jose.Claims{
			ID:        jti,
			Subject:   subject,
			Issuer:    "mariano",
			NotBefore: jose.NewNumericDate(issuedAt),
			Expiry:    jose.NewNumericDate(issuedAt.Add(time.Minute)),
			Audience:  []string{ca + "/sign"},
		},
		SANS: []string{subject},
	}

	token, err := jose.Signed(signer).Claims(claims).CompactSerialize()
	if err != nil {
		panic(err)
	}
	return token
}

// TestBootstrap checks that Bootstrap returns a client with the same endpoint
// and root pool as one built with NewClient when given a valid token, and that
// it fails for malformed tokens and for tokens generated with an empty sha or
// an empty CA URL.
func TestBootstrap(t *testing.T) {
	const rootSHA = "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7"

	srv := startCABootstrapServer()
	defer srv.Close()

	validToken := generateBootstrapToken(srv.URL, "subject", rootSHA)
	wantClient, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt"))
	if err != nil {
		t.Fatal(err)
	}

	type args struct {
		token string
	}
	cases := []struct {
		name    string
		args    args
		want    *Client
		wantErr bool
	}{
		{name: "ok", args: args{token: validToken}, want: wantClient},
		{name: "token err", args: args{token: "badtoken"}, wantErr: true},
		{name: "bad claims", args: args{token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.foo.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"}, wantErr: true},
		{name: "bad sha", args: args{token: generateBootstrapToken(srv.URL, "subject", "")}, wantErr: true},
		{name: "bad aud", args: args{token: generateBootstrapToken("", "subject", rootSHA)}, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := Bootstrap(tc.args.token)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("Bootstrap() error = %v, wantErr %v", err, tc.wantErr)
				return
			}

			// Error cases only compare the returned client with the expected one.
			if tc.wantErr {
				if !reflect.DeepEqual(got, tc.want) {
					t.Errorf("Bootstrap() = %v, want %v", got, tc.want)
				}
				return
			}

			if got == nil {
				t.Error("Bootstrap() = nil, want not nil")
				return
			}
			if !reflect.DeepEqual(got.endpoint, tc.want.endpoint) {
				t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tc.want.endpoint)
			}
			gotPool := got.client.GetTransport().(*http.Transport).TLSClientConfig.RootCAs
			wantPool := tc.want.client.GetTransport().(*http.Transport).TLSClientConfig.RootCAs
			if !equalPools(gotPool, wantPool) {
				t.Errorf("Bootstrap() certPool = %v, want %v", gotPool, wantPool)
			}
		})
	}
}

// TestBootstrapServerWithoutMTLS exercises BootstrapServer with the
// VerifyClientCertIfGiven option against a plain test CA and against one
// wrapped with mTLSMiddleware. Successful calls must return the base server
// with a fully populated TLSConfig.
func TestBootstrapServerWithoutMTLS(t *testing.T) {
	const rootSHA = "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7"

	srv := startCABootstrapServer()
	defer srv.Close()
	newToken := func() string {
		return generateBootstrapToken(srv.URL, "subject", rootSHA)
	}

	mtlsServer := startCABootstrapServer()
	mtlsServer.Config.Handler = mTLSMiddleware(mtlsServer.Config.Handler, "/root/", "/sign")
	defer mtlsServer.Close()
	newMTLSToken := func() string {
		return generateBootstrapToken(mtlsServer.URL, "subject", rootSHA)
	}

	type args struct {
		ctx   context.Context
		token string
		base  *http.Server
	}
	cases := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{name: "ok", args: args{ctx: context.Background(), token: newToken(), base: &http.Server{}}},
		{name: "ok mtls", args: args{ctx: context.Background(), token: newMTLSToken(), base: &http.Server{}}},
		{name: "fail", args: args{ctx: context.Background(), token: "bad-token", base: &http.Server{}}, wantErr: true},
		{name: "fail with TLSConfig", args: args{ctx: context.Background(), token: newToken(), base: &http.Server{TLSConfig: &tls.Config{}}}, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := BootstrapServer(tc.args.ctx, tc.args.token, tc.args.base, VerifyClientCertIfGiven())
			switch {
			case (err != nil) != tc.wantErr:
				t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tc.wantErr)
			case tc.wantErr:
				if got != nil {
					t.Errorf("BootstrapServer() = %v, want nil", got)
				}
			default:
				// Apart from TLSConfig the returned server must be a zero value.
				expected := &http.Server{
					TLSConfig: got.TLSConfig,
				}
				if !reflect.DeepEqual(got, expected) {
					t.Errorf("BootstrapServer() = %v, want %v", got, expected)
				}
				cfg := got.TLSConfig
				if cfg == nil || cfg.ClientCAs == nil || cfg.RootCAs == nil || cfg.GetCertificate == nil || cfg.GetClientCertificate == nil {
					t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig)
				}
			}
		})
	}
}

// TestBootstrapServerWithMTLS exercises BootstrapServer without extra options
// against a plain test CA and against one wrapped with mTLSMiddleware.
// Successful calls must return the base server with a fully populated
// TLSConfig.
func TestBootstrapServerWithMTLS(t *testing.T) {
	const rootSHA = "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7"

	srv := startCABootstrapServer()
	defer srv.Close()
	newToken := func() string {
		return generateBootstrapToken(srv.URL, "subject", rootSHA)
	}

	mtlsServer := startCABootstrapServer()
	mtlsServer.Config.Handler = mTLSMiddleware(mtlsServer.Config.Handler, "/root/", "/sign")
	defer mtlsServer.Close()
	newMTLSToken := func() string {
		return generateBootstrapToken(mtlsServer.URL, "subject", rootSHA)
	}

	type args struct {
		ctx   context.Context
		token string
		base  *http.Server
	}
	cases := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{name: "ok", args: args{ctx: context.Background(), token: newToken(), base: &http.Server{}}},
		{name: "ok mtls", args: args{ctx: context.Background(), token: newMTLSToken(), base: &http.Server{}}},
		{name: "fail", args: args{ctx: context.Background(), token: "bad-token", base: &http.Server{}}, wantErr: true},
		{name: "fail with TLSConfig", args: args{ctx: context.Background(), token: newToken(), base: &http.Server{TLSConfig: &tls.Config{}}}, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := BootstrapServer(tc.args.ctx, tc.args.token, tc.args.base)
			switch {
			case (err != nil) != tc.wantErr:
				t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tc.wantErr)
			case tc.wantErr:
				if got != nil {
					t.Errorf("BootstrapServer() = %v, want nil", got)
				}
			default:
				// Apart from TLSConfig the returned server must be a zero value.
				expected := &http.Server{
					TLSConfig: got.TLSConfig,
				}
				if !reflect.DeepEqual(got, expected) {
					t.Errorf("BootstrapServer() = %v, want %v", got, expected)
				}
				cfg := got.TLSConfig
				if cfg == nil || cfg.ClientCAs == nil || cfg.RootCAs == nil || cfg.GetCertificate == nil || cfg.GetClientCertificate == nil {
					t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig)
				}
			}
		})
	}
}

// TestBootstrapClient checks that BootstrapClient returns an HTTP client whose
// transport is configured for client authentication and that the client can
// post to the /renew endpoint of the test CA and receive a certificate chain.
func TestBootstrapClient(t *testing.T) {
	const rootSHA = "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7"

	srv := startCABootstrapServer()
	defer srv.Close()
	newToken := func() string {
		return generateBootstrapToken(srv.URL, "subject", rootSHA)
	}

	mtlsServer := startCABootstrapServer()
	mtlsServer.Config.Handler = mTLSMiddleware(mtlsServer.Config.Handler, "/root/", "/sign")
	defer mtlsServer.Close()
	newMTLSToken := func() string {
		return generateBootstrapToken(mtlsServer.URL, "subject", rootSHA)
	}

	type args struct {
		ctx   context.Context
		token string
	}
	cases := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{name: "ok", args: args{ctx: context.Background(), token: newToken()}},
		{name: "ok mtls", args: args{ctx: context.Background(), token: newMTLSToken()}},
		{name: "fail", args: args{ctx: context.Background(), token: "bad-token"}, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := BootstrapClient(tc.args.ctx, tc.args.token)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("BootstrapClient() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr {
				if got != nil {
					t.Errorf("BootstrapClient() = %v, want nil", got)
				}
				return
			}

			// A client config needs roots and a client certificate getter, but
			// none of the server-side fields.
			tlsConfig := got.Transport.(*http.Transport).TLSClientConfig
			if tlsConfig == nil || tlsConfig.ClientCAs != nil || tlsConfig.GetClientCertificate == nil || tlsConfig.RootCAs == nil || tlsConfig.GetCertificate != nil {
				t.Errorf("BootstrapClient() invalid Transport = %#v", tlsConfig)
			}

			resp, err := got.Post(srv.URL+"/renew", "application/json", http.NoBody)
			if err != nil {
				t.Errorf("BootstrapClient() failed renewing certificate")
				return
			}
			var renewal api.SignResponse
			if err := readJSON(resp.Body, &renewal); err != nil {
				t.Errorf("BootstrapClient() error reading response: %v", err)
				return
			}
			missingCerts := renewal.CaPEM.Certificate == nil || renewal.ServerPEM.Certificate == nil
			if missingCerts || len(renewal.CertChainPEM) == 0 {
				t.Errorf("BootstrapClient() invalid renewal response: %v", renewal)
			}
		})
	}
}

// TestBootstrapClientServerRotation starts a CA from
// testdata/rotate-ca-0.json, bootstraps a server and a client against it, and
// verifies that requests from the client to both the CA and the bootstrapped
// server keep succeeding after the CA is reloaded, in order, with the
// rotate-ca-1.json, rotate-ca-2.json and rotate-ca-3.json configurations.
//
// The test relies on fixed sleeps between steps, so it takes well over ten
// seconds to run.
func TestBootstrapClientServerRotation(t *testing.T) {
	// Lower the package-level minimum certificate duration for this test and
	// restore it when the test finishes.
	reset := setMinCertDuration(1 * time.Second)
	defer reset()

	// Configuration with current root
	config, err := authority.LoadConfiguration("testdata/rotate-ca-0.json")
	if err != nil {
		t.Fatal(err)
	}

	// Get local address
	listener := newLocalListener()
	config.Address = listener.Addr().String()
	caURL := "https://" + listener.Addr().String()

	// Start CA server
	ca, err := New(config)
	if err != nil {
		t.Fatal(err)
	}
	go func() {
		ca.srv.Serve(listener)
	}()
	defer ca.Stop()
	// Give the CA goroutine time to start serving.
	time.Sleep(1 * time.Second)

	// Create bootstrap server
	token := generateBootstrapToken(caURL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
	server, err := BootstrapServer(context.Background(), token, &http.Server{
		Addr: ":0",
		Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			w.Write([]byte("ok"))
		}),
	}, RequireAndVerifyClientCert())
	if err != nil {
		t.Fatal(err)
	}
	// The listener variable is reused: from here on it is the bootstrapped
	// server's listener, not the CA's.
	listener = newLocalListener()
	srvURL := "https://" + listener.Addr().String()
	go func() {
		server.ServeTLS(listener, "", "")
	}()
	defer server.Close()
	// Give the bootstrapped server time to start serving.
	time.Sleep(1 * time.Second)

	// Create bootstrap client
	token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
	client, err := BootstrapClient(context.Background(), token)
	if err != nil {
		t.Errorf("BootstrapClient() error = %v", err)
		return
	}

	// doTest does a request that requires mTLS
	doTest := func(client *http.Client) error {
		// test with ca
		resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
		if err != nil {
			return errors.Wrap(err, "client.Post() failed")
		}
		var renew api.SignResponse
		if err := readJSON(resp.Body, &renew); err != nil {
			return errors.Wrap(err, "client.Post() error reading response")
		}
		if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil || len(renew.CertChainPEM) == 0 {
			return errors.New("client.Post() unexpected response found")
		}
		// test with bootstrap server
		resp, err = client.Get(srvURL)
		if err != nil {
			return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
		}
		defer resp.Body.Close()
		b, err := io.ReadAll(resp.Body)
		if err != nil {
			return errors.Wrap(err, "client.Get() error reading response")
		}
		if string(b) != "ok" {
			return errors.New("client.Get() unexpected response found")
		}
		return nil
	}

	// Test with default root
	if err := doTest(client); err != nil {
		t.Errorf("Test with rotate-ca-0.json failed: %v", err)
	}

	// wait for renew
	time.Sleep(5 * time.Second)

	// Reload with configuration with current and future root
	ca.opts.configFile = "testdata/rotate-ca-1.json"
	if err := doReload(ca); err != nil {
		t.Errorf("ca.Reload() error = %v", err)
		return
	}
	if err := doTest(client); err != nil {
		t.Errorf("Test with rotate-ca-1.json failed: %v", err)
	}

	// wait for renew
	time.Sleep(5 * time.Second)

	// Reload with new and old root
	ca.opts.configFile = "testdata/rotate-ca-2.json"
	if err := doReload(ca); err != nil {
		t.Errorf("ca.Reload() error = %v", err)
		return
	}
	if err := doTest(client); err != nil {
		t.Errorf("Test with rotate-ca-2.json failed: %v", err)
	}

	// wait for renew
	time.Sleep(5 * time.Second)

	// Reload with only the new root
	ca.opts.configFile = "testdata/rotate-ca-3.json"
	if err := doReload(ca); err != nil {
		t.Errorf("ca.Reload() error = %v", err)
		return
	}
	if err := doTest(client); err != nil {
		t.Errorf("Test with rotate-ca-3.json failed: %v", err)
	}
}

// TestBootstrapClientServerFederation starts two CAs (testdata/ca.json and
// testdata/federated-ca.json), bootstraps a server against the first one with
// the federation added to its client CAs, and a client against the second one
// with the federation added to its root CAs. It then verifies that the client
// can talk to its own CA and to the server bootstrapped from the other CA.
func TestBootstrapClientServerFederation(t *testing.T) {
	// Lower the package-level minimum certificate duration for this test and
	// restore it when the test finishes.
	reset := setMinCertDuration(1 * time.Second)
	defer reset()

	ca1, caURL1, err := startCAServer("testdata/ca.json")
	if err != nil {
		t.Fatal(err)
	}
	defer ca1.Stop()

	ca2, caURL2, err := startCAServer("testdata/federated-ca.json")
	if err != nil {
		t.Fatal(err)
	}
	defer ca2.Stop()

	// Create bootstrap server
	token := generateBootstrapToken(caURL1, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
	server, err := BootstrapServer(context.Background(), token, &http.Server{
		Addr: ":0",
		Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			w.Write([]byte("ok"))
		}),
	}, RequireAndVerifyClientCert(), AddFederationToClientCAs())
	if err != nil {
		t.Fatal(err)
	}
	listener := newLocalListener()
	srvURL := "https://" + listener.Addr().String()
	go func() {
		server.ServeTLS(listener, "", "")
	}()
	defer server.Close()

	// Create bootstrap client
	token = generateBootstrapToken(caURL2, "client", "c86f74bb7eb2eabef45c4f7fc6c146359ed3a5bbad416b31da5dce8093bcbffd")
	client, err := BootstrapClient(context.Background(), token, AddFederationToRootCAs())
	if err != nil {
		t.Errorf("BootstrapClient() error = %v", err)
		return
	}

	// doTest does a request that requires mTLS
	doTest := func(client *http.Client) error {
		// test with ca
		resp, err := client.Post(caURL2+"/renew", "application/json", http.NoBody)
		if err != nil {
			return errors.Wrap(err, "client.Post() failed")
		}
		var renew api.SignResponse
		if err := readJSON(resp.Body, &renew); err != nil {
			return errors.Wrap(err, "client.Post() error reading response")
		}
		if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil || len(renew.CertChainPEM) == 0 {
			return errors.New("client.Post() unexpected response found")
		}
		// test with bootstrap server
		resp, err = client.Get(srvURL)
		if err != nil {
			return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
		}
		defer resp.Body.Close()
		b, err := io.ReadAll(resp.Body)
		if err != nil {
			return errors.Wrap(err, "client.Get() error reading response")
		}
		if string(b) != "ok" {
			return errors.New("client.Get() unexpected response found")
		}
		return nil
	}

	// Test the client against its own CA and the server from the federated CA.
	// The message names the federation setup rather than a rotation config
	// file that this test never loads.
	if err := doTest(client); err != nil {
		t.Errorf("Test with federated CAs failed: %v", err)
	}
}

// doReload mirrors the CA reload logic: it builds a new CA from the
// configuration file currently set in ca.opts, reusing the password and the
// database of the running CA, and swaps it into the running server while
// keeping the address already in use.
func doReload(ca *CA) error {
	cfgPath := ca.opts.configFile

	cfg, err := authority.LoadConfiguration(cfgPath)
	if err != nil {
		return errors.Wrap(err, "error reloading ca")
	}

	reloaded, err := New(
		cfg,
		WithPassword(ca.opts.password),
		WithConfigFile(cfgPath),
		WithDatabase(ca.auth.GetDatabase()),
	)
	if err != nil {
		return errors.Wrap(err, "error reloading ca")
	}

	// Use same address in new server
	reloaded.srv.Addr = ca.srv.Addr
	return ca.srv.Reload(reloaded.srv)
}

// TestBootstrapListener checks that BootstrapListener wraps a local listener
// so that an HTTP server running on it can be reached by a client created
// with BootstrapClient, and that it fails for an invalid token.
func TestBootstrapListener(t *testing.T) {
	const rootSHA = "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7"

	srv := startCABootstrapServer()
	defer srv.Close()
	newToken := func() string {
		return generateBootstrapToken(srv.URL, "127.0.0.1", rootSHA)
	}

	mtlsServer := startCABootstrapServer()
	mtlsServer.Config.Handler = mTLSMiddleware(mtlsServer.Config.Handler, "/root/", "/sign")
	defer mtlsServer.Close()
	newMTLSToken := func() string {
		return generateBootstrapToken(mtlsServer.URL, "127.0.0.1", rootSHA)
	}

	type args struct {
		token string
	}
	cases := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{name: "ok", args: args{token: newToken()}},
		{name: "ok mtls", args: args{token: newMTLSToken()}},
		{name: "fail", args: args{token: "bad-token"}, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			inner := newLocalListener()
			defer inner.Close()

			lis, err := BootstrapListener(context.Background(), tc.args.token, inner)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("BootstrapListener() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr {
				if lis != nil {
					t.Errorf("BootstrapListener() = %v, want nil", lis)
				}
				return
			}

			// Serve a trivial handler on the bootstrapped listener. The
			// deferred calls run in reverse order: the listener is closed
			// first, then the serving goroutine is awaited.
			var wg sync.WaitGroup
			wg.Add(1)
			go func() {
				defer wg.Done()
				http.Serve(lis, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.Write([]byte("ok"))
				}))
			}()
			defer wg.Wait()
			defer lis.Close()

			client, err := BootstrapClient(context.Background(), newToken())
			if err != nil {
				t.Errorf("BootstrapClient() error = %v", err)
				return
			}
			resp, err := client.Get("https://" + lis.Addr().String())
			if err != nil {
				t.Errorf("client.Get() error = %v", err)
				return
			}
			defer resp.Body.Close()

			body, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Errorf("io.ReadAll() error = %v", err)
				return
			}
			if got := string(body); got != "ok" {
				t.Errorf("client.Get() = %s, want ok", got)
				return
			}
		})
	}
}