diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 29d8320e..9bfe8529 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -70,7 +70,7 @@ func (fn CertificateModifierFunc) Modify(cert *x509.Certificate, opts SignOption // with a function. type CertificateEnforcerFunc func(cert *x509.Certificate) error -// Modify implements CertificateEnforcer and just calls the defined function. +// Enforce implements CertificateEnforcer and just calls the defined function. func (fn CertificateEnforcerFunc) Enforce(cert *x509.Certificate) error { return fn(cert) } @@ -248,24 +248,6 @@ func (v defaultSANsValidator) Valid(req *x509.CertificateRequest) (err error) { return } -// ExtraExtsEnforcer enforces only those extra extensions that are strictly -// managed by step-ca. All other "extra extensions" are dropped. -type ExtraExtsEnforcer struct{} - -// Enforce removes all extensions except the step provisioner extension, if it -// exists. If the step provisioner extension is not present, then remove all -// extra extensions from the cert. -func (eee ExtraExtsEnforcer) Enforce(cert *x509.Certificate) error { - for _, ext := range cert.ExtraExtensions { - if ext.Id.Equal(stepOIDProvisioner) { - cert.ExtraExtensions = []pkix.Extension{ext} - return nil - } - } - cert.ExtraExtensions = nil - return nil -} - // profileDefaultDuration is a modifier that sets the certificate // duration. type profileDefaultDuration time.Duration diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index 459455bc..a4a2935c 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -351,51 +351,6 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { } } -func Test_ExtraExtsEnforcer_Enforce(t *testing.T) { - e1 := pkix.Extension{Id: []int{1, 2, 3, 4, 5}, Critical: false, Value: []byte("foo")} - e2 := pkix.Extension{Id: []int{2, 2, 2}, Critical: false, Value: []byte("bar")} - stepExt := pkix.Extension{Id: stepOIDProvisioner, Critical: false, Value: []byte("baz")} - fakeStepExt := pkix.Extension{Id: stepOIDProvisioner, Critical: false, Value: []byte("zap")} - type test struct { - cert *x509.Certificate - check func(*x509.Certificate) - } - tests := map[string]func() test{ - "ok/empty-exts": func() test { - return test{ - cert: &x509.Certificate{}, - check: func(cert *x509.Certificate) { - assert.Equals(t, len(cert.ExtraExtensions), 0) - }, - } - }, - "ok/no-step-provisioner-ext": func() test { - return test{ - cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{e1, e2}}, - check: func(cert *x509.Certificate) { - assert.Equals(t, len(cert.ExtraExtensions), 0) - }, - } - }, - "ok/step-provisioner-ext": func() test { - return test{ - cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{e1, stepExt, fakeStepExt, e2}}, - check: func(cert *x509.Certificate) { - assert.Equals(t, len(cert.ExtraExtensions), 1) - assert.Equals(t, cert.ExtraExtensions[0], stepExt) - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tt := run() - ExtraExtsEnforcer{}.Enforce(tt.cert) - tt.check(tt.cert) - }) - } -} - func Test_validityValidator_Valid(t *testing.T) { type test struct { cert *x509.Certificate @@ -589,10 +544,10 @@ func Test_profileDefaultDuration_Option(t *testing.T) { cert: new(x509.Certificate), valid: func(cert *x509.Certificate) { n := now() - assert.True(t, n.After(cert.NotBefore)) + assert.True(t, n.After(cert.NotBefore.Add(-time.Second))) assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore)) - assert.True(t, n.Add(24*time.Hour).After(cert.NotAfter)) + assert.True(t, n.Add(24*time.Hour).After(cert.NotAfter.Add(-time.Second))) assert.True(t, n.Add(24*time.Hour).Add(-1*time.Minute).Before(cert.NotAfter)) }, } diff --git a/authority/tls_test.go b/authority/tls_test.go index 50e618e8..a27f9c15 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -390,6 +390,34 @@ ZYtQ9Ot36qc= notAfter: signOpts.NotAfter.Time().Truncate(time.Second), } }, + "ok/csr with no template critical SAN extension": func(t *testing.T) *signTest { + csr := getCSR(t, priv, func(csr *x509.CertificateRequest) { + csr.Subject = pkix.Name{} + }, func(csr *x509.CertificateRequest) { + csr.DNSNames = []string{"foo", "bar"} + }) + now := time.Now().UTC() + enforcedExtraOptions := []provisioner.SignOption{&certificateDurationEnforcer{ + NotBefore: now, + NotAfter: now.Add(365 * 24 * time.Hour), + }} + _a := testAuthority(t) + _a.config.AuthorityConfig.Template = &x509util.ASN1DN{} + _a.db = &db.MockAuthDB{ + MStoreCertificate: func(crt *x509.Certificate) error { + assert.Equals(t, crt.Subject, pkix.Name{}) + return nil + }, + } + return &signTest{ + auth: _a, + csr: csr, + extraOpts: enforcedExtraOptions, + signOpts: provisioner.SignOptions{}, + notBefore: now.Truncate(time.Second), + notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second), + } + }, } for name, genTestCase := range tests { @@ -417,22 +445,26 @@ ZYtQ9Ot36qc= assert.Equals(t, leaf.NotBefore, tc.notBefore) assert.Equals(t, leaf.NotAfter, tc.notAfter) tmplt := a.config.AuthorityConfig.Template - assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), - fmt.Sprintf("%v", &pkix.Name{ - Country: []string{tmplt.Country}, - Organization: []string{tmplt.Organization}, - Locality: []string{tmplt.Locality}, - StreetAddress: []string{tmplt.StreetAddress}, - Province: []string{tmplt.Province}, - CommonName: "smallstep test", - })) + if tc.csr.Subject.CommonName == "" { + assert.Equals(t, leaf.Subject, pkix.Name{}) + } else { + assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), + fmt.Sprintf("%v", &pkix.Name{ + Country: []string{tmplt.Country}, + Organization: []string{tmplt.Organization}, + Locality: []string{tmplt.Locality}, + StreetAddress: []string{tmplt.StreetAddress}, + Province: []string{tmplt.Province}, + CommonName: "smallstep test", + })) + assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"}) + } assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) assert.Equals(t, leaf.ExtKeyUsage, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) - assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"}) subjectKeyID, err := generateSubjectKeyID(pub) assert.FatalError(t, err) @@ -452,6 +484,7 @@ ZYtQ9Ot36qc= assert.Equals(t, val.Type, provisionerTypeJWK) assert.Equals(t, val.Name, []byte(p.Name)) assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID)) + // Basic Constraints case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 19})): val := basicConstraints{} @@ -459,10 +492,20 @@ ZYtQ9Ot36qc= assert.FatalError(t, err) assert.False(t, val.IsCA, false) assert.Equals(t, val.MaxPathLen, 0) + + // SAN extension + case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 17})): + if tc.csr.Subject.CommonName == "" { + // Empty CSR subject test does not use any provisioner extensions. + // So provisioner ID ext will be missing. + found = 1 + assert.Len(t, 5, leaf.Extensions) + } else { + assert.Len(t, 6, leaf.Extensions) + } } } assert.Equals(t, found, 1) - assert.Len(t, 6, leaf.Extensions) realIntermediate, err := x509.ParseCertificate(a.x509Issuer.Raw) assert.FatalError(t, err)