From e75a9409a5456be298fe38b74e35994d5f21b337 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 7 Apr 2021 18:57:48 -0700 Subject: [PATCH 1/9] Add experimental support for a TLS over TLS tunnel. --- ca/client.go | 16 ++++++++++------ ca/tls.go | 51 ++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/ca/client.go b/ca/client.go index b9593162..f268155a 100644 --- a/ca/client.go +++ b/ca/client.go @@ -99,12 +99,13 @@ type RetryFunc func(code int) bool type ClientOption func(o *clientOptions) error type clientOptions struct { - transport http.RoundTripper - rootSHA256 string - rootFilename string - rootBundle []byte - certificate tls.Certificate - retryFunc RetryFunc + transport http.RoundTripper + rootSHA256 string + rootFilename string + rootBundle []byte + certificate tls.Certificate + getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error) + retryFunc RetryFunc } func (o *clientOptions) apply(opts []ClientOption) (err error) { @@ -139,6 +140,7 @@ func (o *clientOptions) applyDefaultIdentity() error { return nil } o.certificate = crt + o.getClientCertificate = i.GetClientCertificateFunc() return nil } @@ -193,6 +195,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err } if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} + tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate } case *http2.Transport: if tr.TLSClientConfig == nil { @@ -200,6 +203,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err } if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} + tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate } default: return nil, errors.Errorf("unsupported transport type %T", tr) diff --git a/ca/tls.go b/ca/tls.go index 20a5e504..df0aa2d5 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -10,13 +10,33 @@ import ( "encoding/pem" "net" "net/http" + "os" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api" - "golang.org/x/net/http2" ) +// mTLSDialContext will hold the dial context function to use in +// getDefaultTransport. +var mTLSDialContext func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) + +func init() { + // STEP_TLS_TUNNEL is an environment that can be set to do an TLS over mTLS + // tunnel to step-ca using identity credentials. The value must have the + // form "host:port", if the form is not correct, the default dialer will be + // used. This feature is EXPERIMENTAL and might change at any time. + if hostport := os.Getenv("STEP_TLS_TUNNEL"); hostport != "" { + if host, port, err := net.SplitHostPort(hostport); err == nil { + mTLSDialContext = func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) + } + } + } + } +} + // GetClientTLSConfig returns a tls.Config for client use configured with the // sign certificate, and a new certificate pool with the sign root certificate. // The client certificate will automatically rotate before expiring. @@ -242,23 +262,32 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { // http.DefaultTransport, but adds the given tls.Config and configures the // transport for HTTP/2. func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) { - tr := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ + var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error) + if mTLSDialContext == nil { + d := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, + } + dialContext = d.DialContext + } else { + dialContext = mTLSDialContext(&tls.Dialer{ + NetDialer: &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }, + Config: tlsConfig, + }) + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialContext, + ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: tlsConfig, - } - if err := http2.ConfigureTransport(tr); err != nil { - return nil, errors.Wrap(err, "error configuring transport") - } - return tr, nil + }, nil } func getPEM(i interface{}) ([]byte, error) { From 180b5c3e3c3c8a56815a1fb476db65b9f716b0d9 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 8 Apr 2021 11:25:52 -0700 Subject: [PATCH 2/9] Fix typo. --- kms/yubikey/yubikey.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kms/yubikey/yubikey.go b/kms/yubikey/yubikey.go index 19cef55e..2dde244a 100644 --- a/kms/yubikey/yubikey.go +++ b/kms/yubikey/yubikey.go @@ -313,7 +313,7 @@ func getSlotAndName(name string) (piv.Slot, string, error) { s, ok := slotMapping[slotID] if !ok { - return piv.Slot{}, "", errors.Errorf("usupported slot-id '%s'", name) + return piv.Slot{}, "", errors.Errorf("unsupported slot-id '%s'", name) } name = "yubikey:slot-id=" + url.QueryEscape(slotID) From c5234e9c61faf8eb29fdb87201d97f4f14315988 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 21 Apr 2021 16:05:27 -0700 Subject: [PATCH 3/9] Refactor tls tunnel connections. New method will use an identity-like file with the configuration used to create the (m)TLS connection to the tunnel. --- ca/client.go | 11 ++--- ca/identity/identity.go | 95 +++++++++++++++++++++++++++-------- ca/tls.go | 106 +++++++++++++++++++++++----------------- 3 files changed, 139 insertions(+), 73 deletions(-) diff --git a/ca/client.go b/ca/client.go index f268155a..19f758f1 100644 --- a/ca/client.go +++ b/ca/client.go @@ -56,10 +56,7 @@ func newClient(transport http.RoundTripper) *uaClient { func newInsecureClient() *uaClient { return &uaClient{ Client: &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, + Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}), }, } } @@ -292,7 +289,7 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) { MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, - }) + }), nil } func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { @@ -311,7 +308,7 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, - }) + }), nil } func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) { @@ -323,7 +320,7 @@ func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) { MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, - }) + }), nil } // parseEndpoint parses and validates the given endpoint. It supports general diff --git a/ca/identity/identity.go b/ca/identity/identity.go index fa9ebf71..aad139dc 100644 --- a/ca/identity/identity.go +++ b/ca/identity/identity.go @@ -26,9 +26,16 @@ type Type string // Disabled represents a disabled identity type const Disabled Type = "" -// MutualTLS represents the identity using mTLS +// MutualTLS represents the identity using mTLS. const MutualTLS Type = "mTLS" +// TunnelTLS represents an identity using a (m)TLS tunnel. +// +// TunnelTLS can be optionally configured with client certificates and a root +// file with the CAs to trust. By default it will use the system truststore +// instead of the CA truststore. +const TunnelTLS Type = "tTLS" + // DefaultLeeway is the duration for matching not before claims. const DefaultLeeway = 1 * time.Minute @@ -44,19 +51,30 @@ type Identity struct { Type string `json:"type"` Certificate string `json:"crt"` Key string `json:"key"` + + // Host is the tunnel host for a TunnelTLS (tTLS) identity. + Host string `json:"host,omitempty"` + // Root is the CA bundle of root CAs used in TunnelTLS to trust the + // certificate of the host. + Root string `json:"root,omitempty"` +} + +// LoadIdentity loads an identity present in the given filename. +func LoadIdentity(filename string) (*Identity, error) { + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, errors.Wrapf(err, "error reading %s", filename) + } + identity := new(Identity) + if err := json.Unmarshal(b, &identity); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling %s", filename) + } + return identity, nil } // LoadDefaultIdentity loads the default identity. func LoadDefaultIdentity() (*Identity, error) { - b, err := ioutil.ReadFile(IdentityFile) - if err != nil { - return nil, errors.Wrapf(err, "error reading %s", IdentityFile) - } - identity := new(Identity) - if err := json.Unmarshal(b, &identity); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling %s", IdentityFile) - } - return identity, nil + return LoadIdentity(IdentityFile) } // configDir and identityDir are used in WriteDefaultIdentity for testing @@ -81,7 +99,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er keyFilename := filepath.Join(identityDir, "identity_key") // Write certificate - if err := WriteIdentityCertificate(certChain); err != nil { + if err := writeCertificate(certFilename, certChain); err != nil { return err } @@ -116,22 +134,21 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er return nil } -// WriteIdentityCertificate writes the identity certificate in disk. -func WriteIdentityCertificate(certChain []api.Certificate) error { +// writeCertificate writes the given certificate on disk. +func writeCertificate(filename string, certChain []api.Certificate) error { buf := new(bytes.Buffer) - certFilename := filepath.Join(identityDir, "identity.crt") for _, crt := range certChain { block := &pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, } if err := pem.Encode(buf, block); err != nil { - return errors.Wrap(err, "error encoding identity certificate") + return errors.Wrap(err, "error encoding certificate") } } - if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { - return errors.Wrap(err, "error writing identity certificate") + if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil { + return errors.Wrap(err, "error writing certificate") } return nil @@ -144,6 +161,8 @@ func (i *Identity) Kind() Type { return Disabled case "mtls": return MutualTLS + case "ttls": + return TunnelTLS default: return Type(i.Type) } @@ -164,8 +183,26 @@ func (i *Identity) Validate() error { if err := fileExists(i.Certificate); err != nil { return err } - if err := fileExists(i.Key); err != nil { - return err + return fileExists(i.Key) + case TunnelTLS: + if i.Host == "" { + return errors.New("tunnel.crt cannot be empty") + } + if i.Certificate != "" { + if err := fileExists(i.Certificate); err != nil { + return err + } + if i.Key == "" { + return errors.New("tunnel.key cannot be empty") + } + if err := fileExists(i.Key); err != nil { + return err + } + } + if i.Root != "" { + if err := fileExists(i.Root); err != nil { + return err + } } return nil default: @@ -179,7 +216,7 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) { switch i.Kind() { case Disabled: return tls.Certificate{}, nil - case MutualTLS: + case MutualTLS, TunnelTLS: crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) if err != nil { return fail(errors.Wrap(err, "error creating identity certificate")) @@ -215,6 +252,22 @@ func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo) } } +// GetCertPool returns a x509.CertPool if the identity defines a custom root. +func (i *Identity) GetCertPool() (*x509.CertPool, error) { + if i.Root == "" { + return nil, nil + } + b, err := ioutil.ReadFile(i.Root) + if err != nil { + return nil, errors.Wrap(err, "error reading identity root") + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(b) { + return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root) + } + return pool, nil +} + // Renewer is that interface that a renew client must implement. type Renewer interface { GetRootCAs() *x509.CertPool @@ -227,7 +280,7 @@ func (i *Identity) Renew(client Renewer) error { switch i.Kind() { case Disabled: return nil - case MutualTLS: + case MutualTLS, TunnelTLS: cert, err := i.TLSCertificate() if err != nil { return err diff --git a/ca/tls.go b/ca/tls.go index df0aa2d5..22a9fff2 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -15,23 +15,55 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/ca/identity" ) // mTLSDialContext will hold the dial context function to use in // getDefaultTransport. -var mTLSDialContext func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) +var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error) func init() { - // STEP_TLS_TUNNEL is an environment that can be set to do an TLS over mTLS - // tunnel to step-ca using identity credentials. The value must have the - // form "host:port", if the form is not correct, the default dialer will be - // used. This feature is EXPERIMENTAL and might change at any time. - if hostport := os.Getenv("STEP_TLS_TUNNEL"); hostport != "" { - if host, port, err := net.SplitHostPort(hostport); err == nil { - mTLSDialContext = func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) { - return func(ctx context.Context, network, address string) (net.Conn, error) { - return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) - } + // STEP_TLS_TUNNEL is an environment that can be set to do an TLS over + // (m)TLS tunnel to step-ca using identity-like credentials. The value is a + // path to a json file with the tunnel host, certificate, key and root used + // to create the (m)TLS tunnel. + // + // The configuration should look like: + // { + // "type": "tTLS", + // "host": "tunnel.example.com:443" + // "crt": "/path/to/tunnel.crt", + // "key": "/path/to/tunnel.key", + // "root": "/path/to/tunnel-root.crt" + // } + // + // This feature is EXPERIMENTAL and might change at any time. + if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" { + id, err := identity.LoadIdentity(path) + if err != nil { + panic(err) + } + if err := id.Validate(); err != nil { + panic(err) + } + host, port, err := net.SplitHostPort(id.Host) + if err != nil { + panic(err) + } + pool, err := id.GetCertPool() + if err != nil { + panic(err) + } + mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) { + d := &tls.Dialer{ + NetDialer: getDefaultDialer(), + Config: &tls.Config{ + RootCAs: pool, + GetClientCertificate: id.GetClientCertificateFunc(), + }, + } + return func(ctx context.Context, network, address string) (net.Conn, error) { + return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) } } } @@ -71,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Update renew function with transport - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - return nil, nil, err - } + tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) @@ -123,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx) // Update renew function with transport - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - return nil, err - } + tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) @@ -164,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell // buildDialTLS returns an implementation of DialTLS callback in http.Transport. func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) { return func(network, addr string) (net.Conn, error) { - return tls.DialWithDialer(&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }, network, addr, ctx.mutableConfig.TLSConfig()) + return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig()) } } @@ -176,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net // nolint:unused func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) { + d := getDefaultDialer() // TLS dialers do not support context, but we can use the context // deadline if it is set. - var deadline time.Time if t, ok := ctx.Deadline(); ok { - deadline = t + d.Deadline = t } - return tls.DialWithDialer(&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - Deadline: deadline, - DualStack: true, - }, network, addr, tlsCtx.mutableConfig.TLSConfig()) + return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig()) } } @@ -258,25 +275,24 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { } } +// getDefaultDialer returns a new dialer with the default configuration. +func getDefaultDialer() *net.Dialer { + return &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } +} + // getDefaultTransport returns an http.Transport with the same parameters than // http.DefaultTransport, but adds the given tls.Config and configures the // transport for HTTP/2. -func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) { +func getDefaultTransport(tlsConfig *tls.Config) *http.Transport { var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error) if mTLSDialContext == nil { - d := &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - } + d := getDefaultDialer() dialContext = d.DialContext } else { - dialContext = mTLSDialContext(&tls.Dialer{ - NetDialer: &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }, - Config: tlsConfig, - }) + dialContext = mTLSDialContext() } return &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -287,7 +303,7 @@ func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: tlsConfig, - }, nil + } } func getPEM(i interface{}) ([]byte, error) { From e414d0c8eabe3528d43e54ee21ba5a54bf5e09c2 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 21 Apr 2021 16:20:23 -0700 Subject: [PATCH 4/9] Fix unit tests. --- ca/tls_test.go | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/ca/tls_test.go b/ca/tls_test.go index 5513e06d..ac1d84b6 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -181,13 +181,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { t.Errorf("Client.GetClientTLSConfig() error = %v", err) return nil } - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Errorf("getDefaultTransport() error = %v", err) - return nil - } return &http.Client{ - Transport: tr, + Transport: getDefaultTransport(tlsConfig), } }, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}}, {"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { @@ -199,14 +194,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { tlsConfig := getDefaultTLSConfig(sr) tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs.AddCert(root) - - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Errorf("getDefaultTransport() error = %v", err) - return nil - } return &http.Client{ - Transport: tr, + Transport: getDefaultTransport(tlsConfig), } }, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}}, {"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { @@ -288,10 +277,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { if err != nil { t.Fatalf("Client.GetClientTLSConfig() error = %v", err) } - tr2, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Fatalf("getDefaultTransport() error = %v", err) - } + tr2 := getDefaultTransport(tlsConfig) // No client cert root, err := RootCertificate(sr) if err != nil { @@ -300,10 +286,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { tlsConfig = getDefaultTLSConfig(sr) tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs.AddCert(root) - tr3, err := getDefaultTransport(tlsConfig) - if err != nil { - t.Fatalf("getDefaultTransport() error = %v", err) - } + tr3 := getDefaultTransport(tlsConfig) // Disable keep alives to force TLS handshake tr1.DisableKeepAlives = true From 50b9aaec57db7d42d4c5997024fa8090c6aba27c Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 21 Apr 2021 18:07:59 -0700 Subject: [PATCH 5/9] Add new identity tests. --- ca/identity/identity_test.go | 115 ++++++++++++++++++++++-- ca/identity/testdata/config/tunnel.json | 7 ++ 2 files changed, 115 insertions(+), 7 deletions(-) create mode 100644 ca/identity/testdata/config/tunnel.json diff --git a/ca/identity/identity_test.go b/ca/identity/identity_test.go index 7064cead..ce64768c 100644 --- a/ca/identity/identity_test.go +++ b/ca/identity/identity_test.go @@ -63,6 +63,7 @@ func TestIdentity_Kind(t *testing.T) { }{ {"disabled", fields{""}, Disabled}, {"mutualTLS", fields{"mTLS"}, MutualTLS}, + {"tunnelTLS", fields{"tTLS"}, TunnelTLS}, {"unknown", fields{"unknown"}, Type("unknown")}, } for _, tt := range tests { @@ -82,19 +83,27 @@ func TestIdentity_Validate(t *testing.T) { Type string Certificate string Key string + Host string + Root string } tests := []struct { name string fields fields wantErr bool }{ - {"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, false}, + {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false}, + {"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false}, {"ok disabled", fields{}, false}, - {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, true}, - {"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key"}, true}, - {"fail key", fields{"mTLS", "testdata/identity/identity.crt", ""}, true}, - {"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key"}, true}, - {"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key"}, true}, + {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true}, + {"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true}, + {"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true}, + {"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, + {"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true}, + {"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, + {"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true}, + {"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true}, + {"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true}, + {"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -102,6 +111,8 @@ func TestIdentity_Validate(t *testing.T) { Type: tt.fields.Type, Certificate: tt.fields.Certificate, Key: tt.fields.Key, + Host: tt.fields.Host, + Root: tt.fields.Root, } if err := i.Validate(); (err != nil) != tt.wantErr { t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr) @@ -127,7 +138,8 @@ func TestIdentity_TLSCertificate(t *testing.T) { want tls.Certificate wantErr bool }{ - {"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, + {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, + {"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, {"ok disabled", fields{}, tls.Certificate{}, false}, {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, @@ -255,6 +267,95 @@ func TestWriteDefaultIdentity(t *testing.T) { } } +func TestIdentity_GetClientCertificateFunc(t *testing.T) { + expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key") + if err != nil { + t.Fatal(err) + } + + type fields struct { + Type string + Certificate string + Key string + Host string + Root string + } + tests := []struct { + name string + fields fields + want *tls.Certificate + wantErr bool + }{ + {"ok mTLS", fields{"mtls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false}, + {"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false}, + {"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true}, + {"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &Identity{ + Type: tt.fields.Type, + Certificate: tt.fields.Certificate, + Key: tt.fields.Key, + Host: tt.fields.Host, + Root: tt.fields.Root, + } + fn := i.GetClientCertificateFunc() + got, err := fn(&tls.CertificateRequestInfo{}) + if (err != nil) != tt.wantErr { + t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIdentity_GetCertPool(t *testing.T) { + type fields struct { + Type string + Certificate string + Key string + Host string + Root string + } + tests := []struct { + name string + fields fields + wantSubjects [][]byte + wantErr bool + }{ + {"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false}, + {"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false}, + {"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true}, + {"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &Identity{ + Type: tt.fields.Type, + Certificate: tt.fields.Certificate, + Key: tt.fields.Key, + Host: tt.fields.Host, + Root: tt.fields.Root, + } + got, err := i.GetCertPool() + if (err != nil) != tt.wantErr { + t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil { + subjects := got.Subjects() + if !reflect.DeepEqual(subjects, tt.wantSubjects) { + t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects) + } + } + + }) + } +} + type renewer struct { pool *x509.CertPool sign *api.SignResponse diff --git a/ca/identity/testdata/config/tunnel.json b/ca/identity/testdata/config/tunnel.json new file mode 100644 index 00000000..49c76a55 --- /dev/null +++ b/ca/identity/testdata/config/tunnel.json @@ -0,0 +1,7 @@ +{ + "type": "mTLS", + "crt": "testdata/identity/identity.crt", + "key": "testdata/identity/identity_key", + "host": "tunnel:443", + "root": "testdata/certs/root_ca.crt" +} \ No newline at end of file From e6833ecee32cd0ecac01f6535bf1db91d90fcd30 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 26 Apr 2021 12:28:51 -0700 Subject: [PATCH 6/9] Add extension of db.AuthDB to store the fullchain. Add a temporary solution to allow an extension of an db.AuthDB interface that logs the fullchain of certificates instead of just the leaf. --- authority/tls.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/authority/tls.go b/authority/tls.go index c848d188..bc160ad0 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -156,14 +156,15 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error creating certificate", opts...) } - if err = a.db.StoreCertificate(resp.Certificate); err != nil { + fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) + if err = a.storeCertificate(fullchain); err != nil { if err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error storing certificate in db", opts...) } } - return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil + return fullchain, nil } // Renew creates a new Certificate identical to the old certificate, except @@ -261,13 +262,29 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...) } - if err = a.db.StoreCertificate(resp.Certificate); err != nil { + fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) + if err = a.storeCertificate(fullchain); err != nil { if err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...) } } - return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil + return fullchain, nil +} + +// storeCertificate allows to use an extension of the db.AuthDB interface that +// can log the full chain of certificates. +// +// TODO: at some point we should replace the db.AuthDB interface to implement +// `StoreCertificate(...*x509.Certificate) error` instead of just +// `StoreCertificate(*x509.Certificate) error`. +func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { + if s, ok := a.db.(interface { + StoreCertificateChain(...*x509.Certificate) error + }); ok { + return s.StoreCertificateChain(fullchain...) + } + return a.db.StoreCertificate(fullchain[0]) } // RevokeOptions are the options for the Revoke API. From 1328aa3e47976517d4ca47a881240efa350735d1 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 26 Apr 2021 18:45:46 -0700 Subject: [PATCH 7/9] Fix review comments. --- ca/identity/identity.go | 2 +- ca/tls.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ca/identity/identity.go b/ca/identity/identity.go index aad139dc..08a70c7f 100644 --- a/ca/identity/identity.go +++ b/ca/identity/identity.go @@ -186,7 +186,7 @@ func (i *Identity) Validate() error { return fileExists(i.Key) case TunnelTLS: if i.Host == "" { - return errors.New("tunnel.crt cannot be empty") + return errors.New("tunnel.host cannot be empty") } if i.Certificate != "" { if err := fileExists(i.Certificate); err != nil { diff --git a/ca/tls.go b/ca/tls.go index 22a9fff2..2d9b8f92 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -23,10 +23,10 @@ import ( var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error) func init() { - // STEP_TLS_TUNNEL is an environment that can be set to do an TLS over - // (m)TLS tunnel to step-ca using identity-like credentials. The value is a - // path to a json file with the tunnel host, certificate, key and root used - // to create the (m)TLS tunnel. + // STEP_TLS_TUNNEL is an environment variable that can be set to do an TLS + // over (m)TLS tunnel to step-ca using identity-like credentials. The value + // is a path to a json file with the tunnel host, certificate, key and root + // used to create the (m)TLS tunnel. // // The configuration should look like: // { From 2cbaee9c1dd67644b93f0316c48ebd294430cd0d Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 29 Apr 2021 15:55:22 -0700 Subject: [PATCH 8/9] Allow to use an alternative interface to store renewed certs. This can be useful to know if a certificate has been renewed and link one certificate with the 'parent'. --- authority/tls.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/authority/tls.go b/authority/tls.go index bc160ad0..b7b2f936 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -263,7 +263,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 } fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) - if err = a.storeCertificate(fullchain); err != nil { + if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil { if err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...) } @@ -287,6 +287,19 @@ func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { return a.db.StoreCertificate(fullchain[0]) } +// storeRenewedCertificate allows to use an extension of the db.AuthDB interface +// that can log if a certificate has been renewed or rekeyed. +// +// TODO: at some point we should implement this in the standard implementation. +func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error { + if s, ok := a.db.(interface { + StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error + }); ok { + return s.StoreRenewedCertificate(oldCert, fullchain...) + } + return a.db.StoreCertificate(fullchain[0]) +} + // RevokeOptions are the options for the Revoke API. type RevokeOptions struct { Serial string From 5846314f881561a7ecda4782f6621fc546d3d47d Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 29 Apr 2021 16:06:45 -0700 Subject: [PATCH 9/9] Add missing Rekey method to the ca.Client Fixes #315 --- ca/client.go | 30 +++++++++++++++++++++ ca/client_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/ca/client.go b/ca/client.go index 19f758f1..2292c41e 100644 --- a/ca/client.go +++ b/ca/client.go @@ -616,6 +616,36 @@ retry: return &sign, nil } +// Rekey performs the rekey request to the CA and returns the api.SignResponse +// struct. +func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { + var retried bool + body, err := json.Marshal(req) + if err != nil { + return nil, errors.Wrap(err, "error marshaling request") + } + + u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) + client := &http.Client{Transport: tr} +retry: + resp, err := client.Post(u.String(), "application/json", bytes.NewReader(body)) + if err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; client POST %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readError(resp.Body) + } + var sign api.SignResponse + if err := readJSON(resp.Body, &sign); err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; error reading %s", u) + } + return &sign, nil +} + // Revoke performs the revoke request to the CA and returns the api.RevokeResponse // struct. func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) { diff --git a/ca/client_test.go b/ca/client_test.go index dbba4d4c..30669e6e 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -529,6 +529,75 @@ func TestClient_Renew(t *testing.T) { } } +func TestClient_Rekey(t *testing.T) { + ok := &api.SignResponse{ + ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + CertChainPEM: []api.Certificate{ + {Certificate: parseCertificate(certPEM)}, + {Certificate: parseCertificate(rootPEM)}, + }, + } + + request := &api.RekeyRequest{ + CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, + } + + tests := []struct { + name string + request *api.RekeyRequest + response interface{} + responseCode int + wantErr bool + err error + }{ + {"ok", request, ok, 200, false, nil}, + {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + api.JSONStatus(w, tt.response, tt.responseCode) + }) + + got, err := c.Rekey(tt.request, nil) + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Renew() = %v, want nil", got) + } + + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, tt.err.Error(), err.Error()) + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Renew() = %v, want %v", got, tt.response) + } + } + }) + } +} + func TestClient_Provisioners(t *testing.T) { ok := &api.ProvisionersResponse{ Provisioners: provisioner.List{},