Use mTLS by default on SDK methods.
Add options to modify the tls.Config for different configurations. Fixes #7
This commit is contained in:
parent
bb03aadddf
commit
d872f09910
9 changed files with 417 additions and 419 deletions
|
@ -43,6 +43,12 @@ func Bootstrap(token string) (*Client, error) {
|
||||||
// Authority. By default the server will kick off a routine that will renew the
|
// Authority. By default the server will kick off a routine that will renew the
|
||||||
// certificate after 2/3rd of the certificate's lifetime has expired.
|
// certificate after 2/3rd of the certificate's lifetime has expired.
|
||||||
//
|
//
|
||||||
|
// Without any extra option the server will be configured for mTLS, it will
|
||||||
|
// require and verify clients certificates, but options can be used to drop this
|
||||||
|
// requirement, the most common will be only verify the certs if given with
|
||||||
|
// ca.VerifyClientCertIfGiven(), or add extra CAs with
|
||||||
|
// ca.AddClientCA(*x509.Certificate).
|
||||||
|
//
|
||||||
// Usage:
|
// Usage:
|
||||||
// // Default example with certificate rotation.
|
// // Default example with certificate rotation.
|
||||||
// srv, err := ca.BootstrapServer(context.Background(), token, &http.Server{
|
// srv, err := ca.BootstrapServer(context.Background(), token, &http.Server{
|
||||||
|
@ -61,7 +67,7 @@ func Bootstrap(token string) (*Client, error) {
|
||||||
// return err
|
// return err
|
||||||
// }
|
// }
|
||||||
// srv.ListenAndServeTLS("", "")
|
// srv.ListenAndServeTLS("", "")
|
||||||
func BootstrapServer(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
|
func BootstrapServer(ctx context.Context, token string, base *http.Server, options ...TLSOption) (*http.Server, error) {
|
||||||
if base.TLSConfig != nil {
|
if base.TLSConfig != nil {
|
||||||
return nil, errors.New("server TLSConfig is already set")
|
return nil, errors.New("server TLSConfig is already set")
|
||||||
}
|
}
|
||||||
|
@ -81,60 +87,7 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server) (*htt
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk)
|
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
base.TLSConfig = tlsConfig
|
|
||||||
return base, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BootstrapServerWithMTLS is a helper function that using the given token
|
|
||||||
// returns the given http.Server configured with a TLS certificate signed by the
|
|
||||||
// Certificate Authority, this server will always require and verify a client
|
|
||||||
// certificate. By default the server will kick off a routine that will renew
|
|
||||||
// the certificate after 2/3rd of the certificate's lifetime has expired.
|
|
||||||
//
|
|
||||||
// Usage:
|
|
||||||
// // Default example with certificate rotation.
|
|
||||||
// srv, err := ca.BootstrapServerWithMTLS(context.Background(), token, &http.Server{
|
|
||||||
// Addr: ":443",
|
|
||||||
// Handler: handler,
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// // Example canceling automatic certificate rotation.
|
|
||||||
// ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
// defer cancel()
|
|
||||||
// srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
|
||||||
// Addr: ":443",
|
|
||||||
// Handler: handler,
|
|
||||||
// })
|
|
||||||
// if err != nil {
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// srv.ListenAndServeTLS("", "")
|
|
||||||
func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
|
|
||||||
if base.TLSConfig != nil {
|
|
||||||
return nil, errors.New("server TLSConfig is already set")
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := Bootstrap(token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, pk, err := CreateSignRequest(token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sign, err := client.Sign(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig, err := client.GetServerMutualTLSConfig(ctx, sign, pk)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -161,7 +114,7 @@ func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Serve
|
||||||
// return err
|
// return err
|
||||||
// }
|
// }
|
||||||
// resp, err := client.Get("https://internal.smallstep.com")
|
// resp, err := client.Get("https://internal.smallstep.com")
|
||||||
func BootstrapClient(ctx context.Context, token string) (*http.Client, error) {
|
func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) {
|
||||||
client, err := Bootstrap(token)
|
client, err := Bootstrap(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -177,7 +130,7 @@ func BootstrapClient(ctx context.Context, token string) (*http.Client, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
transport, err := client.Transport(ctx, sign, pk)
|
transport, err := client.Transport(ctx, sign, pk, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -124,7 +124,7 @@ func TestBootstrap(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBootstrapServer(t *testing.T) {
|
func TestBootstrapServerWithoutMTLS(t *testing.T) {
|
||||||
srv := startCABootstrapServer()
|
srv := startCABootstrapServer()
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
token := func() string {
|
token := func() string {
|
||||||
|
@ -146,7 +146,7 @@ func TestBootstrapServer(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base)
|
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base, VerifyClientCertIfGiven())
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -192,24 +192,24 @@ func TestBootstrapServerWithMTLS(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := BootstrapServerWithMTLS(tt.args.ctx, tt.args.token, tt.args.base)
|
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("BootstrapServerWithMTLS() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
if got != nil {
|
if got != nil {
|
||||||
t.Errorf("BootstrapServerWithMTLS() = %v, want nil", got)
|
t.Errorf("BootstrapServer() = %v, want nil", got)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
expected := &http.Server{
|
expected := &http.Server{
|
||||||
TLSConfig: got.TLSConfig,
|
TLSConfig: got.TLSConfig,
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, expected) {
|
if !reflect.DeepEqual(got, expected) {
|
||||||
t.Errorf("BootstrapServerWithMTLS() = %v, want %v", got, expected)
|
t.Errorf("BootstrapServer() = %v, want %v", got, expected)
|
||||||
}
|
}
|
||||||
if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil {
|
if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil {
|
||||||
t.Errorf("BootstrapServerWithMTLS() invalid TLSConfig = %#v", got.TLSConfig)
|
t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
64
ca/tls.go
64
ca/tls.go
|
@ -20,7 +20,7 @@ import (
|
||||||
// GetClientTLSConfig returns a tls.Config for client use configured with the
|
// GetClientTLSConfig returns a tls.Config for client use configured with the
|
||||||
// sign certificate, and a new certificate pool with the sign root certificate.
|
// sign certificate, and a new certificate pool with the sign root certificate.
|
||||||
// The client certificate will automatically rotate before expiring.
|
// The client certificate will automatically rotate before expiring.
|
||||||
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
|
||||||
cert, err := TLSCertificate(sign, pk)
|
cert, err := TLSCertificate(sign, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -36,10 +36,15 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
tlsConfig.PreferServerCipherSuites = true
|
||||||
// Build RootCAs with given root certificate
|
// Build RootCAs with given root certificate
|
||||||
if pool := c.getCertPool(sign); pool != nil {
|
if pool := getCertPool(sign); pool != nil {
|
||||||
tlsConfig.RootCAs = pool
|
tlsConfig.RootCAs = pool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply options if given
|
||||||
|
if err := setTLSOptions(tlsConfig, options); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Update renew function with transport
|
// Update renew function with transport
|
||||||
tr, err := getDefaultTransport(tlsConfig)
|
tr, err := getDefaultTransport(tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -56,7 +61,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
// sign certificate, and a new certificate pool with the sign root certificate.
|
// sign certificate, and a new certificate pool with the sign root certificate.
|
||||||
// The returned tls.Config will only verify the client certificate if provided.
|
// The returned tls.Config will only verify the client certificate if provided.
|
||||||
// The server certificate will automatically rotate before expiring.
|
// The server certificate will automatically rotate before expiring.
|
||||||
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
|
||||||
cert, err := TLSCertificate(sign, pk)
|
cert, err := TLSCertificate(sign, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -74,13 +79,18 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
tlsConfig.PreferServerCipherSuites = true
|
||||||
// Build RootCAs with given root certificate
|
// Build RootCAs with given root certificate
|
||||||
if pool := c.getCertPool(sign); pool != nil {
|
if pool := getCertPool(sign); pool != nil {
|
||||||
tlsConfig.ClientCAs = pool
|
tlsConfig.ClientCAs = pool
|
||||||
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
// Add RootCAs for refresh client
|
// Add RootCAs for refresh client
|
||||||
tlsConfig.RootCAs = pool
|
tlsConfig.RootCAs = pool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply options if given
|
||||||
|
if err := setTLSOptions(tlsConfig, options); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Update renew function with transport
|
// Update renew function with transport
|
||||||
tr, err := getDefaultTransport(tlsConfig)
|
tr, err := getDefaultTransport(tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -93,44 +103,15 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
return tlsConfig, nil
|
return tlsConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServerMutualTLSConfig returns a tls.Config for server use configured with
|
|
||||||
// the sign certificate, and a new certificate pool with the sign root certificate.
|
|
||||||
// The returned tls.Config will always require and verify a client certificate.
|
|
||||||
// The server certificate will automatically rotate before expiring.
|
|
||||||
func (c *Client) GetServerMutualTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
|
||||||
tlsConfig, err := c.GetServerTLSConfig(ctx, sign, pk)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
|
||||||
return tlsConfig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transport returns an http.Transport configured to use the client certificate from the sign response.
|
// Transport returns an http.Transport configured to use the client certificate from the sign response.
|
||||||
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
|
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
|
||||||
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)
|
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return getDefaultTransport(tlsConfig)
|
return getDefaultTransport(tlsConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCertPool returns the transport x509.CertPool or the one from the sign
|
|
||||||
// request.
|
|
||||||
func (c *Client) getCertPool(sign *api.SignResponse) *x509.CertPool {
|
|
||||||
// Return the transport certPool
|
|
||||||
if c.certPool != nil {
|
|
||||||
return c.certPool
|
|
||||||
}
|
|
||||||
// Return certificate used in sign request.
|
|
||||||
if root, err := RootCertificate(sign); err == nil {
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
pool.AddCert(root)
|
|
||||||
return pool
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Certificate returns the server or client certificate from the sign response.
|
// Certificate returns the server or client certificate from the sign response.
|
||||||
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||||||
if sign.ServerPEM.Certificate == nil {
|
if sign.ServerPEM.Certificate == nil {
|
||||||
|
@ -189,6 +170,17 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
|
||||||
return &cert, nil
|
return &cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getCertPool returns the transport x509.CertPool or the one from the sign
|
||||||
|
// request.
|
||||||
|
func getCertPool(sign *api.SignResponse) *x509.CertPool {
|
||||||
|
if root, err := RootCertificate(sign); err == nil {
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(root)
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
||||||
if sign.TLSOptions != nil {
|
if sign.TLSOptions != nil {
|
||||||
return sign.TLSOptions.TLSConfig()
|
return sign.TLSOptions.TLSConfig()
|
||||||
|
|
64
ca/tls_options.go
Normal file
64
ca/tls_options.go
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
package ca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||||
|
type TLSOption func(c *tls.Config) error
|
||||||
|
|
||||||
|
// setTLSOptions takes one or more option function and applies them in order to
|
||||||
|
// a tls.Config.
|
||||||
|
func setTLSOptions(c *tls.Config, options []TLSOption) error {
|
||||||
|
for _, opt := range options {
|
||||||
|
if err := opt(c); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
||||||
|
// a valid TLS client certificate. This is the default option for mTLS servers.
|
||||||
|
func RequireAndVerifyClientCert() TLSOption {
|
||||||
|
return func(c *tls.Config) error {
|
||||||
|
c.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
||||||
|
// a TLS client certificate if it is provided. It does not requires a certificate.
|
||||||
|
func VerifyClientCertIfGiven() TLSOption {
|
||||||
|
return func(c *tls.Config) error {
|
||||||
|
c.ClientAuth = tls.VerifyClientCertIfGiven
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRootCA adds to the tls.Config RootCAs the given certificate. RootCAs
|
||||||
|
// defines the set of root certificate authorities that clients use when
|
||||||
|
// verifying server certificates.
|
||||||
|
func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||||
|
return func(c *tls.Config) error {
|
||||||
|
if c.RootCAs == nil {
|
||||||
|
c.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
c.RootCAs.AddCert(cert)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddClientCA adds to the tls.Config ClientCAs the given certificate. ClientCAs
|
||||||
|
// defines the set of root certificate authorities that servers use if required
|
||||||
|
// to verify a client certificate by the policy in ClientAuth.
|
||||||
|
func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||||
|
return func(c *tls.Config) error {
|
||||||
|
if c.ClientCAs == nil {
|
||||||
|
c.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
c.ClientCAs.AddCert(cert)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
137
ca/tls_options_test.go
Normal file
137
ca/tls_options_test.go
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
package ca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_setTLSOptions(t *testing.T) {
|
||||||
|
fail := func() TLSOption {
|
||||||
|
return func(c *tls.Config) error {
|
||||||
|
return fmt.Errorf("an error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
c *tls.Config
|
||||||
|
options []TLSOption
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{&tls.Config{}, []TLSOption{RequireAndVerifyClientCert()}}, false},
|
||||||
|
{"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false},
|
||||||
|
{"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAndVerifyClientCert(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want *tls.Config
|
||||||
|
}{
|
||||||
|
{"ok", &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := &tls.Config{}
|
||||||
|
if err := RequireAndVerifyClientCert()(got); err != nil {
|
||||||
|
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("RequireAndVerifyClientCert() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyClientCertIfGiven(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want *tls.Config
|
||||||
|
}{
|
||||||
|
{"ok", &tls.Config{ClientAuth: tls.VerifyClientCertIfGiven}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := &tls.Config{}
|
||||||
|
if err := VerifyClientCertIfGiven()(got); err != nil {
|
||||||
|
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("VerifyClientCertIfGiven() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddRootCA(t *testing.T) {
|
||||||
|
cert := parseCertificate(rootPEM)
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(cert)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *tls.Config
|
||||||
|
}{
|
||||||
|
{"ok", args{cert}, &tls.Config{RootCAs: pool}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := &tls.Config{}
|
||||||
|
if err := AddRootCA(tt.args.cert)(got); err != nil {
|
||||||
|
t.Errorf("AddRootCA() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("AddRootCA() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddClientCA(t *testing.T) {
|
||||||
|
cert := parseCertificate(rootPEM)
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(cert)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *tls.Config
|
||||||
|
}{
|
||||||
|
{"ok", args{cert}, &tls.Config{ClientCAs: pool}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := &tls.Config{}
|
||||||
|
if err := AddClientCA(tt.args.cert)(got); err != nil {
|
||||||
|
t.Errorf("AddClientCA() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("AddClientCA() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
410
ca/tls_test.go
410
ca/tls_test.go
|
@ -104,15 +104,8 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) (
|
||||||
return client, sr, pk
|
return client, sr, pk
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
func serverHandler(t *testing.T, clientDomain string) http.Handler {
|
||||||
client, sr, pk := sign("127.0.0.1")
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
|
||||||
}
|
|
||||||
clientDomain := "test.domain"
|
|
||||||
// Create server with given tls.Config
|
|
||||||
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
if req.RequestURI != "/no-cert" {
|
if req.RequestURI != "/no-cert" {
|
||||||
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
||||||
w.Write([]byte("fail"))
|
w.Write([]byte("fail"))
|
||||||
|
@ -129,18 +122,46 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add serial number to check rotation
|
||||||
|
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
|
||||||
|
w.Header().Set("x-fingerprint", hex.EncodeToString(sum[:]))
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Write([]byte("ok"))
|
w.Write([]byte("ok"))
|
||||||
}))
|
})
|
||||||
defer srv.Close()
|
}
|
||||||
|
|
||||||
|
func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
|
clientDomain := "test.domain"
|
||||||
|
client, sr, pk := sign("127.0.0.1")
|
||||||
|
|
||||||
|
// Create mTLS server
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
||||||
|
defer srvMTLS.Close()
|
||||||
|
|
||||||
|
// Create TLS server
|
||||||
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
||||||
|
defer srvTLS.Close()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
path string
|
|
||||||
wantErr bool
|
|
||||||
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
||||||
|
wantErr map[string]bool
|
||||||
}{
|
}{
|
||||||
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
tr, err := client.Transport(context.Background(), sr, pk)
|
tr, err := client.Transport(context.Background(), sr, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Client.Transport() error = %v", err)
|
t.Errorf("Client.Transport() error = %v", err)
|
||||||
|
@ -149,8 +170,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: tr,
|
Transport: tr,
|
||||||
}
|
}
|
||||||
}},
|
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
|
||||||
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
||||||
|
@ -164,8 +185,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: tr,
|
Transport: tr,
|
||||||
}
|
}
|
||||||
}},
|
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
|
||||||
{"ok with no cert", "/no-cert", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
root, err := RootCertificate(sr)
|
root, err := RootCertificate(sr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("RootCertificate() error = %v", err)
|
t.Errorf("RootCertificate() error = %v", err)
|
||||||
|
@ -183,23 +204,27 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: tr,
|
Transport: tr,
|
||||||
}
|
}
|
||||||
}},
|
}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
|
||||||
{"fail with default", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
return &http.Client{}
|
return &http.Client{}
|
||||||
}},
|
}, map[string]bool{srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
client, sr, pk := sign(clientDomain)
|
client, sr, pk := sign(clientDomain)
|
||||||
cli := tt.getClient(t, client, sr, pk)
|
cli := tt.getClient(t, client, sr, pk)
|
||||||
if cli != nil {
|
if cli == nil {
|
||||||
resp, err := cli.Get(srv.URL + tt.path)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if tt.wantErr {
|
for path, wantErr := range tt.wantErr {
|
||||||
|
t.Run(path, func(t *testing.T) {
|
||||||
|
resp, err := cli.Get(path)
|
||||||
|
if (err != nil) != wantErr {
|
||||||
|
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if wantErr {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
@ -210,6 +235,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
if !bytes.Equal(b, []byte("ok")) {
|
if !bytes.Equal(b, []byte("ok")) {
|
||||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -224,44 +250,36 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
||||||
ca := startCATestServer()
|
ca := startCATestServer()
|
||||||
defer ca.Close()
|
defer ca.Close()
|
||||||
|
|
||||||
|
clientDomain := "test.domain"
|
||||||
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
|
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
|
||||||
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
|
|
||||||
|
// Start mTLS server
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||||
}
|
}
|
||||||
clientDomain := "test.domain"
|
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
||||||
fingerprints := make(map[string]struct{})
|
defer srvMTLS.Close()
|
||||||
|
|
||||||
// Create server with given tls.Config
|
// Start TLS server
|
||||||
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
defer cancel()
|
||||||
w.Write([]byte("fail"))
|
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
|
||||||
t.Error("http.Request.TLS does not have peer certificates")
|
if err != nil {
|
||||||
return
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||||
}
|
}
|
||||||
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
||||||
w.Write([]byte("fail"))
|
defer srvTLS.Close()
|
||||||
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
|
||||||
w.Write([]byte("fail"))
|
|
||||||
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Add serial number to check rotation
|
|
||||||
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
|
|
||||||
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
|
|
||||||
w.Write([]byte("ok"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
// Clients: transport and tlsConfig
|
// Transport
|
||||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||||
tr1, err := client.Transport(context.Background(), sr, pk)
|
tr1, err := client.Transport(context.Background(), sr, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Client.Transport() error = %v", err)
|
t.Fatalf("Client.Transport() error = %v", err)
|
||||||
}
|
}
|
||||||
|
// Transport with tlsConfig
|
||||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||||
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
|
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -271,227 +289,15 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getDefaultTransport() error = %v", err)
|
t.Fatalf("getDefaultTransport() error = %v", err)
|
||||||
}
|
}
|
||||||
|
// No client cert
|
||||||
// Disable keep alives to force TLS handshake
|
|
||||||
tr1.DisableKeepAlives = true
|
|
||||||
tr2.DisableKeepAlives = true
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
client *http.Client
|
|
||||||
}{
|
|
||||||
{"with transport", &http.Client{Transport: tr1}},
|
|
||||||
{"with tlsConfig", &http.Client{Transport: tr2}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
resp, err := tt.client.Get(srv.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("http.Client.Get() error = %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b, []byte("ok")) {
|
|
||||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if l := len(fingerprints); l != 2 {
|
|
||||||
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for renewal 40s == 1m-1m/3
|
|
||||||
log.Printf("Sleeping for %s ...\n", 40*time.Second)
|
|
||||||
time.Sleep(40 * time.Second)
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run("renewed "+tt.name, func(t *testing.T) {
|
|
||||||
resp, err := tt.client.Get(srv.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("http.Client.Get() error = %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b, []byte("ok")) {
|
|
||||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if l := len(fingerprints); l != 4 {
|
|
||||||
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
|
|
||||||
client, sr, pk := sign("127.0.0.1")
|
|
||||||
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
|
||||||
}
|
|
||||||
clientDomain := "test.domain"
|
|
||||||
// Create server with given tls.Config
|
|
||||||
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
if req.RequestURI != "/no-cert" {
|
|
||||||
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
|
||||||
w.Write([]byte("fail"))
|
|
||||||
t.Error("http.Request.TLS does not have peer certificates")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
|
||||||
w.Write([]byte("fail"))
|
|
||||||
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
|
||||||
w.Write([]byte("fail"))
|
|
||||||
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.Write([]byte("ok"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
path string
|
|
||||||
wantErr bool
|
|
||||||
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
|
||||||
}{
|
|
||||||
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
||||||
tr, err := client.Transport(context.Background(), sr, pk)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Client.Transport() error = %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &http.Client{
|
|
||||||
Transport: tr,
|
|
||||||
}
|
|
||||||
}},
|
|
||||||
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
||||||
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
tr, err := getDefaultTransport(tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("getDefaultTransport() error = %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &http.Client{
|
|
||||||
Transport: tr,
|
|
||||||
}
|
|
||||||
}},
|
|
||||||
{"fail with no cert", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
|
||||||
root, err := RootCertificate(sr)
|
root, err := RootCertificate(sr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("RootCertificate() error = %v", err)
|
t.Fatalf("RootCertificate() error = %v", err)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
tlsConfig := getDefaultTLSConfig(sr)
|
tlsConfig = getDefaultTLSConfig(sr)
|
||||||
tlsConfig.RootCAs = x509.NewCertPool()
|
tlsConfig.RootCAs = x509.NewCertPool()
|
||||||
tlsConfig.RootCAs.AddCert(root)
|
tlsConfig.RootCAs.AddCert(root)
|
||||||
|
tr3, err := getDefaultTransport(tlsConfig)
|
||||||
tr, err := getDefaultTransport(tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("getDefaultTransport() error = %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &http.Client{
|
|
||||||
Transport: tr,
|
|
||||||
}
|
|
||||||
}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
client, sr, pk := sign(clientDomain)
|
|
||||||
cli := tt.getClient(t, client, sr, pk)
|
|
||||||
if cli != nil {
|
|
||||||
resp, err := cli.Get(srv.URL + tt.path)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.wantErr {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(b, []byte("ok")) {
|
|
||||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skipping test in short mode.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start CA
|
|
||||||
ca := startCATestServer()
|
|
||||||
defer ca.Close()
|
|
||||||
|
|
||||||
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
|
|
||||||
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
|
||||||
}
|
|
||||||
clientDomain := "test.domain"
|
|
||||||
fingerprints := make(map[string]struct{})
|
|
||||||
|
|
||||||
// Create server with given tls.Config
|
|
||||||
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
||||||
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
|
||||||
w.Write([]byte("fail"))
|
|
||||||
t.Error("http.Request.TLS does not have peer certificates")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
|
||||||
w.Write([]byte("fail"))
|
|
||||||
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
|
||||||
w.Write([]byte("fail"))
|
|
||||||
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Add serial number to check rotation
|
|
||||||
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
|
|
||||||
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
|
|
||||||
w.Write([]byte("ok"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
// Clients: transport and tlsConfig
|
|
||||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
|
||||||
tr1, err := client.Transport(context.Background(), sr, pk)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Client.Transport() error = %v", err)
|
|
||||||
}
|
|
||||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
|
||||||
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
|
|
||||||
}
|
|
||||||
tr2, err := getDefaultTransport(tlsConfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("getDefaultTransport() error = %v", err)
|
t.Fatalf("getDefaultTransport() error = %v", err)
|
||||||
}
|
}
|
||||||
|
@ -499,34 +305,66 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
|
||||||
// Disable keep alives to force TLS handshake
|
// Disable keep alives to force TLS handshake
|
||||||
tr1.DisableKeepAlives = true
|
tr1.DisableKeepAlives = true
|
||||||
tr2.DisableKeepAlives = true
|
tr2.DisableKeepAlives = true
|
||||||
|
tr3.DisableKeepAlives = true
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
wantErr map[string]bool
|
||||||
}{
|
}{
|
||||||
{"with transport", &http.Client{Transport: tr1}},
|
{"with transport", &http.Client{Transport: tr1}, map[string]bool{
|
||||||
{"with tlsConfig", &http.Client{Transport: tr2}},
|
srvTLS.URL: false,
|
||||||
|
srvMTLS.URL: false,
|
||||||
|
}},
|
||||||
|
{"with tlsConfig", &http.Client{Transport: tr2}, map[string]bool{
|
||||||
|
srvTLS.URL: false,
|
||||||
|
srvMTLS.URL: false,
|
||||||
|
}},
|
||||||
|
{"with no ClientCert", &http.Client{Transport: tr3}, map[string]bool{
|
||||||
|
srvTLS.URL + "/no-cert": false,
|
||||||
|
srvMTLS.URL + "/no-cert": true,
|
||||||
|
}},
|
||||||
|
{"fail with default", &http.Client{}, map[string]bool{
|
||||||
|
srvTLS.URL + "/no-cert": true,
|
||||||
|
srvMTLS.URL + "/no-cert": true,
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// To count different cert fingerprints
|
||||||
|
fingerprints := map[string]struct{}{}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
resp, err := tt.client.Get(srv.URL)
|
for path, wantErr := range tt.wantErr {
|
||||||
if err != nil {
|
t.Run(path, func(t *testing.T) {
|
||||||
t.Fatalf("http.Client.Get() error = %v", err)
|
resp, err := tt.client.Get(path)
|
||||||
|
if (err != nil) != wantErr {
|
||||||
|
t.Errorf("http.Client.Get() error = %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
if wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
|
||||||
|
fingerprints[fp] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
t.Errorf("ioutil.RealAdd() error = %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if !bytes.Equal(b, []byte("ok")) {
|
if !bytes.Equal(b, []byte("ok")) {
|
||||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if l := len(fingerprints); l != 2 {
|
if l := len(fingerprints); l != 2 {
|
||||||
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
t.Errorf("number of fingerprints unexpected, got %d, want 2", l)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for renewal 40s == 1m-1m/3
|
// Wait for renewal 40s == 1m-1m/3
|
||||||
|
@ -535,17 +373,31 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run("renewed "+tt.name, func(t *testing.T) {
|
t.Run("renewed "+tt.name, func(t *testing.T) {
|
||||||
resp, err := tt.client.Get(srv.URL)
|
for path, wantErr := range tt.wantErr {
|
||||||
if err != nil {
|
t.Run(path, func(t *testing.T) {
|
||||||
t.Fatalf("http.Client.Get() error = %v", err)
|
resp, err := tt.client.Get(path)
|
||||||
|
if (err != nil) != wantErr {
|
||||||
|
t.Errorf("http.Client.Get() error = %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
if wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
|
||||||
|
fingerprints[fp] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
t.Errorf("ioutil.RealAdd() error = %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if !bytes.Equal(b, []byte("ok")) {
|
if !bytes.Equal(b, []byte("ok")) {
|
||||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -139,13 +139,13 @@ password `password` hardcoded, but you can create your own using `step ca init`.
|
||||||
|
|
||||||
These examples show the use of some other helper methods - simple ways to
|
These examples show the use of some other helper methods - simple ways to
|
||||||
create TLS configured http.Server and http.Client objects. The methods are
|
create TLS configured http.Server and http.Client objects. The methods are
|
||||||
`BootstrapServer`, `BootstrapServerWithMTLS` and `BootstrapClient`.
|
`BootstrapServer` and `BootstrapClient`.
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// Get a cancelable context to stop the renewal goroutines and timers.
|
// Get a cancelable context to stop the renewal goroutines and timers.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
// Create an http.Server
|
// Create an http.Server that requires a client certificate
|
||||||
srv, err := ca.BootstrapServer(ctx, token, &http.Server{
|
srv, err := ca.BootstrapServer(ctx, token, &http.Server{
|
||||||
Addr: ":8443",
|
Addr: ":8443",
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
|
@ -160,11 +160,11 @@ srv.ListenAndServeTLS("", "")
|
||||||
// Get a cancelable context to stop the renewal goroutines and timers.
|
// Get a cancelable context to stop the renewal goroutines and timers.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
// Create an http.Server that requires a client certificate
|
// Create an http.Server that does not require a client certificate
|
||||||
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
||||||
Addr: ":8443",
|
Addr: ":8443",
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
})
|
}, ca.VerifyClientCertIfGiven())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -194,13 +194,13 @@ certificates $ bin/step-ca examples/pki/config/ca.json
|
||||||
2018/11/02 18:29:25 Serving HTTPS on :9000 ...
|
2018/11/02 18:29:25 Serving HTTPS on :9000 ...
|
||||||
```
|
```
|
||||||
|
|
||||||
Next we will start the bootstrap-server and enter `password` prompted for the
|
Next we will start the bootstrap-tls-server and enter `password` prompted for the
|
||||||
provisioner password:
|
provisioner password:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
certificates $ export STEPPATH=examples/pki
|
certificates $ export STEPPATH=examples/pki
|
||||||
certificates $ export STEP_CA_URL=https://localhost:9000
|
certificates $ export STEP_CA_URL=https://localhost:9000
|
||||||
certificates $ go run examples/bootstrap-server/server.go $(step ca token localhost)
|
certificates $ go run examples/bootstrap-tls-server/server.go $(step ca token localhost)
|
||||||
✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com)
|
✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com)
|
||||||
Please enter the password to decrypt the provisioner key:
|
Please enter the password to decrypt the provisioner key:
|
||||||
Listening on :8443 ...
|
Listening on :8443 ...
|
||||||
|
|
|
@ -22,7 +22,7 @@ func main() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
srv, err := ca.BootstrapServer(ctx, token, &http.Server{
|
||||||
Addr: ":8443",
|
Addr: ":8443",
|
||||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
name := "nobody"
|
name := "nobody"
|
||||||
|
|
|
@ -31,7 +31,7 @@ func main() {
|
||||||
}
|
}
|
||||||
w.Write([]byte(fmt.Sprintf("Hello %s at %s!!!", name, time.Now().UTC())))
|
w.Write([]byte(fmt.Sprintf("Hello %s at %s!!!", name, time.Now().UTC())))
|
||||||
}),
|
}),
|
||||||
})
|
}, ca.VerifyClientCertIfGiven())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
Loading…
Reference in a new issue