Add experimental support for a TLS over TLS tunnel.

This commit is contained in:
Mariano Cano 2021-04-07 18:57:48 -07:00
parent 75f24a103a
commit e75a9409a5
2 changed files with 50 additions and 17 deletions

View file

@ -99,12 +99,13 @@ type RetryFunc func(code int) bool
type ClientOption func(o *clientOptions) error
type clientOptions struct {
transport http.RoundTripper
rootSHA256 string
rootFilename string
rootBundle []byte
certificate tls.Certificate
retryFunc RetryFunc
transport http.RoundTripper
rootSHA256 string
rootFilename string
rootBundle []byte
certificate tls.Certificate
getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
retryFunc RetryFunc
}
func (o *clientOptions) apply(opts []ClientOption) (err error) {
@ -139,6 +140,7 @@ func (o *clientOptions) applyDefaultIdentity() error {
return nil
}
o.certificate = crt
o.getClientCertificate = i.GetClientCertificateFunc()
return nil
}
@ -193,6 +195,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
}
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
}
case *http2.Transport:
if tr.TLSClientConfig == nil {
@ -200,6 +203,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
}
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
}
default:
return nil, errors.Errorf("unsupported transport type %T", tr)

View file

@ -10,13 +10,33 @@ import (
"encoding/pem"
"net"
"net/http"
"os"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"golang.org/x/net/http2"
)
// mTLSDialContext will hold the dial context function to use in
// getDefaultTransport.
var mTLSDialContext func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error)
func init() {
// STEP_TLS_TUNNEL is an environment that can be set to do an TLS over mTLS
// tunnel to step-ca using identity credentials. The value must have the
// form "host:port", if the form is not correct, the default dialer will be
// used. This feature is EXPERIMENTAL and might change at any time.
if hostport := os.Getenv("STEP_TLS_TUNNEL"); hostport != "" {
if host, port, err := net.SplitHostPort(hostport); err == nil {
mTLSDialContext = func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
}
}
}
}
}
// GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The client certificate will automatically rotate before expiring.
@ -242,23 +262,32 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
// http.DefaultTransport, but adds the given tls.Config and configures the
// transport for HTTP/2.
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
if mTLSDialContext == nil {
d := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
}
dialContext = d.DialContext
} else {
dialContext = mTLSDialContext(&tls.Dialer{
NetDialer: &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
},
Config: tlsConfig,
})
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: tlsConfig,
}
if err := http2.ConfigureTransport(tr); err != nil {
return nil, errors.Wrap(err, "error configuring transport")
}
return tr, nil
}, nil
}
func getPEM(i interface{}) ([]byte, error) {