Refactor tls tunnel connections.
New method will use an identity-like file with the configuration used to create the (m)TLS connection to the tunnel.
This commit is contained in:
parent
180b5c3e3c
commit
c5234e9c61
3 changed files with 139 additions and 73 deletions
11
ca/client.go
11
ca/client.go
|
@ -56,10 +56,7 @@ func newClient(transport http.RoundTripper) *uaClient {
|
|||
func newInsecureClient() *uaClient {
|
||||
return &uaClient{
|
||||
Client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -292,7 +289,7 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) {
|
|||
MinVersion: tls.VersionTLS12,
|
||||
PreferServerCipherSuites: true,
|
||||
RootCAs: pool,
|
||||
})
|
||||
}), nil
|
||||
}
|
||||
|
||||
func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
|
||||
|
@ -311,7 +308,7 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
|
|||
MinVersion: tls.VersionTLS12,
|
||||
PreferServerCipherSuites: true,
|
||||
RootCAs: pool,
|
||||
})
|
||||
}), nil
|
||||
}
|
||||
|
||||
func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
|
||||
|
@ -323,7 +320,7 @@ func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
|
|||
MinVersion: tls.VersionTLS12,
|
||||
PreferServerCipherSuites: true,
|
||||
RootCAs: pool,
|
||||
})
|
||||
}), nil
|
||||
}
|
||||
|
||||
// parseEndpoint parses and validates the given endpoint. It supports general
|
||||
|
|
|
@ -26,9 +26,16 @@ type Type string
|
|||
// Disabled represents a disabled identity type
|
||||
const Disabled Type = ""
|
||||
|
||||
// MutualTLS represents the identity using mTLS
|
||||
// MutualTLS represents the identity using mTLS.
|
||||
const MutualTLS Type = "mTLS"
|
||||
|
||||
// TunnelTLS represents an identity using a (m)TLS tunnel.
|
||||
//
|
||||
// TunnelTLS can be optionally configured with client certificates and a root
|
||||
// file with the CAs to trust. By default it will use the system truststore
|
||||
// instead of the CA truststore.
|
||||
const TunnelTLS Type = "tTLS"
|
||||
|
||||
// DefaultLeeway is the duration for matching not before claims.
|
||||
const DefaultLeeway = 1 * time.Minute
|
||||
|
||||
|
@ -44,19 +51,30 @@ type Identity struct {
|
|||
Type string `json:"type"`
|
||||
Certificate string `json:"crt"`
|
||||
Key string `json:"key"`
|
||||
|
||||
// Host is the tunnel host for a TunnelTLS (tTLS) identity.
|
||||
Host string `json:"host,omitempty"`
|
||||
// Root is the CA bundle of root CAs used in TunnelTLS to trust the
|
||||
// certificate of the host.
|
||||
Root string `json:"root,omitempty"`
|
||||
}
|
||||
|
||||
// LoadIdentity loads an identity present in the given filename.
|
||||
func LoadIdentity(filename string) (*Identity, error) {
|
||||
b, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", filename)
|
||||
}
|
||||
identity := new(Identity)
|
||||
if err := json.Unmarshal(b, &identity); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling %s", filename)
|
||||
}
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
// LoadDefaultIdentity loads the default identity.
|
||||
func LoadDefaultIdentity() (*Identity, error) {
|
||||
b, err := ioutil.ReadFile(IdentityFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", IdentityFile)
|
||||
}
|
||||
identity := new(Identity)
|
||||
if err := json.Unmarshal(b, &identity); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling %s", IdentityFile)
|
||||
}
|
||||
return identity, nil
|
||||
return LoadIdentity(IdentityFile)
|
||||
}
|
||||
|
||||
// configDir and identityDir are used in WriteDefaultIdentity for testing
|
||||
|
@ -81,7 +99,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
|
|||
keyFilename := filepath.Join(identityDir, "identity_key")
|
||||
|
||||
// Write certificate
|
||||
if err := WriteIdentityCertificate(certChain); err != nil {
|
||||
if err := writeCertificate(certFilename, certChain); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -116,22 +134,21 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
|
|||
return nil
|
||||
}
|
||||
|
||||
// WriteIdentityCertificate writes the identity certificate in disk.
|
||||
func WriteIdentityCertificate(certChain []api.Certificate) error {
|
||||
// writeCertificate writes the given certificate on disk.
|
||||
func writeCertificate(filename string, certChain []api.Certificate) error {
|
||||
buf := new(bytes.Buffer)
|
||||
certFilename := filepath.Join(identityDir, "identity.crt")
|
||||
for _, crt := range certChain {
|
||||
block := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: crt.Raw,
|
||||
}
|
||||
if err := pem.Encode(buf, block); err != nil {
|
||||
return errors.Wrap(err, "error encoding identity certificate")
|
||||
return errors.Wrap(err, "error encoding certificate")
|
||||
}
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
|
||||
return errors.Wrap(err, "error writing identity certificate")
|
||||
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil {
|
||||
return errors.Wrap(err, "error writing certificate")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -144,6 +161,8 @@ func (i *Identity) Kind() Type {
|
|||
return Disabled
|
||||
case "mtls":
|
||||
return MutualTLS
|
||||
case "ttls":
|
||||
return TunnelTLS
|
||||
default:
|
||||
return Type(i.Type)
|
||||
}
|
||||
|
@ -164,8 +183,26 @@ func (i *Identity) Validate() error {
|
|||
if err := fileExists(i.Certificate); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileExists(i.Key); err != nil {
|
||||
return err
|
||||
return fileExists(i.Key)
|
||||
case TunnelTLS:
|
||||
if i.Host == "" {
|
||||
return errors.New("tunnel.crt cannot be empty")
|
||||
}
|
||||
if i.Certificate != "" {
|
||||
if err := fileExists(i.Certificate); err != nil {
|
||||
return err
|
||||
}
|
||||
if i.Key == "" {
|
||||
return errors.New("tunnel.key cannot be empty")
|
||||
}
|
||||
if err := fileExists(i.Key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if i.Root != "" {
|
||||
if err := fileExists(i.Root); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
|
@ -179,7 +216,7 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) {
|
|||
switch i.Kind() {
|
||||
case Disabled:
|
||||
return tls.Certificate{}, nil
|
||||
case MutualTLS:
|
||||
case MutualTLS, TunnelTLS:
|
||||
crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key)
|
||||
if err != nil {
|
||||
return fail(errors.Wrap(err, "error creating identity certificate"))
|
||||
|
@ -215,6 +252,22 @@ func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo)
|
|||
}
|
||||
}
|
||||
|
||||
// GetCertPool returns a x509.CertPool if the identity defines a custom root.
|
||||
func (i *Identity) GetCertPool() (*x509.CertPool, error) {
|
||||
if i.Root == "" {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := ioutil.ReadFile(i.Root)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error reading identity root")
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(b) {
|
||||
return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Renewer is that interface that a renew client must implement.
|
||||
type Renewer interface {
|
||||
GetRootCAs() *x509.CertPool
|
||||
|
@ -227,7 +280,7 @@ func (i *Identity) Renew(client Renewer) error {
|
|||
switch i.Kind() {
|
||||
case Disabled:
|
||||
return nil
|
||||
case MutualTLS:
|
||||
case MutualTLS, TunnelTLS:
|
||||
cert, err := i.TLSCertificate()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
106
ca/tls.go
106
ca/tls.go
|
@ -15,23 +15,55 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/ca/identity"
|
||||
)
|
||||
|
||||
// mTLSDialContext will hold the dial context function to use in
|
||||
// getDefaultTransport.
|
||||
var mTLSDialContext func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
func init() {
|
||||
// STEP_TLS_TUNNEL is an environment that can be set to do an TLS over mTLS
|
||||
// tunnel to step-ca using identity credentials. The value must have the
|
||||
// form "host:port", if the form is not correct, the default dialer will be
|
||||
// used. This feature is EXPERIMENTAL and might change at any time.
|
||||
if hostport := os.Getenv("STEP_TLS_TUNNEL"); hostport != "" {
|
||||
if host, port, err := net.SplitHostPort(hostport); err == nil {
|
||||
mTLSDialContext = func(d *tls.Dialer) func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
|
||||
}
|
||||
// STEP_TLS_TUNNEL is an environment that can be set to do an TLS over
|
||||
// (m)TLS tunnel to step-ca using identity-like credentials. The value is a
|
||||
// path to a json file with the tunnel host, certificate, key and root used
|
||||
// to create the (m)TLS tunnel.
|
||||
//
|
||||
// The configuration should look like:
|
||||
// {
|
||||
// "type": "tTLS",
|
||||
// "host": "tunnel.example.com:443"
|
||||
// "crt": "/path/to/tunnel.crt",
|
||||
// "key": "/path/to/tunnel.key",
|
||||
// "root": "/path/to/tunnel-root.crt"
|
||||
// }
|
||||
//
|
||||
// This feature is EXPERIMENTAL and might change at any time.
|
||||
if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" {
|
||||
id, err := identity.LoadIdentity(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := id.Validate(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
host, port, err := net.SplitHostPort(id.Host)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
pool, err := id.GetCertPool()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := &tls.Dialer{
|
||||
NetDialer: getDefaultDialer(),
|
||||
Config: &tls.Config{
|
||||
RootCAs: pool,
|
||||
GetClientCertificate: id.GetClientCertificateFunc(),
|
||||
},
|
||||
}
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -71,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
}
|
||||
|
||||
// Update renew function with transport
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tr := getDefaultTransport(tlsConfig)
|
||||
// Use mutable tls.Config on renew
|
||||
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
||||
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
||||
|
@ -123,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
|
||||
|
||||
// Update renew function with transport
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr := getDefaultTransport(tlsConfig)
|
||||
// Use mutable tls.Config on renew
|
||||
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
||||
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
||||
|
@ -164,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell
|
|||
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
|
||||
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
|
||||
return func(network, addr string) (net.Conn, error) {
|
||||
return tls.DialWithDialer(&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}, network, addr, ctx.mutableConfig.TLSConfig())
|
||||
return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -176,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
|
|||
// nolint:unused
|
||||
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
d := getDefaultDialer()
|
||||
// TLS dialers do not support context, but we can use the context
|
||||
// deadline if it is set.
|
||||
var deadline time.Time
|
||||
if t, ok := ctx.Deadline(); ok {
|
||||
deadline = t
|
||||
d.Deadline = t
|
||||
}
|
||||
return tls.DialWithDialer(&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Deadline: deadline,
|
||||
DualStack: true,
|
||||
}, network, addr, tlsCtx.mutableConfig.TLSConfig())
|
||||
return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -258,25 +275,24 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
|||
}
|
||||
}
|
||||
|
||||
// getDefaultDialer returns a new dialer with the default configuration.
|
||||
func getDefaultDialer() *net.Dialer {
|
||||
return &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultTransport returns an http.Transport with the same parameters than
|
||||
// http.DefaultTransport, but adds the given tls.Config and configures the
|
||||
// transport for HTTP/2.
|
||||
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
|
||||
func getDefaultTransport(tlsConfig *tls.Config) *http.Transport {
|
||||
var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||
if mTLSDialContext == nil {
|
||||
d := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
d := getDefaultDialer()
|
||||
dialContext = d.DialContext
|
||||
} else {
|
||||
dialContext = mTLSDialContext(&tls.Dialer{
|
||||
NetDialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
},
|
||||
Config: tlsConfig,
|
||||
})
|
||||
dialContext = mTLSDialContext()
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
|
@ -287,7 +303,7 @@ func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
|
|||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getPEM(i interface{}) ([]byte, error) {
|
||||
|
|
Loading…
Reference in a new issue