Initiate default RootCAs/ClientCAs when no options are passed.

This commit is contained in:
Mariano Cano 2019-01-23 14:33:16 -08:00
parent 25eba1a96c
commit d394dd233a
3 changed files with 52 additions and 89 deletions

View file

@ -43,13 +43,9 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
// Build RootCAs with given root certificate
if pool := getCertPool(sign); pool != nil {
tlsConfig.RootCAs = pool
}
// Apply options if given
tlsCtx := newTLSOptionCtx(c, tlsConfig)
// Apply options and initialize mutable tls.Config
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
if err := tlsCtx.apply(options); err != nil {
return nil, nil, err
}
@ -92,16 +88,10 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig.GetCertificate = renewer.GetCertificate
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
// Build RootCAs with given root certificate
if pool := getCertPool(sign); pool != nil {
tlsConfig.ClientCAs = pool
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
// Add RootCAs for refresh client
tlsConfig.RootCAs = pool
}
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
// Apply options if given
tlsCtx := newTLSOptionCtx(c, tlsConfig)
// Apply options and initialize mutable tls.Config
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
if err := tlsCtx.apply(options); err != nil {
return nil, err
}
@ -179,7 +169,7 @@ func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error)
// RootCertificate returns the root certificate from the sign response.
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
if sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
if sign == nil || sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
return nil, errors.New("ca: certificate does not exists")
}
lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1]
@ -218,17 +208,6 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
return &cert, nil
}
// getCertPool returns the transport x509.CertPool or the one from the sign
// request.
func getCertPool(sign *api.SignResponse) *x509.CertPool {
if root, err := RootCertificate(sign); err == nil {
pool := x509.NewCertPool()
pool.AddCert(root)
return pool
}
return nil
}
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
if sign.TLSOptions != nil {
return sign.TLSOptions.TLSConfig()

View file

@ -3,6 +3,8 @@ package ca
import (
"crypto/tls"
"crypto/x509"
"github.com/smallstep/certificates/api"
)
// TLSOption defines the type of a function that modifies a tls.Config.
@ -12,15 +14,19 @@ type TLSOption func(ctx *TLSOptionCtx) error
type TLSOptionCtx struct {
Client *Client
Config *tls.Config
Sign *api.SignResponse
OnRenewFunc []TLSOption
mutableConfig *mutableTLSConfig
hasRootCA bool
hasClientCA bool
}
// newTLSOptionCtx creates the TLSOption context.
func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx {
func newTLSOptionCtx(c *Client, config *tls.Config, sign *api.SignResponse) *TLSOptionCtx {
return &TLSOptionCtx{
Client: c,
Config: config,
Sign: sign,
mutableConfig: newMutableTLSConfig(),
}
}
@ -34,6 +40,26 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
// Initialize mutable config with the fully configured tls.Config
ctx.mutableConfig.Init(ctx.Config)
// Build RootCAs and ClientCAs with given root certificate if necessary
if root, err := RootCertificate(ctx.Sign); err == nil {
if !ctx.hasRootCA {
if ctx.Config.RootCAs == nil {
ctx.Config.RootCAs = x509.NewCertPool()
}
ctx.Config.RootCAs.AddCert(root)
ctx.mutableConfig.AddInmutableRootCACert(root)
}
if !ctx.hasClientCA && ctx.Config.ClientAuth != tls.NoClientCert {
if ctx.Config.ClientCAs == nil {
ctx.Config.ClientCAs = x509.NewCertPool()
}
ctx.Config.ClientCAs.AddCert(root)
ctx.mutableConfig.AddInmutableClientCACert(root)
}
}
// Update tls.Config with mutable data
if ctx.Config.RootCAs == nil && len(ctx.mutableConfig.mutRootCerts) > 0 {
ctx.Config.RootCAs = x509.NewCertPool()
@ -41,6 +67,7 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
if ctx.Config.ClientCAs == nil && len(ctx.mutableConfig.mutClientCerts) > 0 {
ctx.Config.ClientCAs = x509.NewCertPool()
}
// Add mutable certificates
for _, cert := range ctx.mutableConfig.mutRootCerts {
ctx.Config.RootCAs.AddCert(cert)
}
@ -120,16 +147,8 @@ func AddRootsToRootCAs() TLSOption {
if err != nil {
return err
}
if ctx.mutableConfig == nil {
if ctx.Config.RootCAs == nil {
ctx.Config.RootCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.RootCAs.AddCert(cert.Certificate)
}
} else {
ctx.mutableConfig.AddRootCAs(certs.Certificates)
}
ctx.hasRootCA = true
ctx.mutableConfig.AddRootCAs(certs.Certificates)
return nil
}
return func(ctx *TLSOptionCtx) error {
@ -151,16 +170,8 @@ func AddRootsToClientCAs() TLSOption {
if err != nil {
return err
}
if ctx.mutableConfig == nil {
if ctx.Config.ClientCAs == nil {
ctx.Config.ClientCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.ClientCAs.AddCert(cert.Certificate)
}
} else {
ctx.mutableConfig.AddClientCAs(certs.Certificates)
}
ctx.hasClientCA = true
ctx.mutableConfig.AddClientCAs(certs.Certificates)
return nil
}
return func(ctx *TLSOptionCtx) error {
@ -178,16 +189,7 @@ func AddFederationToRootCAs() TLSOption {
if err != nil {
return err
}
if ctx.mutableConfig == nil {
if ctx.Config.RootCAs == nil {
ctx.Config.RootCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.RootCAs.AddCert(cert.Certificate)
}
} else {
ctx.mutableConfig.AddRootCAs(certs.Certificates)
}
ctx.mutableConfig.AddRootCAs(certs.Certificates)
return nil
}
return func(ctx *TLSOptionCtx) error {
@ -206,16 +208,7 @@ func AddFederationToClientCAs() TLSOption {
if err != nil {
return err
}
if ctx.mutableConfig == nil {
if ctx.Config.ClientCAs == nil {
ctx.Config.ClientCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.ClientCAs.AddCert(cert.Certificate)
}
} else {
ctx.mutableConfig.AddClientCAs(certs.Certificates)
}
ctx.mutableConfig.AddClientCAs(certs.Certificates)
return nil
}
return func(ctx *TLSOptionCtx) error {
@ -233,21 +226,10 @@ func AddRootsToCAs() TLSOption {
if err != nil {
return err
}
if ctx.mutableConfig == nil {
if ctx.Config.RootCAs == nil {
ctx.Config.RootCAs = x509.NewCertPool()
}
if ctx.Config.ClientCAs == nil {
ctx.Config.ClientCAs = x509.NewCertPool()
}
for _, cert := range certs.Certificates {
ctx.Config.RootCAs.AddCert(cert.Certificate)
ctx.Config.ClientCAs.AddCert(cert.Certificate)
}
} else {
ctx.mutableConfig.AddRootCAs(certs.Certificates)
ctx.mutableConfig.AddClientCAs(certs.Certificates)
}
ctx.hasRootCA = true
ctx.hasClientCA = true
ctx.mutableConfig.AddRootCAs(certs.Certificates)
ctx.mutableConfig.AddClientCAs(certs.Certificates)
return nil
}
return func(ctx *TLSOptionCtx) error {

View file

@ -9,6 +9,8 @@ import (
"reflect"
"sort"
"testing"
"github.com/smallstep/certificates/api"
)
func Test_newTLSOptionCtx(t *testing.T) {
@ -20,17 +22,18 @@ func Test_newTLSOptionCtx(t *testing.T) {
type args struct {
c *Client
config *tls.Config
sign *api.SignResponse
}
tests := []struct {
name string
args args
want *TLSOptionCtx
}{
{"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, mutableConfig: newMutableTLSConfig()}},
{"ok", args{client, &tls.Config{}, &api.SignResponse{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, Sign: &api.SignResponse{}, mutableConfig: newMutableTLSConfig()}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) {
if got := newTLSOptionCtx(tt.args.c, tt.args.config, tt.args.sign); !reflect.DeepEqual(got, tt.want) {
t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
}
})
@ -232,8 +235,7 @@ func TestAddRootsToRootCAs(t *testing.T) {
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(ctx.Config, tt.want) {
if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) {
t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
}
})
@ -287,7 +289,7 @@ func TestAddRootsToClientCAs(t *testing.T) {
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(ctx.Config, tt.want) {
if !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) {
t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
}
})
@ -469,7 +471,7 @@ func TestAddRootsToCAs(t *testing.T) {
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(ctx.Config, tt.want) {
if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) || !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) {
t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
}
})