Add experimental support for a TLS over TLS tunnel.

This commit is contained in:
Mariano Cano 2021-04-07 18:57:48 -07:00
parent 75f24a103a
commit e75a9409a5
2 changed files with 50 additions and 17 deletions

View file

@ -99,12 +99,13 @@ type RetryFunc func(code int) bool
type ClientOption func(o *clientOptions) error type ClientOption func(o *clientOptions) error
type clientOptions struct { type clientOptions struct {
transport http.RoundTripper transport http.RoundTripper
rootSHA256 string rootSHA256 string
rootFilename string rootFilename string
rootBundle []byte rootBundle []byte
certificate tls.Certificate certificate tls.Certificate
retryFunc RetryFunc getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
retryFunc RetryFunc
} }
func (o *clientOptions) apply(opts []ClientOption) (err error) { func (o *clientOptions) apply(opts []ClientOption) (err error) {
@ -139,6 +140,7 @@ func (o *clientOptions) applyDefaultIdentity() error {
return nil return nil
} }
o.certificate = crt o.certificate = crt
o.getClientCertificate = i.GetClientCertificateFunc()
return nil return nil
} }
@ -193,6 +195,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
} }
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
} }
case *http2.Transport: case *http2.Transport:
if tr.TLSClientConfig == nil { if tr.TLSClientConfig == nil {
@ -200,6 +203,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
} }
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
} }
default: default:
return nil, errors.Errorf("unsupported transport type %T", tr) return nil, errors.Errorf("unsupported transport type %T", tr)

View file

@ -10,13 +10,33 @@ import (
"encoding/pem" "encoding/pem"
"net" "net"
"net/http" "net/http"
"os"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"golang.org/x/net/http2"
) )
// mTLSDialContext will hold the dial context function to use in
// getDefaultTransport.
var mTLSDialContext func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error)
func init() {
// STEP_TLS_TUNNEL is an environment that can be set to do an TLS over mTLS
// tunnel to step-ca using identity credentials. The value must have the
// form "host:port", if the form is not correct, the default dialer will be
// used. This feature is EXPERIMENTAL and might change at any time.
if hostport := os.Getenv("STEP_TLS_TUNNEL"); hostport != "" {
if host, port, err := net.SplitHostPort(hostport); err == nil {
mTLSDialContext = func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
}
}
}
}
}
// GetClientTLSConfig returns a tls.Config for client use configured with the // GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate. // sign certificate, and a new certificate pool with the sign root certificate.
// The client certificate will automatically rotate before expiring. // The client certificate will automatically rotate before expiring.
@ -242,23 +262,32 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
// http.DefaultTransport, but adds the given tls.Config and configures the // http.DefaultTransport, but adds the given tls.Config and configures the
// transport for HTTP/2. // transport for HTTP/2.
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) { func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
tr := &http.Transport{ var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
Proxy: http.ProxyFromEnvironment, if mTLSDialContext == nil {
DialContext: (&net.Dialer{ d := &net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
DualStack: true, }
}).DialContext, dialContext = d.DialContext
} else {
dialContext = mTLSDialContext(&tls.Dialer{
NetDialer: &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
},
Config: tlsConfig,
})
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100, MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
} }, nil
if err := http2.ConfigureTransport(tr); err != nil {
return nil, errors.Wrap(err, "error configuring transport")
}
return tr, nil
} }
func getPEM(i interface{}) ([]byte, error) { func getPEM(i interface{}) ([]byte, error) {