Refactor tls tunnel connections.

New method will use an identity-like file with the configuration
used to create the (m)TLS connection to the tunnel.
This commit is contained in:
Mariano Cano 2021-04-21 16:05:27 -07:00
parent 180b5c3e3c
commit c5234e9c61
3 changed files with 139 additions and 73 deletions

View file

@ -56,10 +56,7 @@ func newClient(transport http.RoundTripper) *uaClient {
func newInsecureClient() *uaClient {
return &uaClient{
Client: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}),
},
}
}
@ -292,7 +289,7 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}), nil
}
func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
@ -311,7 +308,7 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}), nil
}
func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
@ -323,7 +320,7 @@ func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}), nil
}
// parseEndpoint parses and validates the given endpoint. It supports general

View file

@ -26,9 +26,16 @@ type Type string
// Disabled represents a disabled identity type
const Disabled Type = ""
// MutualTLS represents the identity using mTLS
// MutualTLS represents the identity using mTLS.
const MutualTLS Type = "mTLS"
// TunnelTLS represents an identity using a (m)TLS tunnel.
//
// TunnelTLS can be optionally configured with client certificates and a root
// file with the CAs to trust. By default it will use the system truststore
// instead of the CA truststore.
const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute
@ -44,19 +51,30 @@ type Identity struct {
Type string `json:"type"`
Certificate string `json:"crt"`
Key string `json:"key"`
// Host is the tunnel host for a TunnelTLS (tTLS) identity.
Host string `json:"host,omitempty"`
// Root is the CA bundle of root CAs used in TunnelTLS to trust the
// certificate of the host.
Root string `json:"root,omitempty"`
}
// LoadIdentity loads an identity present in the given filename.
func LoadIdentity(filename string) (*Identity, error) {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename)
}
identity := new(Identity)
if err := json.Unmarshal(b, &identity); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", filename)
}
return identity, nil
}
// LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) {
b, err := ioutil.ReadFile(IdentityFile)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", IdentityFile)
}
identity := new(Identity)
if err := json.Unmarshal(b, &identity); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", IdentityFile)
}
return identity, nil
return LoadIdentity(IdentityFile)
}
// configDir and identityDir are used in WriteDefaultIdentity for testing
@ -81,7 +99,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
keyFilename := filepath.Join(identityDir, "identity_key")
// Write certificate
if err := WriteIdentityCertificate(certChain); err != nil {
if err := writeCertificate(certFilename, certChain); err != nil {
return err
}
@ -116,22 +134,21 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
return nil
}
// WriteIdentityCertificate writes the identity certificate in disk.
func WriteIdentityCertificate(certChain []api.Certificate) error {
// writeCertificate writes the given certificate on disk.
func writeCertificate(filename string, certChain []api.Certificate) error {
buf := new(bytes.Buffer)
certFilename := filepath.Join(identityDir, "identity.crt")
for _, crt := range certChain {
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: crt.Raw,
}
if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity certificate")
return errors.Wrap(err, "error encoding certificate")
}
}
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing certificate")
}
return nil
@ -144,6 +161,8 @@ func (i *Identity) Kind() Type {
return Disabled
case "mtls":
return MutualTLS
case "ttls":
return TunnelTLS
default:
return Type(i.Type)
}
@ -164,9 +183,27 @@ func (i *Identity) Validate() error {
if err := fileExists(i.Certificate); err != nil {
return err
}
return fileExists(i.Key)
case TunnelTLS:
if i.Host == "" {
return errors.New("tunnel.crt cannot be empty")
}
if i.Certificate != "" {
if err := fileExists(i.Certificate); err != nil {
return err
}
if i.Key == "" {
return errors.New("tunnel.key cannot be empty")
}
if err := fileExists(i.Key); err != nil {
return err
}
}
if i.Root != "" {
if err := fileExists(i.Root); err != nil {
return err
}
}
return nil
default:
return errors.Errorf("unsupported identity type %s", i.Type)
@ -179,7 +216,7 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) {
switch i.Kind() {
case Disabled:
return tls.Certificate{}, nil
case MutualTLS:
case MutualTLS, TunnelTLS:
crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key)
if err != nil {
return fail(errors.Wrap(err, "error creating identity certificate"))
@ -215,6 +252,22 @@ func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo)
}
}
// GetCertPool returns a x509.CertPool if the identity defines a custom root.
func (i *Identity) GetCertPool() (*x509.CertPool, error) {
if i.Root == "" {
return nil, nil
}
b, err := ioutil.ReadFile(i.Root)
if err != nil {
return nil, errors.Wrap(err, "error reading identity root")
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(b) {
return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root)
}
return pool, nil
}
// Renewer is that interface that a renew client must implement.
type Renewer interface {
GetRootCAs() *x509.CertPool
@ -227,7 +280,7 @@ func (i *Identity) Renew(client Renewer) error {
switch i.Kind() {
case Disabled:
return nil
case MutualTLS:
case MutualTLS, TunnelTLS:
cert, err := i.TLSCertificate()
if err != nil {
return err

106
ca/tls.go
View file

@ -15,27 +15,59 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/ca/identity"
)
// mTLSDialContext will hold the dial context function to use in
// getDefaultTransport.
var mTLSDialContext func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error)
var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error)
func init() {
// STEP_TLS_TUNNEL is an environment that can be set to do an TLS over mTLS
// tunnel to step-ca using identity credentials. The value must have the
// form "host:port", if the form is not correct, the default dialer will be
// used. This feature is EXPERIMENTAL and might change at any time.
if hostport := os.Getenv("STEP_TLS_TUNNEL"); hostport != "" {
if host, port, err := net.SplitHostPort(hostport); err == nil {
mTLSDialContext = func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) {
// STEP_TLS_TUNNEL is an environment that can be set to do an TLS over
// (m)TLS tunnel to step-ca using identity-like credentials. The value is a
// path to a json file with the tunnel host, certificate, key and root used
// to create the (m)TLS tunnel.
//
// The configuration should look like:
// {
// "type": "tTLS",
// "host": "tunnel.example.com:443"
// "crt": "/path/to/tunnel.crt",
// "key": "/path/to/tunnel.key",
// "root": "/path/to/tunnel-root.crt"
// }
//
// This feature is EXPERIMENTAL and might change at any time.
if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" {
id, err := identity.LoadIdentity(path)
if err != nil {
panic(err)
}
if err := id.Validate(); err != nil {
panic(err)
}
host, port, err := net.SplitHostPort(id.Host)
if err != nil {
panic(err)
}
pool, err := id.GetCertPool()
if err != nil {
panic(err)
}
mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
d := &tls.Dialer{
NetDialer: getDefaultDialer(),
Config: &tls.Config{
RootCAs: pool,
GetClientCertificate: id.GetClientCertificateFunc(),
},
}
return func(ctx context.Context, network, address string) (net.Conn, error) {
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
}
}
}
}
}
// GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
@ -71,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
}
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
return nil, nil, err
}
tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
@ -123,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
return nil, err
}
tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
@ -164,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
return func(network, addr string) (net.Conn, error) {
return tls.DialWithDialer(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}, network, addr, ctx.mutableConfig.TLSConfig())
return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig())
}
}
@ -176,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
// nolint:unused
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
d := getDefaultDialer()
// TLS dialers do not support context, but we can use the context
// deadline if it is set.
var deadline time.Time
if t, ok := ctx.Deadline(); ok {
deadline = t
d.Deadline = t
}
return tls.DialWithDialer(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Deadline: deadline,
DualStack: true,
}, network, addr, tlsCtx.mutableConfig.TLSConfig())
return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig())
}
}
@ -258,25 +275,24 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
}
}
// getDefaultTransport returns an http.Transport with the same parameters than
// http.DefaultTransport, but adds the given tls.Config and configures the
// transport for HTTP/2.
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
if mTLSDialContext == nil {
d := &net.Dialer{
// getDefaultDialer returns a new dialer with the default configuration.
func getDefaultDialer() *net.Dialer {
return &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
}
// getDefaultTransport returns an http.Transport with the same parameters than
// http.DefaultTransport, but adds the given tls.Config and configures the
// transport for HTTP/2.
func getDefaultTransport(tlsConfig *tls.Config) *http.Transport {
var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
if mTLSDialContext == nil {
d := getDefaultDialer()
dialContext = d.DialContext
} else {
dialContext = mTLSDialContext(&tls.Dialer{
NetDialer: &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
},
Config: tlsConfig,
})
dialContext = mTLSDialContext()
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
@ -287,7 +303,7 @@ func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: tlsConfig,
}, nil
}
}
func getPEM(i interface{}) ([]byte, error) {