Initiate default RootCAs/ClientCAs when no options are passed.
This commit is contained in:
parent
25eba1a96c
commit
d394dd233a
3 changed files with 52 additions and 89 deletions
31
ca/tls.go
31
ca/tls.go
|
@ -43,13 +43,9 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
|
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
|
||||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
tlsConfig.PreferServerCipherSuites = true
|
||||||
// Build RootCAs with given root certificate
|
|
||||||
if pool := getCertPool(sign); pool != nil {
|
|
||||||
tlsConfig.RootCAs = pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options and initialize mutable tls.Config
|
||||||
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
|
||||||
if err := tlsCtx.apply(options); err != nil {
|
if err := tlsCtx.apply(options); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -92,16 +88,10 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
tlsConfig.GetCertificate = renewer.GetCertificate
|
tlsConfig.GetCertificate = renewer.GetCertificate
|
||||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
tlsConfig.PreferServerCipherSuites = true
|
||||||
// Build RootCAs with given root certificate
|
|
||||||
if pool := getCertPool(sign); pool != nil {
|
|
||||||
tlsConfig.ClientCAs = pool
|
|
||||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
// Add RootCAs for refresh client
|
|
||||||
tlsConfig.RootCAs = pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options and initialize mutable tls.Config
|
||||||
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
|
||||||
if err := tlsCtx.apply(options); err != nil {
|
if err := tlsCtx.apply(options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -179,7 +169,7 @@ func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error)
|
||||||
|
|
||||||
// RootCertificate returns the root certificate from the sign response.
|
// RootCertificate returns the root certificate from the sign response.
|
||||||
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||||||
if sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
|
if sign == nil || sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
|
||||||
return nil, errors.New("ca: certificate does not exists")
|
return nil, errors.New("ca: certificate does not exists")
|
||||||
}
|
}
|
||||||
lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1]
|
lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1]
|
||||||
|
@ -218,17 +208,6 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
|
||||||
return &cert, nil
|
return &cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCertPool returns the transport x509.CertPool or the one from the sign
|
|
||||||
// request.
|
|
||||||
func getCertPool(sign *api.SignResponse) *x509.CertPool {
|
|
||||||
if root, err := RootCertificate(sign); err == nil {
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
pool.AddCert(root)
|
|
||||||
return pool
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
||||||
if sign.TLSOptions != nil {
|
if sign.TLSOptions != nil {
|
||||||
return sign.TLSOptions.TLSConfig()
|
return sign.TLSOptions.TLSConfig()
|
||||||
|
|
|
@ -3,6 +3,8 @@ package ca
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||||
|
@ -12,15 +14,19 @@ type TLSOption func(ctx *TLSOptionCtx) error
|
||||||
type TLSOptionCtx struct {
|
type TLSOptionCtx struct {
|
||||||
Client *Client
|
Client *Client
|
||||||
Config *tls.Config
|
Config *tls.Config
|
||||||
|
Sign *api.SignResponse
|
||||||
OnRenewFunc []TLSOption
|
OnRenewFunc []TLSOption
|
||||||
mutableConfig *mutableTLSConfig
|
mutableConfig *mutableTLSConfig
|
||||||
|
hasRootCA bool
|
||||||
|
hasClientCA bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTLSOptionCtx creates the TLSOption context.
|
// newTLSOptionCtx creates the TLSOption context.
|
||||||
func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx {
|
func newTLSOptionCtx(c *Client, config *tls.Config, sign *api.SignResponse) *TLSOptionCtx {
|
||||||
return &TLSOptionCtx{
|
return &TLSOptionCtx{
|
||||||
Client: c,
|
Client: c,
|
||||||
Config: config,
|
Config: config,
|
||||||
|
Sign: sign,
|
||||||
mutableConfig: newMutableTLSConfig(),
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -34,6 +40,26 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
||||||
|
|
||||||
// Initialize mutable config with the fully configured tls.Config
|
// Initialize mutable config with the fully configured tls.Config
|
||||||
ctx.mutableConfig.Init(ctx.Config)
|
ctx.mutableConfig.Init(ctx.Config)
|
||||||
|
|
||||||
|
// Build RootCAs and ClientCAs with given root certificate if necessary
|
||||||
|
if root, err := RootCertificate(ctx.Sign); err == nil {
|
||||||
|
if !ctx.hasRootCA {
|
||||||
|
if ctx.Config.RootCAs == nil {
|
||||||
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
ctx.Config.RootCAs.AddCert(root)
|
||||||
|
ctx.mutableConfig.AddInmutableRootCACert(root)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ctx.hasClientCA && ctx.Config.ClientAuth != tls.NoClientCert {
|
||||||
|
if ctx.Config.ClientCAs == nil {
|
||||||
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
ctx.Config.ClientCAs.AddCert(root)
|
||||||
|
ctx.mutableConfig.AddInmutableClientCACert(root)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Update tls.Config with mutable data
|
// Update tls.Config with mutable data
|
||||||
if ctx.Config.RootCAs == nil && len(ctx.mutableConfig.mutRootCerts) > 0 {
|
if ctx.Config.RootCAs == nil && len(ctx.mutableConfig.mutRootCerts) > 0 {
|
||||||
ctx.Config.RootCAs = x509.NewCertPool()
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
|
@ -41,6 +67,7 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
||||||
if ctx.Config.ClientCAs == nil && len(ctx.mutableConfig.mutClientCerts) > 0 {
|
if ctx.Config.ClientCAs == nil && len(ctx.mutableConfig.mutClientCerts) > 0 {
|
||||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
|
// Add mutable certificates
|
||||||
for _, cert := range ctx.mutableConfig.mutRootCerts {
|
for _, cert := range ctx.mutableConfig.mutRootCerts {
|
||||||
ctx.Config.RootCAs.AddCert(cert)
|
ctx.Config.RootCAs.AddCert(cert)
|
||||||
}
|
}
|
||||||
|
@ -120,16 +147,8 @@ func AddRootsToRootCAs() TLSOption {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.mutableConfig == nil {
|
ctx.hasRootCA = true
|
||||||
if ctx.Config.RootCAs == nil {
|
|
||||||
ctx.Config.RootCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
@ -151,16 +170,8 @@ func AddRootsToClientCAs() TLSOption {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.mutableConfig == nil {
|
ctx.hasClientCA = true
|
||||||
if ctx.Config.ClientCAs == nil {
|
|
||||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
@ -178,16 +189,7 @@ func AddFederationToRootCAs() TLSOption {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.mutableConfig == nil {
|
|
||||||
if ctx.Config.RootCAs == nil {
|
|
||||||
ctx.Config.RootCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
@ -206,16 +208,7 @@ func AddFederationToClientCAs() TLSOption {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.mutableConfig == nil {
|
|
||||||
if ctx.Config.ClientCAs == nil {
|
|
||||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
@ -233,21 +226,10 @@ func AddRootsToCAs() TLSOption {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.mutableConfig == nil {
|
ctx.hasRootCA = true
|
||||||
if ctx.Config.RootCAs == nil {
|
ctx.hasClientCA = true
|
||||||
ctx.Config.RootCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
if ctx.Config.ClientCAs == nil {
|
|
||||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
|
||||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||||
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_newTLSOptionCtx(t *testing.T) {
|
func Test_newTLSOptionCtx(t *testing.T) {
|
||||||
|
@ -20,17 +22,18 @@ func Test_newTLSOptionCtx(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
c *Client
|
c *Client
|
||||||
config *tls.Config
|
config *tls.Config
|
||||||
|
sign *api.SignResponse
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
want *TLSOptionCtx
|
want *TLSOptionCtx
|
||||||
}{
|
}{
|
||||||
{"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, mutableConfig: newMutableTLSConfig()}},
|
{"ok", args{client, &tls.Config{}, &api.SignResponse{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, Sign: &api.SignResponse{}, mutableConfig: newMutableTLSConfig()}},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) {
|
if got := newTLSOptionCtx(tt.args.c, tt.args.config, tt.args.sign); !reflect.DeepEqual(got, tt.want) {
|
||||||
t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
|
t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -232,8 +235,7 @@ func TestAddRootsToRootCAs(t *testing.T) {
|
||||||
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) {
|
||||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
|
||||||
t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
|
t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -287,7 +289,7 @@ func TestAddRootsToClientCAs(t *testing.T) {
|
||||||
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
if !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) {
|
||||||
t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
|
t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -469,7 +471,7 @@ func TestAddRootsToCAs(t *testing.T) {
|
||||||
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) || !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) {
|
||||||
t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
|
t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
Loading…
Reference in a new issue