diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 42087985..0e0f0fe3 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -2,12 +2,14 @@ package ca import ( "context" + "crypto" "crypto/tls" "net" "net/http" "strings" "github.com/pkg/errors" + "github.com/smallstep/certificates/api" "go.step.sm/crypto/jose" ) @@ -58,25 +60,21 @@ func Bootstrap(token string) (*Client, error) { // } // resp, err := client.Get("https://internal.smallstep.com") func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) { - client, err := Bootstrap(token) + b, err := createBootstrap(token) if err != nil { return nil, err } - req, pk, err := CreateSignRequest(token) - if err != nil { - return nil, err + // Make sure the tlsConfig has all supported roots on RootCAs. + // + // The roots request is only supported if identity certificates are not + // required. In all cases the current root is also added after applying all + // options too. + if !b.RequireClientAuth { + options = append(options, AddRootsToRootCAs()) } - sign, err := client.Sign(req) - if err != nil { - return nil, err - } - - // Make sure the tlsConfig have all supported roots on RootCAs - options = append(options, AddRootsToRootCAs()) - - transport, err := client.Transport(ctx, sign, pk, options...) + transport, err := b.Client.Transport(ctx, b.SignResponse, b.PrivateKey, options...) if err != nil { return nil, err } @@ -120,25 +118,21 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio return nil, errors.New("server TLSConfig is already set") } - client, err := Bootstrap(token) + b, err := createBootstrap(token) if err != nil { return nil, err } - req, pk, err := CreateSignRequest(token) - if err != nil { - return nil, err + // Make sure the tlsConfig has all supported roots on RootCAs. + // + // The roots request is only supported if identity certificates are not + // required. In all cases the current root is also added after applying all + // options too. + if !b.RequireClientAuth { + options = append(options, AddRootsToCAs()) } - sign, err := client.Sign(req) - if err != nil { - return nil, err - } - - // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs - options = append(options, AddRootsToCAs()) - - tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) + tlsConfig, err := b.Client.GetServerTLSConfig(ctx, b.SignResponse, b.PrivateKey, options...) if err != nil { return nil, err } @@ -172,11 +166,46 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio // ... // register services // srv.Serve(lis) func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) { + b, err := createBootstrap(token) + if err != nil { + return nil, err + } + + // Make sure the tlsConfig has all supported roots on RootCAs. + // + // The roots request is only supported if identity certificates are not + // required. In all cases the current root is also added after applying all + // options too. + if !b.RequireClientAuth { + options = append(options, AddRootsToCAs()) + } + + tlsConfig, err := b.Client.GetServerTLSConfig(ctx, b.SignResponse, b.PrivateKey, options...) + if err != nil { + return nil, err + } + + return tls.NewListener(inner, tlsConfig), nil +} + +type bootstrap struct { + Client *Client + RequireClientAuth bool + SignResponse *api.SignResponse + PrivateKey crypto.PrivateKey +} + +func createBootstrap(token string) (*bootstrap, error) { client, err := Bootstrap(token) if err != nil { return nil, err } + version, err := client.Version() + if err != nil { + return nil, err + } + req, pk, err := CreateSignRequest(token) if err != nil { return nil, err @@ -187,13 +216,10 @@ func BootstrapListener(ctx context.Context, token string, inner net.Listener, op return nil, err } - // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs - options = append(options, AddRootsToCAs()) - - tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) - if err != nil { - return nil, err - } - - return tls.NewListener(inner, tlsConfig), nil + return &bootstrap{ + Client: client, + RequireClientAuth: version.RequireClientAuthentication, + SignResponse: sign, + PrivateKey: pk, + }, nil } diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 7c1bc908..9482d657 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strings" "sync" "testing" "time" @@ -15,6 +16,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" ) @@ -74,6 +76,30 @@ func startCAServer(configFile string) (*CA, string, error) { return ca, caURL, nil } +func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/version" { + api.JSON(w, api.VersionResponse{ + Version: "test", + RequireClientAuthentication: true, + }) + return + } + + for _, s := range nonAuthenticatedPaths { + if strings.HasPrefix(r.URL.Path, s) || strings.HasPrefix(r.URL.Path, "/1.0"+s) { + next.ServeHTTP(w, r) + } + } + isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 + if !isMTLS { + api.WriteError(w, errs.Unauthorized("missing peer certificate")) + } else { + next.ServeHTTP(w, r) + } + }) +} + func generateBootstrapToken(ca, subject, sha string) string { now := time.Now() jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) @@ -171,6 +197,15 @@ func TestBootstrapServerWithoutMTLS(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { ctx context.Context token string @@ -182,6 +217,7 @@ func TestBootstrapServerWithoutMTLS(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), token(), &http.Server{}}, false}, + {"ok mtls", args{context.Background(), mtlsToken(), &http.Server{}}, false}, {"fail", args{context.Background(), "bad-token", &http.Server{}}, true}, {"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true}, } @@ -217,6 +253,15 @@ func TestBootstrapServerWithMTLS(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { ctx context.Context token string @@ -228,6 +273,7 @@ func TestBootstrapServerWithMTLS(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), token(), &http.Server{}}, false}, + {"ok mtls", args{context.Background(), mtlsToken(), &http.Server{}}, false}, {"fail", args{context.Background(), "bad-token", &http.Server{}}, true}, {"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true}, } @@ -263,6 +309,15 @@ func TestBootstrapClient(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { ctx context.Context token string @@ -273,6 +328,7 @@ func TestBootstrapClient(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), token()}, false}, + {"ok mtls", args{context.Background(), mtlsToken()}, false}, {"fail", args{context.Background(), "bad-token"}, true}, } for _, tt := range tests { @@ -541,6 +597,15 @@ func TestBootstrapListener(t *testing.T) { token := func() string { return generateBootstrapToken(srv.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") } + + mtlsServer := startCABootstrapServer() + next := mtlsServer.Config.Handler + mtlsServer.Config.Handler = mTLSMiddleware(next, "/root/", "/sign") + defer mtlsServer.Close() + mtlsToken := func() string { + return generateBootstrapToken(mtlsServer.URL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + } + type args struct { token string } @@ -550,6 +615,7 @@ func TestBootstrapListener(t *testing.T) { wantErr bool }{ {"ok", args{token()}, false}, + {"ok mtls", args{mtlsToken()}, false}, {"fail", args{"bad-token"}, true}, } for _, tt := range tests { diff --git a/ca/tls_options.go b/ca/tls_options.go index b3b2d057..c77b70c3 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -115,6 +115,7 @@ func AddRootCA(cert *x509.Certificate) TLSOption { if ctx.Config.RootCAs == nil { ctx.Config.RootCAs = x509.NewCertPool() } + ctx.hasRootCA = true ctx.Config.RootCAs.AddCert(cert) ctx.mutableConfig.AddImmutableRootCACert(cert) return nil @@ -129,6 +130,7 @@ func AddClientCA(cert *x509.Certificate) TLSOption { if ctx.Config.ClientCAs == nil { ctx.Config.ClientCAs = x509.NewCertPool() } + ctx.hasClientCA = true ctx.Config.ClientCAs.AddCert(cert) ctx.mutableConfig.AddImmutableClientCACert(cert) return nil