package api import ( "bytes" "context" "crypto/dsa" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/json" "encoding/pem" "fmt" "io/ioutil" "math/big" "net/http" "net/http/httptest" "reflect" "strings" "testing" "time" "github.com/go-chi/chi" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/jose" "golang.org/x/crypto/ssh" ) const ( rootPEM = `-----BEGIN CERTIFICATE----- MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7 qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY /iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/ zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto 
WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6 yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx -----END CERTIFICATE-----` certPEM = `-----BEGIN CERTIFICATE----- MIIDujCCAqKgAwIBAgIIE31FZVaPXTUwDQYJKoZIhvcNAQEFBQAwSTELMAkGA1UE BhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRl cm5ldCBBdXRob3JpdHkgRzIwHhcNMTQwMTI5MTMyNzQzWhcNMTQwNTI5MDAwMDAw WjBpMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwN TW91bnRhaW4gVmlldzETMBEGA1UECgwKR29vZ2xlIEluYzEYMBYGA1UEAwwPbWFp bC5nb29nbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfRrObuSW5T7q 5CnSEqefEmtH4CCv6+5EckuriNr1CjfVvqzwfAhopXkLrq45EQm8vkmf7W96XJhC 7ZM0dYi1/qOCAU8wggFLMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAa BgNVHREEEzARgg9tYWlsLmdvb2dsZS5jb20wCwYDVR0PBAQDAgeAMGgGCCsGAQUF BwEBBFwwWjArBggrBgEFBQcwAoYfaHR0cDovL3BraS5nb29nbGUuY29tL0dJQUcy LmNydDArBggrBgEFBQcwAYYfaHR0cDovL2NsaWVudHMxLmdvb2dsZS5jb20vb2Nz cDAdBgNVHQ4EFgQUiJxtimAuTfwb+aUtBn5UYKreKvMwDAYDVR0TAQH/BAIwADAf BgNVHSMEGDAWgBRK3QYWG7z2aLV29YG2u2IaulqBLzAXBgNVHSAEEDAOMAwGCisG AQQB1nkCBQEwMAYDVR0fBCkwJzAloCOgIYYfaHR0cDovL3BraS5nb29nbGUuY29t L0dJQUcyLmNybDANBgkqhkiG9w0BAQUFAAOCAQEAH6RYHxHdcGpMpFE3oxDoFnP+ gtuBCHan2yE2GRbJ2Cw8Lw0MmuKqHlf9RSeYfd3BXeKkj1qO6TVKwCh+0HdZk283 TZZyzmEOyclm3UGFYe82P/iDFt+CeQ3NpmBg+GoaVCuWAARJN/KfglbLyyYygcQq 0SgeDh8dRKUiaW3HQSoYvTvdTuqzwK4CXsr3b5/dAOY8uMuG/IAR3FgwTbZ1dtoW RvOTa8hYiU6A475WuZKyEHcwnGYe57u2I2KbMgcKjPniocj4QzgYsVAVKW3IwaOh yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA== -----END CERTIFICATE-----` csrPEM = `-----BEGIN CERTIFICATE REQUEST----- MIIEYjCCAkoCAQAwHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0ZXAuY29tMIICIjAN BgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuCpifZfoZhYNywfpnPa21NezXgtn wrWBFE6xhVzE7YDSIqtIsj8aR7R8zwEymxfv5j5298LUy/XSmItVH31CsKyfcGqN QM0PZr9XY3z5V6qchGMqjzt/jqlYMBHujcxIFBfz4HATxSgKyvHqvw14ESsS2huu 7jowx+XTKbFYgKcXrjBkvOej5FXD3ehkg0jDA2UAJNdfKmrc1BBEaaqOtfh7eyU2 HU7+5gxH8C27IiCAmNj719E0B99Nu2MUw6aLFIM4xAcRga33Avevx6UuXZZIEepe V1sihrkcnDK9Vsxkme5erXzvAoOiRusiC2iIomJHJrdRM5ReEU+N+Tl1Kxq+rk7H 
/qAq78wVm07M1/GGi9SUMObZS4WuJpM6whlikIAEbv9iV+CK0sv/Jr/AADdGMmQU lwk+Q0ZNE8p4ZuWILv/dtLDtDVBpnrrJ9e8duBtB0lGcG8MdaUCQ346EI4T0Sgx0 hJ+wMq8zYYFfPIZEHC8o9p1ywWN9ySpJ8Zj/5ubmx9v2bY67GbuVFEa8iAp+S00x /Z8nD6/JsoKtexuHyGr3ixWFzlBqXDuugukIDFUOVDCbuGw4Io4/hEMu4Zz0TIFk Uu/wf2z75Tt8EkosKLu2wieKcY7n7Vhog/0tqexqWlWtJH0tvq4djsGoSvA62WPs 0iXXj+aZIARPNhECAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4ICAQA0vyHIndAkIs/I Nnz5yZWCokRjokoKv3Aj4VilyjncL+W0UIPULLU/47ZyoHVSUj2t8gknr9xu/Kd+ g/2z0RiF3CIp8IUH49w/HYWaR95glzVNAAzr8qD9UbUqloLVQW3lObSRGtezhdZO sspw5dC+inhAb1LZhx8PVxB3SAeJ8h11IEBr0s2Hxt9viKKd7YPtIFZkZdOkVx4R if1DMawj1P6fEomf8z7m+dmbUYTqqosbCbRL01mzEga/kF6JyH/OzpNlcsAiyM8e BxPWH6TtPqwmyy4y7j1outmM0RnyUw5A0HmIbWh+rHpXiHVsnNqse0XfzmaxM8+z dxYeDax8aMWZKfvY1Zew+xIxl7DtEy1BpxrZcawumJYt5+LL+bwF/OtL0inQLnw8 zyqydsXNdrpIQJnfmWPld7ThWbQw2FBE70+nFSxHeG2ULnpF3M9xf6ZNAF4gqaNE Q7vMNPBWrJWu+A++vHY61WGET+h4lY3GFr2I8OE4IiHPQi1D7Y0+fwOmStwuRPM4 2rARcJChNdiYBkkuvs4kixKTTjdXhB8RQtuBSrJ0M1tzq2qMbm7F8G01rOg4KlXU 58jHzJwr1K7cx0lpWfGTtc5bseCGtTKmDBXTziw04yl8eE1+ZFOganixGwCtl4Tt DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w== -----END CERTIFICATE REQUEST-----` stepCertPEM = `-----BEGIN CERTIFICATE----- MIIChTCCAiugAwIBAgIRAJ3O5T28Rdj2lr/UPjf+GAUwCgYIKoZIzj0EAwIwJDEi MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0xOTAyMjAyMDE1 NDNaFw0xOTAyMjEyMDE1NDNaMHExCzAJBgNVBAYTAlVTMQswCQYDVQQIEwJDQTEW MBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEcMBoGA1UEChMTU21hbGxzdGVwIExhYnMg SW5jLjEfMB0GA1UEAxMWaW50ZXJuYWwuc21hbGxzdGVwLmNvbTBZMBMGByqGSM49 AgEGCCqGSM49AwEHA0IABC0aKrTNl+gXFuNkXisqX4/foLO3VMt+Kphngziim+fz aJhiS9JU+oFYLTNW6HWGUD8CNzfwrmWlVsAmiJwHKlKjgfAwge0wDgYDVR0PAQH/ BAQDAgWgMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAdBgNVHQ4EFgQU JheKvlZqNv1IcgaC8WOS1Zg0i1QwHwYDVR0jBBgwFoAUu97PaFQPfuyKOeew7Hg4 5WFIAVMwIQYDVR0RBBowGIIWaW50ZXJuYWwuc21hbGxzdGVwLmNvbTBZBgwrBgEE AYKkZMYoQAEESTBHAgEBBBVtYXJpYW5vQHNtYWxsc3RlcC5jb20EK2pPMzdkdERi a3UtUW5hYnM1VlIwWXc2WUZGdjl3ZUExOGRwM2h0dmRFanMwCgYIKoZIzj0EAwID SAAwRQIhAIrn17fP5CBrGtKuhyPiq6eSwryBCf8ki+k17u5a+E/LAiB24Y2E0Put 
nIHOI54lAqDeF7A0y73fPRVCiJEWmuxz0g== -----END CERTIFICATE-----` pubKey = `{ "use": "sig", "kty": "EC", "kid": "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00", "crv": "P-256", "alg": "ES256", "x": "p9QX4tzjxUrB0fgqRWLKUuPolDtBW681f2Qyh-uVNhk", "y": "CNSEloc4oLDFTX0Vywj0WiqOlh516sFQwCj6WtM8LT8" }` privKey = "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiNEhBYjE0WDQ5OFM4LWxSb29JTnpqZyJ9.RbkJXGzI3kOsaP20KmZs0ELFLgpRddAE49AJHlEblw-uH_gg6SV3QA.M3MArEpHgI171lhm.gBlFySpzK9F7riBJbtLSNkb4nAw_gWokqs1jS-ZK1qxuqTK-9mtX5yILjRnftx9P9uFp5xt7rvv4Mgom1Ed4V9WtIyfNP_Cz3Pme1Eanp5nY68WCe_yG6iSB1RJdMDBUb2qBDZiBdhJim1DRXsOfgedOrNi7GGbppMlD77DEpId118owR5izA-c6Q_hg08hIE3tnMAnebDNQoF9jfEY99_AReVRH8G4hgwZEPCfXMTb3J-lowKGG4vXIbK5knFLh47SgOqG4M2M51SMS-XJ7oBz1Vjoamc90QIqKV51rvZ5m0N_sPFtxzcfV4E9yYH3XVd4O-CG4ydVKfKVyMtQ.mcKFZqBHp_n7Ytj2jz9rvw" ) func parseCertificate(data string) *x509.Certificate { block, _ := pem.Decode([]byte(data)) if block == nil { panic("failed to parse certificate PEM") } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { panic("failed to parse certificate: " + err.Error()) } return cert } func parseCertificateRequest(data string) *x509.CertificateRequest { block, _ := pem.Decode([]byte(data)) if block == nil { panic("failed to parse certificate request PEM") } csr, err := x509.ParseCertificateRequest(block.Bytes) if err != nil { panic("failed to parse certificate request: " + err.Error()) } return csr } func TestNewCertificate(t *testing.T) { cert := parseCertificate(rootPEM) if !reflect.DeepEqual(Certificate{Certificate: cert}, NewCertificate(cert)) { t.Errorf("NewCertificate failed, got %v, wants %v", NewCertificate(cert), Certificate{Certificate: cert}) } } func TestCertificate_MarshalJSON(t *testing.T) { type fields struct { Certificate *x509.Certificate } tests := []struct { name string fields fields want []byte wantErr bool }{ {"nil", fields{Certificate: nil}, []byte("null"), false}, {"empty", 
fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false}, {"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"`), false}, {"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := Certificate{ Certificate: tt.fields.Certificate, } got, err := c.MarshalJSON() if (err != nil) != tt.wantErr { t.Errorf("Certificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Certificate.MarshalJSON() = %s, want %s", got, tt.want) } }) } } func TestCertificate_UnmarshalJSON(t *testing.T) { tests := []struct { name string data []byte wantErr bool }{ {"no data", nil, true}, {"empty string", []byte(`""`), true}, {"incomplete string 1", []byte(`"foobar`), true}, {"incomplete string 2", []byte(`foobar"`), true}, {"invalid string", []byte(`"foobar"`), true}, {"invalid bytes 0", []byte{}, true}, {"invalid bytes 1", []byte{1}, true}, {"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), true}, {"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true}, {"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false}, {"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var c Certificate if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr { t.Errorf("Certificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr && c.Certificate == nil { t.Error("Certificate.UnmarshalJSON() failed, Certificate is nil") } }) } } func TestCertificate_UnmarshalJSON_json(t *testing.T) { tests := []struct { name string data string wantErr bool }{ {"invalid type (null)", 
`{"crt":null}`, true}, {"invalid type (bool)", `{"crt":true}`, true}, {"invalid type (number)", `{"crt":123}`, true}, {"invalid type (object)", `{"crt":{}}`, true}, {"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, true}, {"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, false}, } type request struct { Cert Certificate `json:"crt"` } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var body request if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr { t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr) } switch tt.wantErr { case false: if body.Cert.Certificate == nil { t.Error("json.Unmarshal() failed, Certificate is nil") } case true: if body.Cert.Certificate != nil { t.Error("json.Unmarshal() failed, Certificate is not nil") } } }) } } func TestNewCertificateRequest(t *testing.T) { csr := parseCertificateRequest(csrPEM) if !reflect.DeepEqual(CertificateRequest{CertificateRequest: csr}, NewCertificateRequest(csr)) { t.Errorf("NewCertificateRequest failed, got %v, wants %v", NewCertificateRequest(csr), CertificateRequest{CertificateRequest: csr}) } } func TestCertificateRequest_MarshalJSON(t *testing.T) { type fields struct { CertificateRequest *x509.CertificateRequest } tests := []struct { name string fields fields want []byte wantErr bool }{ {"nil", fields{CertificateRequest: nil}, []byte("null"), false}, {"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false}, {"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `\n"`), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := CertificateRequest{ CertificateRequest: tt.fields.CertificateRequest, } got, err := c.MarshalJSON() if (err != nil) != tt.wantErr { t.Errorf("CertificateRequest.MarshalJSON() error = 
%v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("CertificateRequest.MarshalJSON() = %s, want %s", got, tt.want) } }) } } func TestCertificateRequest_UnmarshalJSON(t *testing.T) { tests := []struct { name string data []byte wantErr bool }{ {"no data", nil, true}, {"empty string", []byte(`""`), true}, {"incomplete string 1", []byte(`"foobar`), true}, {"incomplete string 2", []byte(`foobar"`), true}, {"invalid string", []byte(`"foobar"`), true}, {"invalid bytes 0", []byte{}, true}, {"invalid bytes 1", []byte{1}, true}, {"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), true}, {"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true}, {"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var c CertificateRequest if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr { t.Errorf("CertificateRequest.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr && c.CertificateRequest == nil { t.Error("CertificateRequest.UnmarshalJSON() failed, CertificateRequet is nil") } }) } } func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) { tests := []struct { name string data string wantErr bool }{ {"invalid type (null)", `{"csr":null}`, true}, {"invalid type (bool)", `{"csr":true}`, true}, {"invalid type (number)", `{"csr":123}`, true}, {"invalid type (object)", `{"csr":{}}`, true}, {"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, true}, {"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, false}, } type request struct { CSR CertificateRequest `json:"csr"` } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var body request if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr { t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, 
tt.wantErr) } switch tt.wantErr { case false: if body.CSR.CertificateRequest == nil { t.Error("json.Unmarshal() failed, CertificateRequest is nil") } case true: if body.CSR.CertificateRequest != nil { t.Error("json.Unmarshal() failed, CertificateRequest is not nil") } } }) } } func TestSignRequest_Validate(t *testing.T) { csr := parseCertificateRequest(csrPEM) bad := parseCertificateRequest(csrPEM) bad.Signature[0]++ type fields struct { CsrPEM CertificateRequest OTT string NotBefore time.Time NotAfter time.Time } tests := []struct { name string fields fields err error }{ {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")}, {"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")}, {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &SignRequest{ CsrPEM: tt.fields.CsrPEM, OTT: tt.fields.OTT, NotAfter: NewTimeDuration(tt.fields.NotAfter), NotBefore: NewTimeDuration(tt.fields.NotBefore), } if err := s.Validate(); err != nil { if assert.NotNil(t, tt.err) { assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { assert.Nil(t, tt.err) } }) } } type mockProvisioner struct { ret1, ret2, ret3 interface{} err error getID func() string getTokenID func(string) (string, error) getName func() string getType func() provisioner.Type getEncryptedKey func() (string, string, bool) init func(provisioner.Config) error authorizeRevoke func(ott string) error authorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) authorizeRenewal func(*x509.Certificate) error } func (m *mockProvisioner) GetID() string { if m.getID != nil { return m.getID() } return m.ret1.(string) } func (m *mockProvisioner) GetTokenID(token string) (string, error) { if m.getTokenID != nil { return m.getTokenID(token) } if m.ret1 == nil { return "", 
m.err } return m.ret1.(string), m.err } func (m *mockProvisioner) GetName() string { if m.getName != nil { return m.getName() } return m.ret1.(string) } func (m *mockProvisioner) GetType() provisioner.Type { if m.getType != nil { return m.getType() } return m.ret1.(provisioner.Type) } func (m *mockProvisioner) GetEncryptedKey() (string, string, bool) { if m.getEncryptedKey != nil { return m.getEncryptedKey() } return m.ret1.(string), m.ret2.(string), m.ret3.(bool) } func (m *mockProvisioner) Init(c provisioner.Config) error { if m.init != nil { return m.init(c) } return m.err } func (m *mockProvisioner) AuthorizeRevoke(ott string) error { if m.authorizeRevoke != nil { return m.authorizeRevoke(ott) } return m.err } func (m *mockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { if m.authorizeSign != nil { return m.authorizeSign(ctx, ott) } return m.ret1.([]provisioner.SignOption), m.err } func (m *mockProvisioner) AuthorizeRenewal(c *x509.Certificate) error { if m.authorizeRenewal != nil { return m.authorizeRenewal(c) } return m.err } type mockAuthority struct { ret1, ret2 interface{} err error authorizeSign func(ott string) ([]provisioner.SignOption, error) getTLSOptions func() *tlsutil.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByID func(provID string) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) 
revoke func(*authority.RevokeOptions) error getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error) getSSHKeys func() (*authority.SSHKeys, error) } // TODO: remove once Authorize is deprecated. func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { return m.AuthorizeSign(ott) } func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { if m.authorizeSign != nil { return m.authorizeSign(ott) } return m.ret1.([]provisioner.SignOption), m.err } func (m *mockAuthority) GetTLSOptions() *tlsutil.TLSOptions { if m.getTLSOptions != nil { return m.getTLSOptions() } return m.ret1.(*tlsutil.TLSOptions) } func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { if m.root != nil { return m.root(shasum) } return m.ret1.(*x509.Certificate), m.err } func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) (*x509.Certificate, *x509.Certificate, error) { if m.sign != nil { return m.sign(cr, opts, signOpts...) } return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err } func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { if m.signSSH != nil { return m.signSSH(key, opts, signOpts...) 
} return m.ret1.(*ssh.Certificate), m.err } func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { if m.signSSHAddUser != nil { return m.signSSHAddUser(key, cert) } return m.ret1.(*ssh.Certificate), m.err } func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) { if m.renew != nil { return m.renew(cert) } return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err } func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { if m.getProvisioners != nil { return m.getProvisioners(nextCursor, limit) } return m.ret1.(provisioner.List), m.ret2.(string), m.err } func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { if m.loadProvisionerByCertificate != nil { return m.loadProvisionerByCertificate(cert) } return m.ret1.(provisioner.Interface), m.err } func (m *mockAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) { if m.loadProvisionerByID != nil { return m.loadProvisionerByID(provID) } return m.ret1.(provisioner.Interface), m.err } func (m *mockAuthority) Revoke(opts *authority.RevokeOptions) error { if m.revoke != nil { return m.revoke(opts) } return m.err } func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { if m.getEncryptedKey != nil { return m.getEncryptedKey(kid) } return m.ret1.(string), m.err } func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { if m.getRoots != nil { return m.getRoots() } return m.ret1.([]*x509.Certificate), m.err } func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { if m.getFederation != nil { return m.getFederation() } return m.ret1.([]*x509.Certificate), m.err } func (m *mockAuthority) GetSSHKeys() (*authority.SSHKeys, error) { if m.getSSHKeys != nil { return m.getSSHKeys() } return m.ret1.(*authority.SSHKeys), m.err } func Test_caHandler_Route(t *testing.T) { 
type fields struct { Authority Authority } type args struct { r Router } tests := []struct { name string fields fields args args }{ {"ok", fields{&mockAuthority{}}, args{chi.NewRouter()}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := &caHandler{ Authority: tt.fields.Authority, } h.Route(tt.args.r) }) } } func Test_caHandler_Health(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/health", nil) w := httptest.NewRecorder() h := New(&mockAuthority{}).(*caHandler) h.Health(w, req) res := w.Result() if res.StatusCode != 200 { t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("caHandler.Health unexpected error = %v", err) } expected := []byte("{\"status\":\"ok\"}\n") if !bytes.Equal(body, expected) { t.Errorf("caHandler.Health Body = %s, wants %s", body, expected) } } func Test_caHandler_Root(t *testing.T) { tests := []struct { name string root *x509.Certificate err error statusCode int }{ {"ok", parseCertificate(rootPEM), nil, 200}, {"fail", nil, fmt.Errorf("not found"), 404}, } // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("sha", "efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36") req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil) req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) expected := []byte(`{"ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) w := httptest.NewRecorder() h.Root(w, req) res := w.Result() if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { 
t.Errorf("caHandler.Root unexpected error = %v", err) } if tt.statusCode == 200 { if !bytes.Equal(bytes.TrimSpace(body), expected) { t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) } } }) } } func Test_caHandler_Sign(t *testing.T) { csr := parseCertificateRequest(csrPEM) valid, err := json.Marshal(SignRequest{ CsrPEM: CertificateRequest{csr}, OTT: "foobarzar", }) if err != nil { t.Fatal(err) } invalid, err := json.Marshal(SignRequest{ CsrPEM: CertificateRequest{csr}, OTT: "", }) if err != nil { t.Fatal(err) } expected1 := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`) expected2 := []byte(`{"crt":"` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`) tests := []struct { name string input string certAttrOpts []provisioner.SignOption autherr error cert *x509.Certificate root *x509.Certificate signErr error statusCode int expected []byte }{ {"ok", string(valid), nil, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated, expected1}, {"ok with Provisioner", string(valid), nil, nil, parseCertificate(stepCertPEM), parseCertificate(rootPEM), nil, http.StatusCreated, expected2}, {"json read error", "{", nil, nil, nil, nil, nil, http.StatusBadRequest, nil}, {"validate error", string(invalid), nil, nil, nil, nil, nil, http.StatusBadRequest, nil}, {"authorize error", string(valid), nil, fmt.Errorf("an error"), nil, nil, nil, http.StatusUnauthorized, nil}, {"sign error", string(valid), nil, nil, nil, nil, fmt.Errorf("an error"), http.StatusForbidden, nil}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.signErr, authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return tt.certAttrOpts, tt.autherr }, getTLSOptions: func() *tlsutil.TLSOptions { return nil }, }).(*caHandler) req := 
httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input)) w := httptest.NewRecorder() h.Sign(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("caHandler.Root unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), tt.expected) { t.Errorf("caHandler.Root Body = %s, wants %s", body, tt.expected) } } }) } } func Test_caHandler_Renew(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } tests := []struct { name string tls *tls.ConnectionState cert *x509.Certificate root *x509.Certificate err error statusCode int }{ {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, {"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, } expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *tlsutil.TLSOptions { return nil }, }).(*caHandler) req := httptest.NewRequest("POST", "http://example.com/renew", nil) req.TLS = tt.tls w := httptest.NewRecorder() h.Renew(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("caHandler.Root unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if 
!bytes.Equal(bytes.TrimSpace(body), expected) { t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) } } }) } } func Test_caHandler_Provisioners(t *testing.T) { type fields struct { Authority Authority } type args struct { w http.ResponseWriter r *http.Request } req, err := http.NewRequest("GET", "http://example.com/provisioners?cursor=foo&limit=20", nil) if err != nil { t.Fatal(err) } reqLimitFail, err := http.NewRequest("GET", "http://example.com/provisioners?limit=abc", nil) if err != nil { t.Fatal(err) } var key jose.JSONWebKey if err := json.Unmarshal([]byte(pubKey), &key); err != nil { t.Fatal(err) } p := provisioner.List{ &provisioner.JWK{ Type: "JWK", Name: "max", EncryptedKey: "abc", Key: &key, }, &provisioner.JWK{ Type: "JWK", Name: "mariano", EncryptedKey: "def", Key: &key, }, } pr := ProvisionersResponse{ Provisioners: p, } tests := []struct { name string fields fields args args statusCode int }{ {"ok", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200}, {"fail", fields{&mockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500}, {"limit fail", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400}, } expected, err := json.Marshal(pr) if err != nil { t.Fatal(err) } expectedError400 := []byte(`{"status":400,"message":"Bad Request"}`) expectedError500 := []byte(`{"status":500,"message":"Internal Server Error"}`) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := &caHandler{ Authority: tt.fields.Authority, } h.Provisioners(tt.args.w, tt.args.r) rec := tt.args.w.(*httptest.ResponseRecorder) res := rec.Result() if res.StatusCode != tt.statusCode { t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("caHandler.Provisioners unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if 
!bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected)
				}
			} else {
				switch tt.statusCode {
				case 400:
					if !bytes.Equal(bytes.TrimSpace(body), expectedError400) {
						t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400)
					}
				case 500:
					if !bytes.Equal(bytes.TrimSpace(body), expectedError500) {
						t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500)
					}
				default:
					t.Errorf("caHandler.Provisioner unexpected status code = %d", tt.statusCode)
				}
			}
		})
	}
}

// Test_caHandler_ProvisionerKey exercises caHandler.ProvisionerKey with a
// request whose chi route context carries a "kid" URL parameter. A mock
// authority that returns privKey must produce a 200 response with body
// {"key":"<privKey>"}; a mock that returns an error must produce a 404
// response with body {"status":404,"message":"Not Found"}.
func Test_caHandler_ProvisionerKey(t *testing.T) {
	type fields struct {
		Authority Authority
	}
	type args struct {
		w http.ResponseWriter
		r *http.Request
	}

	// Request with chi context: the handler reads the "kid" route parameter
	// from the chi RouteContext stored on the request context.
	chiCtx := chi.NewRouteContext()
	chiCtx.URLParams.Add("kid", "oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00")
	req := httptest.NewRequest("GET", "http://example.com/provisioners/oV1p0MJeGQ7qBlK6B-oyfVdBRjh_e7VSK_YSEEqgW00/encrypted-key", nil)
	req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))

	tests := []struct {
		name       string
		fields     fields
		args       args
		statusCode int
	}{
		{"ok", fields{&mockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200},
		{"fail", fields{&mockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404},
	}

	expected := []byte(`{"key":"` + privKey + `"}`)
	expectedError := []byte(`{"status":404,"message":"Not Found"}`)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := &caHandler{
				Authority: tt.fields.Authority,
			}
			h.ProvisionerKey(tt.args.w, tt.args.r)

			rec := tt.args.w.(*httptest.ResponseRecorder)
			res := rec.Result()
			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.ProvisionerKey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.ProvisionerKey unexpected error = %v", err)
			}

			// Success responses carry the key; error responses carry the
			// generic 404 payload.
			want := expected
			if tt.statusCode >= http.StatusBadRequest {
				want = expectedError
			}
			if !bytes.Equal(bytes.TrimSpace(body), want) {
				t.Errorf("caHandler.ProvisionerKey Body = %s, wants %s", body, want)
			}
		})
	}
}

// Test_caHandler_Roots checks that caHandler.Roots encodes the root
// certificates returned by the authority as {"crts":["<PEM>"]}. A request
// with a TLS peer certificate and a request with an empty TLS connection
// state are both expected to return 201 Created. When the authority returns
// an error, the handler is expected to respond with 403 Forbidden. Bodies of
// error responses are not inspected.
func Test_caHandler_Roots(t *testing.T) {
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	}

	tests := []struct {
		name       string
		tls        *tls.ConnectionState
		cert       *x509.Certificate // NOTE(review): not read by the subtests
		root       *x509.Certificate
		err        error
		statusCode int
	}{
		{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
	}

	// The expected JSON string is rootPEM followed by a trailing newline,
	// with every newline written as a JSON `\n` escape.
	expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
			req := httptest.NewRequest("GET", "http://example.com/roots", nil)
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			h.Roots(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Roots unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected)
				}
			}
		})
	}
}

// Test_caHandler_Federation checks that caHandler.Federation encodes the
// certificates returned by the authority as {"crts":["<PEM>"]}. It uses the
// same table and expectations as Test_caHandler_Roots: 201 Created on
// success, with or without TLS peer certificates, and 403 Forbidden when the
// authority returns an error. Bodies of error responses are not inspected.
func Test_caHandler_Federation(t *testing.T) {
	cs := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
	}

	tests := []struct {
		name       string
		tls        *tls.ConnectionState
		cert       *x509.Certificate // NOTE(review): not read by the subtests
		root       *x509.Certificate
		err        error
		statusCode int
	}{
		{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
		{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
	}

	// The expected JSON string is rootPEM followed by a trailing newline,
	// with every newline written as a JSON `\n` escape.
	expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
			req := httptest.NewRequest("GET", "http://example.com/federation", nil)
			req.TLS = tt.tls
			w := httptest.NewRecorder()
			h.Federation(w, req)
			res := w.Result()

			if res.StatusCode != tt.statusCode {
				t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
			}

			body, err := ioutil.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Errorf("caHandler.Federation unexpected error = %v", err)
			}
			if tt.statusCode < http.StatusBadRequest {
				if !bytes.Equal(bytes.TrimSpace(body), expected) {
					t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected)
				}
			}
		})
	}
}

// Test_fmtPublicKey checks the human-readable key description produced by
// fmtPublicKey. The P-256 and RSA-1024 cases get real self-signed
// certificates built by mustCertificate. The DSA-2048 case and the "unknown"
// case use hand-built x509.Certificate values. The "unknown" case is an
// ECDSA certificate whose PublicKey is raw bytes rather than an ECDSA key.
func Test_fmtPublicKey(t *testing.T) {
	p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
	if err != nil {
		t.Fatal(err)
	}
	// NOTE(review): DSA parameter generation at L2048N256 may be slow.
	var dsa2048 dsa.PrivateKey
	if err := dsa.GenerateParameters(&dsa2048.Parameters, rand.Reader, dsa.L2048N256); err != nil {
		t.Fatal(err)
	}
	if err := dsa.GenerateKey(&dsa2048, rand.Reader); err != nil {
		t.Fatal(err)
	}

	// Set either cert (used as-is) or the pub/priv pair, which is turned into
	// a certificate with mustCertificate.
	type args struct {
		pub, priv interface{}
		cert      *x509.Certificate
	}
	tests := []struct {
		name string
		args args
		want string
	}{
		{"p256", args{p256.Public(), p256, nil}, "ECDSA P-256"},
		{"rsa1024", args{rsa1024.Public(), rsa1024, nil}, "RSA 1024"},
		// Presumably hand-built because crypto/x509 cannot sign with DSA keys.
		{"dsa2048", args{cert: &x509.Certificate{PublicKeyAlgorithm: x509.DSA, PublicKey: &dsa2048.PublicKey}}, "DSA 2048"},
		{"unknown", args{cert: &x509.Certificate{PublicKeyAlgorithm: x509.ECDSA, PublicKey: []byte("12345678")}}, "ECDSA unknown"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var cert *x509.Certificate
			if tt.args.cert != nil {
				cert = tt.args.cert
			} else {
				cert = mustCertificate(t, tt.args.pub, tt.args.priv)
			}
			if got := fmtPublicKey(cert); got != tt.want {
				t.Errorf("fmtPublicKey() = %v, want %v", got, tt.want)
			}
		})
	}
}

// mustCertificate creates a self-signed certificate for the key pair
// (pub, priv) and parses the DER result back into an *x509.Certificate. The
// certificate is valid for 24 hours from now and is marked for server auth.
// Any error fails the test immediately. The failure is attributed to the
// caller's line because this function is marked as a helper.
func mustCertificate(t *testing.T, pub, priv interface{}) *x509.Certificate {
	t.Helper()

	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"Acme Co"},
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}
	// The template is its own parent, so priv signs its own certificate.
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv)
	if err != nil {
		t.Fatal(err)
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		t.Fatal(err)
	}
	return cert
}