implemented some requested changes

This commit is contained in:
Raal Goff 2022-04-05 11:19:13 +08:00
parent 45975b061c
commit 8520c861d5
6 changed files with 34 additions and 22 deletions

View file

@ -46,7 +46,6 @@ type Authority interface {
GetRoots() (federation []*x509.Certificate, err error) GetRoots() (federation []*x509.Certificate, err error)
GetFederation() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error)
Version() authority.Version Version() authority.Version
GenerateCertificateRevocationList() error
GetCertificateRevocationList() ([]byte, error) GetCertificateRevocationList() ([]byte, error)
} }

View file

@ -580,7 +580,7 @@ type mockAuthority struct {
version func() authority.Version version func() authority.Version
} }
func (m *mockAuthority) GenerateCertificateRevocationList(force bool) ([]byte, error) { func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) {
panic("implement me") panic("implement me")
} }

View file

@ -4,6 +4,7 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"net/http" "net/http"
) )
@ -14,17 +15,16 @@ func (h *caHandler) CRL(w http.ResponseWriter, r *http.Request) {
_, formatAsPEM := r.URL.Query()["pem"] _, formatAsPEM := r.URL.Query()["pem"]
if err != nil { if err != nil {
w.WriteHeader(500)
_, err = fmt.Fprintf(w, "%v\n", err) caErr, isCaErr := err.(*errs.Error)
if err != nil {
panic(errors.Wrap(err, "error writing http response")) if isCaErr {
} http.Error(w, caErr.Msg, caErr.Status)
return return
} }
if crlBytes == nil { w.WriteHeader(500)
w.WriteHeader(404) _, err = fmt.Fprintf(w, "%v\n", err)
_, err = fmt.Fprintln(w, "No CRL available")
if err != nil { if err != nil {
panic(errors.Wrap(err, "error writing http response")) panic(errors.Wrap(err, "error writing http response"))
} }

View file

@ -66,7 +66,7 @@ type Authority struct {
sshCAHostFederatedCerts []ssh.PublicKey sshCAHostFederatedCerts []ssh.PublicKey
// CRL vars // CRL vars
crlChannel chan int crlTicker *time.Ticker
// Do not re-initialize // Do not re-initialize
initOnce bool initOnce bool
@ -586,6 +586,10 @@ func (a *Authority) IsAdminAPIEnabled() bool {
// Shutdown safely shuts down any clients, databases, etc. held by the Authority. // Shutdown safely shuts down any clients, databases, etc. held by the Authority.
func (a *Authority) Shutdown() error { func (a *Authority) Shutdown() error {
if a.crlTicker != nil {
a.crlTicker.Stop()
}
if err := a.keyManager.Close(); err != nil { if err := a.keyManager.Close(); err != nil {
log.Printf("error closing the key manager: %v", err) log.Printf("error closing the key manager: %v", err)
} }
@ -594,6 +598,11 @@ func (a *Authority) Shutdown() error {
// CloseForReload closes internal services, to allow a safe reload. // CloseForReload closes internal services, to allow a safe reload.
func (a *Authority) CloseForReload() { func (a *Authority) CloseForReload() {
if a.crlTicker != nil {
a.crlTicker.Stop()
}
if err := a.keyManager.Close(); err != nil { if err := a.keyManager.Close(); err != nil {
log.Printf("error closing the key manager: %v", err) log.Printf("error closing the key manager: %v", err)
} }
@ -655,12 +664,12 @@ func (a *Authority) startCRLGenerator() error {
if tickerDuration <= 0 { if tickerDuration <= 0 {
panic(fmt.Sprintf("ERROR: Addition of jitter to CRL generation time %v creates a negative duration (%v). Use a CRL generation time of longer than 1 minute.", a.config.CRL.CacheDuration, tickerDuration)) panic(fmt.Sprintf("ERROR: Addition of jitter to CRL generation time %v creates a negative duration (%v). Use a CRL generation time of longer than 1 minute.", a.config.CRL.CacheDuration, tickerDuration))
} }
crlTicker := time.NewTicker(tickerDuration) a.crlTicker = time.NewTicker(tickerDuration)
go func() { go func() {
for { for {
select { select {
case <-crlTicker.C: case <-a.crlTicker.C:
log.Println("Regenerating CRL") log.Println("Regenerating CRL")
err := a.GenerateCertificateRevocationList() err := a.GenerateCertificateRevocationList()
if err != nil { if err != nil {

View file

@ -365,6 +365,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
err error err error
) )
if revokeOpts.Crt == nil {
// Attempt to get the certificate expiry using the serial number. // Attempt to get the certificate expiry using the serial number.
cert, err := a.db.GetCertificate(revokeOpts.Serial) cert, err := a.db.GetCertificate(revokeOpts.Serial)
@ -373,6 +374,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
if err == nil { if err == nil {
rci.ExpiresAt = cert.NotAfter rci.ExpiresAt = cert.NotAfter
} }
}
// If not mTLS then get the TokenID of the token. // If not mTLS then get the TokenID of the token.
if !revokeOpts.MTLS { if !revokeOpts.MTLS {

View file

@ -215,13 +215,15 @@ func (db *DB) GetRevokedCertificates() (*[]RevokedCertificateInfo, error) {
return nil, err return nil, err
} }
var revokedCerts []RevokedCertificateInfo var revokedCerts []RevokedCertificateInfo
now := time.Now().UTC()
for _, e := range entries { for _, e := range entries {
var data RevokedCertificateInfo var data RevokedCertificateInfo
if err := json.Unmarshal(e.Value, &data); err != nil { if err := json.Unmarshal(e.Value, &data); err != nil {
return nil, err return nil, err
} }
if !data.ExpiresAt.IsZero() && data.ExpiresAt.After(time.Now().UTC()) { if !data.ExpiresAt.IsZero() && data.ExpiresAt.After(now) {
revokedCerts = append(revokedCerts, data) revokedCerts = append(revokedCerts, data)
} else if data.ExpiresAt.IsZero() { } else if data.ExpiresAt.IsZero() {
cert, err := db.GetCertificate(data.Serial) cert, err := db.GetCertificate(data.Serial)
@ -232,7 +234,7 @@ func (db *DB) GetRevokedCertificates() (*[]RevokedCertificateInfo, error) {
continue continue
} }
if cert.NotAfter.After(time.Now().UTC()) { if cert.NotAfter.After(now) {
revokedCerts = append(revokedCerts, data) revokedCerts = append(revokedCerts, data)
} }
} }