forked from TrueCloudLab/certificates
Merge pull request #27 from smallstep/mariano/renew-pool
SDK should update certificate pools safely
This commit is contained in:
commit
262a9d0978
7 changed files with 415 additions and 126 deletions
|
@ -2,6 +2,8 @@ package ca
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -145,3 +147,54 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BootstrapListener is a helper function that using the given token returns a
|
||||||
|
// TLS listener which accepts connections from an inner listener and wraps each
|
||||||
|
// connection with Server.
|
||||||
|
//
|
||||||
|
// Without any extra option the server will be configured for mTLS, it will
|
||||||
|
// require and verify clients certificates, but options can be used to drop this
|
||||||
|
// requirement, the most common will be only verify the certs if given with
|
||||||
|
// ca.VerifyClientCertIfGiven(), or add extra CAs with
|
||||||
|
// ca.AddClientCA(*x509.Certificate).
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
// inner, err := net.Listen("tcp", ":443")
|
||||||
|
// if err != nil {
|
||||||
|
// return nil
|
||||||
|
// }
|
||||||
|
// ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
// defer cancel()
|
||||||
|
// lis, err := ca.BootstrapListener(ctx, token, inner)
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// srv := grpc.NewServer()
|
||||||
|
// ... // register services
|
||||||
|
// srv.Serve(lis)
|
||||||
|
func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) {
|
||||||
|
client, err := Bootstrap(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, pk, err := CreateSignRequest(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sign, err := client.Sign(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs
|
||||||
|
options = append(options, AddRootsToCAs())
|
||||||
|
|
||||||
|
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tls.NewListener(inner, tlsConfig), nil
|
||||||
|
}
|
||||||
|
|
|
@ -8,13 +8,13 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
|
|
||||||
"github.com/smallstep/cli/crypto/randutil"
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
stepJOSE "github.com/smallstep/cli/jose"
|
stepJOSE "github.com/smallstep/cli/jose"
|
||||||
jose "gopkg.in/square/go-jose.v2"
|
jose "gopkg.in/square/go-jose.v2"
|
||||||
|
@ -365,6 +365,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
|
||||||
|
|
||||||
// doTest does a request that requires mTLS
|
// doTest does a request that requires mTLS
|
||||||
doTest := func(client *http.Client) error {
|
doTest := func(client *http.Client) error {
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
// test with ca
|
// test with ca
|
||||||
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
|
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -532,3 +533,70 @@ func doReload(ca *CA) error {
|
||||||
newCA.srv.Addr = ca.srv.Addr
|
newCA.srv.Addr = ca.srv.Addr
|
||||||
return ca.srv.Reload(newCA.srv)
|
return ca.srv.Reload(newCA.srv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBootstrapListener(t *testing.T) {
|
||||||
|
srv := startCABootstrapServer()
|
||||||
|
defer srv.Close()
|
||||||
|
token := func() string {
|
||||||
|
return generateBootstrapToken(srv.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
token string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{token()}, false},
|
||||||
|
{"fail", args{"bad-token"}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
inner := newLocalListener()
|
||||||
|
defer inner.Close()
|
||||||
|
lis, err := BootstrapListener(context.Background(), tt.args.token, inner)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("BootstrapListener() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr {
|
||||||
|
if lis != nil {
|
||||||
|
t.Errorf("BootstrapListener() = %v, want nil", lis)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wg := new(sync.WaitGroup)
|
||||||
|
go func() {
|
||||||
|
wg.Add(1)
|
||||||
|
http.Serve(lis, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
defer wg.Wait()
|
||||||
|
defer lis.Close()
|
||||||
|
|
||||||
|
client, err := BootstrapClient(context.Background(), token())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("BootstrapClient() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := client.Get("https://" + lis.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("client.Get() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ioutil.ReadAll() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if string(b) != "ok" {
|
||||||
|
t.Errorf("client.Get() = %s, want ok", string(b))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
13
ca/client.go
13
ca/client.go
|
@ -12,7 +12,6 @@ import (
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -117,16 +116,10 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "error reading %s", filename)
|
return nil, errors.Wrapf(err, "error reading %s", filename)
|
||||||
}
|
}
|
||||||
block, _ := pem.Decode(data)
|
|
||||||
if block == nil {
|
|
||||||
return nil, errors.Errorf("error decoding %s", filename)
|
|
||||||
}
|
|
||||||
root, err := x509.ParseCertificate(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "error parsing %s", filename)
|
|
||||||
}
|
|
||||||
pool := x509.NewCertPool()
|
pool := x509.NewCertPool()
|
||||||
pool.AddCert(root)
|
if !pool.AppendCertsFromPEM(data) {
|
||||||
|
return nil, errors.Errorf("error parsing %s: no certificates found", filename)
|
||||||
|
}
|
||||||
return getDefaultTransport(&tls.Config{
|
return getDefaultTransport(&tls.Config{
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
PreferServerCipherSuites: true,
|
PreferServerCipherSuites: true,
|
||||||
|
|
109
ca/mutable_tls_config.go
Normal file
109
ca/mutable_tls_config.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package ca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mutableTLSConfig allows to use a tls.Config with mutable cert pools.
|
||||||
|
type mutableTLSConfig struct {
|
||||||
|
sync.RWMutex
|
||||||
|
config *tls.Config
|
||||||
|
clientCerts []*x509.Certificate
|
||||||
|
rootCerts []*x509.Certificate
|
||||||
|
mutClientCerts []*x509.Certificate
|
||||||
|
mutRootCerts []*x509.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
// newMutableTLSConfig creates a new mutableTLSConfig that will be later
|
||||||
|
// initialized with a tls.Config.
|
||||||
|
func newMutableTLSConfig() *mutableTLSConfig {
|
||||||
|
return &mutableTLSConfig{
|
||||||
|
clientCerts: []*x509.Certificate{},
|
||||||
|
rootCerts: []*x509.Certificate{},
|
||||||
|
mutClientCerts: []*x509.Certificate{},
|
||||||
|
mutRootCerts: []*x509.Certificate{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init initializes the mutable tls.Config with the given tls.Config.
|
||||||
|
func (c *mutableTLSConfig) Init(base *tls.Config) {
|
||||||
|
c.Lock()
|
||||||
|
c.config = base.Clone()
|
||||||
|
c.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSConfig returns the updated tls.Config it it has changed. It's used in the
|
||||||
|
// tls.Config GetConfigForClient.
|
||||||
|
func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) {
|
||||||
|
c.RLock()
|
||||||
|
config = c.config
|
||||||
|
c.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload reloads the tls.Config with the new CAs.
|
||||||
|
func (c *mutableTLSConfig) Reload() {
|
||||||
|
// Prepare new pools
|
||||||
|
c.RLock()
|
||||||
|
rootCAs := x509.NewCertPool()
|
||||||
|
clientCAs := x509.NewCertPool()
|
||||||
|
// Fixed certs
|
||||||
|
for _, cert := range c.rootCerts {
|
||||||
|
rootCAs.AddCert(cert)
|
||||||
|
}
|
||||||
|
for _, cert := range c.clientCerts {
|
||||||
|
clientCAs.AddCert(cert)
|
||||||
|
}
|
||||||
|
// Mutable certs
|
||||||
|
for _, cert := range c.mutRootCerts {
|
||||||
|
rootCAs.AddCert(cert)
|
||||||
|
}
|
||||||
|
for _, cert := range c.mutClientCerts {
|
||||||
|
clientCAs.AddCert(cert)
|
||||||
|
}
|
||||||
|
c.RUnlock()
|
||||||
|
|
||||||
|
// Set new pool
|
||||||
|
c.Lock()
|
||||||
|
c.config.RootCAs = rootCAs
|
||||||
|
c.config.ClientCAs = clientCAs
|
||||||
|
c.mutRootCerts = []*x509.Certificate{}
|
||||||
|
c.mutClientCerts = []*x509.Certificate{}
|
||||||
|
c.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddImmutableClientCACert add an immutable cert to ClientCAs.
|
||||||
|
func (c *mutableTLSConfig) AddImmutableClientCACert(cert *x509.Certificate) {
|
||||||
|
c.Lock()
|
||||||
|
c.clientCerts = append(c.clientCerts, cert)
|
||||||
|
c.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddImmutableRootCACert add an immutable cert to RootCas.
|
||||||
|
func (c *mutableTLSConfig) AddImmutableRootCACert(cert *x509.Certificate) {
|
||||||
|
c.Lock()
|
||||||
|
c.rootCerts = append(c.rootCerts, cert)
|
||||||
|
c.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddClientCAs add mutable certs to ClientCAs.
|
||||||
|
func (c *mutableTLSConfig) AddClientCAs(certs []api.Certificate) {
|
||||||
|
c.Lock()
|
||||||
|
for _, cert := range certs {
|
||||||
|
c.mutClientCerts = append(c.mutClientCerts, cert.Certificate)
|
||||||
|
}
|
||||||
|
c.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRootCAs add mutable certs to RootCAs.
|
||||||
|
func (c *mutableTLSConfig) AddRootCAs(certs []api.Certificate) {
|
||||||
|
c.Lock()
|
||||||
|
for _, cert := range certs {
|
||||||
|
c.mutRootCerts = append(c.mutRootCerts, cert.Certificate)
|
||||||
|
}
|
||||||
|
c.Unlock()
|
||||||
|
}
|
95
ca/tls.go
95
ca/tls.go
|
@ -21,13 +21,21 @@ import (
|
||||||
// sign certificate, and a new certificate pool with the sign root certificate.
|
// sign certificate, and a new certificate pool with the sign root certificate.
|
||||||
// The client certificate will automatically rotate before expiring.
|
// The client certificate will automatically rotate before expiring.
|
||||||
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
|
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
|
||||||
cert, err := TLSCertificate(sign, pk)
|
tlsConfig, _, err := c.getClientTLSConfig(ctx, sign, pk, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, *http.Transport, error) {
|
||||||
|
cert, err := TLSCertificate(sign, pk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
renewer, err := NewTLSRenewer(cert, nil)
|
renewer, err := NewTLSRenewer(cert, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig := getDefaultTLSConfig(sign)
|
tlsConfig := getDefaultTLSConfig(sign)
|
||||||
|
@ -35,22 +43,20 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
|
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
|
||||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
tlsConfig.PreferServerCipherSuites = true
|
||||||
// Build RootCAs with given root certificate
|
|
||||||
if pool := getCertPool(sign); pool != nil {
|
|
||||||
tlsConfig.RootCAs = pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options and initialize mutable tls.Config
|
||||||
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
|
||||||
if err := tlsCtx.apply(options); err != nil {
|
if err := tlsCtx.apply(options); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update renew function with transport
|
// Update renew function with transport
|
||||||
tr, err := getDefaultTransport(tlsConfig)
|
tr, err := getDefaultTransport(tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
// Use mutable tls.Config on renew
|
||||||
|
tr.DialTLS = c.buildDialTLS(tlsCtx)
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
// Update client transport
|
// Update client transport
|
||||||
|
@ -58,7 +64,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
|
|
||||||
// Start renewer
|
// Start renewer
|
||||||
renewer.RunContext(ctx)
|
renewer.RunContext(ctx)
|
||||||
return tlsConfig, nil
|
return tlsConfig, tr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServerTLSConfig returns a tls.Config for server use configured with the
|
// GetServerTLSConfig returns a tls.Config for server use configured with the
|
||||||
|
@ -82,25 +88,26 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
tlsConfig.GetCertificate = renewer.GetCertificate
|
tlsConfig.GetCertificate = renewer.GetCertificate
|
||||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
tlsConfig.PreferServerCipherSuites = true
|
||||||
// Build RootCAs with given root certificate
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
if pool := getCertPool(sign); pool != nil {
|
|
||||||
tlsConfig.ClientCAs = pool
|
|
||||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
|
||||||
// Add RootCAs for refresh client
|
|
||||||
tlsConfig.RootCAs = pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options and initialize mutable tls.Config
|
||||||
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
|
||||||
if err := tlsCtx.apply(options); err != nil {
|
if err := tlsCtx.apply(options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfigForClient allows seamless root and federated roots rotation.
|
||||||
|
// If the return of the callback is not-nil, it will use the returned
|
||||||
|
// tls.Config instead of the default one.
|
||||||
|
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
|
||||||
|
|
||||||
// Update renew function with transport
|
// Update renew function with transport
|
||||||
tr, err := getDefaultTransport(tlsConfig)
|
tr, err := getDefaultTransport(tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// Use mutable tls.Config on renew
|
||||||
|
tr.DialTLS = c.buildDialTLS(tlsCtx)
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
// Update client transport
|
// Update client transport
|
||||||
|
@ -113,17 +120,40 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
|
|
||||||
// Transport returns an http.Transport configured to use the client certificate from the sign response.
|
// Transport returns an http.Transport configured to use the client certificate from the sign response.
|
||||||
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
|
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
|
||||||
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk, options...)
|
_, tr, err := c.getClientTLSConfig(ctx, sign, pk, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return getDefaultTransport(tlsConfig)
|
return tr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildGetConfigForClient returns an implementation of GetConfigForClient
|
||||||
|
// callback in tls.Config.
|
||||||
|
//
|
||||||
|
// If the implementation returns a nil tls.Config, the original Config will be
|
||||||
|
// used, but if it's non-nil, the returned Config will be used to handle this
|
||||||
|
// connection.
|
||||||
|
func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return ctx.mutableConfig.TLSConfig(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
|
||||||
|
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
|
||||||
|
return func(network, addr string) (net.Conn, error) {
|
||||||
|
return tls.DialWithDialer(&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
DualStack: true,
|
||||||
|
}, network, addr, ctx.mutableConfig.TLSConfig())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Certificate returns the server or client certificate from the sign response.
|
// Certificate returns the server or client certificate from the sign response.
|
||||||
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||||||
if sign.ServerPEM.Certificate == nil {
|
if sign.ServerPEM.Certificate == nil {
|
||||||
return nil, errors.New("ca: certificate does not exists")
|
return nil, errors.New("ca: certificate does not exist")
|
||||||
}
|
}
|
||||||
return sign.ServerPEM.Certificate, nil
|
return sign.ServerPEM.Certificate, nil
|
||||||
}
|
}
|
||||||
|
@ -132,19 +162,19 @@ func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||||||
// response.
|
// response.
|
||||||
func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||||||
if sign.CaPEM.Certificate == nil {
|
if sign.CaPEM.Certificate == nil {
|
||||||
return nil, errors.New("ca: certificate does not exists")
|
return nil, errors.New("ca: certificate does not exist")
|
||||||
}
|
}
|
||||||
return sign.CaPEM.Certificate, nil
|
return sign.CaPEM.Certificate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RootCertificate returns the root certificate from the sign response.
|
// RootCertificate returns the root certificate from the sign response.
|
||||||
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||||||
if sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
|
if sign == nil || sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
|
||||||
return nil, errors.New("ca: certificate does not exists")
|
return nil, errors.New("ca: certificate does not exist")
|
||||||
}
|
}
|
||||||
lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1]
|
lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1]
|
||||||
if len(lastChain) == 0 {
|
if len(lastChain) == 0 {
|
||||||
return nil, errors.New("ca: certificate does not exists")
|
return nil, errors.New("ca: certificate does not exist")
|
||||||
}
|
}
|
||||||
return lastChain[len(lastChain)-1], nil
|
return lastChain[len(lastChain)-1], nil
|
||||||
}
|
}
|
||||||
|
@ -178,17 +208,6 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
|
||||||
return &cert, nil
|
return &cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCertPool returns the transport x509.CertPool or the one from the sign
|
|
||||||
// request.
|
|
||||||
func getCertPool(sign *api.SignResponse) *x509.CertPool {
|
|
||||||
if root, err := RootCertificate(sign); err == nil {
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
pool.AddCert(root)
|
|
||||||
return pool
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
||||||
if sign.TLSOptions != nil {
|
if sign.TLSOptions != nil {
|
||||||
return sign.TLSOptions.TLSConfig()
|
return sign.TLSOptions.TLSConfig()
|
||||||
|
|
|
@ -3,6 +3,8 @@ package ca
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||||
|
@ -10,16 +12,22 @@ type TLSOption func(ctx *TLSOptionCtx) error
|
||||||
|
|
||||||
// TLSOptionCtx is the context modified on TLSOption methods.
|
// TLSOptionCtx is the context modified on TLSOption methods.
|
||||||
type TLSOptionCtx struct {
|
type TLSOptionCtx struct {
|
||||||
Client *Client
|
Client *Client
|
||||||
Config *tls.Config
|
Config *tls.Config
|
||||||
OnRenewFunc []TLSOption
|
Sign *api.SignResponse
|
||||||
|
OnRenewFunc []TLSOption
|
||||||
|
mutableConfig *mutableTLSConfig
|
||||||
|
hasRootCA bool
|
||||||
|
hasClientCA bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTLSOptionCtx creates the TLSOption context.
|
// newTLSOptionCtx creates the TLSOption context.
|
||||||
func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx {
|
func newTLSOptionCtx(c *Client, config *tls.Config, sign *api.SignResponse) *TLSOptionCtx {
|
||||||
return &TLSOptionCtx{
|
return &TLSOptionCtx{
|
||||||
Client: c,
|
Client: c,
|
||||||
Config: config,
|
Config: config,
|
||||||
|
Sign: sign,
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,6 +37,44 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize mutable config with the fully configured tls.Config
|
||||||
|
ctx.mutableConfig.Init(ctx.Config)
|
||||||
|
|
||||||
|
// Build RootCAs and ClientCAs with given root certificate if necessary
|
||||||
|
if root, err := RootCertificate(ctx.Sign); err == nil {
|
||||||
|
if !ctx.hasRootCA {
|
||||||
|
if ctx.Config.RootCAs == nil {
|
||||||
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
ctx.Config.RootCAs.AddCert(root)
|
||||||
|
ctx.mutableConfig.AddImmutableRootCACert(root)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ctx.hasClientCA && ctx.Config.ClientAuth != tls.NoClientCert {
|
||||||
|
if ctx.Config.ClientCAs == nil {
|
||||||
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
ctx.Config.ClientCAs.AddCert(root)
|
||||||
|
ctx.mutableConfig.AddImmutableClientCACert(root)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update tls.Config with mutable data
|
||||||
|
if ctx.Config.RootCAs == nil && len(ctx.mutableConfig.mutRootCerts) > 0 {
|
||||||
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
if ctx.Config.ClientCAs == nil && len(ctx.mutableConfig.mutClientCerts) > 0 {
|
||||||
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
// Add mutable certificates
|
||||||
|
for _, cert := range ctx.mutableConfig.mutRootCerts {
|
||||||
|
ctx.Config.RootCAs.AddCert(cert)
|
||||||
|
}
|
||||||
|
for _, cert := range ctx.mutableConfig.mutClientCerts {
|
||||||
|
ctx.Config.ClientCAs.AddCert(cert)
|
||||||
|
}
|
||||||
|
ctx.mutableConfig.Reload()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,6 +84,8 @@ func (ctx *TLSOptionCtx) applyRenew() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Reload mutable config with the changes
|
||||||
|
ctx.mutableConfig.Reload()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,6 +116,7 @@ func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||||
ctx.Config.RootCAs = x509.NewCertPool()
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
ctx.Config.RootCAs.AddCert(cert)
|
ctx.Config.RootCAs.AddCert(cert)
|
||||||
|
ctx.mutableConfig.AddImmutableRootCACert(cert)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -81,6 +130,7 @@ func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
ctx.Config.ClientCAs.AddCert(cert)
|
ctx.Config.ClientCAs.AddCert(cert)
|
||||||
|
ctx.mutableConfig.AddImmutableClientCACert(cert)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -91,17 +141,14 @@ func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||||
//
|
//
|
||||||
// BootstrapServer and BootstrapClient methods include this option by default.
|
// BootstrapServer and BootstrapClient methods include this option by default.
|
||||||
func AddRootsToRootCAs() TLSOption {
|
func AddRootsToRootCAs() TLSOption {
|
||||||
|
// var once sync.Once
|
||||||
fn := func(ctx *TLSOptionCtx) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := ctx.Client.Roots()
|
certs, err := ctx.Client.Roots()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.Config.RootCAs == nil {
|
ctx.hasRootCA = true
|
||||||
ctx.Config.RootCAs = x509.NewCertPool()
|
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
@ -117,17 +164,14 @@ func AddRootsToRootCAs() TLSOption {
|
||||||
//
|
//
|
||||||
// BootstrapServer method includes this option by default.
|
// BootstrapServer method includes this option by default.
|
||||||
func AddRootsToClientCAs() TLSOption {
|
func AddRootsToClientCAs() TLSOption {
|
||||||
|
// var once sync.Once
|
||||||
fn := func(ctx *TLSOptionCtx) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := ctx.Client.Roots()
|
certs, err := ctx.Client.Roots()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.Config.ClientCAs == nil {
|
ctx.hasClientCA = true
|
||||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
@ -145,12 +189,7 @@ func AddFederationToRootCAs() TLSOption {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.Config.RootCAs == nil {
|
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||||
ctx.Config.RootCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
@ -169,12 +208,7 @@ func AddFederationToClientCAs() TLSOption {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.Config.ClientCAs == nil {
|
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
@ -192,16 +226,10 @@ func AddRootsToCAs() TLSOption {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.Config.ClientCAs == nil {
|
ctx.hasRootCA = true
|
||||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
ctx.hasClientCA = true
|
||||||
}
|
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||||
if ctx.Config.RootCAs == nil {
|
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||||
ctx.Config.RootCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
for _, cert := range certs.Certificates {
|
|
||||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
|
||||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return func(ctx *TLSOptionCtx) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
@ -219,15 +247,20 @@ func AddFederationToCAs() TLSOption {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if ctx.Config.ClientCAs == nil {
|
if ctx.mutableConfig == nil {
|
||||||
ctx.Config.ClientCAs = x509.NewCertPool()
|
if ctx.Config.RootCAs == nil {
|
||||||
}
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
if ctx.Config.RootCAs == nil {
|
}
|
||||||
ctx.Config.RootCAs = x509.NewCertPool()
|
if ctx.Config.ClientCAs == nil {
|
||||||
}
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
for _, cert := range certs.Certificates {
|
}
|
||||||
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
for _, cert := range certs.Certificates {
|
||||||
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||||
|
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ctx.mutableConfig.AddRootCAs(certs.Certificates)
|
||||||
|
ctx.mutableConfig.AddClientCAs(certs.Certificates)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_newTLSOptionCtx(t *testing.T) {
|
func Test_newTLSOptionCtx(t *testing.T) {
|
||||||
|
@ -20,17 +22,18 @@ func Test_newTLSOptionCtx(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
c *Client
|
c *Client
|
||||||
config *tls.Config
|
config *tls.Config
|
||||||
|
sign *api.SignResponse
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
want *TLSOptionCtx
|
want *TLSOptionCtx
|
||||||
}{
|
}{
|
||||||
{"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}}},
|
{"ok", args{client, &tls.Config{}, &api.SignResponse{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}, Sign: &api.SignResponse{}, mutableConfig: newMutableTLSConfig()}},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) {
|
if got := newTLSOptionCtx(tt.args.c, tt.args.config, tt.args.sign); !reflect.DeepEqual(got, tt.want) {
|
||||||
t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
|
t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -63,7 +66,8 @@ func TestTLSOptionCtx_apply(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Config: tt.fields.Config,
|
Config: tt.fields.Config,
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr {
|
if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -82,7 +86,8 @@ func TestRequireAndVerifyClientCert(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Config: &tls.Config{},
|
Config: &tls.Config{},
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := RequireAndVerifyClientCert()(ctx); err != nil {
|
if err := RequireAndVerifyClientCert()(ctx); err != nil {
|
||||||
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
||||||
|
@ -105,7 +110,8 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Config: &tls.Config{},
|
Config: &tls.Config{},
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := VerifyClientCertIfGiven()(ctx); err != nil {
|
if err := VerifyClientCertIfGiven()(ctx); err != nil {
|
||||||
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
||||||
|
@ -136,7 +142,8 @@ func TestAddRootCA(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Config: &tls.Config{},
|
Config: &tls.Config{},
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := AddRootCA(tt.args.cert)(ctx); err != nil {
|
if err := AddRootCA(tt.args.cert)(ctx); err != nil {
|
||||||
t.Errorf("AddRootCA() error = %v", err)
|
t.Errorf("AddRootCA() error = %v", err)
|
||||||
|
@ -167,7 +174,8 @@ func TestAddClientCA(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Config: &tls.Config{},
|
Config: &tls.Config{},
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := AddClientCA(tt.args.cert)(ctx); err != nil {
|
if err := AddClientCA(tt.args.cert)(ctx); err != nil {
|
||||||
t.Errorf("AddClientCA() error = %v", err)
|
t.Errorf("AddClientCA() error = %v", err)
|
||||||
|
@ -219,14 +227,15 @@ func TestAddRootsToRootCAs(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: tt.args.client,
|
Client: tt.args.client,
|
||||||
Config: tt.args.config,
|
Config: tt.args.config,
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := ctx.apply([]TLSOption{AddRootsToRootCAs()}); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) {
|
||||||
t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
|
t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -272,14 +281,15 @@ func TestAddRootsToClientCAs(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: tt.args.client,
|
Client: tt.args.client,
|
||||||
Config: tt.args.config,
|
Config: tt.args.config,
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := ctx.apply([]TLSOption{AddRootsToClientCAs()}); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
if !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) {
|
||||||
t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
|
t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -332,10 +342,11 @@ func TestAddFederationToRootCAs(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: tt.args.client,
|
Client: tt.args.client,
|
||||||
Config: tt.args.config,
|
Config: tt.args.config,
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := ctx.apply([]TLSOption{AddFederationToRootCAs()}); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -395,10 +406,11 @@ func TestAddFederationToClientCAs(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: tt.args.client,
|
Client: tt.args.client,
|
||||||
Config: tt.args.config,
|
Config: tt.args.config,
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := ctx.apply([]TLSOption{AddFederationToClientCAs()}); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -451,14 +463,15 @@ func TestAddRootsToCAs(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: tt.args.client,
|
Client: tt.args.client,
|
||||||
Config: tt.args.config,
|
Config: tt.args.config,
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := ctx.apply([]TLSOption{AddRootsToCAs()}); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
if !reflect.DeepEqual(ctx.Config.RootCAs, tt.want.RootCAs) || !reflect.DeepEqual(ctx.Config.ClientCAs, tt.want.ClientCAs) {
|
||||||
t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
|
t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -511,10 +524,11 @@ func TestAddFederationToCAs(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: tt.args.client,
|
Client: tt.args.client,
|
||||||
Config: tt.args.config,
|
Config: tt.args.config,
|
||||||
|
mutableConfig: newMutableTLSConfig(),
|
||||||
}
|
}
|
||||||
if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := ctx.apply([]TLSOption{AddFederationToCAs()}); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue