Add helpers to add direct support for mTLS.

This commit is contained in:
Mariano Cano 2018-11-07 16:07:35 -08:00
parent 272bbc57dd
commit 9c64dbda9a
6 changed files with 495 additions and 26 deletions

View file

@ -90,6 +90,59 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server) (*htt
return base, nil
}
// BootstrapServerWithMTLS is a helper function that using the given token
// returns the given http.Server configured with a TLS certificate signed by the
// Certificate Authority, this server will always require and verify a client
// certificate. By default the server will kick off a routine that will renew
// the certificate after 2/3rd of the certificate's lifetime has expired.
//
// Usage:
// // Default example with certificate rotation.
// srv, err := ca.BootstrapServerWithMTLS(context.Background(), token, &http.Server{
// Addr: ":443",
// Handler: handler,
// })
//
// // Example canceling automatic certificate rotation.
// ctx, cancel := context.WithCancel(context.Background())
// defer cancel()
// srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
// Addr: ":443",
// Handler: handler,
// })
// if err != nil {
// return err
// }
// srv.ListenAndServeTLS("", "")
func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
if base.TLSConfig != nil {
return nil, errors.New("server TLSConfig is already set")
}
client, err := Bootstrap(token)
if err != nil {
return nil, err
}
req, pk, err := CreateSignRequest(token)
if err != nil {
return nil, err
}
sign, err := client.Sign(req)
if err != nil {
return nil, err
}
tlsConfig, err := client.GetServerMutualTLSConfig(ctx, sign, pk)
if err != nil {
return nil, err
}
base.TLSConfig = tlsConfig
return base, nil
}
// BootstrapClient is a helper function that using the given bootstrap token
// return an http.Client configured with a Transport prepared to do TLS
// connections using the client certificate returned by the certificate

View file

@ -170,6 +170,52 @@ func TestBootstrapServer(t *testing.T) {
}
}
func TestBootstrapServerWithMTLS(t *testing.T) {
srv := startCABootstrapServer()
defer srv.Close()
token := func() string {
return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
}
type args struct {
ctx context.Context
token string
base *http.Server
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{context.Background(), token(), &http.Server{}}, false},
{"fail", args{context.Background(), "bad-token", &http.Server{}}, true},
{"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := BootstrapServerWithMTLS(tt.args.ctx, tt.args.token, tt.args.base)
if (err != nil) != tt.wantErr {
t.Errorf("BootstrapServerWithMTLS() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
if got != nil {
t.Errorf("BootstrapServerWithMTLS() = %v, want nil", got)
}
} else {
expected := &http.Server{
TLSConfig: got.TLSConfig,
}
if !reflect.DeepEqual(got, expected) {
t.Errorf("BootstrapServerWithMTLS() = %v, want %v", got, expected)
}
if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil {
t.Errorf("BootstrapServerWithMTLS() invalid TLSConfig = %#v", got.TLSConfig)
}
}
})
}
}
func TestBootstrapClient(t *testing.T) {
srv := startCABootstrapServer()
defer srv.Close()

View file

@ -19,7 +19,7 @@ import (
// GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The certificate will automatically rotate before expiring.
// The client certificate will automatically rotate before expiring.
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
cert, err := TLSCertificate(sign, pk)
if err != nil {
@ -32,6 +32,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig := getDefaultTLSConfig(sign)
// Note that with GetClientCertificate tlsConfig.Certificates is not used.
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
// Build RootCAs with given root certificate
@ -39,9 +40,6 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig.RootCAs = pool
}
// Parse Certificates and build NameToCertificate
tlsConfig.BuildNameToCertificate()
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
@ -56,7 +54,8 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
// GetServerTLSConfig returns a tls.Config for server use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The certificate will automatically rotate before expiring.
// The returned tls.Config will only verify the client certificate if provided.
// The server certificate will automatically rotate before expiring.
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
cert, err := TLSCertificate(sign, pk)
if err != nil {
@ -70,6 +69,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig := getDefaultTLSConfig(sign)
// Note that GetCertificate will only be called if the client supplies SNI
// information or if tlsConfig.Certificates is empty.
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
tlsConfig.GetCertificate = renewer.GetCertificate
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
@ -93,6 +93,19 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
return tlsConfig, nil
}
// GetServerMutualTLSConfig returns a tls.Config for server use configured with
// the sign certificate, and a new certificate pool with the sign root certificate.
// The returned tls.Config will always require and verify a client certificate.
// The server certificate will automatically rotate before expiring.
func (c *Client) GetServerMutualTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
tlsConfig, err := c.GetServerTLSConfig(ctx, sign, pk)
if err != nil {
return nil, err
}
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
return tlsConfig, nil
}
// Transport returns an http.Transport configured to use the client certificate from the sign response.
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)

View file

@ -113,6 +113,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
clientDomain := "test.domain"
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.RequestURI != "/no-cert" {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
@ -128,15 +129,18 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
}
w.Write([]byte("ok"))
}))
defer srv.Close()
tests := []struct {
name string
path string
wantErr bool
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
}{
{"with transport", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tr, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.Transport() error = %v", err)
@ -146,7 +150,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
Transport: tr,
}
}},
{"with tlsConfig", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
@ -161,6 +165,28 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
Transport: tr,
}
}},
{"ok with no cert", "/no-cert", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
root, err := RootCertificate(sr)
if err != nil {
t.Errorf("RootCertificate() error = %v", err)
return nil
}
tlsConfig := getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
{"fail with default", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
return &http.Client{}
}},
}
for _, tt := range tests {
@ -168,9 +194,13 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
client, sr, pk := sign(clientDomain)
cli := tt.getClient(t, client, sr, pk)
if cli != nil {
resp, err := cli.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
resp, err := cli.Get(srv.URL + tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
@ -301,6 +331,230 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
}
}
func TestClient_GetServerMutualTLSConfig_http(t *testing.T) {
client, sr, pk := sign("127.0.0.1")
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.RequestURI != "/no-cert" {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
}
w.Write([]byte("ok"))
}))
defer srv.Close()
tests := []struct {
name string
path string
wantErr bool
getClient func(*testing.T, *Client, *api.SignResponse, crypto.PrivateKey) *http.Client
}{
{"with transport", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tr, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.Transport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
{"with tlsConfig", "", false, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
tlsConfig, err := client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
return nil
}
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
{"fail with no cert", "/no-cert", true, func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
root, err := RootCertificate(sr)
if err != nil {
t.Errorf("RootCertificate() error = %v", err)
return nil
}
tlsConfig := getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
}
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, sr, pk := sign(clientDomain)
cli := tt.getClient(t, client, sr, pk)
if cli != nil {
resp, err := cli.Get(srv.URL + tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("http.Client.Get() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
}
})
}
}
func TestClient_GetServerMutualTLSConfig_renew(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
// Start CA
ca := startCATestServer()
defer ca.Close()
client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute)
tlsConfig, err := client.GetServerMutualTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetServerTLSConfig() error = %v", err)
}
clientDomain := "test.domain"
fingerprints := make(map[string]struct{})
// Create server with given tls.Config
srv := startTestServer(tlsConfig, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.TLS == nil || len(req.TLS.PeerCertificates) == 0 {
w.Write([]byte("fail"))
t.Error("http.Request.TLS does not have peer certificates")
return
}
if req.TLS.PeerCertificates[0].Subject.CommonName != clientDomain {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].Subject.CommonName = %s, wants %s", req.TLS.PeerCertificates[0].Subject.CommonName, clientDomain)
return
}
if !reflect.DeepEqual(req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain}) {
w.Write([]byte("fail"))
t.Errorf("http.Request.TLS.PeerCertificates[0].DNSNames %v, wants %v", req.TLS.PeerCertificates[0].DNSNames, []string{clientDomain})
return
}
// Add serial number to check rotation
sum := sha256.Sum256(req.TLS.PeerCertificates[0].Raw)
fingerprints[hex.EncodeToString(sum[:])] = struct{}{}
w.Write([]byte("ok"))
}))
defer srv.Close()
// Clients: transport and tlsConfig
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tr1, err := client.Transport(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.Transport() error = %v", err)
}
client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute)
tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk)
if err != nil {
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
}
tr2, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
// Disable keep alives to force TLS handshake
tr1.DisableKeepAlives = true
tr2.DisableKeepAlives = true
tests := []struct {
name string
client *http.Client
}{
{"with transport", &http.Client{Transport: tr1}},
{"with tlsConfig", &http.Client{Transport: tr2}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
if l := len(fingerprints); l != 2 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
}
// Wait for renewal 40s == 1m-1m/3
log.Printf("Sleeping for %s ...\n", 40*time.Second)
time.Sleep(40 * time.Second)
for _, tt := range tests {
t.Run("renewed "+tt.name, func(t *testing.T) {
resp, err := tt.client.Get(srv.URL)
if err != nil {
t.Fatalf("http.Client.Get() error = %v", err)
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err)
}
if !bytes.Equal(b, []byte("ok")) {
t.Errorf("response body unexpected, got %s, want ok", b)
}
})
}
if l := len(fingerprints); l != 4 {
t.Errorf("number of fingerprints unexpected, got %d, want 4", l)
}
}
func TestCertificate(t *testing.T) {
cert := parseCertificate(certPEM)
ok := &api.SignResponse{

View file

@ -142,7 +142,8 @@ password `password` hardcoded, but you can create your own using `step ca init`.
These examples show the use of other helper methods, they are simple ways to
create TLS configured http.Server and http.Client objects. The methods are
`BootstrapServer` and `BootstrapClient` and they are used like:
`BootstrapServer`, `BootstrapServerWithMTLS` and `BootstrapClient` and they are
used like:
```go
// Get a cancelable context to stop the renewal goroutines and timers.
@ -159,6 +160,21 @@ if err != nil {
srv.ListenAndServeTLS("", "")
```
```go
// Get a cancelable context to stop the renewal goroutines and timers.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create an http.Server that requires a client certificate
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
Addr: ":8443",
Handler: handler,
})
if err != nil {
panic(err)
}
srv.ListenAndServeTLS("", "")
```
```go
// Get a cancelable context to stop the renewal goroutines and timers.
ctx, cancel := context.WithCancel(context.Background())
@ -171,6 +187,9 @@ if err != nil {
resp, err := client.Get("https://localhost:8443")
```
We will demonstrate the mTLS configuration if a different example, for this one
we will only verify it if provided.
To run the example first we will start the certificate authority:
```sh
certificates $ bin/step-ca examples/pki/config/ca.json
@ -229,6 +248,47 @@ Server responded: Hello Mike at 2018-11-03 01:52:54.682787 +0000 UTC!!!
...
```
## Bootstrap mTLS Client & Server
This example demonstrates a stricter configuration of the bootstrap server, this
one always requires a valid client certificate.
As always, to run this example will require the Certificate Authority running:
```sh
certificates $ bin/step-ca examples/pki/config/ca.json
2018/11/02 18:29:25 Serving HTTPS on :9000 ...
```
We will start the mTLS server and we will type `password` when step asks for the
provisioner password:
```sh
certificates $ export STEPPATH=examples/pki
certificates $ export STEP_CA_URL=https://localhost:9000
certificates $ go run examples/bootstrap-mtls-server/server.go $(step ca token localhost)
✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com)
Please enter the password to decrypt the provisioner key:
Listening on :8443 ...
```
For mTLS, curl and curl with the root certificate will fail:
```sh
certificates $ curl --cacert examples/pki/secrets/root_ca.crt https://localhost:8443
curl: (35) error:1401E412:SSL routines:CONNECT_CR_FINISHED:sslv3 alert bad certificate
```
But if we the client with the certificate name Mike we'll see:
```sh
certificates $ export STEPPATH=examples/pki
certificates $ export STEP_CA_URL=https://localhost:9000
certificates $ go run examples/bootstrap-client/client.go $(step ca token Mike)
✔ Key ID: DmAtZt2EhmZr_iTJJ387fr4Md2NbzMXGdXQNW1UWPXk (mariano@smallstep.com)
Please enter the password to decrypt the provisioner key:
Server responded: Hello Mike at 2018-11-07 21:54:00.140022 +0000 UTC!!!
Server responded: Hello Mike at 2018-11-07 21:54:01.140827 +0000 UTC!!!
Server responded: Hello Mike at 2018-11-07 21:54:02.141578 +0000 UTC!!!
...
```
## Certificate rotation
We can use the bootstrap-server to demonstrate the certificate rotation. We've
@ -240,7 +300,7 @@ rotates after approximately two thirds of the duration has passed.
```sh
certificates $ export STEPPATH=examples/pki
certificates $ export STEP_CA_URL=https://localhost:9000
certificates $ go run examples/bootstrap-server/server.go $(step ca token localhost))
certificates $ go run examples/bootstrap-server/server.go $(step ca token localhost)
✔ Key ID: YYNxZ0rq0WsT2MlqLCWvgme3jszkmt99KjoGEJJwAKs (mike@smallstep.com)
Please enter the password to decrypt the provisioner key:
Listening on :8443 ...

View file

@ -0,0 +1,43 @@
package main
import (
"context"
"fmt"
"net/http"
"os"
"time"
"github.com/smallstep/certificates/ca"
)
func main() {
if len(os.Args) != 2 {
fmt.Fprintf(os.Stderr, "Usage: %s <token>\n", os.Args[0])
os.Exit(1)
}
token := os.Args[1]
// make sure to cancel the renew goroutine
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
Addr: ":8443",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
name := "nobody"
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
name = r.TLS.PeerCertificates[0].Subject.CommonName
}
w.Write([]byte(fmt.Sprintf("Hello %s at %s!!!", name, time.Now().UTC())))
}),
})
if err != nil {
panic(err)
}
fmt.Println("Listening on :8443 ...")
if err := srv.ListenAndServeTLS("", ""); err != nil {
panic(err)
}
}