diff --git a/ca/client.go b/ca/client.go index f268155a..19f758f1 100644 --- a/ca/client.go +++ b/ca/client.go @@ -56,10 +56,7 @@ func newClient(transport http.RoundTripper) *uaClient { func newInsecureClient() *uaClient { return &uaClient{ Client: &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, + Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}), }, } } @@ -292,7 +289,7 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) { MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, - }) + }), nil } func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { @@ -311,7 +308,7 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, - }) + }), nil } func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) { @@ -323,7 +320,7 @@ func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) { MinVersion: tls.VersionTLS12, PreferServerCipherSuites: true, RootCAs: pool, - }) + }), nil } // parseEndpoint parses and validates the given endpoint. It supports general diff --git a/ca/identity/identity.go b/ca/identity/identity.go index fa9ebf71..aad139dc 100644 --- a/ca/identity/identity.go +++ b/ca/identity/identity.go @@ -26,9 +26,16 @@ type Type string // Disabled represents a disabled identity type const Disabled Type = "" -// MutualTLS represents the identity using mTLS +// MutualTLS represents the identity using mTLS. const MutualTLS Type = "mTLS" +// TunnelTLS represents an identity using a (m)TLS tunnel. +// +// TunnelTLS can be optionally configured with client certificates and a root +// file with the CAs to trust. By default it will use the system truststore +// instead of the CA truststore. +const TunnelTLS Type = "tTLS" + // DefaultLeeway is the duration for matching not before claims. const DefaultLeeway = 1 * time.Minute @@ -44,19 +51,30 @@ type Identity struct { Type string `json:"type"` Certificate string `json:"crt"` Key string `json:"key"` + + // Host is the tunnel host for a TunnelTLS (tTLS) identity. + Host string `json:"host,omitempty"` + // Root is the CA bundle of root CAs used in TunnelTLS to trust the + // certificate of the host. + Root string `json:"root,omitempty"` +} + +// LoadIdentity loads an identity present in the given filename. +func LoadIdentity(filename string) (*Identity, error) { + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, errors.Wrapf(err, "error reading %s", filename) + } + identity := new(Identity) + if err := json.Unmarshal(b, &identity); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling %s", filename) + } + return identity, nil } // LoadDefaultIdentity loads the default identity. func LoadDefaultIdentity() (*Identity, error) { - b, err := ioutil.ReadFile(IdentityFile) - if err != nil { - return nil, errors.Wrapf(err, "error reading %s", IdentityFile) - } - identity := new(Identity) - if err := json.Unmarshal(b, &identity); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling %s", IdentityFile) - } - return identity, nil + return LoadIdentity(IdentityFile) } // configDir and identityDir are used in WriteDefaultIdentity for testing @@ -81,7 +99,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er keyFilename := filepath.Join(identityDir, "identity_key") // Write certificate - if err := WriteIdentityCertificate(certChain); err != nil { + if err := writeCertificate(certFilename, certChain); err != nil { return err } @@ -116,22 +134,21 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er return nil } -// WriteIdentityCertificate writes the identity certificate in disk. -func WriteIdentityCertificate(certChain []api.Certificate) error { +// writeCertificate writes the given certificate on disk. +func writeCertificate(filename string, certChain []api.Certificate) error { buf := new(bytes.Buffer) - certFilename := filepath.Join(identityDir, "identity.crt") for _, crt := range certChain { block := &pem.Block{ Type: "CERTIFICATE", Bytes: crt.Raw, } if err := pem.Encode(buf, block); err != nil { - return errors.Wrap(err, "error encoding identity certificate") + return errors.Wrap(err, "error encoding certificate") } } - if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { - return errors.Wrap(err, "error writing identity certificate") + if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil { + return errors.Wrap(err, "error writing certificate") } return nil @@ -144,6 +161,8 @@ func (i *Identity) Kind() Type { return Disabled case "mtls": return MutualTLS + case "ttls": + return TunnelTLS default: return Type(i.Type) } @@ -164,8 +183,26 @@ func (i *Identity) Validate() error { if err := fileExists(i.Certificate); err != nil { return err } - if err := fileExists(i.Key); err != nil { - return err + return fileExists(i.Key) + case TunnelTLS: + if i.Host == "" { + return errors.New("tunnel.crt cannot be empty") + } + if i.Certificate != "" { + if err := fileExists(i.Certificate); err != nil { + return err + } + if i.Key == "" { + return errors.New("tunnel.key cannot be empty") + } + if err := fileExists(i.Key); err != nil { + return err + } + } + if i.Root != "" { + if err := fileExists(i.Root); err != nil { + return err + } } return nil default: @@ -179,7 +216,7 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) { switch i.Kind() { case Disabled: return tls.Certificate{}, nil - case MutualTLS: + case MutualTLS, TunnelTLS: crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) if err != nil { return fail(errors.Wrap(err, "error creating identity certificate")) @@ -215,6 +252,22 @@ func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo) } } +// GetCertPool returns a x509.CertPool if the identity defines a custom root. +func (i *Identity) GetCertPool() (*x509.CertPool, error) { + if i.Root == "" { + return nil, nil + } + b, err := ioutil.ReadFile(i.Root) + if err != nil { + return nil, errors.Wrap(err, "error reading identity root") + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(b) { + return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root) + } + return pool, nil +} + // Renewer is that interface that a renew client must implement. type Renewer interface { GetRootCAs() *x509.CertPool @@ -227,7 +280,7 @@ func (i *Identity) Renew(client Renewer) error { switch i.Kind() { case Disabled: return nil - case MutualTLS: + case MutualTLS, TunnelTLS: cert, err := i.TLSCertificate() if err != nil { return err diff --git a/ca/tls.go b/ca/tls.go index df0aa2d5..22a9fff2 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -15,23 +15,55 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/ca/identity" ) // mTLSDialContext will hold the dial context function to use in // getDefaultTransport. -var mTLSDialContext func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) +var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error) func init() { - // STEP_TLS_TUNNEL is an environment that can be set to do an TLS over mTLS - // tunnel to step-ca using identity credentials. The value must have the - // form "host:port", if the form is not correct, the default dialer will be - // used. This feature is EXPERIMENTAL and might change at any time. - if hostport := os.Getenv("STEP_TLS_TUNNEL"); hostport != "" { - if host, port, err := net.SplitHostPort(hostport); err == nil { - mTLSDialContext = func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) { - return func(ctx context.Context, network, address string) (net.Conn, error) { - return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) - } + // STEP_TLS_TUNNEL is an environment that can be set to do an TLS over + // (m)TLS tunnel to step-ca using identity-like credentials. The value is a + // path to a json file with the tunnel host, certificate, key and root used + // to create the (m)TLS tunnel. + // + // The configuration should look like: + // { + // "type": "tTLS", + // "host": "tunnel.example.com:443" + // "crt": "/path/to/tunnel.crt", + // "key": "/path/to/tunnel.key", + // "root": "/path/to/tunnel-root.crt" + // } + // + // This feature is EXPERIMENTAL and might change at any time. + if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" { + id, err := identity.LoadIdentity(path) + if err != nil { + panic(err) + } + if err := id.Validate(); err != nil { + panic(err) + } + host, port, err := net.SplitHostPort(id.Host) + if err != nil { + panic(err) + } + pool, err := id.GetCertPool() + if err != nil { + panic(err) + } + mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) { + d := &tls.Dialer{ + NetDialer: getDefaultDialer(), + Config: &tls.Config{ + RootCAs: pool, + GetClientCertificate: id.GetClientCertificateFunc(), + }, + } + return func(ctx context.Context, network, address string) (net.Conn, error) { + return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) } } } @@ -71,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Update renew function with transport - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - return nil, nil, err - } + tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) @@ -123,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx) // Update renew function with transport - tr, err := getDefaultTransport(tlsConfig) - if err != nil { - return nil, err - } + tr := getDefaultTransport(tlsConfig) // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) @@ -164,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell // buildDialTLS returns an implementation of DialTLS callback in http.Transport. func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) { return func(network, addr string) (net.Conn, error) { - return tls.DialWithDialer(&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }, network, addr, ctx.mutableConfig.TLSConfig()) + return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig()) } } @@ -176,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net // nolint:unused func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) { + d := getDefaultDialer() // TLS dialers do not support context, but we can use the context // deadline if it is set. - var deadline time.Time if t, ok := ctx.Deadline(); ok { - deadline = t + d.Deadline = t } - return tls.DialWithDialer(&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - Deadline: deadline, - DualStack: true, - }, network, addr, tlsCtx.mutableConfig.TLSConfig()) + return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig()) } } @@ -258,25 +275,24 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { } } +// getDefaultDialer returns a new dialer with the default configuration. +func getDefaultDialer() *net.Dialer { + return &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } +} + // getDefaultTransport returns an http.Transport with the same parameters than // http.DefaultTransport, but adds the given tls.Config and configures the // transport for HTTP/2. -func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) { +func getDefaultTransport(tlsConfig *tls.Config) *http.Transport { var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error) if mTLSDialContext == nil { - d := &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - } + d := getDefaultDialer() dialContext = d.DialContext } else { - dialContext = mTLSDialContext(&tls.Dialer{ - NetDialer: &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }, - Config: tlsConfig, - }) + dialContext = mTLSDialContext() } return &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -287,7 +303,7 @@ func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: tlsConfig, - }, nil + } } func getPEM(i interface{}) ([]byte, error) {