forked from TrueCloudLab/certificates
Add helpers to add direct support for mTLS.
This commit is contained in:
parent
272bbc57dd
commit
9c64dbda9a
6 changed files with 495 additions and 26 deletions
|
@ -90,6 +90,59 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server) (*htt
|
||||||
return base, nil
|
return base, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BootstrapServerWithMTLS is a helper function that using the given token
|
||||||
|
// returns the given http.Server configured with a TLS certificate signed by the
|
||||||
|
// Certificate Authority, this server will always require and verify a client
|
||||||
|
// certificate. By default the server will kick off a routine that will renew
|
||||||
|
// the certificate after 2/3rd of the certificate's lifetime has expired.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
// // Default example with certificate rotation.
|
||||||
|
// srv, err := ca.BootstrapServerWithMTLS(context.Background(), token, &http.Server{
|
||||||
|
// Addr: ":443",
|
||||||
|
// Handler: handler,
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// // Example canceling automatic certificate rotation.
|
||||||
|
// ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
// defer cancel()
|
||||||
|
// srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
||||||
|
// Addr: ":443",
|
||||||
|
// Handler: handler,
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// return err
|
||||||
|
// }
|
||||||
|
// srv.ListenAndServeTLS("", "")
|
||||||
|
func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
|
||||||
|
if base.TLSConfig != nil {
|
||||||
|
return nil, errors.New("server TLSConfig is already set")
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := Bootstrap(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, pk, err := CreateSignRequest(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sign, err := client.Sign(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig, err := client.GetServerMutualTLSConfig(ctx, sign, pk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
base.TLSConfig = tlsConfig
|
||||||
|
return base, nil
|
||||||
|
}
|
||||||
|
|
||||||
// BootstrapClient is a helper function that using the given bootstrap token
|
// BootstrapClient is a helper function that using the given bootstrap token
|
||||||
// return an http.Client configured with a Transport prepared to do TLS
|
// return an http.Client configured with a Transport prepared to do TLS
|
||||||
// connections using the client certificate returned by the certificate
|
// connections using the client certificate returned by the certificate
|
||||||
|
|
|
@ -170,6 +170,52 @@ func TestBootstrapServer(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBootstrapServerWithMTLS(t *testing.T) {
|
||||||
|
srv := startCABootstrapServer()
|
||||||
|
defer srv.Close()
|
||||||
|
token := func() string {
|
||||||
|
return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
token string
|
||||||
|
base *http.Server
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{context.Background(), token(), &http.Server{}}, false},
|
||||||
|
{"fail", args{context.Background(), "bad-token", &http.Server{}}, true},
|
||||||
|
{"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := BootstrapServerWithMTLS(tt.args.ctx, tt.args.token, tt.args.base)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("BootstrapServerWithMTLS() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr {
|
||||||
|
if got != nil {
|
||||||
|
t.Errorf("BootstrapServerWithMTLS() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
expected := &http.Server{
|
||||||
|
TLSConfig: got.TLSConfig,
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, expected) {
|
||||||
|
t.Errorf("BootstrapServerWithMTLS() = %v, want %v", got, expected)
|
||||||
|
}
|
||||||
|
if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil {
|
||||||
|
t.Errorf("BootstrapServerWithMTLS() invalid TLSConfig = %#v", got.TLSConfig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBootstrapClient(t *testing.T) {
|
func TestBootstrapClient(t *testing.T) {
|
||||||
srv := startCABootstrapServer()
|
srv := startCABootstrapServer()
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
23
ca/tls.go
23
ca/tls.go
|
@ -19,7 +19,7 @@ import (
|
||||||
|
|
||||||
// GetClientTLSConfig returns a tls.Config for client use configured with the
|
// GetClientTLSConfig returns a tls.Config for client use configured with the
|
||||||
// sign certificate, and a new certificate pool with the sign root certificate.
|
// sign certificate, and a new certificate pool with the sign root certificate.
|
||||||
// The certificate will automatically rotate before expiring.
|
// The client certificate will automatically rotate before expiring.
|
||||||
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
||||||
cert, err := TLSCertificate(sign, pk)
|
cert, err := TLSCertificate(sign, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -32,6 +32,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
|
|
||||||
tlsConfig := getDefaultTLSConfig(sign)
|
tlsConfig := getDefaultTLSConfig(sign)
|
||||||
// Note that with GetClientCertificate tlsConfig.Certificates is not used.
|
// Note that with GetClientCertificate tlsConfig.Certificates is not used.
|
||||||
|
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
|
||||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
tlsConfig.PreferServerCipherSuites = true
|
||||||
// Build RootCAs with given root certificate
|
// Build RootCAs with given root certificate
|
||||||
|
@ -39,9 +40,6 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
tlsConfig.RootCAs = pool
|
tlsConfig.RootCAs = pool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse Certificates and build NameToCertificate
|
|
||||||
tlsConfig.BuildNameToCertificate()
|
|
||||||
|
|
||||||
// Update renew function with transport
|
// Update renew function with transport
|
||||||
tr, err := getDefaultTransport(tlsConfig)
|
tr, err := getDefaultTransport(tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -56,7 +54,8 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
|
|
||||||
// GetServerTLSConfig returns a tls.Config for server use configured with the
|
// GetServerTLSConfig returns a tls.Config for server use configured with the
|
||||||
// sign certificate, and a new certificate pool with the sign root certificate.
|
// sign certificate, and a new certificate pool with the sign root certificate.
|
||||||
// The certificate will automatically rotate before expiring.
|
// The returned tls.Config will only verify the client certificate if provided.
|
||||||
|
// The server certificate will automatically rotate before expiring.
|
||||||
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
||||||
cert, err := TLSCertificate(sign, pk)
|
cert, err := TLSCertificate(sign, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -70,6 +69,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
tlsConfig := getDefaultTLSConfig(sign)
|
tlsConfig := getDefaultTLSConfig(sign)
|
||||||
// Note that GetCertificate will only be called if the client supplies SNI
|
// Note that GetCertificate will only be called if the client supplies SNI
|
||||||
// information or if tlsConfig.Certificates is empty.
|
// information or if tlsConfig.Certificates is empty.
|
||||||
|
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
|
||||||
tlsConfig.GetCertificate = renewer.GetCertificate
|
tlsConfig.GetCertificate = renewer.GetCertificate
|
||||||
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
tlsConfig.PreferServerCipherSuites = true
|
||||||
|
@ -93,6 +93,19 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
return tlsConfig, nil
|
return tlsConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetServerMutualTLSConfig returns a tls.Config for server use configured with
|
||||||
|
// the sign certificate, and a new certificate pool with the sign root certificate.
|
||||||
|
// The returned tls.Config will always require and verify a client certificate.
|
||||||
|
// The server certificate will automatically rotate before expiring.
|
||||||
|
func (c *Client) GetServerMutualTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
||||||
|
tlsConfig, err := c.GetServerTLSConfig(ctx, sign, pk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Transport returns an http.Transport configured to use the client certificate from the sign response.
|
// Transport returns an http.Transport configured to use the client certificate from the sign response.
|
||||||
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
|
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
|
||||||
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)
|
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)
|
||||||
|
|
292
ca/tls_test.go
292
ca/tls_test.go
|
@ -113,20 +113,22 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
clientDomain := "test.domain"
|
clientDomain := "test.domain"
|
||||||
// Create server with given tls.Config
|
// Create server with given tls.Config
|
||||||
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
if req.RequestURI != "/no-cert" {
|
||||||
w.Write([]byte("fail"))
|
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
||||||
t.Error("http.Request.TLS does not have peer certificates")
|
w.Write([]byte("fail"))
|
||||||
return
|
t.Error("http.Request.TLS does not have peer certificates")
|
||||||
}
|
return
|
||||||
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
}
|
||||||
w.Write([]byte("fail"))
|
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
||||||
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
w.Write([]byte("fail"))
|
||||||
return
|
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
||||||
}
|
return
|
||||||
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
}
|
||||||
w.Write([]byte("fail"))
|
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
||||||
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
w.Write([]byte("fail"))
|
||||||
return
|
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
w.Write([]byte("ok"))
|
w.Write([]byte("ok"))
|
||||||
}))
|
}))
|
||||||
|
@ -134,9 +136,11 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
path string
|
||||||
|
wantErr bool
|
||||||
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
||||||
}{
|
}{
|
||||||
{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
tr, err := client.Transport(context.Background(), sr, pk)
|
tr, err := client.Transport(context.Background(), sr, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Client.Transport() error = %v", err)
|
t.Errorf("Client.Transport() error = %v", err)
|
||||||
|
@ -146,7 +150,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
Transport: tr,
|
Transport: tr,
|
||||||
}
|
}
|
||||||
}},
|
}},
|
||||||
{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
||||||
|
@ -161,6 +165,28 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
Transport: tr,
|
Transport: tr,
|
||||||
}
|
}
|
||||||
}},
|
}},
|
||||||
|
{"ok with no cert", "/no-cert", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
|
root, err := RootCertificate(sr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("RootCertificate() error = %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tlsConfig := getDefaultTLSConfig(sr)
|
||||||
|
tlsConfig.RootCAs = x509.NewCertPool()
|
||||||
|
tlsConfig.RootCAs.AddCert(root)
|
||||||
|
|
||||||
|
tr, err := getDefaultTransport(tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("getDefaultTransport() error = %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"fail with default", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
|
return &http.Client{}
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -168,9 +194,13 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
||||||
client, sr, pk := sign(clientDomain)
|
client, sr, pk := sign(clientDomain)
|
||||||
cli := tt.getClient(t, client, sr, pk)
|
cli := tt.getClient(t, client, sr, pk)
|
||||||
if cli != nil {
|
if cli != nil {
|
||||||
resp, err := cli.Get(srv.URL)
|
resp, err := cli.Get(srv.URL + tt.path)
|
||||||
if err != nil {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Fatalf("http.Client.Get() error = %v", err)
|
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
@ -301,6 +331,230 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
|
||||||
|
client, sr, pk := sign("127.0.0.1")
|
||||||
|
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
clientDomain := "test.domain"
|
||||||
|
// Create server with given tls.Config
|
||||||
|
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.RequestURI != "/no-cert" {
|
||||||
|
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
||||||
|
w.Write([]byte("fail"))
|
||||||
|
t.Error("http.Request.TLS does not have peer certificates")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
||||||
|
w.Write([]byte("fail"))
|
||||||
|
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
||||||
|
w.Write([]byte("fail"))
|
||||||
|
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
wantErr bool
|
||||||
|
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
|
||||||
|
}{
|
||||||
|
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
|
tr, err := client.Transport(context.Background(), sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Client.Transport() error = %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
|
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tr, err := getDefaultTransport(tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("getDefaultTransport() error = %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
{"fail with no cert", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||||
|
root, err := RootCertificate(sr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("RootCertificate() error = %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tlsConfig := getDefaultTLSConfig(sr)
|
||||||
|
tlsConfig.RootCAs = x509.NewCertPool()
|
||||||
|
tlsConfig.RootCAs.AddCert(root)
|
||||||
|
|
||||||
|
tr, err := getDefaultTransport(tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("getDefaultTransport() error = %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
}
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
client, sr, pk := sign(clientDomain)
|
||||||
|
cli := tt.getClient(t, client, sr, pk)
|
||||||
|
if cli != nil {
|
||||||
|
resp, err := cli.Get(srv.URL + tt.path)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b, []byte("ok")) {
|
||||||
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping test in short mode.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start CA
|
||||||
|
ca := startCATestServer()
|
||||||
|
defer ca.Close()
|
||||||
|
|
||||||
|
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
|
||||||
|
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
clientDomain := "test.domain"
|
||||||
|
fingerprints := make(map[string]struct{})
|
||||||
|
|
||||||
|
// Create server with given tls.Config
|
||||||
|
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
|
||||||
|
w.Write([]byte("fail"))
|
||||||
|
t.Error("http.Request.TLS does not have peer certificates")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
|
||||||
|
w.Write([]byte("fail"))
|
||||||
|
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
|
||||||
|
w.Write([]byte("fail"))
|
||||||
|
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Add serial number to check rotation
|
||||||
|
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
|
||||||
|
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
|
||||||
|
w.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
// Clients: transport and tlsConfig
|
||||||
|
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||||
|
tr1, err := client.Transport(context.Background(), sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Client.Transport() error = %v", err)
|
||||||
|
}
|
||||||
|
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
|
||||||
|
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
tr2, err := getDefaultTransport(tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getDefaultTransport() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable keep alives to force TLS handshake
|
||||||
|
tr1.DisableKeepAlives = true
|
||||||
|
tr2.DisableKeepAlives = true
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
client *http.Client
|
||||||
|
}{
|
||||||
|
{"with transport", &http.Client{Transport: tr1}},
|
||||||
|
{"with tlsConfig", &http.Client{Transport: tr2}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resp, err := tt.client.Get(srv.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("http.Client.Get() error = %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b, []byte("ok")) {
|
||||||
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if l := len(fingerprints); l != 2 {
|
||||||
|
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for renewal 40s == 1m-1m/3
|
||||||
|
log.Printf("Sleeping for %s ...\n", 40*time.Second)
|
||||||
|
time.Sleep(40 * time.Second)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run("renewed "+tt.name, func(t *testing.T) {
|
||||||
|
resp, err := tt.client.Get(srv.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("http.Client.Get() error = %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ioutil.RealAdd() error = %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b, []byte("ok")) {
|
||||||
|
t.Errorf("response body unexpected, got %s, want ok", b)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if l := len(fingerprints); l != 4 {
|
||||||
|
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCertificate(t *testing.T) {
|
func TestCertificate(t *testing.T) {
|
||||||
cert := parseCertificate(certPEM)
|
cert := parseCertificate(certPEM)
|
||||||
ok := &api.SignResponse{
|
ok := &api.SignResponse{
|
||||||
|
|
|
@ -142,7 +142,8 @@ password `password` hardcoded, but you can create your own using `step ca init`.
|
||||||
|
|
||||||
These examples show the use of other helper methods, they are simple ways to
|
These examples show the use of other helper methods, they are simple ways to
|
||||||
create TLS configured http.Server and http.Client objects. The methods are
|
create TLS configured http.Server and http.Client objects. The methods are
|
||||||
`BootstrapServer` and `BootstrapClient` and they are used like:
|
`BootstrapServer`, `BootstrapServerWithMTLS` and `BootstrapClient` and they are
|
||||||
|
used like:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// Get a cancelable context to stop the renewal goroutines and timers.
|
// Get a cancelable context to stop the renewal goroutines and timers.
|
||||||
|
@ -159,6 +160,21 @@ if err != nil {
|
||||||
srv.ListenAndServeTLS("", "")
|
srv.ListenAndServeTLS("", "")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Get a cancelable context to stop the renewal goroutines and timers.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
// Create an http.Server that requires a client certificate
|
||||||
|
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
||||||
|
Addr: ":8443",
|
||||||
|
Handler: handler,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
srv.ListenAndServeTLS("", "")
|
||||||
|
```
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// Get a cancelable context to stop the renewal goroutines and timers.
|
// Get a cancelable context to stop the renewal goroutines and timers.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -171,6 +187,9 @@ if err != nil {
|
||||||
resp, err := client.Get("https://localhost:8443")
|
resp, err := client.Get("https://localhost:8443")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
We will demonstrate the mTLS configuration if a different example, for this one
|
||||||
|
we will only verify it if provided.
|
||||||
|
|
||||||
To run the example first we will start the certificate authority:
|
To run the example first we will start the certificate authority:
|
||||||
```sh
|
```sh
|
||||||
certificates $ bin/step-ca examples/pki/config/ca.json
|
certificates $ bin/step-ca examples/pki/config/ca.json
|
||||||
|
@ -229,6 +248,47 @@ Server responded: Hello Mike at 2018-11-03 01:52:54.682787 +0000 UTC!!!
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Bootstrap mTLS Client & Server
|
||||||
|
|
||||||
|
This example demonstrates a stricter configuration of the bootstrap server, this
|
||||||
|
one always requires a valid client certificate.
|
||||||
|
|
||||||
|
As always, to run this example will require the Certificate Authority running:
|
||||||
|
```sh
|
||||||
|
certificates $ bin/step-ca examples/pki/config/ca.json
|
||||||
|
2018/11/02 18:29:25 Serving HTTPS on :9000 ...
|
||||||
|
```
|
||||||
|
|
||||||
|
We will start the mTLS server and we will type `password` when step asks for the
|
||||||
|
provisioner password:
|
||||||
|
```sh
|
||||||
|
certificates $ export STEPPATH=examples/pki
|
||||||
|
certificates $ export STEP_CA_URL=https://localhost:9000
|
||||||
|
certificates $ go run examples/bootstrap-mtls-server/server.go $(step ca token localhost)
|
||||||
|
✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com)
|
||||||
|
Please enter the password to decrypt the provisioner key:
|
||||||
|
Listening on :8443 ...
|
||||||
|
```
|
||||||
|
|
||||||
|
For mTLS, curl and curl with the root certificate will fail:
|
||||||
|
```sh
|
||||||
|
certificates $ curl --cacert examples/pki/secrets/root_ca.crt https://localhost:8443
|
||||||
|
curl: (35) error:1401E412:SSL routines:CONNECT_CR_FINISHED:sslv3 alert bad certificate
|
||||||
|
```
|
||||||
|
|
||||||
|
But if we the client with the certificate name Mike we'll see:
|
||||||
|
```sh
|
||||||
|
certificates $ export STEPPATH=examples/pki
|
||||||
|
certificates $ export STEP_CA_URL=https://localhost:9000
|
||||||
|
certificates $ go run examples/bootstrap-client/client.go $(step ca token Mike)
|
||||||
|
✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com)
|
||||||
|
Please enter the password to decrypt the provisioner key:
|
||||||
|
Server responded: Hello Mike at 2018-11-07 21:54:00.140022 +0000 UTC!!!
|
||||||
|
Server responded: Hello Mike at 2018-11-07 21:54:01.140827 +0000 UTC!!!
|
||||||
|
Server responded: Hello Mike at 2018-11-07 21:54:02.141578 +0000 UTC!!!
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
## Certificate rotation
|
## Certificate rotation
|
||||||
|
|
||||||
We can use the bootstrap-server to demonstrate the certificate rotation. We've
|
We can use the bootstrap-server to demonstrate the certificate rotation. We've
|
||||||
|
@ -240,7 +300,7 @@ rotates after approximately two thirds of the duration has passed.
|
||||||
```sh
|
```sh
|
||||||
certificates $ export STEPPATH=examples/pki
|
certificates $ export STEPPATH=examples/pki
|
||||||
certificates $ export STEP_CA_URL=https://localhost:9000
|
certificates $ export STEP_CA_URL=https://localhost:9000
|
||||||
certificates $ go run examples/bootstrap-server/server.go $(step ca token localhost))
|
certificates $ go run examples/bootstrap-server/server.go $(step ca token localhost)
|
||||||
✔ Key ID: YYNxZ0rq0WsT2MlqLCWvgme3jszkmt99KjoGEJJwAKs (mike@smallstep.com)
|
✔ Key ID: YYNxZ0rq0WsT2MlqLCWvgme3jszkmt99KjoGEJJwAKs (mike@smallstep.com)
|
||||||
Please enter the password to decrypt the provisioner key:
|
Please enter the password to decrypt the provisioner key:
|
||||||
Listening on :8443 ...
|
Listening on :8443 ...
|
||||||
|
|
43
examples/bootstrap-mtls-server/server.go
Normal file
43
examples/bootstrap-mtls-server/server.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/ca"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if len(os.Args) != 2 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Usage: %s <token>\n", os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := os.Args[1]
|
||||||
|
|
||||||
|
// make sure to cancel the renew goroutine
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
|
||||||
|
Addr: ":8443",
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
name := "nobody"
|
||||||
|
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||||
|
name = r.TLS.PeerCertificates[0].Subject.CommonName
|
||||||
|
}
|
||||||
|
w.Write([]byte(fmt.Sprintf("Hello %s at %s!!!", name, time.Now().UTC())))
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Listening on :8443 ...")
|
||||||
|
if err := srv.ListenAndServeTLS("", ""); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue