Rename interface to CertificateEnforcer and add tests.

This commit is contained in:
Mariano Cano 2020-03-31 11:41:36 -07:00
parent 64f26c0f40
commit bfe1f4952d
4 changed files with 53 additions and 15 deletions

View file

@ -489,7 +489,7 @@ type identityModifier struct {
NotAfter time.Time NotAfter time.Time
} }
func (m *identityModifier) Constrain(cert *x509.Certificate) error { func (m *identityModifier) Enforce(cert *x509.Certificate) error {
cert.NotBefore = m.NotBefore cert.NotBefore = m.NotBefore
cert.NotAfter = m.NotAfter cert.NotAfter = m.NotAfter
return nil return nil

View file

@ -47,11 +47,11 @@ type ProfileModifier interface {
Option(o Options) x509util.WithOption Option(o Options) x509util.WithOption
} }
// CertificateConstrainModifier is the interface used to modify a certificate // CertificateEnforcer is the interface used to modify a certificate after
// after validation. // validation.
type CertificateConstrainModifier interface { type CertificateEnforcer interface {
SignOption SignOption
Constrain(cert *x509.Certificate) error Enforce(cert *x509.Certificate) error
} }
// profileWithOption is a wrapper against x509util.WithOption to conform the // profileWithOption is a wrapper against x509util.WithOption to conform the

View file

@ -61,10 +61,10 @@ func withDefaultASN1DN(def *x509util.ASN1DN) x509util.WithOption {
// Sign creates a signed certificate from a certificate signing request. // Sign creates a signed certificate from a certificate signing request.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
var ( var (
opts = []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} opts = []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)} mods = []x509util.WithOption{withDefaultASN1DN(a.config.AuthorityConfig.Template)}
certValidators = []provisioner.CertificateValidator{} certValidators = []provisioner.CertificateValidator{}
constrainModifiers = []provisioner.CertificateConstrainModifier{} forcedModifiers = []provisioner.CertificateEnforcer{}
) )
// Set backdate with the configured value // Set backdate with the configured value
@ -80,8 +80,8 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
} }
case provisioner.ProfileModifier: case provisioner.ProfileModifier:
mods = append(mods, k.Option(signOpts)) mods = append(mods, k.Option(signOpts))
case provisioner.CertificateConstrainModifier: case provisioner.CertificateEnforcer:
constrainModifiers = append(constrainModifiers, k) forcedModifiers = append(forcedModifiers, k)
default: default:
return nil, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]interface{}{k}, opts...)...) return nil, errs.InternalServer("authority.Sign; invalid extra option type %T", append([]interface{}{k}, opts...)...)
} }
@ -104,8 +104,8 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Opti
} }
// Certificate modifier after validation // Certificate modifier after validation
for _, m := range constrainModifiers { for _, m := range forcedModifiers {
if err := m.Constrain(leaf.Subject()); err != nil { if err := m.Enforce(leaf.Subject()); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...)
} }
} }

View file

@ -41,6 +41,17 @@ type stepProvisionerASN1 struct {
CredentialID []byte CredentialID []byte
} }
type certificateDurationEnforcer struct {
NotBefore time.Time
NotAfter time.Time
}
func (m *certificateDurationEnforcer) Enforce(cert *x509.Certificate) error {
cert.NotBefore = m.NotBefore
cert.NotAfter = m.NotAfter
return nil
}
func withProvisionerOID(name, kid string) x509util.WithOption { func withProvisionerOID(name, kid string) x509util.WithOption {
return func(p x509util.Profile) error { return func(p x509util.Profile) error {
crt := p.Subject() crt := p.Subject()
@ -114,6 +125,8 @@ func TestAuthority_Sign(t *testing.T) {
csr *x509.CertificateRequest csr *x509.CertificateRequest
signOpts provisioner.Options signOpts provisioner.Options
extraOpts []provisioner.SignOption extraOpts []provisioner.SignOption
notBefore time.Time
notAfter time.Time
err error err error
code int code int
} }
@ -253,6 +266,31 @@ ZYtQ9Ot36qc=
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
}
},
"ok with enforced modifier": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
now := time.Now().UTC()
enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{
NotBefore: now,
NotAfter: now.Add(365 * 24 * time.Hour),
})
_a := testAuthority(t)
_a.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil
},
}
return &signTest{
auth: a,
csr: csr,
extraOpts: enforcedExtraOptions,
signOpts: signOpts,
notBefore: now.Truncate(time.Second),
notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second),
} }
}, },
} }
@ -279,8 +317,8 @@ ZYtQ9Ot36qc=
leaf := certChain[0] leaf := certChain[0]
intermediate := certChain[1] intermediate := certChain[1]
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, leaf.NotBefore, signOpts.NotBefore.Time().Truncate(time.Second)) assert.Equals(t, leaf.NotBefore, tc.notBefore)
assert.Equals(t, leaf.NotAfter, signOpts.NotAfter.Time().Truncate(time.Second)) assert.Equals(t, leaf.NotAfter, tc.notAfter)
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, fmt.Sprintf("%v", leaf.Subject),
fmt.Sprintf("%v", &pkix.Name{ fmt.Sprintf("%v", &pkix.Name{