diff --git a/ca/client.go b/ca/client.go index b9593162..f268155a 100644 --- a/ca/client.go +++ b/ca/client.go @@ -99,12 +99,13 @@ type RetryFunc func(code int) bool type ClientOption func(o *clientOptions) error type clientOptions struct { - transport http.RoundTripper - rootSHA256 string - rootFilename string - rootBundle []byte - certificate tls.Certificate - retryFunc RetryFunc + transport http.RoundTripper + rootSHA256 string + rootFilename string + rootBundle []byte + certificate tls.Certificate + getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error) + retryFunc RetryFunc } func (o *clientOptions) apply(opts []ClientOption) (err error) { @@ -139,6 +140,7 @@ func (o *clientOptions) applyDefaultIdentity() error { return nil } o.certificate = crt + o.getClientCertificate = i.GetClientCertificateFunc() return nil } @@ -193,6 +195,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err } if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} + tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate } case *http2.Transport: if tr.TLSClientConfig == nil { @@ -200,6 +203,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err } if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} + tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate } default: return nil, errors.Errorf("unsupported transport type %T", tr) diff --git a/ca/tls.go b/ca/tls.go index 20a5e504..df0aa2d5 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -10,13 +10,33 @@ import ( "encoding/pem" "net" "net/http" + "os" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/api" - "golang.org/x/net/http2" ) +// mTLSDialContext will hold the dial context function to use in +// getDefaultTransport. +var mTLSDialContext func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) + +func init() { + // STEP_TLS_TUNNEL is an environment that can be set to do an TLS over mTLS + // tunnel to step-ca using identity credentials. The value must have the + // form "host:port", if the form is not correct, the default dialer will be + // used. This feature is EXPERIMENTAL and might change at any time. + if hostport := os.Getenv("STEP_TLS_TUNNEL"); hostport != "" { + if host, port, err := net.SplitHostPort(hostport); err == nil { + mTLSDialContext = func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port)) + } + } + } + } +} + // GetClientTLSConfig returns a tls.Config for client use configured with the // sign certificate, and a new certificate pool with the sign root certificate. // The client certificate will automatically rotate before expiring. @@ -242,23 +262,32 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config { // http.DefaultTransport, but adds the given tls.Config and configures the // transport for HTTP/2. func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) { - tr := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ + var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error) + if mTLSDialContext == nil { + d := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, + } + dialContext = d.DialContext + } else { + dialContext = mTLSDialContext(&tls.Dialer{ + NetDialer: &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }, + Config: tlsConfig, + }) + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialContext, + ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: tlsConfig, - } - if err := http2.ConfigureTransport(tr); err != nil { - return nil, errors.Wrap(err, "error configuring transport") - } - return tr, nil + }, nil } func getPEM(i interface{}) ([]byte, error) {