forked from TrueCloudLab/certificates
Use mTLS by default on SDK methods.
Add options to modify the tls.Config for different configurations. Fixes #7
This commit is contained in:
parent
bb03aadddf
commit
d872f09910
9 changed files with 417 additions and 419 deletions
|
@ -43,6 +43,12 @@ func Bootstrap(token string) (*Client, error) {
|
|||
// Authority. By default the server will kick off a routine that will renew the
|
||||
// certificate after 2/3rd of the certificate's lifetime has expired.
|
||||
//
|
||||
// Without any extra option the server will be configured for mTLS, it will
|
||||
// require and verify clients certificates, but options can be used to drop this
|
||||
// requirement, the most common will be only verify the certs if given with
|
||||
// ca.VerifyClientCertIfGiven(), or add extra CAs with
|
||||
// ca.AddClientCA(*x509.Certificate).
|
||||
//
|
||||
// Usage:
|
||||
// // Default example with certificate rotation.
|
||||
// srv, err := ca.BootstrapServer(context.Background(), token, &http.Server{
|
||||
|
@ -61,7 +67,7 @@ func Bootstrap(token string) (*Client, error) {
|
|||
// return err
|
||||
// }
|
||||
// srv.ListenAndServeTLS("", "")
|
||||
func BootstrapServer(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
|
||||
func BootstrapServer(ctx context.Context, token string, base *http.Server, options ...TLSOption) (*http.Server, error) {
|
||||
if base.TLSConfig != nil {
|
||||
return nil, errors.New("server TLSConfig is already set")
|
||||
}
|
||||
|
@ -81,60 +87,7 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server) (*htt
|
|||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base.TLSConfig = tlsConfig
|
||||
return base, nil
|
||||
}
|
||||
|
||||
// BootstrapServerWithMTLS is a helper function that using the given token
|
||||
// returns the given http.Server configured with a TLS certificate signed by the
|
||||
// Certificate Authority, this server will always require and verify a client
|
||||
// certificate. By default the server will kick off a routine that will renew
|
||||
// the certificate after 2/3rd of the certificate's lifetime has expired.
|
||||
//
|
||||
// Usage:
|
||||
// // Default example with certificate rotation.
|
||||
// srv, err := ca.BootstrapServerWithMTLS(context.Background(), token, &http.Server{
|
||||
// Addr: ":443",
|
||||
// Handler: handler,
|
||||
// })
|
||||
//
|
||||
// // Example canceling automatic certificate rotation.
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
// defer cancel()
|
||||
// srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
||||
// Addr: ":443",
|
||||
// Handler: handler,
|
||||
// })
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// srv.ListenAndServeTLS("", "")
|
||||
func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
|
||||
if base.TLSConfig != nil {
|
||||
return nil, errors.New("server TLSConfig is already set")
|
||||
}
|
||||
|
||||
client, err := Bootstrap(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, pk, err := CreateSignRequest(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sign, err := client.Sign(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig, err := client.GetServerMutualTLSConfig(ctx, sign, pk)
|
||||
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -161,7 +114,7 @@ func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Serve
|
|||
// return err
|
||||
// }
|
||||
// resp, err := client.Get("https://internal.smallstep.com")
|
||||
func BootstrapClient(ctx context.Context, token string) (*http.Client, error) {
|
||||
func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) {
|
||||
client, err := Bootstrap(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -177,7 +130,7 @@ func BootstrapClient(ctx context.Context, token string) (*http.Client, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
transport, err := client.Transport(ctx, sign, pk)
|
||||
transport, err := client.Transport(ctx, sign, pk, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -124,7 +124,7 @@ func TestBootstrap(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBootstrapServer(t *testing.T) {
|
||||
func TestBootstrapServerWithoutMTLS(t *testing.T) {
|
||||
srv := startCABootstrapServer()
|
||||
defer srv.Close()
|
||||
token := func() string {
|
||||
|
@ -146,7 +146,7 @@ func TestBootstrapServer(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base)
|
||||
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base, VerifyClientCertIfGiven())
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -192,24 +192,24 @@ func TestBootstrapServerWithMTLS(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := BootstrapServerWithMTLS(tt.args.ctx, tt.args.token, tt.args.base)
|
||||
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("BootstrapServerWithMTLS() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
if got != nil {
|
||||
t.Errorf("BootstrapServerWithMTLS() = %v, want nil", got)
|
||||
t.Errorf("BootstrapServer() = %v, want nil", got)
|
||||
}
|
||||
} else {
|
||||
expected := &http.Server{
|
||||
TLSConfig: got.TLSConfig,
|
||||
}
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("BootstrapServerWithMTLS() = %v, want %v", got, expected)
|
||||
t.Errorf("BootstrapServer() = %v, want %v", got, expected)
|
||||
}
|
||||
if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil {
|
||||
t.Errorf("BootstrapServerWithMTLS() invalid TLSConfig = %#v", got.TLSConfig)
|
||||
t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
64
ca/tls.go
64
ca/tls.go
|
@ -20,7 +20,7 @@ import (
|
|||
// GetClientTLSConfig returns a tls.Config for client use configured with the
|
||||
// sign certificate, and a new certificate pool with the sign root certificate.
|
||||
// The client certificate will automatically rotate before expiring.
|
||||
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
||||
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
|
||||
cert, err := TLSCertificate(sign, pk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -36,10 +36,15 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||
tlsConfig.PreferServerCipherSuites = true
|
||||
// Build RootCAs with given root certificate
|
||||
if pool := c.getCertPool(sign); pool != nil {
|
||||
if pool := getCertPool(sign); pool != nil {
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
|
||||
// Apply options if given
|
||||
if err := setTLSOptions(tlsConfig, options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update renew function with transport
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
|
@ -56,7 +61,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
// sign certificate, and a new certificate pool with the sign root certificate.
|
||||
// The returned tls.Config will only verify the client certificate if provided.
|
||||
// The server certificate will automatically rotate before expiring.
|
||||
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
||||
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
|
||||
cert, err := TLSCertificate(sign, pk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -74,13 +79,18 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||
tlsConfig.PreferServerCipherSuites = true
|
||||
// Build RootCAs with given root certificate
|
||||
if pool := c.getCertPool(sign); pool != nil {
|
||||
if pool := getCertPool(sign); pool != nil {
|
||||
tlsConfig.ClientCAs = pool
|
||||
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
// Add RootCAs for refresh client
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
|
||||
// Apply options if given
|
||||
if err := setTLSOptions(tlsConfig, options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update renew function with transport
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
|
@ -93,44 +103,15 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// GetServerMutualTLSConfig returns a tls.Config for server use configured with
|
||||
// the sign certificate, and a new certificate pool with the sign root certificate.
|
||||
// The returned tls.Config will always require and verify a client certificate.
|
||||
// The server certificate will automatically rotate before expiring.
|
||||
func (c *Client) GetServerMutualTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
||||
tlsConfig, err := c.GetServerTLSConfig(ctx, sign, pk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// Transport returns an http.Transport configured to use the client certificate from the sign response.
|
||||
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
|
||||
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)
|
||||
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
|
||||
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getDefaultTransport(tlsConfig)
|
||||
}
|
||||
|
||||
// getCertPool returns the transport x509.CertPool or the one from the sign
|
||||
// request.
|
||||
func (c *Client) getCertPool(sign *api.SignResponse) *x509.CertPool {
|
||||
// Return the transport certPool
|
||||
if c.certPool != nil {
|
||||
return c.certPool
|
||||
}
|
||||
// Return certificate used in sign request.
|
||||
if root, err := RootCertificate(sign); err == nil {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(root)
|
||||
return pool
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Certificate returns the server or client certificate from the sign response.
|
||||
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||||
if sign.ServerPEM.Certificate == nil {
|
||||
|
@ -189,6 +170,17 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
|
|||
return &cert, nil
|
||||
}
|
||||
|
||||
// getCertPool returns the transport x509.CertPool or the one from the sign
|
||||
// request.
|
||||
func getCertPool(sign *api.SignResponse) *x509.CertPool {
|
||||
if root, err := RootCertificate(sign); err == nil {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(root)
|
||||
return pool
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
||||
if sign.TLSOptions != nil {
|
||||
return sign.TLSOptions.TLSConfig()
|
||||
|
|
64
ca/tls_options.go
Normal file
64
ca/tls_options.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||
type TLSOption func(c *tls.Config) error
|
||||
|
||||
// setTLSOptions takes one or more option function and applies them in order to
|
||||
// a tls.Config.
|
||||
func setTLSOptions(c *tls.Config, options []TLSOption) error {
|
||||
for _, opt := range options {
|
||||
if err := opt(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
||||
// a valid TLS client certificate. This is the default option for mTLS servers.
|
||||
func RequireAndVerifyClientCert() TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
c.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
||||
// a TLS client certificate if it is provided. It does not requires a certificate.
|
||||
func VerifyClientCertIfGiven() TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
c.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddRootCA adds to the tls.Config RootCAs the given certificate. RootCAs
|
||||
// defines the set of root certificate authorities that clients use when
|
||||
// verifying server certificates.
|
||||
func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
if c.RootCAs == nil {
|
||||
c.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
c.RootCAs.AddCert(cert)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddClientCA adds to the tls.Config ClientCAs the given certificate. ClientCAs
|
||||
// defines the set of root certificate authorities that servers use if required
|
||||
// to verify a client certificate by the policy in ClientAuth.
|
||||
func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
if c.ClientCAs == nil {
|
||||
c.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
c.ClientCAs.AddCert(cert)
|
||||
return nil
|
||||
}
|
||||
}
|
137
ca/tls_options_test.go
Normal file
137
ca/tls_options_test.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package ca
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_setTLSOptions(t *testing.T) {
|
||||
fail := func() TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
return fmt.Errorf("an error")
|
||||
}
|
||||
}
|
||||
type args struct {
|
||||
c *tls.Config
|
||||
options []TLSOption
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{&tls.Config{}, []TLSOption{RequireAndVerifyClientCert()}}, false},
|
||||
{"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false},
|
||||
{"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
||||
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAndVerifyClientCert(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want *tls.Config
|
||||
}{
|
||||
{"ok", &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := RequireAndVerifyClientCert()(got); err != nil {
|
||||
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("RequireAndVerifyClientCert() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyClientCertIfGiven(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want *tls.Config
|
||||
}{
|
||||
{"ok", &tls.Config{ClientAuth: tls.VerifyClientCertIfGiven}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := VerifyClientCertIfGiven()(got); err != nil {
|
||||
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("VerifyClientCertIfGiven() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddRootCA(t *testing.T) {
|
||||
cert := parseCertificate(rootPEM)
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tls.Config
|
||||
}{
|
||||
{"ok", args{cert}, &tls.Config{RootCAs: pool}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := AddRootCA(tt.args.cert)(got); err != nil {
|
||||
t.Errorf("AddRootCA() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("AddRootCA() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddClientCA(t *testing.T) {
|
||||
cert := parseCertificate(rootPEM)
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert)
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *tls.Config
|
||||
}{
|
||||
{"ok", args{cert}, &tls.Config{ClientCAs: pool}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := AddClientCA(tt.args.cert)(got); err != nil {
|
||||
t.Errorf("AddClientCA() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("AddClientCA() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
474
ca/tls_test.go
474
ca/tls_test.go
|
@ -104,15 +104,8 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) (
|
|||
return client, sr, pk
|
||||
}
|
||||
|
||||
func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||
client, sr, pk := sign("127.0.0.1")
|
||||
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||
}
|
||||
clientDomain := "test.domain"
|
||||
// Create server with given tls.Config
|
||||
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
func serverHandler(t *testing.T, clientDomain string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.RequestURI != "/no-cert" {
|
||||
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
||||
w.Write([]byte("fail"))
|
||||
|
@ -129,18 +122,46 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|||
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
||||
return
|
||||
}
|
||||
|
||||
// Add serial number to check rotation
|
||||
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
|
||||
w.Header().Set("x-fingerprint", hex.EncodeToString(sum[:]))
|
||||
}
|
||||
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||
clientDomain := "test.domain"
|
||||
client, sr, pk := sign("127.0.0.1")
|
||||
|
||||
// Create mTLS server
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||
}
|
||||
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
||||
defer srvMTLS.Close()
|
||||
|
||||
// Create TLS server
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
|
||||
if err != nil {
|
||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||
}
|
||||
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
||||
defer srvTLS.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantErr bool
|
||||
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
||||
wantErr map[string]bool
|
||||
}{
|
||||
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
tr, err := client.Transport(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Errorf("Client.Transport() error = %v", err)
|
||||
|
@ -149,8 +170,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|||
return &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
}},
|
||||
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
|
||||
{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
||||
|
@ -164,8 +185,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|||
return &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
}},
|
||||
{"ok with no cert", "/no-cert", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
|
||||
{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
root, err := RootCertificate(sr)
|
||||
if err != nil {
|
||||
t.Errorf("RootCertificate() error = %v", err)
|
||||
|
@ -183,33 +204,38 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|||
return &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
}},
|
||||
{"fail with default", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
|
||||
{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
return &http.Client{}
|
||||
}},
|
||||
}, map[string]bool{srvTLS.URL + "/no-cert": true, srvMTLS.URL + "/no-cert": true}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, sr, pk := sign(clientDomain)
|
||||
cli := tt.getClient(t, client, sr, pk)
|
||||
if cli != nil {
|
||||
resp, err := cli.Get(srv.URL + tt.path)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(b, []byte("ok")) {
|
||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||
}
|
||||
if cli == nil {
|
||||
return
|
||||
}
|
||||
for path, wantErr := range tt.wantErr {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
resp, err := cli.Get(path)
|
||||
if (err != nil) != wantErr {
|
||||
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, wantErr)
|
||||
return
|
||||
}
|
||||
if wantErr {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(b, []byte("ok")) {
|
||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -224,44 +250,36 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
|||
ca := startCATestServer()
|
||||
defer ca.Close()
|
||||
|
||||
clientDomain := "test.domain"
|
||||
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
|
||||
tlsConfig, err := client.GetServerTLSConfig(context.Background(), sr, pk)
|
||||
|
||||
// Start mTLS server
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||
}
|
||||
clientDomain := "test.domain"
|
||||
fingerprints := make(map[string]struct{})
|
||||
srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
||||
defer srvMTLS.Close()
|
||||
|
||||
// Create server with given tls.Config
|
||||
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
||||
w.Write([]byte("fail"))
|
||||
t.Error("http.Request.TLS does not have peer certificates")
|
||||
return
|
||||
}
|
||||
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
||||
w.Write([]byte("fail"))
|
||||
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
||||
w.Write([]byte("fail"))
|
||||
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
||||
return
|
||||
}
|
||||
// Add serial number to check rotation
|
||||
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
|
||||
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
// Start TLS server
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven())
|
||||
if err != nil {
|
||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||
}
|
||||
srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain))
|
||||
defer srvTLS.Close()
|
||||
|
||||
// Clients: transport and tlsConfig
|
||||
// Transport
|
||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||
tr1, err := client.Transport(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.Transport() error = %v", err)
|
||||
}
|
||||
// Transport with tlsConfig
|
||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
|
@ -271,227 +289,15 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("getDefaultTransport() error = %v", err)
|
||||
}
|
||||
|
||||
// Disable keep alives to force TLS handshake
|
||||
tr1.DisableKeepAlives = true
|
||||
tr2.DisableKeepAlives = true
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
client *http.Client
|
||||
}{
|
||||
{"with transport", &http.Client{Transport: tr1}},
|
||||
{"with tlsConfig", &http.Client{Transport: tr2}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, err := tt.client.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Client.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(b, []byte("ok")) {
|
||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if l := len(fingerprints); l != 2 {
|
||||
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
||||
}
|
||||
|
||||
// Wait for renewal 40s == 1m-1m/3
|
||||
log.Printf("Sleeping for %s ...\n", 40*time.Second)
|
||||
time.Sleep(40 * time.Second)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("renewed "+tt.name, func(t *testing.T) {
|
||||
resp, err := tt.client.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Client.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(b, []byte("ok")) {
|
||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if l := len(fingerprints); l != 4 {
|
||||
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
|
||||
client, sr, pk := sign("127.0.0.1")
|
||||
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
|
||||
// No client cert
|
||||
root, err := RootCertificate(sr)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||
t.Fatalf("RootCertificate() error = %v", err)
|
||||
}
|
||||
clientDomain := "test.domain"
|
||||
// Create server with given tls.Config
|
||||
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.RequestURI != "/no-cert" {
|
||||
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
||||
w.Write([]byte("fail"))
|
||||
t.Error("http.Request.TLS does not have peer certificates")
|
||||
return
|
||||
}
|
||||
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
||||
w.Write([]byte("fail"))
|
||||
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
||||
w.Write([]byte("fail"))
|
||||
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantErr bool
|
||||
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
||||
}{
|
||||
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
tr, err := client.Transport(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Errorf("Client.Transport() error = %v", err)
|
||||
return nil
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
}},
|
||||
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
||||
return nil
|
||||
}
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
t.Errorf("getDefaultTransport() error = %v", err)
|
||||
return nil
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
}},
|
||||
{"fail with no cert", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
root, err := RootCertificate(sr)
|
||||
if err != nil {
|
||||
t.Errorf("RootCertificate() error = %v", err)
|
||||
return nil
|
||||
}
|
||||
tlsConfig := getDefaultTLSConfig(sr)
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
tlsConfig.RootCAs.AddCert(root)
|
||||
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
t.Errorf("getDefaultTransport() error = %v", err)
|
||||
return nil
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, sr, pk := sign(clientDomain)
|
||||
cli := tt.getClient(t, client, sr, pk)
|
||||
if cli != nil {
|
||||
resp, err := cli.Get(srv.URL + tt.path)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(b, []byte("ok")) {
|
||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode.")
|
||||
}
|
||||
|
||||
// Start CA
|
||||
ca := startCATestServer()
|
||||
defer ca.Close()
|
||||
|
||||
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
|
||||
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||
}
|
||||
clientDomain := "test.domain"
|
||||
fingerprints := make(map[string]struct{})
|
||||
|
||||
// Create server with given tls.Config
|
||||
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
||||
w.Write([]byte("fail"))
|
||||
t.Error("http.Request.TLS does not have peer certificates")
|
||||
return
|
||||
}
|
||||
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
||||
w.Write([]byte("fail"))
|
||||
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
||||
w.Write([]byte("fail"))
|
||||
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
||||
return
|
||||
}
|
||||
// Add serial number to check rotation
|
||||
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
|
||||
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
|
||||
w.Write([]byte("ok"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Clients: transport and tlsConfig
|
||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||
tr1, err := client.Transport(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.Transport() error = %v", err)
|
||||
}
|
||||
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||
if err != nil {
|
||||
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
|
||||
}
|
||||
tr2, err := getDefaultTransport(tlsConfig)
|
||||
tlsConfig = getDefaultTLSConfig(sr)
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
tlsConfig.RootCAs.AddCert(root)
|
||||
tr3, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("getDefaultTransport() error = %v", err)
|
||||
}
|
||||
|
@ -499,34 +305,66 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
|
|||
// Disable keep alives to force TLS handshake
|
||||
tr1.DisableKeepAlives = true
|
||||
tr2.DisableKeepAlives = true
|
||||
tr3.DisableKeepAlives = true
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
client *http.Client
|
||||
name string
|
||||
client *http.Client
|
||||
wantErr map[string]bool
|
||||
}{
|
||||
{"with transport", &http.Client{Transport: tr1}},
|
||||
{"with tlsConfig", &http.Client{Transport: tr2}},
|
||||
{"with transport", &http.Client{Transport: tr1}, map[string]bool{
|
||||
srvTLS.URL: false,
|
||||
srvMTLS.URL: false,
|
||||
}},
|
||||
{"with tlsConfig", &http.Client{Transport: tr2}, map[string]bool{
|
||||
srvTLS.URL: false,
|
||||
srvMTLS.URL: false,
|
||||
}},
|
||||
{"with no ClientCert", &http.Client{Transport: tr3}, map[string]bool{
|
||||
srvTLS.URL + "/no-cert": false,
|
||||
srvMTLS.URL + "/no-cert": true,
|
||||
}},
|
||||
{"fail with default", &http.Client{}, map[string]bool{
|
||||
srvTLS.URL + "/no-cert": true,
|
||||
srvMTLS.URL + "/no-cert": true,
|
||||
}},
|
||||
}
|
||||
|
||||
// To count different cert fingerprints
|
||||
fingerprints := map[string]struct{}{}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, err := tt.client.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Client.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(b, []byte("ok")) {
|
||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||
for path, wantErr := range tt.wantErr {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
resp, err := tt.client.Get(path)
|
||||
if (err != nil) != wantErr {
|
||||
t.Errorf("http.Client.Get() error = %v", err)
|
||||
return
|
||||
}
|
||||
if wantErr {
|
||||
return
|
||||
}
|
||||
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
|
||||
fingerprints[fp] = struct{}{}
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("ioutil.RealAdd() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(b, []byte("ok")) {
|
||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if l := len(fingerprints); l != 2 {
|
||||
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
||||
t.Errorf("number of fingerprints unexpected, got %d, want 2", l)
|
||||
}
|
||||
|
||||
// Wait for renewal 40s == 1m-1m/3
|
||||
|
@ -535,17 +373,31 @@ func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run("renewed "+tt.name, func(t *testing.T) {
|
||||
resp, err := tt.client.Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Client.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||
}
|
||||
if !bytes.Equal(b, []byte("ok")) {
|
||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||
for path, wantErr := range tt.wantErr {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
resp, err := tt.client.Get(path)
|
||||
if (err != nil) != wantErr {
|
||||
t.Errorf("http.Client.Get() error = %v", err)
|
||||
return
|
||||
}
|
||||
if wantErr {
|
||||
return
|
||||
}
|
||||
if fp := resp.Header.Get("x-fingerprint"); fp != "" {
|
||||
fingerprints[fp] = struct{}{}
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("ioutil.RealAdd() error = %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(b, []byte("ok")) {
|
||||
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -139,13 +139,13 @@ password `password` hardcoded, but you can create your own using `step ca init`.
|
|||
|
||||
These examples show the use of some other helper methods - simple ways to
|
||||
create TLS configured http.Server and http.Client objects. The methods are
|
||||
`BootstrapServer`, `BootstrapServerWithMTLS` and `BootstrapClient`.
|
||||
`BootstrapServer` and `BootstrapClient`.
|
||||
|
||||
```go
|
||||
// Get a cancelable context to stop the renewal goroutines and timers.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// Create an http.Server
|
||||
// Create an http.Server that requires a client certificate
|
||||
srv, err := ca.BootstrapServer(ctx, token, &http.Server{
|
||||
Addr: ":8443",
|
||||
Handler: handler,
|
||||
|
@ -160,11 +160,11 @@ srv.ListenAndServeTLS("", "")
|
|||
// Get a cancelable context to stop the renewal goroutines and timers.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// Create an http.Server that requires a client certificate
|
||||
// Create an http.Server that does not require a client certificate
|
||||
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
||||
Addr: ":8443",
|
||||
Handler: handler,
|
||||
})
|
||||
}, ca.VerifyClientCertIfGiven())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -194,13 +194,13 @@ certificates $ bin/step-ca examples/pki/config/ca.json
|
|||
2018/11/02 18:29:25 Serving HTTPS on :9000 ...
|
||||
```
|
||||
|
||||
Next we will start the bootstrap-server and enter `password` prompted for the
|
||||
Next we will start the bootstrap-tls-server and enter `password` prompted for the
|
||||
provisioner password:
|
||||
|
||||
```sh
|
||||
certificates $ export STEPPATH=examples/pki
|
||||
certificates $ export STEP_CA_URL=https://localhost:9000
|
||||
certificates $ go run examples/bootstrap-server/server.go $(step ca token localhost)
|
||||
certificates $ go run examples/bootstrap-tls-server/server.go $(step ca token localhost)
|
||||
✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com)
|
||||
Please enter the password to decrypt the provisioner key:
|
||||
Listening on :8443 ...
|
||||
|
|
|
@ -22,7 +22,7 @@ func main() {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
||||
srv, err := ca.BootstrapServer(ctx, token, &http.Server{
|
||||
Addr: ":8443",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
name := "nobody"
|
||||
|
|
|
@ -31,7 +31,7 @@ func main() {
|
|||
}
|
||||
w.Write([]byte(fmt.Sprintf("Hello %s at %s!!!", name, time.Now().UTC())))
|
||||
}),
|
||||
})
|
||||
}, ca.VerifyClientCertIfGiven())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
Loading…
Reference in a new issue