refactor crl config, add some tests

This commit is contained in:
Raal Goff 2022-10-07 10:30:00 +08:00
parent d0e81af524
commit f7df865687
6 changed files with 232 additions and 37 deletions

View file

@ -7,7 +7,6 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt"
"log" "log"
"strings" "strings"
"sync" "sync"
@ -664,12 +663,22 @@ func (a *Authority) init() error {
a.initOnce = true a.initOnce = true
// Start the CRL generator // Start the CRL generator
if a.config.CRL != nil && a.config.CRL.Generate { if a.config.CRL != nil && a.config.CRL.Enabled {
if v := a.config.CRL.CacheDuration; v != nil && v.Duration > 0 { if v := a.config.CRL.CacheDuration; v != nil && v.Duration < 0 {
err := a.startCRLGenerator() return errors.New("crl cacheDuration must be >= 0")
if err != nil { }
return err
} if v := a.config.CRL.CacheDuration; v != nil && v.Duration == 0 {
a.config.CRL.CacheDuration.Duration, _ = time.ParseDuration("24h")
}
if a.config.CRL.CacheDuration == nil {
a.config.CRL.CacheDuration, _ = provisioner.NewDuration("24h")
}
err = a.startCRLGenerator()
if err != nil {
return err
} }
} }
@ -797,6 +806,12 @@ func (a *Authority) startCRLGenerator() error {
return nil return nil
} }
// Make the renewal ticker run ~2/3 of cacheDuration by default, or use renewPeriod if available
tickerDuration := (a.config.CRL.CacheDuration.Duration / 3) * 2
if v := a.config.CRL.RenewPeriod; v != nil && v.Duration > 0 {
tickerDuration = v.Duration
}
// Check that there is a valid CRL in the DB right now. If it doesn't exist // Check that there is a valid CRL in the DB right now. If it doesn't exist
// or is expired, generate one now // or is expired, generate one now
_, ok := a.db.(db.CertificateRevocationListDB) _, ok := a.db.(db.CertificateRevocationListDB)
@ -811,11 +826,6 @@ func (a *Authority) startCRLGenerator() error {
return errors.Wrap(err, "could not generate a CRL") return errors.Wrap(err, "could not generate a CRL")
} }
log.Printf("CRL will be auto-generated every %v", a.config.CRL.CacheDuration)
tickerDuration := a.config.CRL.CacheDuration.Duration - time.Minute // generate the new CRL 1 minute before it expires
if tickerDuration <= 0 {
panic(fmt.Sprintf("ERROR: Addition of jitter to CRL generation time %v creates a negative duration (%v). Use a CRL generation time of longer than 1 minute.", a.config.CRL.CacheDuration, tickerDuration))
}
a.crlTicker = time.NewTicker(tickerDuration) a.crlTicker = time.NewTicker(tickerDuration)
go func() { go func() {
@ -832,3 +842,14 @@ func (a *Authority) startCRLGenerator() error {
return nil return nil
} }
func (a *Authority) resetCRLGeneratorTimer() {
if a.crlTicker != nil {
tickerDuration := (a.config.CRL.CacheDuration.Duration / 3) * 2
if v := a.config.CRL.RenewPeriod; v != nil && v.Duration > 0 {
tickerDuration = v.Duration
}
a.crlTicker.Reset(tickerDuration)
}
}

View file

@ -80,6 +80,12 @@ func testAuthority(t *testing.T, opts ...Option) *Authority {
AuthorityConfig: &AuthConfig{ AuthorityConfig: &AuthConfig{
Provisioners: p, Provisioners: p,
}, },
CRL: &config.CRLConfig{
Enabled: true,
CacheDuration: nil,
GenerateOnRevoke: true,
RenewPeriod: nil,
},
} }
a, err := New(c, opts...) a, err := New(c, opts...)
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -111,8 +111,10 @@ type AuthConfig struct {
// CRLConfig represents config options for CRL generation // CRLConfig represents config options for CRL generation
type CRLConfig struct { type CRLConfig struct {
Generate bool `json:"generate,omitempty"` Enabled bool `json:"enabled"`
CacheDuration *provisioner.Duration `json:"cacheDuration,omitempty"` CacheDuration *provisioner.Duration `json:"cacheDuration,omitempty"`
GenerateOnRevoke bool `json:"generateOnRevoke,omitempty"`
RenewPeriod *provisioner.Duration `json:"renewPeriod,omitempty"`
} }
// init initializes the required fields in the AuthConfig if they are not // init initializes the required fields in the AuthConfig if they are not

View file

@ -572,10 +572,17 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
} }
// Generate a new CRL so CRL requesters will always get an up-to-date CRL whenever they request it if a.config.CRL != nil && a.config.CRL.GenerateOnRevoke {
err = a.GenerateCertificateRevocationList() // Generate a new CRL so CRL requesters will always get an up-to-date CRL whenever they request it
if err != nil { err = a.GenerateCertificateRevocationList()
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
}
// the timer only gets reset if CRL is enabled
if a.config.CRL.Enabled {
a.resetCRLGeneratorTimer()
}
} }
} }
switch { switch {
@ -693,6 +700,13 @@ func (a *Authority) GenerateCertificateRevocationList() error {
}) })
} }
var updateDuration time.Duration
if a.config.CRL.CacheDuration != nil {
updateDuration = a.config.CRL.CacheDuration.Duration
} else if crlInfo != nil {
updateDuration = crlInfo.Duration
}
// Create a RevocationList representation ready for the CAS to sign // Create a RevocationList representation ready for the CAS to sign
// TODO: allow SignatureAlgorithm to be specified? // TODO: allow SignatureAlgorithm to be specified?
revocationList := x509.RevocationList{ revocationList := x509.RevocationList{
@ -700,7 +714,7 @@ func (a *Authority) GenerateCertificateRevocationList() error {
RevokedCertificates: revokedCertificates, RevokedCertificates: revokedCertificates,
Number: &bn, Number: &bn,
ThisUpdate: time.Now().UTC(), ThisUpdate: time.Now().UTC(),
NextUpdate: time.Now().UTC().Add(a.config.CRL.CacheDuration.Duration), NextUpdate: time.Now().UTC().Add(updateDuration),
ExtraExtensions: nil, ExtraExtensions: nil,
} }
@ -710,11 +724,12 @@ func (a *Authority) GenerateCertificateRevocationList() error {
} }
// Create a new db.CertificateRevocationListInfo, which stores the new Number we just generated, the // Create a new db.CertificateRevocationListInfo, which stores the new Number we just generated, the
// expiry time, and the DER-encoded CRL // expiry time, duration, and the DER-encoded CRL
newCRLInfo := db.CertificateRevocationListInfo{ newCRLInfo := db.CertificateRevocationListInfo{
Number: n, Number: n,
ExpiresAt: revocationList.NextUpdate, ExpiresAt: revocationList.NextUpdate,
DER: certificateRevocationList.CRL, DER: certificateRevocationList.CRL,
Duration: updateDuration,
} }
// Store the CRL in the database ready for retrieval by api endpoints // Store the CRL in the database ready for retrieval by api endpoints

View file

@ -1673,3 +1673,141 @@ func TestAuthority_constraints(t *testing.T) {
}) })
} }
} }
func TestAuthority_CRL(t *testing.T) {
reasonCode := 2
reason := "bob was let go"
validIssuer := "step-cli"
validAudience := testAudiences.Revoke
now := time.Now().UTC()
//
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
//
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err)
crlCtx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
var crlStore db.CertificateRevocationListInfo
var revokedList []db.RevokedCertificateInfo
type test struct {
auth *Authority
ctx context.Context
expected []string
err error
code int
}
tests := map[string]func() test{
"ok/empty-crl": func() test {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificate: func(sn string) (*x509.Certificate, error) {
return nil, errors.New("not found")
},
MStoreCRL: func(i *db.CertificateRevocationListInfo) error {
crlStore = *i
return nil
},
MGetCRL: func() (*db.CertificateRevocationListInfo, error) {
return &crlStore, nil
},
MGetRevokedCertificates: func() (*[]db.RevokedCertificateInfo, error) {
return &revokedList, nil
},
MRevoke: func(rci *db.RevokedCertificateInfo) error {
revokedList = append(revokedList, *rci)
return nil
},
}))
return test{
auth: _a,
ctx: crlCtx,
expected: nil,
}
},
"ok/crl-full": func() test {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificate: func(sn string) (*x509.Certificate, error) {
return nil, errors.New("not found")
},
MStoreCRL: func(i *db.CertificateRevocationListInfo) error {
crlStore = *i
return nil
},
MGetCRL: func() (*db.CertificateRevocationListInfo, error) {
return &crlStore, nil
},
MGetRevokedCertificates: func() (*[]db.RevokedCertificateInfo, error) {
return &revokedList, nil
},
MRevoke: func(rci *db.RevokedCertificateInfo) error {
revokedList = append(revokedList, *rci)
return nil
},
}))
var ex []string
for i := 0; i < 100; i++ {
sn := fmt.Sprintf("%v", i)
cl := jwt.Claims{
Subject: fmt.Sprintf("sn-%v", i),
Issuer: validIssuer,
NotBefore: jwt.NewNumericDate(now),
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: sn,
}
raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
err = _a.Revoke(crlCtx, &RevokeOptions{
Serial: sn,
ReasonCode: reasonCode,
Reason: reason,
OTT: raw,
})
assert.FatalError(t, err)
ex = append(ex, sn)
}
return test{
auth: _a,
ctx: crlCtx,
expected: ex,
}
},
}
for name, f := range tests {
tc := f()
t.Run(name, func(t *testing.T) {
if crlBytes, _err := tc.auth.GetCertificateRevocationList(); _err == nil {
crl, parseErr := x509.ParseCRL(crlBytes)
if parseErr != nil {
t.Errorf("x509.ParseCertificateRequest() error = %v, wantErr %v", parseErr, nil)
}
var cmpList []string
for _, c := range crl.TBSCertList.RevokedCertificates {
cmpList = append(cmpList, c.SerialNumber.String())
}
assert.Equals(t, cmpList, tc.expected)
} else {
assert.NotNil(t, tc.err)
}
})
}
}

View file

@ -155,6 +155,7 @@ type RevokedCertificateInfo struct {
type CertificateRevocationListInfo struct { type CertificateRevocationListInfo struct {
Number int64 Number int64
ExpiresAt time.Time ExpiresAt time.Time
Duration time.Duration
DER []byte DER []byte
} }
@ -471,32 +472,44 @@ func (db *DB) Shutdown() error {
// MockAuthDB mocks the AuthDB interface. // // MockAuthDB mocks the AuthDB interface. //
type MockAuthDB struct { type MockAuthDB struct {
Err error Err error
Ret1 interface{} Ret1 interface{}
MIsRevoked func(string) (bool, error) MIsRevoked func(string) (bool, error)
MIsSSHRevoked func(string) (bool, error) MIsSSHRevoked func(string) (bool, error)
MRevoke func(rci *RevokedCertificateInfo) error MRevoke func(rci *RevokedCertificateInfo) error
MRevokeSSH func(rci *RevokedCertificateInfo) error MRevokeSSH func(rci *RevokedCertificateInfo) error
MGetCertificate func(serialNumber string) (*x509.Certificate, error) MGetCertificate func(serialNumber string) (*x509.Certificate, error)
MGetCertificateData func(serialNumber string) (*CertificateData, error) MGetCertificateData func(serialNumber string) (*CertificateData, error)
MStoreCertificate func(crt *x509.Certificate) error MStoreCertificate func(crt *x509.Certificate) error
MUseToken func(id, tok string) (bool, error) MUseToken func(id, tok string) (bool, error)
MIsSSHHost func(principal string) (bool, error) MIsSSHHost func(principal string) (bool, error)
MStoreSSHCertificate func(crt *ssh.Certificate) error MStoreSSHCertificate func(crt *ssh.Certificate) error
MGetSSHHostPrincipals func() ([]string, error) MGetSSHHostPrincipals func() ([]string, error)
MShutdown func() error MShutdown func() error
MGetRevokedCertificates func() (*[]RevokedCertificateInfo, error)
MGetCRL func() (*CertificateRevocationListInfo, error)
MStoreCRL func(*CertificateRevocationListInfo) error
} }
func (m *MockAuthDB) GetRevokedCertificates() (*[]RevokedCertificateInfo, error) { func (m *MockAuthDB) GetRevokedCertificates() (*[]RevokedCertificateInfo, error) {
panic("implement me") if m.MGetRevokedCertificates != nil {
return m.MGetRevokedCertificates()
}
return m.Ret1.(*[]RevokedCertificateInfo), m.Err
} }
func (m *MockAuthDB) GetCRL() (*CertificateRevocationListInfo, error) { func (m *MockAuthDB) GetCRL() (*CertificateRevocationListInfo, error) {
panic("implement me") if m.MGetCRL != nil {
return m.MGetCRL()
}
return m.Ret1.(*CertificateRevocationListInfo), m.Err
} }
func (m *MockAuthDB) StoreCRL(info *CertificateRevocationListInfo) error { func (m *MockAuthDB) StoreCRL(info *CertificateRevocationListInfo) error {
panic("implement me") if m.MStoreCRL != nil {
return m.MStoreCRL(info)
}
return m.Err
} }
// IsRevoked mock. // IsRevoked mock.