619 lines
21 KiB
Go
619 lines
21 KiB
Go
|
package api
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"encoding/json"
|
||
|
"encoding/pem"
|
||
|
"fmt"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/go-chi/chi"
|
||
|
"github.com/smallstep/cli/crypto/tlsutil"
|
||
|
)
|
||
|
|
||
|
// PEM fixtures shared by the tests in this file.
const (
	// rootPEM is a PEM-encoded certificate; the handler tests use it as the
	// root ("ca") certificate returned by the mock authority.
	rootPEM = `-----BEGIN CERTIFICATE-----
MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT
MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i
YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG
EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy
bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP
VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv
h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE
ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ
EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC
DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7
qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD
VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g
K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI
KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n
ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB
BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY
/iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/
zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza
HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto
WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6
yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx
-----END CERTIFICATE-----`

	// certPEM is a PEM-encoded certificate; the handler tests use it as the
	// issued ("crt") certificate and as the TLS peer certificate on renew.
	certPEM = `-----BEGIN CERTIFICATE-----
MIIDujCCAqKgAwIBAgIIE31FZVaPXTUwDQYJKoZIhvcNAQEFBQAwSTELMAkGA1UE
BhMCVVMxEzARBgNVBAoTCkdvb2dsZSBJbmMxJTAjBgNVBAMTHEdvb2dsZSBJbnRl
cm5ldCBBdXRob3JpdHkgRzIwHhcNMTQwMTI5MTMyNzQzWhcNMTQwNTI5MDAwMDAw
WjBpMQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwN
TW91bnRhaW4gVmlldzETMBEGA1UECgwKR29vZ2xlIEluYzEYMBYGA1UEAwwPbWFp
bC5nb29nbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEfRrObuSW5T7q
5CnSEqefEmtH4CCv6+5EckuriNr1CjfVvqzwfAhopXkLrq45EQm8vkmf7W96XJhC
7ZM0dYi1/qOCAU8wggFLMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAa
BgNVHREEEzARgg9tYWlsLmdvb2dsZS5jb20wCwYDVR0PBAQDAgeAMGgGCCsGAQUF
BwEBBFwwWjArBggrBgEFBQcwAoYfaHR0cDovL3BraS5nb29nbGUuY29tL0dJQUcy
LmNydDArBggrBgEFBQcwAYYfaHR0cDovL2NsaWVudHMxLmdvb2dsZS5jb20vb2Nz
cDAdBgNVHQ4EFgQUiJxtimAuTfwb+aUtBn5UYKreKvMwDAYDVR0TAQH/BAIwADAf
BgNVHSMEGDAWgBRK3QYWG7z2aLV29YG2u2IaulqBLzAXBgNVHSAEEDAOMAwGCisG
AQQB1nkCBQEwMAYDVR0fBCkwJzAloCOgIYYfaHR0cDovL3BraS5nb29nbGUuY29t
L0dJQUcyLmNybDANBgkqhkiG9w0BAQUFAAOCAQEAH6RYHxHdcGpMpFE3oxDoFnP+
gtuBCHan2yE2GRbJ2Cw8Lw0MmuKqHlf9RSeYfd3BXeKkj1qO6TVKwCh+0HdZk283
TZZyzmEOyclm3UGFYe82P/iDFt+CeQ3NpmBg+GoaVCuWAARJN/KfglbLyyYygcQq
0SgeDh8dRKUiaW3HQSoYvTvdTuqzwK4CXsr3b5/dAOY8uMuG/IAR3FgwTbZ1dtoW
RvOTa8hYiU6A475WuZKyEHcwnGYe57u2I2KbMgcKjPniocj4QzgYsVAVKW3IwaOh
yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----`

	// csrPEM is a PEM-encoded certificate request used wherever the tests
	// need a CertificateRequest fixture.
	csrPEM = `-----BEGIN CERTIFICATE REQUEST-----
MIIEYjCCAkoCAQAwHTEbMBkGA1UEAxMSdGVzdC5zbWFsbHN0ZXAuY29tMIICIjAN
BgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAuCpifZfoZhYNywfpnPa21NezXgtn
wrWBFE6xhVzE7YDSIqtIsj8aR7R8zwEymxfv5j5298LUy/XSmItVH31CsKyfcGqN
QM0PZr9XY3z5V6qchGMqjzt/jqlYMBHujcxIFBfz4HATxSgKyvHqvw14ESsS2huu
7jowx+XTKbFYgKcXrjBkvOej5FXD3ehkg0jDA2UAJNdfKmrc1BBEaaqOtfh7eyU2
HU7+5gxH8C27IiCAmNj719E0B99Nu2MUw6aLFIM4xAcRga33Avevx6UuXZZIEepe
V1sihrkcnDK9Vsxkme5erXzvAoOiRusiC2iIomJHJrdRM5ReEU+N+Tl1Kxq+rk7H
/qAq78wVm07M1/GGi9SUMObZS4WuJpM6whlikIAEbv9iV+CK0sv/Jr/AADdGMmQU
lwk+Q0ZNE8p4ZuWILv/dtLDtDVBpnrrJ9e8duBtB0lGcG8MdaUCQ346EI4T0Sgx0
hJ+wMq8zYYFfPIZEHC8o9p1ywWN9ySpJ8Zj/5ubmx9v2bY67GbuVFEa8iAp+S00x
/Z8nD6/JsoKtexuHyGr3ixWFzlBqXDuugukIDFUOVDCbuGw4Io4/hEMu4Zz0TIFk
Uu/wf2z75Tt8EkosKLu2wieKcY7n7Vhog/0tqexqWlWtJH0tvq4djsGoSvA62WPs
0iXXj+aZIARPNhECAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4ICAQA0vyHIndAkIs/I
Nnz5yZWCokRjokoKv3Aj4VilyjncL+W0UIPULLU/47ZyoHVSUj2t8gknr9xu/Kd+
g/2z0RiF3CIp8IUH49w/HYWaR95glzVNAAzr8qD9UbUqloLVQW3lObSRGtezhdZO
sspw5dC+inhAb1LZhx8PVxB3SAeJ8h11IEBr0s2Hxt9viKKd7YPtIFZkZdOkVx4R
if1DMawj1P6fEomf8z7m+dmbUYTqqosbCbRL01mzEga/kF6JyH/OzpNlcsAiyM8e
BxPWH6TtPqwmyy4y7j1outmM0RnyUw5A0HmIbWh+rHpXiHVsnNqse0XfzmaxM8+z
dxYeDax8aMWZKfvY1Zew+xIxl7DtEy1BpxrZcawumJYt5+LL+bwF/OtL0inQLnw8
zyqydsXNdrpIQJnfmWPld7ThWbQw2FBE70+nFSxHeG2ULnpF3M9xf6ZNAF4gqaNE
Q7vMNPBWrJWu+A++vHY61WGET+h4lY3GFr2I8OE4IiHPQi1D7Y0+fwOmStwuRPM4
2rARcJChNdiYBkkuvs4kixKTTjdXhB8RQtuBSrJ0M1tzq2qMbm7F8G01rOg4KlXU
58jHzJwr1K7cx0lpWfGTtc5bseCGtTKmDBXTziw04yl8eE1+ZFOganixGwCtl4Tt
DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w==
-----END CERTIFICATE REQUEST-----`
)
|
||
|
|
||
|
// parseCertificate decodes the first PEM block in data and parses its bytes
// as an X.509 certificate. It panics when data contains no PEM block or when
// the block's contents cannot be parsed as a certificate.
func parseCertificate(data string) *x509.Certificate {
	pemBlock, _ := pem.Decode([]byte(data))
	if pemBlock == nil {
		panic("failed to parse certificate PEM")
	}

	crt, parseErr := x509.ParseCertificate(pemBlock.Bytes)
	if parseErr != nil {
		panic("failed to parse certificate: " + parseErr.Error())
	}
	return crt
}
|
||
|
|
||
|
// parseCertificateRequest decodes the first PEM block in data and parses its
// bytes as an X.509 certificate request. It panics when data contains no PEM
// block or when the block's contents cannot be parsed as a request.
func parseCertificateRequest(data string) *x509.CertificateRequest {
	// Decode the caller's input. The previous implementation decoded the
	// package-level csrPEM constant here and silently ignored data.
	block, _ := pem.Decode([]byte(data))
	if block == nil {
		panic("failed to parse certificate request PEM")
	}
	csr, err := x509.ParseCertificateRequest(block.Bytes)
	if err != nil {
		panic("failed to parse certificate request: " + err.Error())
	}
	return csr
}
|
||
|
|
||
|
func TestNewCertificate(t *testing.T) {
|
||
|
cert := parseCertificate(rootPEM)
|
||
|
if !reflect.DeepEqual(Certificate{Certificate: cert}, NewCertificate(cert)) {
|
||
|
t.Errorf("NewCertificate failed, got %v, wants %v", NewCertificate(cert), Certificate{Certificate: cert})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCertificate_MarshalJSON(t *testing.T) {
|
||
|
type fields struct {
|
||
|
Certificate *x509.Certificate
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
fields fields
|
||
|
want []byte
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{"nil", fields{Certificate: nil}, []byte("null"), false},
|
||
|
{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
|
||
|
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"`), false},
|
||
|
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`), false},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
c := Certificate{
|
||
|
Certificate: tt.fields.Certificate,
|
||
|
}
|
||
|
got, err := c.MarshalJSON()
|
||
|
if (err != nil) != tt.wantErr {
|
||
|
t.Errorf("Certificate.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
return
|
||
|
}
|
||
|
if !reflect.DeepEqual(got, tt.want) {
|
||
|
t.Errorf("Certificate.MarshalJSON() = %s, want %s", got, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCertificate_UnmarshalJSON(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
data []byte
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{"no data", nil, true},
|
||
|
{"empty string", []byte(`""`), true},
|
||
|
{"incomplete string 1", []byte(`"foobar`), true}, {"incomplete string 2", []byte(`foobar"`), true},
|
||
|
{"invalid string", []byte(`"foobar"`), true},
|
||
|
{"invalid bytes 0", []byte{}, true}, {"invalid bytes 1", []byte{1}, true},
|
||
|
{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), true},
|
||
|
{"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true},
|
||
|
{"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false},
|
||
|
{"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), false},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
var c Certificate
|
||
|
if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr {
|
||
|
t.Errorf("Certificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
}
|
||
|
if !tt.wantErr && c.Certificate == nil {
|
||
|
t.Error("Certificate.UnmarshalJSON() failed, Certificate is nil")
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCertificate_UnmarshalJSON_json(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
data string
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{"invalid type (null)", `{"crt":null}`, true},
|
||
|
{"invalid type (bool)", `{"crt":true}`, true},
|
||
|
{"invalid type (number)", `{"crt":123}`, true},
|
||
|
{"invalid type (object)", `{"crt":{}}`, true},
|
||
|
{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, true},
|
||
|
{"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, false},
|
||
|
}
|
||
|
|
||
|
type request struct {
|
||
|
Cert Certificate `json:"crt"`
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
var body request
|
||
|
if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr {
|
||
|
t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
}
|
||
|
|
||
|
switch tt.wantErr {
|
||
|
case false:
|
||
|
if body.Cert.Certificate == nil {
|
||
|
t.Error("json.Unmarshal() failed, Certificate is nil")
|
||
|
}
|
||
|
case true:
|
||
|
if body.Cert.Certificate != nil {
|
||
|
t.Error("json.Unmarshal() failed, Certificate is not nil")
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
func TestNewCertificateRequest(t *testing.T) {
|
||
|
csr := parseCertificateRequest(csrPEM)
|
||
|
if !reflect.DeepEqual(CertificateRequest{CertificateRequest: csr}, NewCertificateRequest(csr)) {
|
||
|
t.Errorf("NewCertificateRequest failed, got %v, wants %v", NewCertificateRequest(csr), CertificateRequest{CertificateRequest: csr})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCertificateRequest_MarshalJSON(t *testing.T) {
|
||
|
type fields struct {
|
||
|
CertificateRequest *x509.CertificateRequest
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
fields fields
|
||
|
want []byte
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
|
||
|
{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
|
||
|
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `\n"`), false},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
c := CertificateRequest{
|
||
|
CertificateRequest: tt.fields.CertificateRequest,
|
||
|
}
|
||
|
got, err := c.MarshalJSON()
|
||
|
if (err != nil) != tt.wantErr {
|
||
|
t.Errorf("CertificateRequest.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
return
|
||
|
}
|
||
|
if !reflect.DeepEqual(got, tt.want) {
|
||
|
t.Errorf("CertificateRequest.MarshalJSON() = %s, want %s", got, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCertificateRequest_UnmarshalJSON(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
data []byte
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{"no data", nil, true},
|
||
|
{"empty string", []byte(`""`), true},
|
||
|
{"incomplete string 1", []byte(`"foobar`), true}, {"incomplete string 2", []byte(`foobar"`), true},
|
||
|
{"invalid string", []byte(`"foobar"`), true},
|
||
|
{"invalid bytes 0", []byte{}, true}, {"invalid bytes 1", []byte{1}, true},
|
||
|
{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), true},
|
||
|
{"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true},
|
||
|
{"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
var c CertificateRequest
|
||
|
if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr {
|
||
|
t.Errorf("CertificateRequest.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
}
|
||
|
if !tt.wantErr && c.CertificateRequest == nil {
|
||
|
t.Error("CertificateRequest.UnmarshalJSON() failed, CertificateRequet is nil")
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
data string
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{"invalid type (null)", `{"csr":null}`, true},
|
||
|
{"invalid type (bool)", `{"csr":true}`, true},
|
||
|
{"invalid type (number)", `{"csr":123}`, true},
|
||
|
{"invalid type (object)", `{"csr":{}}`, true},
|
||
|
{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, true},
|
||
|
{"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, false},
|
||
|
}
|
||
|
|
||
|
type request struct {
|
||
|
CSR CertificateRequest `json:"csr"`
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
var body request
|
||
|
if err := json.Unmarshal([]byte(tt.data), &body); (err != nil) != tt.wantErr {
|
||
|
t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
}
|
||
|
|
||
|
switch tt.wantErr {
|
||
|
case false:
|
||
|
if body.CSR.CertificateRequest == nil {
|
||
|
t.Error("json.Unmarshal() failed, CertificateRequest is nil")
|
||
|
}
|
||
|
case true:
|
||
|
if body.CSR.CertificateRequest != nil {
|
||
|
t.Error("json.Unmarshal() failed, CertificateRequest is not nil")
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestSignRequest_Validate(t *testing.T) {
|
||
|
now := time.Now()
|
||
|
csr := parseCertificateRequest(csrPEM)
|
||
|
bad := parseCertificateRequest(csrPEM)
|
||
|
bad.Signature[0]++
|
||
|
type fields struct {
|
||
|
CsrPEM CertificateRequest
|
||
|
OTT string
|
||
|
NotBefore time.Time
|
||
|
NotAfter time.Time
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
fields fields
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{"ok", fields{CertificateRequest{csr}, "foobarzar", time.Time{}, time.Time{}}, false},
|
||
|
{"ok 5m", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(5 * time.Minute)}, false},
|
||
|
{"ok 24h", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(24 * time.Hour)}, false},
|
||
|
{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, true},
|
||
|
{"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, true},
|
||
|
{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, true},
|
||
|
{"notAfter < now", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(-5 * time.Minute)}, true},
|
||
|
{"notAfter < notBefore", fields{CertificateRequest{csr}, "foobarzar", now.Add(5 * time.Minute), now.Add(4 * time.Minute)}, true},
|
||
|
{"too short", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(4 * time.Minute)}, true},
|
||
|
{"too long", fields{CertificateRequest{csr}, "foobarzar", now, now.Add(24 * time.Hour).Add(1 * time.Minute)}, true},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
s := &SignRequest{
|
||
|
CsrPEM: tt.fields.CsrPEM,
|
||
|
OTT: tt.fields.OTT,
|
||
|
NotAfter: tt.fields.NotAfter,
|
||
|
NotBefore: tt.fields.NotBefore,
|
||
|
}
|
||
|
if err := s.Validate(); (err != nil) != tt.wantErr {
|
||
|
t.Errorf("SignRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type mockAuthority struct {
|
||
|
ret1, ret2 interface{}
|
||
|
err error
|
||
|
authorize func(ott string) ([]Claim, error)
|
||
|
getTLSOptions func() *tlsutil.TLSOptions
|
||
|
root func(shasum string) (*x509.Certificate, error)
|
||
|
sign func(cr *x509.CertificateRequest, opts SignOptions, claims ...Claim) (*x509.Certificate, *x509.Certificate, error)
|
||
|
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||
|
}
|
||
|
|
||
|
func (m *mockAuthority) Authorize(ott string) ([]Claim, error) {
|
||
|
if m.authorize != nil {
|
||
|
return m.authorize(ott)
|
||
|
}
|
||
|
return m.ret1.([]Claim), m.err
|
||
|
}
|
||
|
|
||
|
func (m *mockAuthority) GetTLSOptions() *tlsutil.TLSOptions {
|
||
|
if m.getTLSOptions != nil {
|
||
|
return m.getTLSOptions()
|
||
|
}
|
||
|
return m.ret1.(*tlsutil.TLSOptions)
|
||
|
}
|
||
|
|
||
|
func (m *mockAuthority) GetMinDuration() time.Duration {
|
||
|
return minCertDuration
|
||
|
}
|
||
|
|
||
|
func (m *mockAuthority) GetMaxDuration() time.Duration {
|
||
|
return maxCertDuration
|
||
|
}
|
||
|
|
||
|
func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
|
||
|
if m.root != nil {
|
||
|
return m.root(shasum)
|
||
|
}
|
||
|
return m.ret1.(*x509.Certificate), m.err
|
||
|
}
|
||
|
|
||
|
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts SignOptions, claims ...Claim) (*x509.Certificate, *x509.Certificate, error) {
|
||
|
if m.sign != nil {
|
||
|
return m.sign(cr, opts, claims...)
|
||
|
}
|
||
|
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
||
|
}
|
||
|
|
||
|
func (m *mockAuthority) Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) {
|
||
|
if m.renew != nil {
|
||
|
return m.renew(cert)
|
||
|
}
|
||
|
return m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate), m.err
|
||
|
}
|
||
|
|
||
|
func Test_caHandler_Health(t *testing.T) {
|
||
|
req := httptest.NewRequest("GET", "http://example.com/health", nil)
|
||
|
w := httptest.NewRecorder()
|
||
|
h := New(&mockAuthority{}).(*caHandler)
|
||
|
h.Health(w, req)
|
||
|
|
||
|
res := w.Result()
|
||
|
if res.StatusCode != 200 {
|
||
|
t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode)
|
||
|
}
|
||
|
|
||
|
body, err := ioutil.ReadAll(res.Body)
|
||
|
res.Body.Close()
|
||
|
if err != nil {
|
||
|
t.Errorf("caHandler.Health unexpected error = %v", err)
|
||
|
}
|
||
|
expected := []byte("{\"status\":\"ok\"}\n")
|
||
|
if !bytes.Equal(body, expected) {
|
||
|
t.Errorf("caHandler.Health Body = %s, wants %s", body, expected)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func Test_caHandler_Root(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
root *x509.Certificate
|
||
|
err error
|
||
|
statusCode int
|
||
|
}{
|
||
|
{"ok", parseCertificate(rootPEM), nil, 200},
|
||
|
{"fail", nil, fmt.Errorf("not found"), 404},
|
||
|
}
|
||
|
|
||
|
// Request with chi context
|
||
|
chiCtx := chi.NewRouteContext()
|
||
|
chiCtx.URLParams.Add("sha", "efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36")
|
||
|
req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
|
||
|
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
|
||
|
|
||
|
expected := []byte(`{"ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`)
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
|
||
|
w := httptest.NewRecorder()
|
||
|
h.Root(w, req)
|
||
|
res := w.Result()
|
||
|
|
||
|
if res.StatusCode != tt.statusCode {
|
||
|
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||
|
}
|
||
|
|
||
|
body, err := ioutil.ReadAll(res.Body)
|
||
|
res.Body.Close()
|
||
|
if err != nil {
|
||
|
t.Errorf("caHandler.Root unexpected error = %v", err)
|
||
|
}
|
||
|
if tt.statusCode == 200 {
|
||
|
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||
|
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func Test_caHandler_Sign(t *testing.T) {
|
||
|
csr := parseCertificateRequest(csrPEM)
|
||
|
valid, err := json.Marshal(SignRequest{
|
||
|
CsrPEM: CertificateRequest{csr},
|
||
|
OTT: "foobarzar",
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
invalid, err := json.Marshal(SignRequest{
|
||
|
CsrPEM: CertificateRequest{csr},
|
||
|
OTT: "",
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
input string
|
||
|
claims []Claim
|
||
|
autherr error
|
||
|
cert *x509.Certificate
|
||
|
root *x509.Certificate
|
||
|
signErr error
|
||
|
statusCode int
|
||
|
}{
|
||
|
{"ok", string(valid), nil, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||
|
{"json read error", "{", nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||
|
{"validate error", string(invalid), nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||
|
{"authorize error", string(valid), nil, fmt.Errorf("an error"), nil, nil, nil, http.StatusUnauthorized},
|
||
|
{"sign error", string(valid), nil, nil, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||
|
}
|
||
|
|
||
|
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`)
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
h := New(&mockAuthority{
|
||
|
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
||
|
authorize: func(ott string) ([]Claim, error) {
|
||
|
return tt.claims, tt.autherr
|
||
|
},
|
||
|
getTLSOptions: func() *tlsutil.TLSOptions {
|
||
|
return nil
|
||
|
},
|
||
|
}).(*caHandler)
|
||
|
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
|
||
|
w := httptest.NewRecorder()
|
||
|
h.Sign(w, req)
|
||
|
res := w.Result()
|
||
|
|
||
|
if res.StatusCode != tt.statusCode {
|
||
|
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||
|
}
|
||
|
|
||
|
body, err := ioutil.ReadAll(res.Body)
|
||
|
res.Body.Close()
|
||
|
if err != nil {
|
||
|
t.Errorf("caHandler.Root unexpected error = %v", err)
|
||
|
}
|
||
|
if tt.statusCode < http.StatusBadRequest {
|
||
|
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||
|
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func Test_caHandler_Renew(t *testing.T) {
|
||
|
cs := &tls.ConnectionState{
|
||
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
tls *tls.ConnectionState
|
||
|
cert *x509.Certificate
|
||
|
root *x509.Certificate
|
||
|
err error
|
||
|
statusCode int
|
||
|
}{
|
||
|
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||
|
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
|
||
|
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
|
||
|
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||
|
}
|
||
|
|
||
|
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`)
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
h := New(&mockAuthority{
|
||
|
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||
|
getTLSOptions: func() *tlsutil.TLSOptions {
|
||
|
return nil
|
||
|
},
|
||
|
}).(*caHandler)
|
||
|
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
||
|
req.TLS = tt.tls
|
||
|
w := httptest.NewRecorder()
|
||
|
h.Renew(w, req)
|
||
|
res := w.Result()
|
||
|
|
||
|
if res.StatusCode != tt.statusCode {
|
||
|
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||
|
}
|
||
|
|
||
|
body, err := ioutil.ReadAll(res.Body)
|
||
|
res.Body.Close()
|
||
|
if err != nil {
|
||
|
t.Errorf("caHandler.Root unexpected error = %v", err)
|
||
|
}
|
||
|
if tt.statusCode < http.StatusBadRequest {
|
||
|
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||
|
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|