Initiate default RootCAs/ClientCAs when no options are passed.
This commit is contained in:
parent
25eba1a96c
commit
d394dd233a
3 changed files with 52 additions and 89 deletions
33
ca/tls.go
33
ca/tls.go
|
@ -43,13 +43,9 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
|
||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||
tlsConfig.PreferServerCipherSuites = true
|
||||
// Build RootCAs with given root certificate
|
||||
if pool := getCertPool(sign); pool != nil {
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
|
||||
// Apply options if given
|
||||
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
||||
// Apply options and initialize mutable tls.Config
|
||||
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
|
||||
if err := tlsCtx.apply(options); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -92,16 +88,10 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
tlsConfig.GetCertificate = renewer.GetCertificate
|
||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||
tlsConfig.PreferServerCipherSuites = true
|
||||
// Build RootCAs with given root certificate
|
||||
if pool := getCertPool(sign); pool != nil {
|
||||
tlsConfig.ClientCAs = pool
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
// Add RootCAs for refresh client
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
|
||||
// Apply options if given
|
||||
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
||||
// Apply options and initialize mutable tls.Config
|
||||
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
|
||||
if err := tlsCtx.apply(options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -179,7 +169,7 @@ func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error)
|
|||
|
||||
// RootCertificate returns the root certificate from the sign response.
|
||||
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||||
if sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
|
||||
if sign == nil || sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
|
||||
return nil, errors.New("ca: certificate does not exists")
|
||||
}
|
||||
lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1]
|
||||
|
@ -218,17 +208,6 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
|
|||
return &cert, nil
|
||||
}
|
||||
|
||||
// getCertPool returns the transport x509.CertPool or the one from the sign
|
||||
// request.
|
||||
func getCertPool(sign *api.SignResponse) *x509.CertPool {
|
||||
if root, err := RootCertificate(sign); err == nil {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(root)
|
||||
return pool
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
||||
if sign.TLSOptions != nil {
|
||||
return sign.TLSOptions.TLSConfig()
|
||||
|
|
|
@ -3,6 +3,8 @@ package ca
|
|||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
)
|
||||
|
||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||
|
@ -12,15 +14,19 @@ type TLSOption func(ctx *TLSOptionCtx) error
|
|||
type TLSOptionCtx struct {
|
||||
Client *Client
|
||||
Config *tls.Config
|
||||
Sign *api.SignResponse
|
||||
OnRenewFunc []TLSOption
|
||||
mutableConfig *mutableTLSConfig
|
||||
hasRootCA bool
|
||||
hasClientCA bool
|
||||
}
|
||||
|
||||
// newTLSOptionCtx creates the TLSOption context.
|
||||
func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx {
|
||||
func newTLSOptionCtx(c *Client, config *tls.Config, sign *api.SignResponse) *TLSOptionCtx {
|
||||
return &TLSOptionCtx{
|
||||
Client: c,
|
||||
Config: config,
|
||||
Sign: sign,
|
||||
mutableConfig: newMutableTLSConfig(),
|
||||
}
|
||||
}
|
||||
|
@ -34,6 +40,26 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
|||
|
||||
// Initialize mutable config with the fully configured tls.Config
|
||||
ctx.mutableConfig.Init(ctx.Config)
|
||||
|
||||
// Build RootCAs and ClientCAs with given root certificate if necessary
|
||||
if root, err := RootCertificate(ctx.Sign); err == nil {
|
||||
if !ctx.hasRootCA {
|
||||
if ctx.Config.RootCAs == nil {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
ctx.Config.RootCAs.AddCert(root)
|
||||
ctx.mutableConfig.AddInmutableRootCACert(root)
|
||||
}
|
||||
|
||||
if !ctx.hasClientCA && ctx.Config.ClientAuth != tls.NoClientCert {
|
||||
if ctx.Config.ClientCAs == nil {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
ctx.Config.ClientCAs.AddCert(root)
|
||||
ctx.mutableConfig.AddInmutableClientCACert(root)
|
||||
}
|
||||
}
|
||||
|
||||
// Update tls.Config with mutable data
|
||||
if ctx.Config.RootCAs == nil && len(ctx.mutableConfig.mutRootCerts) > 0 {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
|
@ -41,6 +67,7 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
|||
if ctx.Config.ClientCAs == nil && len(ctx.mutableConfig.mutClientCerts) > 0 {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
// Add mutable certificates
|
||||
for _, cert := range ctx.mutableConfig.mutRootCerts {
|
||||
ctx.Config.RootCAs.AddCert(cert)
|
||||
}
|
||||
|
@ -120,16 +147,8 @@ func AddRootsToRootCAs() TLSOption {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.mutableConfig == nil {
|
||||
if ctx.Config.RootCAs == nil {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
} else {
|
||||
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||
}
|
||||
ctx.hasRootCA = true
|
||||
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
|
@ -151,16 +170,8 @@ func AddRootsToClientCAs() TLSOption {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.mutableConfig == nil {
|
||||
if ctx.Config.ClientCAs == nil {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
} else {
|
||||
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||
}
|
||||
ctx.hasClientCA = true
|
||||
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
|
@ -178,16 +189,7 @@ func AddFederationToRootCAs() TLSOption {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.mutableConfig == nil {
|
||||
if ctx.Config.RootCAs == nil {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
} else {
|
||||
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||
}
|
||||
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
|
@ -206,16 +208,7 @@ func AddFederationToClientCAs() TLSOption {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.mutableConfig == nil {
|
||||
if ctx.Config.ClientCAs == nil {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
} else {
|
||||
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||
}
|
||||
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
|
@ -233,21 +226,10 @@ func AddRootsToCAs() TLSOption {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.mutableConfig == nil {
|
||||
if ctx.Config.RootCAs == nil {
|
||||
ctx.Config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
if ctx.Config.ClientCAs == nil {
|
||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
} else {
|
||||
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||
}
|
||||
ctx.hasRootCA = true
|
||||
ctx.hasClientCA = true
|
||||
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||
return nil
|
||||
}
|
||||
return func(ctx *TLSOptionCtx) error {
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
)
|
||||
|
||||
func Test_newTLSOptionCtx(t *testing.T) {
|
||||
|
@ -20,17 +22,18 @@ func Test_newTLSOptionCtx(t *testing.T) {
|
|||
type args struct {
|
||||
c *Client
|
||||
config *tls.Config
|
||||
sign *api.SignResponse
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *TLSOptionCtx
|
||||
}{
|
||||
{"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, mutableConfig: newMutableTLSConfig()}},
|
||||
{"ok", args{client, &tls.Config{}, &api.SignResponse{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, Sign: &api.SignResponse{}, mutableConfig: newMutableTLSConfig()}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) {
|
||||
if got := newTLSOptionCtx(tt.args.c, tt.args.config, tt.args.sign); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
|
@ -232,8 +235,7 @@ func TestAddRootsToRootCAs(t *testing.T) {
|
|||
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) {
|
||||
t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
|
@ -287,7 +289,7 @@ func TestAddRootsToClientCAs(t *testing.T) {
|
|||
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
if !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) {
|
||||
t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
|
@ -469,7 +471,7 @@ func TestAddRootsToCAs(t *testing.T) {
|
|||
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||
if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) || !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) {
|
||||
t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
|
||||
}
|
||||
})
|
||||
|
|
Loading…
Reference in a new issue