package api

import (
	"crypto/dsa" //nolint


// Test fixtures shared by the tests in this file:
//   - rootPEM, certPEM, stepCertPEM: PEM-encoded certificates parsed with
//     parseCertificate and embedded in expected JSON responses.
//   - pubKey: a public EC P-256 JWK; its "kid" is the same value used as the
//     "kid" URL parameter in Test_caHandler_ProvisionerKey.
//   - privKey: a compact JWE whose protected header declares
//     PBES2-HS256+A128KW / A256GCM with cty "jwk+json" (a password-encrypted
//     JWK); it is the canned encrypted key returned by the mock authority.
//
// NOTE(review): in this copy the PEM bodies, their closing backticks, the end
// of pubKey and the closing parenthesis of this const block are missing, so
// the block does not compile as shown. csrPEM is referenced throughout the
// file but is not defined here. TODO: restore the fixtures from version control.
const (
	rootPEM = `-----BEGIN CERTIFICATE-----

	certPEM = `-----BEGIN CERTIFICATE-----


	stepCertPEM = `-----BEGIN CERTIFICATE-----

	pubKey = `{
	"use": "sig",
	"kty": "EC",
	"kid": "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00",
	"crv": "P-256",
	"alg": "ES256",
	"x": "p9QX4tzjxUrB0fgqRWLKUuPolDtBW681f2Qyh-uVNhk",
	"y": "CNSEloc4oLDFTX0Vywj0WiqOlh516sFQwCj6WtM8LT8"

	privKey = "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiNEhBYjE0WDQ5OFM4LWxSb29JTnpqZyJ9.RbkJXGzI3kOsaP20KmZs0ELFLgpRddAE49AJHlEblw-uH_gg6SV3QA.M3MArEpHgI171lhm.gBlFySpzK9F7riBJbtLSNkb4nAw_gWokqs1jS-ZK1qxuqTK-9mtX5yILjRnftx9P9uFp5xt7rvv4Mgom1Ed4V9WtIyfNP_Cz3Pme1Eanp5nY68WCe_yG6iSB1RJdMDBUb2qBDZiBdhJim1DRXsOfgedOrNi7GGbppMlD77DEpId118owR5izA-c6Q_hg08hIE3tnMAnebDNQoF9jfEY99_AReVRH8G4hgwZEPCfXMTb3J-lowKGG4vXIbK5knFLh47SgOqG4M2M51SMS-XJ7oBz1Vjoamc90QIqKV51rvZ5m0N_sPFtxzcfV4E9yYH3XVd4O-CG4ydVKfKVyMtQ.mcKFZqBHp_n7Ytj2jz9rvw"

// parseCertificate decodes the first PEM block in data and parses it as an
// X.509 certificate. It is a test-fixture helper and panics on malformed
// input (no PEM block, or DER that x509.ParseCertificate rejects).
func parseCertificate(data string) *x509.Certificate {
	block, _ := pem.Decode([]byte(data))
	if block == nil {
		panic("failed to parse certificate PEM")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		panic("failed to parse certificate: " + err.Error())
	}
	return cert
}

// parseCertificateRequest decodes the first PEM block in data and parses it
// as a PKCS #10 certificate request. It is a test-fixture helper and panics
// on malformed input (no PEM block, or DER that x509 rejects).
func parseCertificateRequest(data string) *x509.CertificateRequest {
	block, _ := pem.Decode([]byte(data))
	if block == nil {
		panic("failed to parse certificate request PEM")
	}
	csr, err := x509.ParseCertificateRequest(block.Bytes)
	if err != nil {
		panic("failed to parse certificate request: " + err.Error())
	}
	return csr
}

func TestNewCertificate(t *testing.T) {
	cert := parseCertificate(rootPEM)
	if !reflect.DeepEqual(Certificate{Certificate: cert}, NewCertificate(cert)) {
		t.Errorf("NewCertificate failed, got %v, wants %v", NewCertificate(cert), Certificate{Certificate: cert})

func TestCertificate_MarshalJSON(t *testing.T) {
	type fields struct {
		Certificate *x509.Certificate
	tests := []struct {
		name    string
		fields  fields
		want    []byte
		wantErr bool
		{"nil", fields{Certificate: nil}, []byte("null"), false},
		{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
		{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"`), false},
		{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`), false},

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			c := Certificate{
				Certificate: tt.fields.Certificate,
			got, err := c.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("Certificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("Certificate.MarshalJSON() = %s, want %s", got, tt.want)

func TestCertificate_UnmarshalJSON(t *testing.T) {
	tests := []struct {
		name     string
		data     []byte
		wantCert bool
		wantErr  bool
		{"no data", nil, false, true},
		{"incomplete string 1", []byte(`"foobar`), false, true}, {"incomplete string 2", []byte(`foobar"`), false, true},
		{"invalid string", []byte(`"foobar"`), false, true},
		{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
		{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
		{"invalid type", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), false, true},
		{"empty string", []byte(`""`), false, false},
		{"json null", []byte(`null`), false, false},
		{"valid root", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), true, false},
		{"valid cert", []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"`), true, false},
	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			var c Certificate
			if err := c.UnmarshalJSON(; (err != nil) != tt.wantErr {
				t.Errorf("Certificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			if tt.wantCert && c.Certificate == nil {
				t.Error("Certificate.UnmarshalJSON() failed, Certificate is nil")

func TestCertificate_UnmarshalJSON_json(t *testing.T) {
	tests := []struct {
		name     string
		data     string
		wantCert bool
		wantErr  bool
		{"invalid type (bool)", `{"crt":true}`, false, true},
		{"invalid type (number)", `{"crt":123}`, false, true},
		{"invalid type (object)", `{"crt":{}}`, false, true},
		{"empty crt (null)", `{"crt":null}`, false, false},
		{"empty crt (string)", `{"crt":""}`, false, false},
		{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
		{"valid crt", `{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"}`, true, false},

	type request struct {
		Cert Certificate `json:"crt"`

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			var body request
			if err := json.Unmarshal([]byte(, &body); (err != nil) != tt.wantErr {
				t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)

			switch tt.wantCert {
			case true:
				if body.Cert.Certificate == nil {
					t.Error("json.Unmarshal() failed, Certificate is nil")
			case false:
				if body.Cert.Certificate != nil {
					t.Error("json.Unmarshal() failed, Certificate is not nil")
func TestNewCertificateRequest(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	if !reflect.DeepEqual(CertificateRequest{CertificateRequest: csr}, NewCertificateRequest(csr)) {
		t.Errorf("NewCertificateRequest failed, got %v, wants %v", NewCertificateRequest(csr), CertificateRequest{CertificateRequest: csr})

func TestCertificateRequest_MarshalJSON(t *testing.T) {
	type fields struct {
		CertificateRequest *x509.CertificateRequest
	tests := []struct {
		name    string
		fields  fields
		want    []byte
		wantErr bool
		{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
		{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
		{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `\n"`), false},

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			c := CertificateRequest{
				CertificateRequest: tt.fields.CertificateRequest,
			got, err := c.MarshalJSON()
			if (err != nil) != tt.wantErr {
				t.Errorf("CertificateRequest.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("CertificateRequest.MarshalJSON() = %s, want %s", got, tt.want)

func TestCertificateRequest_UnmarshalJSON(t *testing.T) {
	tests := []struct {
		name     string
		data     []byte
		wantCert bool
		wantErr  bool
		{"no data", nil, false, true},
		{"incomplete string 1", []byte(`"foobar`), false, true}, {"incomplete string 2", []byte(`foobar"`), false, true},
		{"invalid string", []byte(`"foobar"`), false, true},
		{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
		{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true},
		{"invalid type", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), false, true},
		{"empty string", []byte(`""`), false, false},
		{"json null", []byte(`null`), false, false},
		{"valid csr", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), true, false},
	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			var c CertificateRequest
			if err := c.UnmarshalJSON(; (err != nil) != tt.wantErr {
				t.Errorf("CertificateRequest.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
			if tt.wantCert && c.CertificateRequest == nil {
				t.Error("CertificateRequest.UnmarshalJSON() failed, CertificateRequet is nil")

func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) {
	tests := []struct {
		name     string
		data     string
		wantCert bool
		wantErr  bool
		{"invalid type (bool)", `{"csr":true}`, false, true},
		{"invalid type (number)", `{"csr":123}`, false, true},
		{"invalid type (object)", `{"csr":{}}`, false, true},
		{"empty csr (null)", `{"csr":null}`, false, false},
		{"empty csr (string)", `{"csr":""}`, false, false},
		{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
		{"valid csr", `{"csr":"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"}`, true, false},

	type request struct {
		CSR CertificateRequest `json:"csr"`

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			var body request
			if err := json.Unmarshal([]byte(, &body); (err != nil) != tt.wantErr {
				t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)

			switch tt.wantCert {
			case true:
				if body.CSR.CertificateRequest == nil {
					t.Error("json.Unmarshal() failed, CertificateRequest is nil")
			case false:
				if body.CSR.CertificateRequest != nil {
					t.Error("json.Unmarshal() failed, CertificateRequest is not nil")

func TestSignRequest_Validate(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	bad := parseCertificateRequest(csrPEM)
	type fields struct {
		CsrPEM    CertificateRequest
		OTT       string
		NotBefore time.Time
		NotAfter  time.Time
	tests := []struct {
		name   string
		fields fields
		err    error
		{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")},
		{"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")},
		{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")},
	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			s := &SignRequest{
				CsrPEM:    tt.fields.CsrPEM,
				OTT:       tt.fields.OTT,
				NotAfter:  NewTimeDuration(tt.fields.NotAfter),
				NotBefore: NewTimeDuration(tt.fields.NotBefore),
			if err := s.Validate(); err != nil {
				if assert.NotNil(t, tt.err) {
					assert.HasPrefix(t, err.Error(), tt.err.Error())
			} else {
				assert.Nil(t, tt.err)

type mockProvisioner struct {
	ret1, ret2, ret3   interface{}
	err                error
	getID              func() string
	getIDForToken      func() string
	getTokenID         func(string) (string, error)
	getName            func() string
	getType            func() provisioner.Type
	getEncryptedKey    func() (string, string, bool)
	init               func(provisioner.Config) error
	authorizeRenew     func(ctx context.Context, cert *x509.Certificate) error
	authorizeRevoke    func(ctx context.Context, token string) error
	authorizeSign      func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
	authorizeRenewal   func(*x509.Certificate) error
	authorizeSSHSign   func(ctx context.Context, token string) ([]provisioner.SignOption, error)
	authorizeSSHRevoke func(ctx context.Context, token string) error
	authorizeSSHRenew  func(ctx context.Context, token string) (*ssh.Certificate, error)
	authorizeSSHRekey  func(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error)

func (m *mockProvisioner) GetID() string {
	if m.getID != nil {
		return m.getID()
	return m.ret1.(string)

func (m *mockProvisioner) GetIDForToken() string {
	if m.getIDForToken != nil {
		return m.getIDForToken()
	return m.ret1.(string)

func (m *mockProvisioner) GetTokenID(token string) (string, error) {
	if m.getTokenID != nil {
		return m.getTokenID(token)
	if m.ret1 == nil {
		return "", m.err
	return m.ret1.(string), m.err

func (m *mockProvisioner) GetName() string {
	if m.getName != nil {
		return m.getName()
	return m.ret1.(string)

func (m *mockProvisioner) GetType() provisioner.Type {
	if m.getType != nil {
		return m.getType()
	return m.ret1.(provisioner.Type)

func (m *mockProvisioner) GetEncryptedKey() (string, string, bool) {
	if m.getEncryptedKey != nil {
		return m.getEncryptedKey()
	return m.ret1.(string), m.ret2.(string), m.ret3.(bool)

func (m *mockProvisioner) Init(c provisioner.Config) error {
	if m.init != nil {
		return m.init(c)
	return m.err

func (m *mockProvisioner) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
	if m.authorizeRenew != nil {
		return m.authorizeRenew(ctx, cert)
	return m.err

func (m *mockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error {
	if m.authorizeRevoke != nil {
		return m.authorizeRevoke(ctx, token)
	return m.err

func (m *mockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
	if m.authorizeSign != nil {
		return m.authorizeSign(ctx, ott)
	return m.ret1.([]provisioner.SignOption), m.err

func (m *mockProvisioner) AuthorizeRenewal(c *x509.Certificate) error {
	if m.authorizeRenewal != nil {
		return m.authorizeRenewal(c)
	return m.err

func (m *mockProvisioner) AuthorizeSSHSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
	if m.authorizeSSHSign != nil {
		return m.authorizeSSHSign(ctx, token)
	return m.ret1.([]provisioner.SignOption), m.err
func (m *mockProvisioner) AuthorizeSSHRevoke(ctx context.Context, token string) error {
	if m.authorizeSSHRevoke != nil {
		return m.authorizeSSHRevoke(ctx, token)
	return m.err
func (m *mockProvisioner) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
	if m.authorizeSSHRenew != nil {
		return m.authorizeSSHRenew(ctx, token)
	return m.ret1.(*ssh.Certificate), m.err
func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []provisioner.SignOption, error) {
	if m.authorizeSSHRekey != nil {
		return m.authorizeSSHRekey(ctx, token)
	return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err

type mockAuthority struct {
	ret1, ret2                   interface{}
	err                          error
	authorizeSign                func(ott string) ([]provisioner.SignOption, error)
	getTLSOptions                func() *authority.TLSOptions
	root                         func(shasum string) (*x509.Certificate, error)
	sign                         func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
	renew                        func(cert *x509.Certificate) ([]*x509.Certificate, error)
	rekey                        func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
	loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
	loadProvisionerByName        func(name string) (provisioner.Interface, error)
	getProvisioners              func(nextCursor string, limit int) (provisioner.List, string, error)
	revoke                       func(context.Context, *authority.RevokeOptions) error
	getEncryptedKey              func(kid string) (string, error)
	getRoots                     func() ([]*x509.Certificate, error)
	getFederation                func() ([]*x509.Certificate, error)
	signSSH                      func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
	signSSHAddUser               func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
	renewSSH                     func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
	rekeySSH                     func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
	getSSHHosts                  func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error)
	getSSHRoots                  func(ctx context.Context) (*authority.SSHKeys, error)
	getSSHFederation             func(ctx context.Context) (*authority.SSHKeys, error)
	getSSHConfig                 func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
	checkSSHHost                 func(ctx context.Context, principal, token string) (bool, error)
	getSSHBastion                func(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
	version                      func() authority.Version

// TODO: remove once Authorize is deprecated.
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
	return m.AuthorizeSign(ott)

func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
	if m.authorizeSign != nil {
		return m.authorizeSign(ott)
	return m.ret1.([]provisioner.SignOption), m.err

func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
	if m.getTLSOptions != nil {
		return m.getTLSOptions()
	return m.ret1.(*authority.TLSOptions)

func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
	if m.root != nil {
		return m.root(shasum)
	return m.ret1.(*x509.Certificate), m.err

func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
	if m.sign != nil {
		return m.sign(cr, opts, signOpts...)
	return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err

func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
	if m.renew != nil {
		return m.renew(cert)
	return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err

func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
	if m.rekey != nil {
		return m.rekey(oldcert, pk)
	return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err

func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) {
	if m.getProvisioners != nil {
		return m.getProvisioners(nextCursor, limit)
	return m.ret1.(provisioner.List), m.ret2.(string), m.err

func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) {
	if m.loadProvisionerByCertificate != nil {
		return m.loadProvisionerByCertificate(cert)
	return m.ret1.(provisioner.Interface), m.err

func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) {
	if m.loadProvisionerByName != nil {
		return m.loadProvisionerByName(name)
	return m.ret1.(provisioner.Interface), m.err

func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error {
	if m.revoke != nil {
		return m.revoke(ctx, opts)
	return m.err

func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
	if m.getEncryptedKey != nil {
		return m.getEncryptedKey(kid)
	return m.ret1.(string), m.err

func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) {
	if m.getRoots != nil {
		return m.getRoots()
	return m.ret1.([]*x509.Certificate), m.err

func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
	if m.getFederation != nil {
		return m.getFederation()
	return m.ret1.([]*x509.Certificate), m.err

func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
	if m.signSSH != nil {
		return m.signSSH(ctx, key, opts, signOpts...)
	return m.ret1.(*ssh.Certificate), m.err

func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
	if m.signSSHAddUser != nil {
		return m.signSSHAddUser(ctx, key, cert)
	return m.ret1.(*ssh.Certificate), m.err

func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) {
	if m.renewSSH != nil {
		return m.renewSSH(ctx, cert)
	return m.ret1.(*ssh.Certificate), m.err

func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
	if m.rekeySSH != nil {
		return m.rekeySSH(ctx, cert, key, signOpts...)
	return m.ret1.(*ssh.Certificate), m.err

func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) {
	if m.getSSHHosts != nil {
		return m.getSSHHosts(ctx, cert)
	return m.ret1.([]authority.Host), m.err

func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) {
	if m.getSSHRoots != nil {
		return m.getSSHRoots(ctx)
	return m.ret1.(*authority.SSHKeys), m.err

func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) {
	if m.getSSHFederation != nil {
		return m.getSSHFederation(ctx)
	return m.ret1.(*authority.SSHKeys), m.err

func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
	if m.getSSHConfig != nil {
		return m.getSSHConfig(ctx, typ, data)
	return m.ret1.([]templates.Output), m.err

func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) {
	if m.checkSSHHost != nil {
		return m.checkSSHHost(ctx, principal, token)
	return m.ret1.(bool), m.err

func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
	if m.getSSHBastion != nil {
		return m.getSSHBastion(ctx, user, hostname)
	return m.ret1.(*authority.Bastion), m.err

func (m *mockAuthority) Version() authority.Version {
	if m.version != nil {
		return m.version()
	return m.ret1.(authority.Version)

func Test_caHandler_Route(t *testing.T) {
	type fields struct {
		Authority Authority
	type args struct {
		r Router
	tests := []struct {
		name   string
		fields fields
		args   args
		{"ok", fields{&mockAuthority{}}, args{chi.NewRouter()}},
	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			h := &caHandler{
				Authority: tt.fields.Authority,

func Test_caHandler_Health(t *testing.T) {
	req := httptest.NewRequest("GET", "", nil)
	w := httptest.NewRecorder()
	h := New(&mockAuthority{}).(*caHandler)
	h.Health(w, req)

	res := w.Result()
	if res.StatusCode != 200 {
		t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode)

	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Errorf("caHandler.Health unexpected error = %v", err)
	expected := []byte("{\"status\":\"ok\"}\n")
	if !bytes.Equal(body, expected) {
		t.Errorf("caHandler.Health Body = %s, wants %s", body, expected)

func Test_caHandler_Root(t *testing.T) {
	tests := []struct {
		name       string
		root       *x509.Certificate
		err        error
		statusCode int
		{"ok", parseCertificate(rootPEM), nil, 200},
		{"fail", nil, fmt.Errorf("not found"), 404},

	// Request with chi context
	chiCtx := chi.NewRouteContext()
	chiCtx.URLParams.Add("sha", "efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36")
	req := httptest.NewRequest("GET", "", nil)
	req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))

	expected := []byte(`{"ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"}`)

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
			w := httptest.NewRecorder()
			h.Root(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)

			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Errorf("caHandler.Root unexpected error = %v", err)
			if tt.statusCode == 200 {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)

func Test_caHandler_Sign(t *testing.T) {
	csr := parseCertificateRequest(csrPEM)
	valid, err := json.Marshal(SignRequest{
		CsrPEM: CertificateRequest{csr},
		OTT:    "foobarzar",
	if err != nil {
	invalid, err := json.Marshal(SignRequest{
		CsrPEM: CertificateRequest{csr},
		OTT:    "",
	if err != nil {

	expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
	expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	tests := []struct {
		name         string
		input        string
		certAttrOpts []provisioner.SignOption
		autherr      error
		cert         *x509.Certificate
		root         *x509.Certificate
		signErr      error
		statusCode   int
		expected     []byte
		{"ok", string(valid), nil, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated, expected1},
		{"ok with Provisioner", string(valid), nil, nil, parseCertificate(stepCertPEM), parseCertificate(rootPEM), nil, http.StatusCreated, expected2},
		{"json read error", "{", nil, nil, nil, nil, nil, http.StatusBadRequest, nil},
		{"validate error", string(invalid), nil, nil, nil, nil, nil, http.StatusBadRequest, nil},
		{"authorize error", string(valid), nil, fmt.Errorf("an error"), nil, nil, nil, http.StatusUnauthorized, nil},
		{"sign error", string(valid), nil, nil, nil, nil, fmt.Errorf("an error"), http.StatusForbidden, nil},

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			h := New(&mockAuthority{
				ret1: tt.cert, ret2: tt.root, err: tt.signErr,
				authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
					return tt.certAttrOpts, tt.autherr
				getTLSOptions: func() *authority.TLSOptions {
					return nil
			req := httptest.NewRequest("POST", "", strings.NewReader(tt.input))
			w := httptest.NewRecorder()
			h.Sign(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)

			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Errorf("caHandler.Root unexpected error = %v", err)
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), tt.expected) {
					t.Errorf("caHandler.Root Body = %s, wants %s", body, tt.expected)

func Test_caHandler_Renew(t *testing.T) {
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	tests := []struct {
		name       string
		tls        *tls.ConnectionState
		cert       *x509.Certificate
		root       *x509.Certificate
		err        error
		statusCode int
		{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
		{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
		{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},

	expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			h := New(&mockAuthority{
				ret1: tt.cert, ret2: tt.root, err: tt.err,
				getTLSOptions: func() *authority.TLSOptions {
					return nil
			req := httptest.NewRequest("POST", "", nil)
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			h.Renew(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)

			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Errorf("caHandler.Renew unexpected error = %v", err)
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)

func Test_caHandler_Rekey(t *testing.T) {
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	csr := parseCertificateRequest(csrPEM)
	valid, err := json.Marshal(RekeyRequest{
		CsrPEM: CertificateRequest{csr},
	if err != nil {
	tests := []struct {
		name       string
		input      string
		tls        *tls.ConnectionState
		cert       *x509.Certificate
		root       *x509.Certificate
		err        error
		statusCode int
		{"ok", string(valid), cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"no tls", string(valid), nil, nil, nil, nil, http.StatusBadRequest},
		{"no peer certificates", string(valid), &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
		{"rekey error", string(valid), cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
		{"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest},

	expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			h := New(&mockAuthority{
				ret1: tt.cert, ret2: tt.root, err: tt.err,
				getTLSOptions: func() *authority.TLSOptions {
					return nil
			req := httptest.NewRequest("POST", "", strings.NewReader(tt.input))
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			h.Rekey(logging.NewResponseLogger(w), req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)

			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Errorf("caHandler.Rekey unexpected error = %v", err)
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Rekey Body = %s, wants %s", body, expected)

func Test_caHandler_Provisioners(t *testing.T) {
	type fields struct {
		Authority Authority
	type args struct {
		w http.ResponseWriter
		r *http.Request

	req, err := http.NewRequest("GET", "", nil)
	if err != nil {

	reqLimitFail, err := http.NewRequest("GET", "", nil)
	if err != nil {

	var key jose.JSONWebKey
	if err := json.Unmarshal([]byte(pubKey), &key); err != nil {

	p := provisioner.List{
			Type:         "JWK",
			Name:         "max",
			EncryptedKey: "abc",
			Key:          &key,
			Type:         "JWK",
			Name:         "mariano",
			EncryptedKey: "def",
			Key:          &key,
	pr := ProvisionersResponse{
		Provisioners: p,

	tests := []struct {
		name       string
		fields     fields
		args       args
		statusCode int
		{"ok", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200},
		{"fail", fields{&mockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500},
		{"limit fail", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400},

	expected, err := json.Marshal(pr)
	if err != nil {

	expectedError400 := errs.BadRequest("force")
	expectedError400Bytes, err := json.Marshal(expectedError400)
	assert.FatalError(t, err)
	expectedError500 := errs.InternalServer("force")
	expectedError500Bytes, err := json.Marshal(expectedError500)
	assert.FatalError(t, err)
	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			h := &caHandler{
				Authority: tt.fields.Authority,
			h.Provisioners(tt.args.w, tt.args.r)

			rec := tt.args.w.(*httptest.ResponseRecorder)
			res := rec.Result()
			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Errorf("caHandler.Provisioners unexpected error = %v", err)
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected)
			} else {
				switch tt.statusCode {
				case 400:
					if !bytes.Equal(bytes.TrimSpace(body), expectedError400Bytes) {
						t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400Bytes)
				case 500:
					if !bytes.Equal(bytes.TrimSpace(body), expectedError500Bytes) {
						t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500Bytes)
					t.Errorf("caHandler.Provisioner unexpected status code = %d", tt.statusCode)


func Test_caHandler_ProvisionerKey(t *testing.T) {
	type fields struct {
		Authority Authority
	type args struct {
		w http.ResponseWriter
		r *http.Request

	// Request with chi context
	chiCtx := chi.NewRouteContext()
	chiCtx.URLParams.Add("kid", "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00")
	req := httptest.NewRequest("GET", "", nil)
	req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))

	tests := []struct {
		name       string
		fields     fields
		args       args
		statusCode int
		{"ok", fields{&mockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200},
		{"fail", fields{&mockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404},

	expected := []byte(`{"key":"` + privKey + `"}`)
	expectedError404 := errs.NotFound("force")
	expectedError404Bytes, err := json.Marshal(expectedError404)
	assert.FatalError(t, err)

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			h := &caHandler{
				Authority: tt.fields.Authority,
			h.ProvisionerKey(tt.args.w, tt.args.r)

			rec := tt.args.w.(*httptest.ResponseRecorder)
			res := rec.Result()
			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Errorf("caHandler.Provisioners unexpected error = %v", err)
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected)
			} else {
				if !bytes.Equal(bytes.TrimSpace(body), expectedError404Bytes) {
					t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError404Bytes)

func Test_caHandler_Roots(t *testing.T) {
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	tests := []struct {
		name       string
		tls        *tls.ConnectionState
		cert       *x509.Certificate
		root       *x509.Certificate
		err        error
		statusCode int
		{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},

	expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
			req := httptest.NewRequest("GET", "", nil)
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			h.Roots(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)

			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Errorf("caHandler.Roots unexpected error = %v", err)
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected)

func Test_caHandler_Federation(t *testing.T) {
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	tests := []struct {
		name       string
		tls        *tls.ConnectionState
		cert       *x509.Certificate
		root       *x509.Certificate
		err        error
		statusCode int
		{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},

	expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)

	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
			req := httptest.NewRequest("GET", "", nil)
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			h.Federation(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)

			body, err := ioutil.ReadAll(res.Body)
			if err != nil {
				t.Errorf("caHandler.Federation unexpected error = %v", err)
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected)

func Test_fmtPublicKey(t *testing.T) {
	p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
	rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
	edPub, edPriv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
	var dsa2048 dsa.PrivateKey
	if err := dsa.GenerateParameters(&dsa2048.Parameters, rand.Reader, dsa.L2048N256); err != nil {
	if err := dsa.GenerateKey(&dsa2048, rand.Reader); err != nil {

	type args struct {
		pub, priv interface{}
		cert      *x509.Certificate
	tests := []struct {
		name string
		args args
		want string
		{"p256", args{p256.Public(), p256, nil}, "ECDSA P-256"},
		{"rsa1024", args{rsa1024.Public(), rsa1024, nil}, "RSA 1024"},
		{"ed25519", args{edPub, edPriv, nil}, "Ed25519"},
		{"dsa2048", args{cert: &x509.Certificate{PublicKeyAlgorithm: x509.DSA, PublicKey: &dsa2048.PublicKey}}, "DSA 2048"},
		{"unknown", args{cert: &x509.Certificate{PublicKeyAlgorithm: x509.ECDSA, PublicKey: []byte("12345678")}}, "ECDSA unknown"},
	for _, tt := range tests {
		t.Run(, func(t *testing.T) {
			var cert *x509.Certificate
			if tt.args.cert != nil {
				cert = tt.args.cert
			} else {
				cert = mustCertificate(t,, tt.args.priv)
			if got := fmtPublicKey(cert); got != tt.want {
				t.Errorf("fmtPublicKey() = %v, want %v", got, tt.want)

func mustCertificate(t *testing.T, pub, priv interface{}) *x509.Certificate {
	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"Acme Co"},
		NotBefore: time.Now(),
		NotAfter:  time.Now().Add(24 * time.Hour),

		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,

	der, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv)
	if err != nil {

	cert, err := x509.ParseCertificate(der)
	if err != nil {
	return cert