diff --git a/cas/stepcas/stepcas.go b/cas/stepcas/stepcas.go index 1befdc35..86a995ef 100644 --- a/cas/stepcas/stepcas.go +++ b/cas/stepcas/stepcas.go @@ -4,6 +4,7 @@ import ( "context" "crypto/x509" "net/url" + "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api" @@ -70,25 +71,11 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 return nil, errors.New("createCertificateRequest `lifetime` cannot be 0") } - token, err := s.signToken(req.CSR.Subject.CommonName, req.CSR.DNSNames) + cert, chain, err := s.createCertificate(req.CSR, req.Lifetime) if err != nil { return nil, err } - resp, err := s.client.Sign(&api.SignRequest{ - CsrPEM: api.CertificateRequest{CertificateRequest: req.CSR}, - OTT: token, - }) - if err != nil { - return nil, err - } - - var chain []*x509.Certificate - cert := resp.CertChainPEM[0].Certificate - for _, c := range resp.CertChainPEM[1:] { - chain = append(chain, c.Certificate) - } - return &apiv1.CreateCertificateResponse{ Certificate: cert, CertificateChain: chain, @@ -98,30 +85,16 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.RenewCertificateResponse, error) { switch { case req.CSR == nil: - return nil, errors.New("createCertificateRequest `template` cannot be nil") + return nil, errors.New("renewCertificateRequest `template` cannot be nil") case req.Lifetime == 0: - return nil, errors.New("createCertificateRequest `lifetime` cannot be 0") + return nil, errors.New("renewCertificateRequest `lifetime` cannot be 0") } - token, err := s.signToken(req.CSR.Subject.CommonName, req.CSR.DNSNames) + cert, chain, err := s.createCertificate(req.CSR, req.Lifetime) if err != nil { return nil, err } - resp, err := s.client.Sign(&api.SignRequest{ - CsrPEM: api.CertificateRequest{CertificateRequest: req.CSR}, - OTT: token, - }) - if err != nil { - return nil, err - } - - var chain []*x509.Certificate - cert := resp.CertChainPEM[0].Certificate - for _, c := range resp.CertChainPEM[1:] { - chain = append(chain, c.Certificate) - } - return &apiv1.RenewCertificateResponse{ Certificate: cert, CertificateChain: chain, @@ -173,6 +146,48 @@ func (s *StepCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequ }, nil } +func (s *StepCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) { + sans := make([]string, 0, len(cr.DNSNames)+len(cr.EmailAddresses)+len(cr.IPAddresses)+len(cr.URIs)) + for _, s := range cr.DNSNames { + sans = append(sans, s) + } + for _, s := range cr.EmailAddresses { + sans = append(sans, s) + } + for _, ip := range cr.IPAddresses { + sans = append(sans, ip.String()) + } + for _, u := range cr.URIs { + sans = append(sans, u.String()) + } + + commonName := cr.Subject.CommonName + if commonName == "" && len(sans) > 0 { + commonName = sans[0] + } + + token, err := s.signToken(commonName, sans) + if err != nil { + return nil, nil, err + } + + resp, err := s.client.Sign(&api.SignRequest{ + CsrPEM: api.CertificateRequest{CertificateRequest: cr}, + OTT: token, + }) + if err != nil { + return nil, nil, err + } + + var chain []*x509.Certificate + cert := resp.CertChainPEM[0].Certificate + for _, c := range resp.CertChainPEM[1:] { + chain = append(chain, c.Certificate) + } + + return cert, chain, nil +} + func (s *StepCAS) signToken(subject string, sans []string) (string, error) { if s.x5c != nil { return s.x5c.SignToken(subject, sans) diff --git a/cas/stepcas/stepcas_test.go b/cas/stepcas/stepcas_test.go index c10fb5ca..e954c9ff 100644 --- a/cas/stepcas/stepcas_test.go +++ b/cas/stepcas/stepcas_test.go @@ -174,14 +174,15 @@ func TestMain(m *testing.M) { // Final certificate. var err error - testCrt, testKey = mustSignCertificate("Test Certificate", []string{"doe.org"}, x509util.DefaultLeafTemplate, testIssCrt, testIssKey) - testCR, err = x509util.CreateCertificateRequest("Test Certificate", []string{"doe.org"}, testKey) + sans := []string{"doe.org", "jane@doe.org", "127.0.0.1", "::1", "localhost", "uuid:f81d4fae-7dec-11d0-a765-00a0c91e6bf6;name=value"} + testCrt, testKey = mustSignCertificate("Test Certificate", sans, x509util.DefaultLeafTemplate, testIssCrt, testIssKey) + testCR, err = x509util.CreateCertificateRequest("Test Certificate", sans, testKey) if err != nil { panic(err) } // CR used in errors. - testFailCR, err = x509util.CreateCertificateRequest("Test Certificate", []string{"fail.doe.org"}, testKey) + testFailCR, err = x509util.CreateCertificateRequest("", []string{"fail.doe.org"}, testKey) if err != nil { panic(err) }