diff --git a/ca/tls.go b/ca/tls.go index 415707db..494be574 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -43,13 +43,9 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.PreferServerCipherSuites = true - // Build RootCAs with given root certificate - if pool := getCertPool(sign); pool != nil { - tlsConfig.RootCAs = pool - } - // Apply options if given - tlsCtx := newTLSOptionCtx(c, tlsConfig) + // Apply options and initialize mutable tls.Config + tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) if err := tlsCtx.apply(options); err != nil { return nil, nil, err } @@ -92,16 +88,10 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, tlsConfig.GetCertificate = renewer.GetCertificate tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.PreferServerCipherSuites = true - // Build RootCAs with given root certificate - if pool := getCertPool(sign); pool != nil { - tlsConfig.ClientCAs = pool - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - // Add RootCAs for refresh client - tlsConfig.RootCAs = pool - } + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - // Apply options if given - tlsCtx := newTLSOptionCtx(c, tlsConfig) + // Apply options and initialize mutable tls.Config + tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) if err := tlsCtx.apply(options); err != nil { return nil, err } @@ -179,7 +169,7 @@ func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) // RootCertificate returns the root certificate from the sign response. func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) { - if sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 { + if sign == nil || sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 { return nil, errors.New("ca: certificate does not exists") } lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1] @@ -218,17 +208,6 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific return &cert, nil } -// getCertPool returns the transport x509.CertPool or the one from the sign -// request. -func getCertPool(sign *api.SignResponse) *x509.CertPool { - if root, err := RootCertificate(sign); err == nil { - pool := x509.NewCertPool() - pool.AddCert(root) - return pool - } - return nil -} - func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { if sign.TLSOptions != nil { return sign.TLSOptions.TLSConfig() diff --git a/ca/tls_options.go b/ca/tls_options.go index 17233b75..80c57d0e 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -3,6 +3,8 @@ package ca import ( "crypto/tls" "crypto/x509" + + "github.com/smallstep/certificates/api" ) // TLSOption defines the type of a function that modifies a tls.Config. @@ -12,15 +14,19 @@ type TLSOption func(ctx *TLSOptionCtx) error type TLSOptionCtx struct { Client *Client Config *tls.Config + Sign *api.SignResponse OnRenewFunc []TLSOption mutableConfig *mutableTLSConfig + hasRootCA bool + hasClientCA bool } // newTLSOptionCtx creates the TLSOption context. -func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx { +func newTLSOptionCtx(c *Client, config *tls.Config, sign *api.SignResponse) *TLSOptionCtx { return &TLSOptionCtx{ Client: c, Config: config, + Sign: sign, mutableConfig: newMutableTLSConfig(), } } @@ -34,6 +40,26 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error { // Initialize mutable config with the fully configured tls.Config ctx.mutableConfig.Init(ctx.Config) + + // Build RootCAs and ClientCAs with given root certificate if necessary + if root, err := RootCertificate(ctx.Sign); err == nil { + if !ctx.hasRootCA { + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + ctx.Config.RootCAs.AddCert(root) + ctx.mutableConfig.AddInmutableRootCACert(root) + } + + if !ctx.hasClientCA && ctx.Config.ClientAuth != tls.NoClientCert { + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + ctx.Config.ClientCAs.AddCert(root) + ctx.mutableConfig.AddInmutableClientCACert(root) + } + } + // Update tls.Config with mutable data if ctx.Config.RootCAs == nil && len(ctx.mutableConfig.mutRootCerts) > 0 { ctx.Config.RootCAs = x509.NewCertPool() @@ -41,6 +67,7 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error { if ctx.Config.ClientCAs == nil && len(ctx.mutableConfig.mutClientCerts) > 0 { ctx.Config.ClientCAs = x509.NewCertPool() } + // Add mutable certificates for _, cert := range ctx.mutableConfig.mutRootCerts { ctx.Config.RootCAs.AddCert(cert) } @@ -120,16 +147,8 @@ func AddRootsToRootCAs() TLSOption { if err != nil { return err } - if ctx.mutableConfig == nil { - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.RootCAs.AddCert(cert.Certificate) - } - } else { - ctx.mutableConfig.AddRootCAs(certs.Certificates) - } + ctx.hasRootCA = true + ctx.mutableConfig.AddRootCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { @@ -151,16 +170,8 @@ func AddRootsToClientCAs() TLSOption { if err != nil { return err } - if ctx.mutableConfig == nil { - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) - } - } else { - ctx.mutableConfig.AddClientCAs(certs.Certificates) - } + ctx.hasClientCA = true + ctx.mutableConfig.AddClientCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { @@ -178,16 +189,7 @@ func AddFederationToRootCAs() TLSOption { if err != nil { return err } - if ctx.mutableConfig == nil { - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.RootCAs.AddCert(cert.Certificate) - } - } else { - ctx.mutableConfig.AddRootCAs(certs.Certificates) - } + ctx.mutableConfig.AddRootCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { @@ -206,16 +208,7 @@ func AddFederationToClientCAs() TLSOption { if err != nil { return err } - if ctx.mutableConfig == nil { - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.ClientCAs.AddCert(cert.Certificate) - } - } else { - ctx.mutableConfig.AddClientCAs(certs.Certificates) - } + ctx.mutableConfig.AddClientCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { @@ -233,21 +226,10 @@ func AddRootsToCAs() TLSOption { if err != nil { return err } - if ctx.mutableConfig == nil { - if ctx.Config.RootCAs == nil { - ctx.Config.RootCAs = x509.NewCertPool() - } - if ctx.Config.ClientCAs == nil { - ctx.Config.ClientCAs = x509.NewCertPool() - } - for _, cert := range certs.Certificates { - ctx.Config.RootCAs.AddCert(cert.Certificate) - ctx.Config.ClientCAs.AddCert(cert.Certificate) - } - } else { - ctx.mutableConfig.AddRootCAs(certs.Certificates) - ctx.mutableConfig.AddClientCAs(certs.Certificates) - } + ctx.hasRootCA = true + ctx.hasClientCA = true + ctx.mutableConfig.AddRootCAs(certs.Certificates) + ctx.mutableConfig.AddClientCAs(certs.Certificates) return nil } return func(ctx *TLSOptionCtx) error { diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 181ed682..a422799e 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -9,6 +9,8 @@ import ( "reflect" "sort" "testing" + + "github.com/smallstep/certificates/api" ) func Test_newTLSOptionCtx(t *testing.T) { @@ -20,17 +22,18 @@ func Test_newTLSOptionCtx(t *testing.T) { type args struct { c *Client config *tls.Config + sign *api.SignResponse } tests := []struct { name string args args want *TLSOptionCtx }{ - {"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, mutableConfig: newMutableTLSConfig()}}, + {"ok", args{client, &tls.Config{}, &api.SignResponse{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, Sign: &api.SignResponse{}, mutableConfig: newMutableTLSConfig()}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) { + if got := newTLSOptionCtx(tt.args.c, tt.args.config, tt.args.sign); !reflect.DeepEqual(got, tt.want) { t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want) } }) @@ -232,8 +235,7 @@ func TestAddRootsToRootCAs(t *testing.T) { t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr) return } - - if !reflect.DeepEqual(ctx.Config, tt.want) { + if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) { t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want) } }) @@ -287,7 +289,7 @@ func TestAddRootsToClientCAs(t *testing.T) { t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(ctx.Config, tt.want) { + if !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) { t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want) } }) @@ -469,7 +471,7 @@ func TestAddRootsToCAs(t *testing.T) { t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(ctx.Config, tt.want) { + if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) || !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) { t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want) } })