Merge pull request #11 from smallstep/ca-force-renew

Force the renew of the CA server.
This commit is contained in:
Mariano Cano 2018-11-27 16:30:02 -08:00 committed by GitHub
commit eaa9bc5faf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 6 deletions

View file

@ -121,6 +121,7 @@ func (ca *CA) Run() error {
// Stop stops the CA calling to the server Shutdown method. // Stop stops the CA calling to the server Shutdown method.
func (ca *CA) Stop() error { func (ca *CA) Stop() error {
ca.renewer.Stop()
return ca.srv.Shutdown() return ca.srv.Shutdown()
} }
@ -185,7 +186,7 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
// empty we are implicitly forcing GetCertificate to be the only mechanism // empty we are implicitly forcing GetCertificate to be the only mechanism
// by which the server can find it's own leaf Certificate. // by which the server can find it's own leaf Certificate.
tlsConfig.Certificates = []tls.Certificate{} tlsConfig.Certificates = []tls.Certificate{}
tlsConfig.GetCertificate = ca.renewer.GetCertificate tlsConfig.GetCertificate = ca.renewer.GetCertificateForCA
// Add support for mutual tls to renew certificates // Add support for mutual tls to renew certificates
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven

View file

@ -14,7 +14,7 @@ import (
// certificate. // certificate.
type RenewFunc func() (*tls.Certificate, error) type RenewFunc func() (*tls.Certificate, error)
// TLSRenewer renews automatically a tls certificate with a given function. // TLSRenewer automatically renews a tls certificate using a RenewFunc.
type TLSRenewer struct { type TLSRenewer struct {
sync.RWMutex sync.RWMutex
RenewCertificate RenewFunc RenewCertificate RenewFunc
@ -22,6 +22,7 @@ type TLSRenewer struct {
timer *time.Timer timer *time.Timer
renewBefore time.Duration renewBefore time.Duration
renewJitter time.Duration renewJitter time.Duration
certNotAfter time.Time
} }
type tlsRenewerOptions func(r *TLSRenewer) error type tlsRenewerOptions func(r *TLSRenewer) error
@ -43,7 +44,7 @@ func WithRenewJitter(j time.Duration) func(r *TLSRenewer) error {
} }
// NewTLSRenewer creates a TLSRenewer for the given cert. It will use the given // NewTLSRenewer creates a TLSRenewer for the given cert. It will use the given
// function to get a new certificate when required. // RenewFunc to get a new certificate when required.
func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOptions) (*TLSRenewer, error) { func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOptions) (*TLSRenewer, error) {
r := &TLSRenewer{ r := &TLSRenewer{
RenewCertificate: fn, RenewCertificate: fn,
@ -91,8 +92,11 @@ func (r *TLSRenewer) RunContext(ctx context.Context) {
// Stop prevents the renew timer from firing. // Stop prevents the renew timer from firing.
func (r *TLSRenewer) Stop() bool { func (r *TLSRenewer) Stop() bool {
if r.timer != nil {
return r.timer.Stop() return r.timer.Stop()
} }
return true
}
// GetCertificate returns the current server certificate. // GetCertificate returns the current server certificate.
// //
@ -101,6 +105,15 @@ func (r *TLSRenewer) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Cert
return r.getCertificate(), nil return r.getCertificate(), nil
} }
// GetCertificateForCA returns the current server certificate. It can only be
// used if the renew function creates the new certificate and do not uses a TLS
// request. It's intended to be use by the certificate authority server.
//
// This method is set in the tls.Config GetCertificate property.
func (r *TLSRenewer) GetCertificateForCA(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return r.getCertificateForCA(), nil
}
// GetClientCertificate returns the current client certificate. // GetClientCertificate returns the current client certificate.
// //
// This method is set in the tls.Config GetClientCertificate property. // This method is set in the tls.Config GetClientCertificate property.
@ -109,6 +122,11 @@ func (r *TLSRenewer) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Cer
} }
// getCertificate returns the certificate using a read-only lock. // getCertificate returns the certificate using a read-only lock.
//
// Known issue: It cannot renew an expired certificate because the /renew
// endpoint requires a valid client certificate. The certificate can expire
// if the timer does not fire e.g. when the CA is run from a laptop that
// enters sleep mode.
func (r *TLSRenewer) getCertificate() *tls.Certificate { func (r *TLSRenewer) getCertificate() *tls.Certificate {
r.RLock() r.RLock()
cert := r.cert cert := r.cert
@ -116,10 +134,29 @@ func (r *TLSRenewer) getCertificate() *tls.Certificate {
return cert return cert
} }
// setCertificate updates the certificate using a read-write lock. // getCertificateForCA returns the certificate using a read-only lock. It will
// automatically renew the certificate if it has expired.
func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
r.RLock()
// Force certificate renewal if the timer didn't run.
// This is an special case that can happen after a computer sleep.
if time.Now().After(r.certNotAfter) {
r.RUnlock()
r.renewCertificate()
r.RLock()
}
cert := r.cert
r.RUnlock()
return cert
}
// setCertificate updates the certificate using a read-write lock. It also
// updates certNotAfter with 1m of delta; this will force the renewal of the
// certificate if it is about to expire.
func (r *TLSRenewer) setCertificate(cert *tls.Certificate) { func (r *TLSRenewer) setCertificate(cert *tls.Certificate) {
r.Lock() r.Lock()
r.cert = cert r.cert = cert
r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute)
r.Unlock() r.Unlock()
} }
@ -133,7 +170,9 @@ func (r *TLSRenewer) renewCertificate() {
r.setCertificate(cert) r.setCertificate(cert)
next = r.nextRenewDuration(cert.Leaf.NotAfter) next = r.nextRenewDuration(cert.Leaf.NotAfter)
} }
r.timer = time.AfterFunc(next, r.renewCertificate) r.Lock()
r.timer.Reset(next)
r.Unlock()
} }
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration { func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {