Add a custom enforcer that can be used to modify a cert.

This commit is contained in:
Mariano Cano 2022-02-02 14:36:58 -08:00
parent 09a9b3e1c8
commit 300c19f8b9
4 changed files with 139 additions and 35 deletions

View file

@ -50,6 +50,7 @@ type Authority struct {
rootX509CertPool *x509.CertPool rootX509CertPool *x509.CertPool
federatedX509Certs []*x509.Certificate federatedX509Certs []*x509.Certificate
certificates *sync.Map certificates *sync.Map
x509Enforcers []provisioner.CertificateEnforcer
// SCEP CA // SCEP CA
scepService *scep.Service scepService *scep.Service

View file

@ -241,6 +241,15 @@ func WithLinkedCAToken(token string) Option {
} }
} }
// WithX509Enforcers is an option that allows to define custom certificate
// modifiers that will be processed just before the signing of the certificate.
func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option {
return func(a *Authority) error {
a.x509Enforcers = ces
return nil
}
}
func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) { func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
var block *pem.Block var block *pem.Block
var certs []*x509.Certificate var certs []*x509.Certificate

View file

@ -180,6 +180,17 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
} }
} }
// Process injected modifiers after validation
for _, m := range a.x509Enforcers {
if err := m.Enforce(leaf); err != nil {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"),
opts...,
)
}
}
// Sign certificate
lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate)) lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate))
resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{ resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{
Template: leaf, Template: leaf,

View file

@ -205,6 +205,17 @@ type basicConstraints struct {
MaxPathLen int `asn1:"optional,default:-1"` MaxPathLen int `asn1:"optional,default:-1"`
} }
type testEnforcer struct {
enforcer func(*x509.Certificate) error
}
func (e *testEnforcer) Enforce(cert *x509.Certificate) error {
if e.enforcer != nil {
return e.enforcer(cert)
}
return nil
}
func TestAuthority_Sign(t *testing.T) { func TestAuthority_Sign(t *testing.T) {
pub, priv, err := keyutil.GenerateDefaultKeyPair() pub, priv, err := keyutil.GenerateDefaultKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -238,14 +249,15 @@ func TestAuthority_Sign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
type signTest struct { type signTest struct {
auth *Authority auth *Authority
csr *x509.CertificateRequest csr *x509.CertificateRequest
signOpts provisioner.SignOptions signOpts provisioner.SignOptions
extraOpts []provisioner.SignOption extraOpts []provisioner.SignOption
notBefore time.Time notBefore time.Time
notAfter time.Time notAfter time.Time
err error extensionsCount int
code int err error
code int
} }
tests := map[string]func(*testing.T) *signTest{ tests := map[string]func(*testing.T) *signTest{
"fail invalid signature": func(t *testing.T) *signTest { "fail invalid signature": func(t *testing.T) *signTest {
@ -454,6 +466,49 @@ ZYtQ9Ot36qc=
code: http.StatusInternalServerError, code: http.StatusInternalServerError,
} }
}, },
"fail with provisioner enforcer": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
aa := testAuthority(t)
aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil
},
}
return &signTest{
auth: aa,
csr: csr,
extraOpts: append(extraOpts, &testEnforcer{
enforcer: func(crt *x509.Certificate) error { return fmt.Errorf("an error") },
}),
signOpts: signOpts,
err: errors.New("error creating certificate"),
code: http.StatusForbidden,
}
},
"fail with custom enforcer": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
aa := testAuthority(t, WithX509Enforcers(&testEnforcer{
enforcer: func(cert *x509.Certificate) error {
return fmt.Errorf("an error")
},
}))
aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
return nil
},
}
return &signTest{
auth: aa,
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: errors.New("error creating certificate"),
code: http.StatusForbidden,
}
},
"ok": func(t *testing.T) *signTest { "ok": func(t *testing.T) *signTest {
csr := getCSR(t, priv) csr := getCSR(t, priv)
_a := testAuthority(t) _a := testAuthority(t)
@ -464,12 +519,13 @@ ZYtQ9Ot36qc=
}, },
} }
return &signTest{ return &signTest{
auth: a, auth: a,
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
notBefore: signOpts.NotBefore.Time().Truncate(time.Second), notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
notAfter: signOpts.NotAfter.Time().Truncate(time.Second), notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
extensionsCount: 6,
} }
}, },
"ok with enforced modifier": func(t *testing.T) *signTest { "ok with enforced modifier": func(t *testing.T) *signTest {
@ -497,12 +553,13 @@ ZYtQ9Ot36qc=
}, },
} }
return &signTest{ return &signTest{
auth: a, auth: a,
csr: csr, csr: csr,
extraOpts: enforcedExtraOptions, extraOpts: enforcedExtraOptions,
signOpts: signOpts, signOpts: signOpts,
notBefore: now.Truncate(time.Second), notBefore: now.Truncate(time.Second),
notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second), notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second),
extensionsCount: 6,
} }
}, },
"ok with custom template": func(t *testing.T) *signTest { "ok with custom template": func(t *testing.T) *signTest {
@ -530,12 +587,13 @@ ZYtQ9Ot36qc=
}, },
} }
return &signTest{ return &signTest{
auth: testAuthority, auth: testAuthority,
csr: csr, csr: csr,
extraOpts: testExtraOpts, extraOpts: testExtraOpts,
signOpts: signOpts, signOpts: signOpts,
notBefore: signOpts.NotBefore.Time().Truncate(time.Second), notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
notAfter: signOpts.NotAfter.Time().Truncate(time.Second), notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
extensionsCount: 6,
} }
}, },
"ok/csr with no template critical SAN extension": func(t *testing.T) *signTest { "ok/csr with no template critical SAN extension": func(t *testing.T) *signTest {
@ -558,12 +616,39 @@ ZYtQ9Ot36qc=
}, },
} }
return &signTest{ return &signTest{
auth: _a, auth: _a,
csr: csr, csr: csr,
extraOpts: enforcedExtraOptions, extraOpts: enforcedExtraOptions,
signOpts: provisioner.SignOptions{}, signOpts: provisioner.SignOptions{},
notBefore: now.Truncate(time.Second), notBefore: now.Truncate(time.Second),
notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second), notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second),
extensionsCount: 5,
}
},
"ok with custom enforcer": func(t *testing.T) *signTest {
csr := getCSR(t, priv)
aa := testAuthority(t, WithX509Enforcers(&testEnforcer{
enforcer: func(cert *x509.Certificate) error {
cert.CRLDistributionPoints = []string{"http://ca.example.org/leaf.crl"}
return nil
},
}))
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
aa.db = &db.MockAuthDB{
MStoreCertificate: func(crt *x509.Certificate) error {
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
assert.Equals(t, crt.CRLDistributionPoints, []string{"http://ca.example.org/leaf.crl"})
return nil
},
}
return &signTest{
auth: aa,
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
extensionsCount: 7,
} }
}, },
} }
@ -645,9 +730,6 @@ ZYtQ9Ot36qc=
// Empty CSR subject test does not use any provisioner extensions. // Empty CSR subject test does not use any provisioner extensions.
// So provisioner ID ext will be missing. // So provisioner ID ext will be missing.
found = 1 found = 1
assert.Len(t, 5, leaf.Extensions)
} else {
assert.Len(t, 6, leaf.Extensions)
} }
} }
} }
@ -655,6 +737,7 @@ ZYtQ9Ot36qc=
realIntermediate, err := x509.ParseCertificate(issuer.Raw) realIntermediate, err := x509.ParseCertificate(issuer.Raw)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, intermediate, realIntermediate) assert.Equals(t, intermediate, realIntermediate)
assert.Len(t, tc.extensionsCount, leaf.Extensions)
} }
} }
}) })