Remove unused code; fix usage wrong word; add gap time for unit test

This commit is contained in:
max furman 2020-08-20 18:48:17 -07:00
parent 03d642e59c
commit 46fc922afd
3 changed files with 57 additions and 77 deletions

View file

@ -70,7 +70,7 @@ func (fn CertificateModifierFunc) Modify(cert *x509.Certificate, opts SignOption
// with a function. // with a function.
type CertificateEnforcerFunc func(cert *x509.Certificate) error type CertificateEnforcerFunc func(cert *x509.Certificate) error
// Modify implements CertificateEnforcer and just calls the defined function. // Enforce implements CertificateEnforcer and just calls the defined function.
func (fn CertificateEnforcerFunc) Enforce(cert *x509.Certificate) error { func (fn CertificateEnforcerFunc) Enforce(cert *x509.Certificate) error {
return fn(cert) return fn(cert)
} }
@ -248,24 +248,6 @@ func (v defaultSANsValidator) Valid(req *x509.CertificateRequest) (err error) {
return return
} }
// ExtraExtsEnforcer enforces only those extra extensions that are strictly
// managed by step-ca. All other "extra extensions" are dropped.
type ExtraExtsEnforcer struct{}
// Enforce removes all extensions except the step provisioner extension, if it
// exists. If the step provisioner extension is not present, then remove all
// extra extensions from the cert.
func (eee ExtraExtsEnforcer) Enforce(cert *x509.Certificate) error {
for _, ext := range cert.ExtraExtensions {
if ext.Id.Equal(stepOIDProvisioner) {
cert.ExtraExtensions = []pkix.Extension{ext}
return nil
}
}
cert.ExtraExtensions = nil
return nil
}
// profileDefaultDuration is a modifier that sets the certificate // profileDefaultDuration is a modifier that sets the certificate
// duration. // duration.
type profileDefaultDuration time.Duration type profileDefaultDuration time.Duration

View file

@ -351,51 +351,6 @@ func Test_defaultSANsValidator_Valid(t *testing.T) {
} }
} }
func Test_ExtraExtsEnforcer_Enforce(t *testing.T) {
e1 := pkix.Extension{Id: []int{1, 2, 3, 4, 5}, Critical: false, Value: []byte("foo")}
e2 := pkix.Extension{Id: []int{2, 2, 2}, Critical: false, Value: []byte("bar")}
stepExt := pkix.Extension{Id: stepOIDProvisioner, Critical: false, Value: []byte("baz")}
fakeStepExt := pkix.Extension{Id: stepOIDProvisioner, Critical: false, Value: []byte("zap")}
type test struct {
cert *x509.Certificate
check func(*x509.Certificate)
}
tests := map[string]func() test{
"ok/empty-exts": func() test {
return test{
cert: &x509.Certificate{},
check: func(cert *x509.Certificate) {
assert.Equals(t, len(cert.ExtraExtensions), 0)
},
}
},
"ok/no-step-provisioner-ext": func() test {
return test{
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{e1, e2}},
check: func(cert *x509.Certificate) {
assert.Equals(t, len(cert.ExtraExtensions), 0)
},
}
},
"ok/step-provisioner-ext": func() test {
return test{
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{e1, stepExt, fakeStepExt, e2}},
check: func(cert *x509.Certificate) {
assert.Equals(t, len(cert.ExtraExtensions), 1)
assert.Equals(t, cert.ExtraExtensions[0], stepExt)
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tt := run()
ExtraExtsEnforcer{}.Enforce(tt.cert)
tt.check(tt.cert)
})
}
}
func Test_validityValidator_Valid(t *testing.T) { func Test_validityValidator_Valid(t *testing.T) {
type test struct { type test struct {
cert *x509.Certificate cert *x509.Certificate
@ -589,10 +544,10 @@ func Test_profileDefaultDuration_Option(t *testing.T) {
cert: new(x509.Certificate), cert: new(x509.Certificate),
valid: func(cert *x509.Certificate) { valid: func(cert *x509.Certificate) {
n := now() n := now()
assert.True(t, n.After(cert.NotBefore)) assert.True(t, n.After(cert.NotBefore.Add(-time.Second)))
assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore)) assert.True(t, n.Add(-1*time.Minute).Before(cert.NotBefore))
assert.True(t, n.Add(24*time.Hour).After(cert.NotAfter)) assert.True(t, n.Add(24*time.Hour).After(cert.NotAfter.Add(-time.Second)))
assert.True(t, n.Add(24*time.Hour).Add(-1*time.Minute).Before(cert.NotAfter)) assert.True(t, n.Add(24*time.Hour).Add(-1*time.Minute).Before(cert.NotAfter))
}, },
} }

View file

@ -390,6 +390,34 @@ ZYtQ9Ot36qc=
notAfter: signOpts.NotAfter.Time().Truncate(time.Second), notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
} }
}, },
"ok/csr with no template critical SAN extension": func(t *testing.T) *signTest {
csr := getCSR(t, priv, func(csr *x509.CertificateRequest) {
csr.Subject = pkix.Name{}
}, func(csr *x509.CertificateRequest) {
csr.DNSNames = []string{"foo", "bar"}
})
now := time.Now().UTC()
enforcedExtraOptions := []provisioner.SignOption{&certificateDurationEnforcer{
NotBefore: now,
NotAfter: now.Add(365 * 24 * time.Hour),
}}
_a := testAuthority(t)
_a.config.AuthorityConfig.Template = &x509util.ASN1DN{}
_a.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject, pkix.Name{})
return nil
},
}
return &signTest{
auth: _a,
csr: csr,
extraOpts: enforcedExtraOptions,
signOpts: provisioner.SignOptions{},
notBefore: now.Truncate(time.Second),
notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second),
}
},
} }
for name, genTestCase := range tests { for name, genTestCase := range tests {
@ -417,6 +445,9 @@ ZYtQ9Ot36qc=
assert.Equals(t, leaf.NotBefore, tc.notBefore) assert.Equals(t, leaf.NotBefore, tc.notBefore)
assert.Equals(t, leaf.NotAfter, tc.notAfter) assert.Equals(t, leaf.NotAfter, tc.notAfter)
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
if tc.csr.Subject.CommonName == "" {
assert.Equals(t, leaf.Subject, pkix.Name{})
} else {
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, fmt.Sprintf("%v", leaf.Subject),
fmt.Sprintf("%v", &pkix.Name{ fmt.Sprintf("%v", &pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
@ -426,13 +457,14 @@ ZYtQ9Ot36qc=
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: "smallstep test", CommonName: "smallstep test",
})) }))
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"})
}
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA) assert.Equals(t, leaf.PublicKeyAlgorithm, x509.ECDSA)
assert.Equals(t, leaf.ExtKeyUsage, assert.Equals(t, leaf.ExtKeyUsage,
[]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth})
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"})
subjectKeyID, err := generateSubjectKeyID(pub) subjectKeyID, err := generateSubjectKeyID(pub)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -452,6 +484,7 @@ ZYtQ9Ot36qc=
assert.Equals(t, val.Type, provisionerTypeJWK) assert.Equals(t, val.Type, provisionerTypeJWK)
assert.Equals(t, val.Name, []byte(p.Name)) assert.Equals(t, val.Name, []byte(p.Name))
assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID)) assert.Equals(t, val.CredentialID, []byte(p.Key.KeyID))
// Basic Constraints // Basic Constraints
case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 19})): case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 19})):
val := basicConstraints{} val := basicConstraints{}
@ -459,10 +492,20 @@ ZYtQ9Ot36qc=
assert.FatalError(t, err) assert.FatalError(t, err)
assert.False(t, val.IsCA, false) assert.False(t, val.IsCA, false)
assert.Equals(t, val.MaxPathLen, 0) assert.Equals(t, val.MaxPathLen, 0)
// SAN extension
case ext.Id.Equal(asn1.ObjectIdentifier([]int{2, 5, 29, 17})):
if tc.csr.Subject.CommonName == "" {
// Empty CSR subject test does not use any provisioner extensions.
// So provisioner ID ext will be missing.
found = 1
assert.Len(t, 5, leaf.Extensions)
} else {
assert.Len(t, 6, leaf.Extensions)
}
} }
} }
assert.Equals(t, found, 1) assert.Equals(t, found, 1)
assert.Len(t, 6, leaf.Extensions)
realIntermediate, err := x509.ParseCertificate(a.x509Issuer.Raw) realIntermediate, err := x509.ParseCertificate(a.x509Issuer.Raw)
assert.FatalError(t, err) assert.FatalError(t, err)