certificates/ca/renew.go

187 lines
5.1 KiB
Go
Raw Normal View History

2018-10-05 21:48:36 +00:00
package ca
import (
"context"
"crypto/tls"
"math/rand"
"sync"
"time"
"github.com/pkg/errors"
)
// RenewFunc is the signature of the callback used to obtain a fresh TLS
// certificate when the current one is due for renewal.
type RenewFunc func() (cert *tls.Certificate, err error)

// TLSRenewer automatically renews a tls certificate using a RenewFunc.
//
// RenewCertificate is the callback invoked for each renewal. The embedded
// RWMutex guards the current certificate (cert) and certNotAfter, the
// expiration guard consulted before serving a certificate to the CA. timer
// schedules the next renewal; renewBefore and renewJitter control when that
// renewal happens relative to the certificate's expiration.
type TLSRenewer struct {
	sync.RWMutex
	RenewCertificate RenewFunc
	cert             *tls.Certificate
	timer            *time.Timer
	renewBefore      time.Duration
	renewJitter      time.Duration
	certNotAfter     time.Time
}

// tlsRenewerOptions is a functional option applied by NewTLSRenewer to a
// freshly created TLSRenewer; a non-nil error aborts construction.
type tlsRenewerOptions func(*TLSRenewer) error

// WithRenewBefore returns an option that sets how long before a certificate's
// expiration the renewer schedules its renewal (the renewBefore attribute).
// A zero duration leaves NewTLSRenewer's default in place.
func WithRenewBefore(b time.Duration) func(r *TLSRenewer) error {
	setRenewBefore := func(r *TLSRenewer) error {
		r.renewBefore = b
		return nil
	}
	return setRenewBefore
}

// WithRenewJitter returns an option that sets the random jitter applied when
// scheduling renewals (the renewJitter attribute). A zero duration leaves
// NewTLSRenewer's default in place.
func WithRenewJitter(j time.Duration) func(r *TLSRenewer) error {
	setRenewJitter := func(r *TLSRenewer) error {
		r.renewJitter = j
		return nil
	}
	return setRenewJitter
}
// NewTLSRenewer creates a TLSRenewer for the given cert. It will use the given
2018-11-28 00:25:01 +00:00
// RenewFunc to get a new certificate when required.
2018-10-05 21:48:36 +00:00
func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOptions) (*TLSRenewer, error) {
r := &TLSRenewer{
RenewCertificate: fn,
cert: cert,
}
for _, f := range opts {
if err := f(r); err != nil {
return nil, errors.Wrap(err, "error applying options")
}
}
period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore)
if period < time.Minute {
return nil, errors.Errorf("period must be greater than or equal to 1 Minute, but got %v.", period)
}
// By default we will try to renew the cert before 2/3 of the validity
// period have expired.
if r.renewBefore == 0 {
r.renewBefore = period / 3
}
// By default we set the jitter to 1/20th of the validity period.
if r.renewJitter == 0 {
r.renewJitter = period / 20
}
return r, nil
}
// Run starts the certificate renewer for the given certificate. It schedules
// the first renewal from the current certificate's NotAfter.
func (r *TLSRenewer) Run() {
	cert := r.getCertificate()
	next := r.nextRenewDuration(cert.Leaf.NotAfter)
	// Hold the write lock while the timer is created. When next is zero the
	// callback can fire immediately on another goroutine, and renewCertificate
	// takes this same lock before calling r.timer.Reset. Without the lock it
	// could observe r.timer before the assignment below.
	r.Lock()
	r.timer = time.AfterFunc(next, r.renewCertificate)
	r.Unlock()
}
// RunContext starts the certificate renewer for the given certificate and
// stops it once ctx is done. The helper goroutine it spawns exits only when
// ctx is done.
func (r *TLSRenewer) RunContext(ctx context.Context) {
	r.Run()
	done := ctx.Done()
	go func() {
		<-done
		r.Stop()
	}()
}
// Stop prevents the renew timer from firing.
func (r *TLSRenewer) Stop() bool {
2018-11-27 23:57:13 +00:00
if r.timer != nil {
return r.timer.Stop()
}
return true
2018-10-05 21:48:36 +00:00
}
// GetCertificate returns the current server certificate. It never returns an
// error and ignores the ClientHello.
//
// This method is set in the tls.Config GetCertificate property.
func (r *TLSRenewer) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	current := r.getCertificate()
	return current, nil
}
2018-11-27 23:57:13 +00:00
// GetCertificateForCA returns the current server certificate. It can only be
// used if the renew function creates the new certificate and do not uses a TLS
// request. It's intended to be use by the certificate authority server.
//
// This method is set in the tls.Config GetCertificate property.
func (r *TLSRenewer) GetCertificateForCA(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return r.getCertificateForCA(), nil
}
2018-10-05 21:48:36 +00:00
// GetClientCertificate returns the current client certificate.
//
// This method is set in the tls.Config GetClientCertificate property.
func (r *TLSRenewer) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return r.getCertificate(), nil
}
// getCertificate returns the certificate using a read-only lock.
2018-11-27 23:57:13 +00:00
//
2018-11-28 00:25:01 +00:00
// Known issue: It cannot renew an expired certificate because the /renew
// endpoint requires a valid client certificate. The certificate can expire
// if the timer does not fire e.g. when the CA is run from a laptop that
// enters sleep mode.
2018-10-05 21:48:36 +00:00
func (r *TLSRenewer) getCertificate() *tls.Certificate {
r.RLock()
cert := r.cert
r.RUnlock()
return cert
}
2018-11-27 23:57:13 +00:00
// getCertificateForCA returns the certificate using a read-only lock. It will
2018-11-28 00:25:01 +00:00
// automatically renew the certificate if it has expired.
2018-11-27 23:57:13 +00:00
func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
r.RLock()
// Force certificate renewal if the timer didn't run.
// This is an special case that can happen after a computer sleep.
if time.Now().After(r.certNotAfter) {
r.RUnlock()
r.renewCertificate()
r.RLock()
}
cert := r.cert
r.RUnlock()
return cert
}
// setCertificate updates the certificate using a read-write lock. It also
2018-11-28 00:25:01 +00:00
// updates certNotAfter with 1m of delta; this will force the renewal of the
// certificate if it is about to expire.
2018-10-05 21:48:36 +00:00
func (r *TLSRenewer) setCertificate(cert *tls.Certificate) {
r.Lock()
r.cert = cert
2018-11-27 23:57:13 +00:00
r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute)
2018-10-05 21:48:36 +00:00
r.Unlock()
}
func (r *TLSRenewer) renewCertificate() {
var next time.Duration
cert, err := r.RenewCertificate()
if err != nil {
next = r.renewJitter / 2
next += time.Duration(rand.Int63n(int64(next)))
} else {
r.setCertificate(cert)
next = r.nextRenewDuration(cert.Leaf.NotAfter)
}
2018-11-27 23:57:13 +00:00
r.Lock()
r.timer.Reset(next)
r.Unlock()
2018-10-05 21:48:36 +00:00
}
// nextRenewDuration returns how long to wait before renewing a certificate
// that expires at notAfter. The result is the time remaining until notAfter,
// minus renewBefore, minus a random jitter in [0, renewJitter). It is clamped
// at zero, so an overdue renewal is scheduled immediately.
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {
	d := time.Until(notAfter) - r.renewBefore
	// rand.Int63n panics on a non-positive bound, so apply jitter only when
	// one is configured.
	if r.renewJitter > 0 {
		d -= time.Duration(rand.Int63n(int64(r.renewJitter)))
	}
	if d < 0 {
		d = 0
	}
	return d
}