diff --git a/api/api_test.go b/api/api_test.go index e68eb7db..98d612ab 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -29,6 +29,7 @@ import ( "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/sshutil" "github.com/smallstep/certificates/templates" "github.com/smallstep/cli/crypto/tlsutil" "github.com/smallstep/cli/jose" @@ -207,19 +208,21 @@ func TestCertificate_MarshalJSON(t *testing.T) { func TestCertificate_UnmarshalJSON(t *testing.T) { tests := []struct { - name string - data []byte - wantErr bool + name string + data []byte + wantCert bool + wantErr bool }{ - {"no data", nil, true}, - {"empty string", []byte(`""`), true}, - {"incomplete string 1", []byte(`"foobar`), true}, {"incomplete string 2", []byte(`foobar"`), true}, - {"invalid string", []byte(`"foobar"`), true}, - {"invalid bytes 0", []byte{}, true}, {"invalid bytes 1", []byte{1}, true}, - {"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), true}, - {"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true}, - {"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false}, - {"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), false}, + {"no data", nil, false, true}, + {"incomplete string 1", []byte(`"foobar`), false, true}, {"incomplete string 2", []byte(`foobar"`), false, true}, + {"invalid string", []byte(`"foobar"`), false, true}, + {"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true}, + {"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true}, + {"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false, true}, + {"empty string", []byte(`""`), false, false}, + {"json null", []byte(`null`), false, false}, + {"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true, false}, + {"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -227,7 +230,7 @@ func TestCertificate_UnmarshalJSON(t *testing.T) { if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr { t.Errorf("Certificate.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } - if !tt.wantErr && c.Certificate == nil { + if tt.wantCert && c.Certificate == nil { t.Error("Certificate.UnmarshalJSON() failed, Certificate is nil") } }) @@ -236,16 +239,18 @@ func TestCertificate_UnmarshalJSON(t *testing.T) { func TestCertificate_UnmarshalJSON_json(t *testing.T) { tests := []struct { - name string - data string - wantErr bool + name string + data string + wantCert bool + wantErr bool }{ - {"invalid type (null)", `{"crt":null}`, true}, - {"invalid type (bool)", `{"crt":true}`, true}, - {"invalid type (number)", `{"crt":123}`, true}, - {"invalid type (object)", `{"crt":{}}`, true}, - {"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, true}, - {"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, false}, + {"invalid type (bool)", `{"crt":true}`, false, true}, + {"invalid type (number)", `{"crt":123}`, false, true}, + {"invalid type (object)", `{"crt":{}}`, false, true}, + {"empty crt (null)", `{"crt":null}`, false, false}, + {"empty crt (string)", `{"crt":""}`, false, false}, + {"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true}, + {"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, true, false}, } type request struct { @@ -259,12 +264,12 @@ func TestCertificate_UnmarshalJSON_json(t *testing.T) { t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr) } - switch tt.wantErr { - case false: + switch tt.wantCert { + case true: if body.Cert.Certificate == nil { t.Error("json.Unmarshal() failed, Certificate is nil") } - case true: + case false: if body.Cert.Certificate != nil { t.Error("json.Unmarshal() failed, Certificate is not nil") } @@ -313,18 +318,20 @@ func TestCertificateRequest_MarshalJSON(t *testing.T) { func TestCertificateRequest_UnmarshalJSON(t *testing.T) { tests := []struct { - name string - data []byte - wantErr bool + name string + data []byte + wantCert bool + wantErr bool }{ - {"no data", nil, true}, - {"empty string", []byte(`""`), true}, - {"incomplete string 1", []byte(`"foobar`), true}, {"incomplete string 2", []byte(`foobar"`), true}, - {"invalid string", []byte(`"foobar"`), true}, - {"invalid bytes 0", []byte{}, true}, {"invalid bytes 1", []byte{1}, true}, - {"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), true}, - {"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true}, - {"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false}, + {"no data", nil, false, true}, + {"incomplete string 1", []byte(`"foobar`), false, true}, {"incomplete string 2", []byte(`foobar"`), false, true}, + {"invalid string", []byte(`"foobar"`), false, true}, + {"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true}, + {"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true}, + {"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false, true}, + {"empty string", []byte(`""`), false, false}, + {"json null", []byte(`null`), false, false}, + {"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -332,7 +339,7 @@ func TestCertificateRequest_UnmarshalJSON(t *testing.T) { if err := c.UnmarshalJSON(tt.data); (err != nil) != tt.wantErr { t.Errorf("CertificateRequest.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) } - if !tt.wantErr && c.CertificateRequest == nil { + if tt.wantCert && c.CertificateRequest == nil { t.Error("CertificateRequest.UnmarshalJSON() failed, CertificateRequet is nil") } }) @@ -341,16 +348,18 @@ func TestCertificateRequest_UnmarshalJSON(t *testing.T) { func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) { tests := []struct { - name string - data string - wantErr bool + name string + data string + wantCert bool + wantErr bool }{ - {"invalid type (null)", `{"csr":null}`, true}, - {"invalid type (bool)", `{"csr":true}`, true}, - {"invalid type (number)", `{"csr":123}`, true}, - {"invalid type (object)", `{"csr":{}}`, true}, - {"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, true}, - {"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, false}, + {"invalid type (bool)", `{"csr":true}`, false, true}, + {"invalid type (number)", `{"csr":123}`, false, true}, + {"invalid type (object)", `{"csr":{}}`, false, true}, + {"empty csr (null)", `{"csr":null}`, false, false}, + {"empty csr (string)", `{"csr":""}`, false, false}, + {"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true}, + {"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, true, false}, } type request struct { @@ -364,12 +373,12 @@ func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) { t.Errorf("json.Unmarshal() error = %v, wantErr %v", err, tt.wantErr) } - switch tt.wantErr { - case false: + switch tt.wantCert { + case true: if body.CSR.CertificateRequest == nil { t.Error("json.Unmarshal() failed, CertificateRequest is nil") } - case true: + case false: if body.CSR.CertificateRequest != nil { t.Error("json.Unmarshal() failed, CertificateRequest is not nil") } @@ -552,12 +561,13 @@ type mockAuthority struct { getFederation func() ([]*x509.Certificate, error) renewSSH func(cert *ssh.Certificate) (*ssh.Certificate, error) rekeySSH func(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - getSSHHosts func() ([]string, error) + getSSHHosts func(*x509.Certificate) ([]sshutil.Host, error) getSSHRoots func() (*authority.SSHKeys, error) getSSHFederation func() (*authority.SSHKeys, error) getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error) checkSSHHost func(principal string) (bool, error) getSSHBastion func(user string, hostname string) (*authority.Bastion, error) + version func() authority.Version } // TODO: remove once Authorize is deprecated. @@ -677,11 +687,11 @@ func (m *mockAuthority) RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signO return m.ret1.(*ssh.Certificate), m.err } -func (m *mockAuthority) GetSSHHosts() ([]string, error) { +func (m *mockAuthority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) { if m.getSSHHosts != nil { - return m.getSSHHosts() + return m.getSSHHosts(cert) } - return m.ret1.([]string), m.err + return m.ret1.([]sshutil.Host), m.err } func (m *mockAuthority) GetSSHRoots() (*authority.SSHKeys, error) { @@ -719,6 +729,13 @@ func (m *mockAuthority) GetSSHBastion(user string, hostname string) (*authority. return m.ret1.(*authority.Bastion), m.err } +func (m *mockAuthority) Version() authority.Version { + if m.version != nil { + return m.version() + } + return m.ret1.(authority.Version) +} + func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority