From 25eba1a96c77feeae18d81900b60d978c112bfc9 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 22 Jan 2019 19:54:12 -0800 Subject: [PATCH] WIP on the safely rotate of root and federated certificates. Fixes #23 --- ca/mutable_tls_config.go | 109 +++++++++++++++++++++++++++++++ ca/tls.go | 54 +++++++++++++-- ca/tls_options.go | 137 +++++++++++++++++++++++++++------------ ca/tls_options_test.go | 60 ++++++++++------- 4 files changed, 286 insertions(+), 74 deletions(-) create mode 100644 ca/mutable_tls_config.go diff --git a/ca/mutable_tls_config.go b/ca/mutable_tls_config.go new file mode 100644 index 00000000..7a564a3e --- /dev/null +++ b/ca/mutable_tls_config.go @@ -0,0 +1,109 @@ +package ca + +import ( + "crypto/tls" + "crypto/x509" + "sync" + + "github.com/smallstep/certificates/api" +) + +// mutableTLSConfig allows to use a tls.Config with mutable cert pools. +type mutableTLSConfig struct { + sync.RWMutex + config *tls.Config + clientCerts []*x509.Certificate + rootCerts []*x509.Certificate + mutClientCerts []*x509.Certificate + mutRootCerts []*x509.Certificate +} + +// newMutableTLSConfig creates a new mutableTLSConfig using the passed one as +// the base one. +func newMutableTLSConfig() *mutableTLSConfig { + return &mutableTLSConfig{ + clientCerts: []*x509.Certificate{}, + rootCerts: []*x509.Certificate{}, + mutClientCerts: []*x509.Certificate{}, + mutRootCerts: []*x509.Certificate{}, + } +} + +// Init initializes the mutable tls.Config with the given tls.Config. +func (c *mutableTLSConfig) Init(base *tls.Config) { + c.Lock() + c.config = base.Clone() + c.Unlock() +} + +// TLSConfig returns the updated tls.Config it it has changed. It's is used in +// the tls.Config GetConfigForClient. +func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) { + c.RLock() + config = c.config + c.RUnlock() + return +} + +// Reload reloads the tls.Config with the new CAs. +func (c *mutableTLSConfig) Reload() { + // Prepare new pools + c.RLock() + rootCAs := x509.NewCertPool() + clientCAs := x509.NewCertPool() + // Fixed certs + for _, cert := range c.rootCerts { + rootCAs.AddCert(cert) + } + for _, cert := range c.clientCerts { + clientCAs.AddCert(cert) + } + // Mutable certs + for _, cert := range c.mutRootCerts { + rootCAs.AddCert(cert) + } + for _, cert := range c.mutClientCerts { + clientCAs.AddCert(cert) + } + c.RUnlock() + + // Set new pool + c.Lock() + c.config.RootCAs = rootCAs + c.config.ClientCAs = clientCAs + c.mutRootCerts = []*x509.Certificate{} + c.mutClientCerts = []*x509.Certificate{} + c.Unlock() +} + +// AddFixedClientCACert add an in-mutable cert to ClientCAs. +func (c *mutableTLSConfig) AddInmutableClientCACert(cert *x509.Certificate) { + c.Lock() + c.clientCerts = append(c.clientCerts, cert) + c.Unlock() +} + +// AddInmutableRootCACert add an in-mutable cert to RootCas. +func (c *mutableTLSConfig) AddInmutableRootCACert(cert *x509.Certificate) { + c.Lock() + c.rootCerts = append(c.rootCerts, cert) + c.Unlock() +} + +// AddClientCAs add mutable certs to ClientCAs. +func (c *mutableTLSConfig) AddClientCAs(certs []api.Certificate) { + c.Lock() + for _, cert := range certs { + c.mutClientCerts = append(c.mutClientCerts, cert.Certificate) + } + c.Unlock() +} + +// AddRootCAs add mutable certs to RootCAs. +func (c *mutableTLSConfig) AddRootCAs(certs []api.Certificate) { + c.Lock() + for _, cert := range certs { + c.mutRootCerts = append(c.mutRootCerts, cert.Certificate) + } + c.Unlock() +} diff --git a/ca/tls.go b/ca/tls.go index 31d1632b..415707db 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -21,13 +21,21 @@ import ( // sign certificate, and a new certificate pool with the sign root certificate. // The client certificate will automatically rotate before expiring. func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) { - cert, err := TLSCertificate(sign, pk) + tlsConfig, _, err := c.getClientTLSConfig(ctx, sign, pk, options) if err != nil { return nil, err } + return tlsConfig, nil +} + +func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, *http.Transport, error) { + cert, err := TLSCertificate(sign, pk) + if err != nil { + return nil, nil, err + } renewer, err := NewTLSRenewer(cert, nil) if err != nil { - return nil, err + return nil, nil, err } tlsConfig := getDefaultTLSConfig(sign) @@ -43,14 +51,16 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, // Apply options if given tlsCtx := newTLSOptionCtx(c, tlsConfig) if err := tlsCtx.apply(options); err != nil { - return nil, err + return nil, nil, err } // Update renew function with transport tr, err := getDefaultTransport(tlsConfig) if err != nil { - return nil, err + return nil, nil, err } + // Use mutable tls.Config on renew + tr.DialTLS = c.buildDialTLS(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Update client transport @@ -58,7 +68,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, // Start renewer renewer.RunContext(ctx) - return tlsConfig, nil + return tlsConfig, tr, nil } // GetServerTLSConfig returns a tls.Config for server use configured with the @@ -96,11 +106,18 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, return nil, err } + // GetConfigForClient allows seamless root and federated roots rotation. + // If the return of the callback is not-nil, it will use the returned + // tls.Config instead of the default one. + tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx) + // Update renew function with transport tr, err := getDefaultTransport(tlsConfig) if err != nil { return nil, err } + // Use mutable tls.Config on renew + tr.DialTLS = c.buildDialTLS(tlsCtx) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Update client transport @@ -113,11 +130,34 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, // Transport returns an http.Transport configured to use the client certificate from the sign response. func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) { - tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk, options...) + _, tr, err := c.getClientTLSConfig(ctx, sign, pk, options) if err != nil { return nil, err } - return getDefaultTransport(tlsConfig) + return tr, nil +} + +// buildGetConfigForClient returns an implementation of GetConfigForClient +// callback in tls.Config. +// +// If the implementation returns a nil tls.Config, the original Config will be +// used, but if it's non-nil, the returned Config will be used to handle this +// connection. +func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHelloInfo) (*tls.Config, error) { + return func(*tls.ClientHelloInfo) (*tls.Config, error) { + return ctx.mutableConfig.TLSConfig(), nil + } +} + +// buildDialTLS returns an implementation of DialTLS callback in http.Transport. +func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) { + return func(network, addr string) (net.Conn, error) { + return tls.DialWithDialer(&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }, network, addr, ctx.mutableConfig.TLSConfig()) + } } // Certificate returns the server or client certificate from the sign response. diff --git a/ca/tls_options.go b/ca/tls_options.go index 47e2c627..17233b75 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -10,16 +10,18 @@ type TLSOption func(ctx *TLSOptionCtx) error // TLSOptionCtx is the context modified on TLSOption methods. type TLSOptionCtx struct { - Client *Client - Config *tls.Config - OnRenewFunc []TLSOption + Client *Client + Config *tls.Config + OnRenewFunc []TLSOption + mutableConfig *mutableTLSConfig } // newTLSOptionCtx creates the TLSOption context. func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx { return &TLSOptionCtx{ - Client: c, - Config: config, + Client: c, + Config: config, + mutableConfig: newMutableTLSConfig(), } } @@ -29,6 +31,23 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error { return err } } + + // Initialize mutable config with the fully configured tls.Config + ctx.mutableConfig.Init(ctx.Config) + // Update tls.Config with mutable data + if ctx.Config.RootCAs == nil && len(ctx.mutableConfig.mutRootCerts) > 0 { + ctx.Config.RootCAs = x509.NewCertPool() + } + if ctx.Config.ClientCAs == nil && len(ctx.mutableConfig.mutClientCerts) > 0 { + ctx.Config.ClientCAs = x509.NewCertPool() + } + for _, cert := range ctx.mutableConfig.mutRootCerts { + ctx.Config.RootCAs.AddCert(cert) + } + for _, cert := range ctx.mutableConfig.mutClientCerts { + ctx.Config.ClientCAs.AddCert(cert) + } + ctx.mutableConfig.Reload() return nil } @@ -38,6 +57,8 @@ func (ctx *TLSOptionCtx) applyRenew() error { return err } } + // Reload mutable config with the changes + ctx.mutableConfig.Reload() return nil } @@ -68,6 +89,7 @@ func AddRootCA(cert *x509.Certificate) TLSOption { ctx.Config.RootCAs = x509.NewCertPool() } ctx.Config.RootCAs.AddCert(cert) + ctx.mutableConfig.AddInmutableRootCACert(cert) return nil } } @@ -81,6 +103,7 @@ func AddClientCA(cert *x509.Certificate) TLSOption { ctx.Config.ClientCAs = x509.NewCertPool() } ctx.Config.ClientCAs.AddCert(cert) + ctx.mutableConfig.AddInmutableClientCACert(cert) return nil } } @@ -91,16 +114,21 @@ func AddClientCA(cert *x509.Certificate) TLSOption { // // BootstrapServer and BootstrapClient methods include this option by default. func AddRootsToRootCAs() TLSOption { + // var once sync.Once fn := func(ctx *TLSOptionCtx) error { certs, err := ctx.Client.Roots() if err != nil { return err } - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.RootCAs.AddCert(cert.Certificate) + if ctx.mutableConfig == nil { + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.RootCAs.AddCert(cert.Certificate) + } + } else { + ctx.mutableConfig.AddRootCAs(certs.Certificates) } return nil } @@ -117,16 +145,21 @@ func AddRootsToRootCAs() TLSOption { // // BootstrapServer method includes this option by default. func AddRootsToClientCAs() TLSOption { + // var once sync.Once fn := func(ctx *TLSOptionCtx) error { certs, err := ctx.Client.Roots() if err != nil { return err } - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) + if ctx.mutableConfig == nil { + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.ClientCAs.AddCert(cert.Certificate) + } + } else { + ctx.mutableConfig.AddClientCAs(certs.Certificates) } return nil } @@ -145,11 +178,15 @@ func AddFederationToRootCAs() TLSOption { if err != nil { return err } - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.RootCAs.AddCert(cert.Certificate) + if ctx.mutableConfig == nil { + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.RootCAs.AddCert(cert.Certificate) + } + } else { + ctx.mutableConfig.AddRootCAs(certs.Certificates) } return nil } @@ -169,11 +206,15 @@ func AddFederationToClientCAs() TLSOption { if err != nil { return err } - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) + if ctx.mutableConfig == nil { + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.ClientCAs.AddCert(cert.Certificate) + } + } else { + ctx.mutableConfig.AddClientCAs(certs.Certificates) } return nil } @@ -192,15 +233,20 @@ func AddRootsToCAs() TLSOption { if err != nil { return err } - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) - ctx.Config.RootCAs.AddCert(cert.Certificate) + if ctx.mutableConfig == nil { + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.RootCAs.AddCert(cert.Certificate) + ctx.Config.ClientCAs.AddCert(cert.Certificate) + } + } else { + ctx.mutableConfig.AddRootCAs(certs.Certificates) + ctx.mutableConfig.AddClientCAs(certs.Certificates) } return nil } @@ -219,15 +265,20 @@ func AddFederationToCAs() TLSOption { if err != nil { return err } - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) - ctx.Config.RootCAs.AddCert(cert.Certificate) + if ctx.mutableConfig == nil { + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.RootCAs.AddCert(cert.Certificate) + ctx.Config.ClientCAs.AddCert(cert.Certificate) + } + } else { + ctx.mutableConfig.AddRootCAs(certs.Certificates) + ctx.mutableConfig.AddClientCAs(certs.Certificates) } return nil } diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index ceeea7dc..181ed682 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -26,7 +26,7 @@ func Test_newTLSOptionCtx(t *testing.T) { args args want *TLSOptionCtx }{ - {"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}}}, + {"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, mutableConfig: newMutableTLSConfig()}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -63,7 +63,8 @@ func TestTLSOptionCtx_apply(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: tt.fields.Config, + Config: tt.fields.Config, + mutableConfig: newMutableTLSConfig(), } if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr { t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr) @@ -82,7 +83,8 @@ func TestRequireAndVerifyClientCert(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: &tls.Config{}, + Config: &tls.Config{}, + mutableConfig: newMutableTLSConfig(), } if err := RequireAndVerifyClientCert()(ctx); err != nil { t.Errorf("RequireAndVerifyClientCert() error = %v", err) @@ -105,7 +107,8 @@ func TestVerifyClientCertIfGiven(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: &tls.Config{}, + Config: &tls.Config{}, + mutableConfig: newMutableTLSConfig(), } if err := VerifyClientCertIfGiven()(ctx); err != nil { t.Errorf("VerifyClientCertIfGiven() error = %v", err) @@ -136,7 +139,8 @@ func TestAddRootCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: &tls.Config{}, + Config: &tls.Config{}, + mutableConfig: newMutableTLSConfig(), } if err := AddRootCA(tt.args.cert)(ctx); err != nil { t.Errorf("AddRootCA() error = %v", err) @@ -167,7 +171,8 @@ func TestAddClientCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Config: &tls.Config{}, + Config: &tls.Config{}, + mutableConfig: newMutableTLSConfig(), } if err := AddClientCA(tt.args.cert)(ctx); err != nil { t.Errorf("AddClientCA() error = %v", err) @@ -219,13 +224,15 @@ func TestAddRootsToRootCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddRootsToRootCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr) return } + if !reflect.DeepEqual(ctx.Config, tt.want) { t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want) } @@ -272,10 +279,11 @@ func TestAddRootsToClientCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddRootsToClientCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr) return } @@ -332,10 +340,11 @@ func TestAddFederationToRootCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddFederationToRootCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr) return } @@ -395,10 +404,11 @@ func TestAddFederationToClientCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddFederationToClientCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr) return } @@ -451,10 +461,11 @@ func TestAddRootsToCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddRootsToCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr) return } @@ -511,10 +522,11 @@ func TestAddFederationToCAs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: tt.args.client, - Config: tt.args.config, + Client: tt.args.client, + Config: tt.args.config, + mutableConfig: newMutableTLSConfig(), } - if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr { + if err := ctx.apply([]TLSOption{AddFederationToCAs()}); (err != nil) != tt.wantErr { t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr) return }