package api

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/authority"
	"github.com/smallstep/certificates/authority/provisioner"
	"github.com/smallstep/certificates/logging"
	"github.com/smallstep/certificates/sshutil"
	"github.com/smallstep/certificates/templates"
	"golang.org/x/crypto/ssh"
)

var (
	sshSignerKey = mustKey()
	sshUserKey   = mustKey()
	sshHostKey   = mustKey()
)

func mustKey() *ecdsa.PrivateKey {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	return priv
}

func signSSHCertificate(cert *ssh.Certificate) error {
	signerKey, err := ssh.NewPublicKey(sshSignerKey.Public())
	if err != nil {
		return err
	}
	signer, err := ssh.NewSignerFromSigner(sshSignerKey)
	if err != nil {
		return err
	}
	cert.SignatureKey = signerKey
	data := cert.Marshal()
	data = data[:len(data)-4]
	sig, err := signer.Sign(rand.Reader, data)
	if err != nil {
		return err
	}
	cert.Signature = sig
	return nil
}

func getSignedUserCertificate() (*ssh.Certificate, error) {
	key, err := ssh.NewPublicKey(sshUserKey.Public())
	if err != nil {
		return nil, err
	}
	t := time.Now()
	cert := &ssh.Certificate{
		Nonce:           []byte("1234567890"),
		Key:             key,
		Serial:          1234567890,
		CertType:        ssh.UserCert,
		KeyId:           "user@localhost",
		ValidPrincipals: []string{"user"},
		ValidAfter:      uint64(t.Unix()),
		ValidBefore:     uint64(t.Add(time.Hour).Unix()),
		Permissions: ssh.Permissions{
			CriticalOptions: map[string]string{},
			Extensions: map[string]string{
				"permit-X11-forwarding":   "",
				"permit-agent-forwarding": "",
				"permit-port-forwarding":  "",
				"permit-pty":              "",
				"permit-user-rc":          "",
			},
		},
		Reserved: []byte{},
	}
	if err := signSSHCertificate(cert); err != nil {
		return nil, err
	}
	return cert, nil
}

func getSignedHostCertificate() (*ssh.Certificate, error) {
	key, err := ssh.NewPublicKey(sshHostKey.Public())
	if err != nil {
		return nil, err
	}
	t := time.Now()
	cert := &ssh.Certificate{
		Nonce:           []byte("1234567890"),
		Key:             key,
		Serial:          1234567890,
		CertType:        ssh.UserCert,
		KeyId:           "internal.smallstep.com",
		ValidPrincipals: []string{"internal.smallstep.com"},
		ValidAfter:      uint64(t.Unix()),
		ValidBefore:     uint64(t.Add(time.Hour).Unix()),
		Permissions: ssh.Permissions{
			CriticalOptions: map[string]string{},
			Extensions:      map[string]string{},
		},
		Reserved: []byte{},
	}
	if err := signSSHCertificate(cert); err != nil {
		return nil, err
	}
	return cert, nil
}

func TestSSHCertificate_MarshalJSON(t *testing.T) {
	user, err := getSignedUserCertificate()
	assert.FatalError(t, err)
	host, err := getSignedHostCertificate()
	assert.FatalError(t, err)
	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

	type fields struct {
		Certificate *ssh.Certificate
	}
	tests := []struct {
		name    string
		fields  fields
		want    []byte
		wantErr bool
	}{
		{"nil", fields{Certificate: nil}, []byte("null"), false},
		{"user", fields{Certificate: user}, []byte(`"` + userB64 + `"`), false},
		{"user", fields{Certificate: host}, []byte(`"` + hostB64 + `"`), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := SSHCertificate{
				Certificate: tt.fields.Certificate,
			}
			got, err := c.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("SSHCertificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("SSHCertificate.MarshalJSON() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestSSHCertificate_UnmarshalJSON(t *testing.T) {
	user, err := getSignedUserCertificate()
	assert.FatalError(t, err)
	host, err := getSignedHostCertificate()
	assert.FatalError(t, err)
	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())
	keyB64 := base64.StdEncoding.EncodeToString(user.Key.Marshal())

	type args struct {
		data []byte
	}
	tests := []struct {
		name    string
		args    args
		want    *ssh.Certificate
		wantErr bool
	}{
		{"null", args{[]byte(`null`)}, nil, false},
		{"empty", args{[]byte(`""`)}, nil, false},
		{"user", args{[]byte(`"` + userB64 + `"`)}, user, false},
		{"host", args{[]byte(`"` + hostB64 + `"`)}, host, false},
		{"bad-string", args{[]byte(userB64)}, nil, true},
		{"bad-base64", args{[]byte(`"this-is-not-base64"`)}, nil, true},
		{"bad-key", args{[]byte(`"bm90LWEta2V5"`)}, nil, true},
		{"bat-cert", args{[]byte(`"` + keyB64 + `"`)}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &SSHCertificate{}
			if err := c.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
				t.Errorf("SSHCertificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			}
			if !reflect.DeepEqual(tt.want, c.Certificate) {
				t.Errorf("SSHCertificate.UnmarshalJSON() got = %v, want %v\n", c.Certificate, tt.want)
			}
		})
	}
}

func TestSignSSHRequest_Validate(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	badCSR := parseCertificateRequest(csrPEM)
	badCSR.SignatureAlgorithm = x509.SHA1WithRSA

	type fields struct {
		PublicKey        []byte
		OTT              string
		CertType         string
		Principals       []string
		ValidAfter       TimeDuration
		ValidBefore      TimeDuration
		AddUserPublicKey []byte
		KeyID            string
		IdentityCSR      CertificateRequest
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{"ok-empty", fields{[]byte("Zm9v"), "ott", "", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
		{"ok-user", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
		{"ok-host", fields{[]byte("Zm9v"), "ott", "host", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, false},
		{"ok-keyID", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{}}, false},
		{"ok-identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: csr}}, false},
		{"key", fields{nil, "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
		{"key", fields{[]byte(""), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
		{"type", fields{[]byte("Zm9v"), "ott", "foo", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
		{"ott", fields{[]byte("Zm9v"), "", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "", CertificateRequest{}}, true},
		{"identityCSR", fields{[]byte("Zm9v"), "ott", "user", []string{"user"}, TimeDuration{}, TimeDuration{}, nil, "key-id", CertificateRequest{CertificateRequest: badCSR}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &SSHSignRequest{
				PublicKey:        tt.fields.PublicKey,
				OTT:              tt.fields.OTT,
				CertType:         tt.fields.CertType,
				Principals:       tt.fields.Principals,
				ValidAfter:       tt.fields.ValidAfter,
				ValidBefore:      tt.fields.ValidBefore,
				AddUserPublicKey: tt.fields.AddUserPublicKey,
				KeyID:            tt.fields.KeyID,
				IdentityCSR:      tt.fields.IdentityCSR,
			}
			if err := s.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("SignSSHRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func Test_caHandler_SSHSign(t *testing.T) {
	user, err := getSignedUserCertificate()
	assert.FatalError(t, err)
	host, err := getSignedHostCertificate()
	assert.FatalError(t, err)

	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

	userReq, err := json.Marshal(SSHSignRequest{
		PublicKey: user.Key.Marshal(),
		OTT:       "ott",
	})
	assert.FatalError(t, err)
	hostReq, err := json.Marshal(SSHSignRequest{
		PublicKey: host.Key.Marshal(),
		OTT:       "ott",
	})
	assert.FatalError(t, err)
	userAddReq, err := json.Marshal(SSHSignRequest{
		PublicKey:        user.Key.Marshal(),
		OTT:              "ott",
		AddUserPublicKey: user.Key.Marshal(),
	})
	assert.FatalError(t, err)
	userIdentityReq, err := json.Marshal(SSHSignRequest{
		PublicKey:   user.Key.Marshal(),
		OTT:         "ott",
		IdentityCSR: CertificateRequest{parseCertificateRequest(csrPEM)},
	})
	assert.FatalError(t, err)
	identityCerts := []*x509.Certificate{
		parseCertificate(certPEM),
	}
	identityCertsPEM := []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`)

	tests := []struct {
		name         string
		req          []byte
		authErr      error
		signCert     *ssh.Certificate
		signErr      error
		addUserCert  *ssh.Certificate
		addUserErr   error
		tlsSignCerts []*x509.Certificate
		tlsSignErr   error
		body         []byte
		statusCode   int
	}{
		{"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated},
		{"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated},
		{"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated},
		{"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":"%s","identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated},
		{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
		{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
		{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
		{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
		{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized},
		{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden},
		{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
		{"fail-user-identity", userIdentityReq, nil, user, nil, user, nil, nil, fmt.Errorf("an-error"), nil, http.StatusForbidden},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{
				authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
					return []provisioner.SignOption{}, tt.authErr
				},
				signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
					return tt.signCert, tt.signErr
				},
				signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
					return tt.addUserCert, tt.addUserErr
				},
				sign: func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
					return tt.tlsSignCerts, tt.tlsSignErr
				},
			}).(*caHandler)

			req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
			w := httptest.NewRecorder()
			h.SSHSign(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SignSSH unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SignSSH Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

func Test_caHandler_SSHRoots(t *testing.T) {
	user, err := ssh.NewPublicKey(sshUserKey.Public())
	assert.FatalError(t, err)
	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())

	host, err := ssh.NewPublicKey(sshHostKey.Public())
	assert.FatalError(t, err)
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

	tests := []struct {
		name       string
		keys       *authority.SSHKeys
		keysErr    error
		body       []byte
		statusCode int
	}{
		{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK},
		{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
		{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK},
		{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK},
		{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
		{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{
				getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
					return tt.keys, tt.keysErr
				},
			}).(*caHandler)

			req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
			w := httptest.NewRecorder()
			h.SSHRoots(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHRoots unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHRoots Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

func Test_caHandler_SSHFederation(t *testing.T) {
	user, err := ssh.NewPublicKey(sshUserKey.Public())
	assert.FatalError(t, err)
	userB64 := base64.StdEncoding.EncodeToString(user.Marshal())

	host, err := ssh.NewPublicKey(sshHostKey.Public())
	assert.FatalError(t, err)
	hostB64 := base64.StdEncoding.EncodeToString(host.Marshal())

	tests := []struct {
		name       string
		keys       *authority.SSHKeys
		keysErr    error
		body       []byte
		statusCode int
	}{
		{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK},
		{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
		{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK},
		{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK},
		{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
		{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{
				getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
					return tt.keys, tt.keysErr
				},
			}).(*caHandler)

			req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
			w := httptest.NewRecorder()
			h.SSHFederation(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHFederation unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHFederation Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

func Test_caHandler_SSHConfig(t *testing.T) {
	userOutput := []templates.Output{
		{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
		{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
	}
	hostOutput := []templates.Output{
		{Name: "sshd_config.tpl", Type: templates.Snippet, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("TrustedUserCAKeys /etc/ssh/ca.pub")},
		{Name: "ca.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/ca.pub", Content: []byte("ecdsa-sha2-nistp256 AAAA...=")},
	}
	userJSON, err := json.Marshal(userOutput)
	assert.FatalError(t, err)
	hostJSON, err := json.Marshal(hostOutput)
	assert.FatalError(t, err)

	tests := []struct {
		name       string
		req        string
		output     []templates.Output
		err        error
		body       []byte
		statusCode int
	}{
		{"user", `{"type":"user"}`, userOutput, nil, []byte(fmt.Sprintf(`{"userTemplates":%s}`, userJSON)), http.StatusOK},
		{"host", `{"type":"host"}`, hostOutput, nil, []byte(fmt.Sprintf(`{"hostTemplates":%s}`, hostJSON)), http.StatusOK},
		{"noType", `{}`, userOutput, nil, []byte(fmt.Sprintf(`{"userTemplates":%s}`, userJSON)), http.StatusOK},
		{"badType", `{"type":"bad"}`, userOutput, nil, nil, http.StatusBadRequest},
		{"badData", `{"type":"user","data":{"bad"}}`, userOutput, nil, nil, http.StatusBadRequest},
		{"error", `{"type": "user"}`, nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{
				getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
					return tt.output, tt.err
				},
			}).(*caHandler)

			req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
			w := httptest.NewRecorder()
			h.SSHConfig(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHConfig unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHConfig Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

func Test_caHandler_SSHCheckHost(t *testing.T) {
	tests := []struct {
		name       string
		req        string
		exists     bool
		err        error
		body       []byte
		statusCode int
	}{
		{"true", `{"type":"host","principal":"foo.example.com"}`, true, nil, []byte(`{"exists":true}`), http.StatusOK},
		{"false", `{"type":"host","principal":"bar.example.com"}`, false, nil, []byte(`{"exists":false}`), http.StatusOK},
		{"badType", `{"type":"user","principal":"bar.example.com"}`, false, nil, nil, http.StatusBadRequest},
		{"badPrincipal", `{"type":"host","principal":""}`, false, nil, nil, http.StatusBadRequest},
		{"badRequest", `{"foo"}`, false, nil, nil, http.StatusBadRequest},
		{"error", `{"type":"host","principal":"foo.example.com"}`, false, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{
				checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
					return tt.exists, tt.err
				},
			}).(*caHandler)

			req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
			w := httptest.NewRecorder()
			h.SSHCheckHost(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHCheckHost Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

func Test_caHandler_SSHGetHosts(t *testing.T) {
	hosts := []sshutil.Host{
		{HostID: "1", HostTags: []sshutil.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
		{HostID: "2", HostTags: []sshutil.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
	}
	hostsJSON, err := json.Marshal(hosts)
	assert.FatalError(t, err)

	tests := []struct {
		name       string
		hosts      []sshutil.Host
		err        error
		body       []byte
		statusCode int
	}{
		{"ok", hosts, nil, []byte(fmt.Sprintf(`{"hosts":%s}`, hostsJSON)), http.StatusOK},
		{"empty (array)", []sshutil.Host{}, nil, []byte(`{"hosts":[]}`), http.StatusOK},
		{"empty (nil)", nil, nil, []byte(`{"hosts":null}`), http.StatusOK},
		{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{
				getSSHHosts: func(context.Context, *x509.Certificate) ([]sshutil.Host, error) {
					return tt.hosts, tt.err
				},
			}).(*caHandler)

			req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
			w := httptest.NewRecorder()
			h.SSHGetHosts(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHGetHosts Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

func Test_caHandler_SSHBastion(t *testing.T) {
	bastion := &authority.Bastion{
		Hostname: "bastion.local",
	}
	bastionPort := &authority.Bastion{
		Hostname: "bastion.local",
		Port:     "2222",
	}

	tests := []struct {
		name       string
		bastion    *authority.Bastion
		bastionErr error
		req        []byte
		body       []byte
		statusCode int
	}{
		{"ok", bastion, nil, []byte(`{"hostname":"host.local"}`), []byte(`{"hostname":"host.local","bastion":{"hostname":"bastion.local"}}`), http.StatusOK},
		{"ok", bastionPort, nil, []byte(`{"hostname":"host.local","user":"user"}`), []byte(`{"hostname":"host.local","bastion":{"hostname":"bastion.local","port":"2222"}}`), http.StatusOK},
		{"empty", nil, nil, []byte(`{"hostname":"host.local"}`), []byte(`{"hostname":"host.local"}`), http.StatusOK},
		{"bad json", bastion, nil, []byte(`bad json`), nil, http.StatusBadRequest},
		{"bad request", bastion, nil, []byte(`{"hostname": ""}`), nil, http.StatusBadRequest},
		{"error", nil, fmt.Errorf("an error"), []byte(`{"hostname":"host.local"}`), nil, http.StatusInternalServerError},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{
				getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
					return tt.bastion, tt.bastionErr
				},
			}).(*caHandler)

			req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
			w := httptest.NewRecorder()
			h.SSHBastion(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.SSHBastion unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.body) {
					t.Errorf("caHandler.SSHBastion Body = %s, wants %s", body, tt.body)
				}
			}
		})
	}
}

func TestSSHPublicKey_MarshalJSON(t *testing.T) {
	key, err := ssh.NewPublicKey(sshUserKey.Public())
	assert.FatalError(t, err)
	keyB64 := base64.StdEncoding.EncodeToString(key.Marshal())

	tests := []struct {
		name      string
		publicKey *SSHPublicKey
		want      []byte
		wantErr   bool
	}{
		{"ok", &SSHPublicKey{PublicKey: key}, []byte(`"` + keyB64 + `"`), false},
		{"null", nil, []byte("null"), false},
		{"null", &SSHPublicKey{PublicKey: nil}, []byte("null"), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := tt.publicKey.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("SSHPublicKey.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("SSHPublicKey.MarshalJSON() = %s, want %s", got, tt.want)
			}
		})
	}
}

func TestSSHPublicKey_UnmarshalJSON(t *testing.T) {
	key, err := ssh.NewPublicKey(sshUserKey.Public())
	assert.FatalError(t, err)
	keyB64 := base64.StdEncoding.EncodeToString(key.Marshal())

	type args struct {
		data []byte
	}
	tests := []struct {
		name    string
		args    args
		want    *SSHPublicKey
		wantErr bool
	}{
		{"ok", args{[]byte(`"` + keyB64 + `"`)}, &SSHPublicKey{PublicKey: key}, false},
		{"empty", args{[]byte(`""`)}, &SSHPublicKey{}, false},
		{"null", args{[]byte(`null`)}, &SSHPublicKey{}, false},
		{"noString", args{[]byte("123")}, &SSHPublicKey{}, true},
		{"badB64", args{[]byte(`"bad"`)}, &SSHPublicKey{}, true},
		{"badKey", args{[]byte(`"Zm9vYmFyCg=="`)}, &SSHPublicKey{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := &SSHPublicKey{}
			if err := p.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
				t.Errorf("SSHPublicKey.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			}
			if !reflect.DeepEqual(p, tt.want) {
				t.Errorf("SSHPublicKey.UnmarshalJSON() = %v, want %v", p, tt.want)
			}
		})
	}
}