347 lines
11 KiB
Go
347 lines
11 KiB
Go
package ca
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/smallstep/certificates/api"
|
|
"github.com/smallstep/certificates/ca/identity"
|
|
)
|
|
|
|
// mTLSDialContext will hold the dial context function to use in
|
|
// getDefaultTransport.
|
|
var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
func init() {
|
|
// STEP_TLS_TUNNEL is an environment variable that can be set to do an TLS
|
|
// over (m)TLS tunnel to step-ca using identity-like credentials. The value
|
|
// is a path to a json file with the tunnel host, certificate, key and root
|
|
// used to create the (m)TLS tunnel.
|
|
//
|
|
// The configuration should look like:
|
|
// {
|
|
// "type": "tTLS",
|
|
// "host": "tunnel.example.com:443"
|
|
// "crt": "/path/to/tunnel.crt",
|
|
// "key": "/path/to/tunnel.key",
|
|
// "root": "/path/to/tunnel-root.crt"
|
|
// }
|
|
//
|
|
// This feature is EXPERIMENTAL and might change at any time.
|
|
if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" {
|
|
id, err := identity.LoadIdentity(path)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if err := id.Validate(); err != nil {
|
|
panic(err)
|
|
}
|
|
host, port, err := net.SplitHostPort(id.Host)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
pool, err := id.GetCertPool()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
d := &tls.Dialer{
|
|
NetDialer: getDefaultDialer(),
|
|
Config: &tls.Config{
|
|
RootCAs: pool,
|
|
GetClientCertificate: id.GetClientCertificateFunc(),
|
|
},
|
|
}
|
|
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetClientTLSConfig returns a tls.Config for client use configured with the
|
|
// sign certificate, and a new certificate pool with the sign root certificate.
|
|
// The client certificate will automatically rotate before expiring.
|
|
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
|
|
tlsConfig, _, err := c.getClientTLSConfig(ctx, sign, pk, options)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, *http.Transport, error) {
|
|
cert, err := TLSCertificate(sign, pk)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
renewer, err := NewTLSRenewer(cert, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
tlsConfig := getDefaultTLSConfig(sign)
|
|
// Note that with GetClientCertificate tlsConfig.Certificates is not used.
|
|
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
|
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
|
tlsConfig.PreferServerCipherSuites = true
|
|
|
|
// Apply options and initialize mutable tls.Config
|
|
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
|
|
if err := tlsCtx.apply(options); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Update renew function with transport
|
|
tr := getDefaultTransport(tlsConfig)
|
|
// Use mutable tls.Config on renew
|
|
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
|
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
|
|
|
// Update client transport
|
|
c.SetTransport(tr)
|
|
|
|
// Start renewer
|
|
renewer.RunContext(ctx)
|
|
return tlsConfig, tr, nil
|
|
}
|
|
|
|
// GetServerTLSConfig returns a tls.Config for server use configured with the
|
|
// sign certificate, and a new certificate pool with the sign root certificate.
|
|
// The returned tls.Config will only verify the client certificate if provided.
|
|
// The server certificate will automatically rotate before expiring.
|
|
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
|
|
cert, err := TLSCertificate(sign, pk)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
renewer, err := NewTLSRenewer(cert, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tlsConfig := getDefaultTLSConfig(sign)
|
|
// Note that GetCertificate will only be called if the client supplies SNI
|
|
// information or if tlsConfig.Certificates is empty.
|
|
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
|
|
tlsConfig.GetCertificate = renewer.GetCertificate
|
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
|
tlsConfig.PreferServerCipherSuites = true
|
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
|
|
|
// Apply options and initialize mutable tls.Config
|
|
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
|
|
if err := tlsCtx.apply(options); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// GetConfigForClient allows seamless root and federated roots rotation.
|
|
// If the return of the callback is not-nil, it will use the returned
|
|
// tls.Config instead of the default one.
|
|
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
|
|
|
|
// Update renew function with transport
|
|
tr := getDefaultTransport(tlsConfig)
|
|
// Use mutable tls.Config on renew
|
|
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
|
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
|
|
|
// Update client transport
|
|
c.SetTransport(tr)
|
|
|
|
// Start renewer
|
|
renewer.RunContext(ctx)
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
// Transport returns an http.Transport configured to use the client certificate from the sign response.
|
|
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
|
|
_, tr, err := c.getClientTLSConfig(ctx, sign, pk, options)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return tr, nil
|
|
}
|
|
|
|
// buildGetConfigForClient returns an implementation of GetConfigForClient
|
|
// callback in tls.Config.
|
|
//
|
|
// If the implementation returns a nil tls.Config, the original Config will be
|
|
// used, but if it's non-nil, the returned Config will be used to handle this
|
|
// connection.
|
|
func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
|
return ctx.mutableConfig.TLSConfig(), nil
|
|
}
|
|
}
|
|
|
|
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
|
|
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
|
|
return func(network, addr string) (net.Conn, error) {
|
|
return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig())
|
|
}
|
|
}
|
|
|
|
// buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport.
|
|
// nolint:unused
|
|
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
d := getDefaultDialer()
|
|
// TLS dialers do not support context, but we can use the context
|
|
// deadline if it is set.
|
|
if t, ok := ctx.Deadline(); ok {
|
|
d.Deadline = t
|
|
}
|
|
return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig())
|
|
}
|
|
}
|
|
|
|
// Certificate returns the server or client certificate from the sign response.
|
|
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
|
if sign.ServerPEM.Certificate == nil {
|
|
return nil, errors.New("ca: certificate does not exist")
|
|
}
|
|
return sign.ServerPEM.Certificate, nil
|
|
}
|
|
|
|
// IntermediateCertificate returns the CA intermediate certificate from the sign
|
|
// response.
|
|
func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
|
if sign.CaPEM.Certificate == nil {
|
|
return nil, errors.New("ca: certificate does not exist")
|
|
}
|
|
return sign.CaPEM.Certificate, nil
|
|
}
|
|
|
|
// RootCertificate returns the root certificate from the sign response.
|
|
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
|
if sign == nil || sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
|
|
return nil, errors.New("ca: certificate does not exist")
|
|
}
|
|
lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1]
|
|
if len(lastChain) == 0 {
|
|
return nil, errors.New("ca: certificate does not exist")
|
|
}
|
|
return lastChain[len(lastChain)-1], nil
|
|
}
|
|
|
|
// TLSCertificate creates a new TLS certificate from the sign response and the
|
|
// private key used.
|
|
func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certificate, error) {
|
|
certPEM, err := getPEM(sign.ServerPEM)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
caPEM, err := getPEM(sign.CaPEM)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
keyPEM, err := getPEM(pk)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
chain := append(certPEM, caPEM...)
|
|
cert, err := tls.X509KeyPair(chain, keyPEM)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "error creating tls certificate")
|
|
}
|
|
leaf, err := x509.ParseCertificate(cert.Certificate[0])
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "error parsing tls certificate")
|
|
}
|
|
cert.Leaf = leaf
|
|
return &cert, nil
|
|
}
|
|
|
|
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
|
if sign.TLSOptions != nil {
|
|
return sign.TLSOptions.TLSConfig()
|
|
}
|
|
return &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
}
|
|
|
|
// getDefaultDialer returns a new dialer with the default configuration.
|
|
func getDefaultDialer() *net.Dialer {
|
|
return &net.Dialer{
|
|
Timeout: 30 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}
|
|
}
|
|
|
|
// getDefaultTransport returns an http.Transport with the same parameters than
|
|
// http.DefaultTransport, but adds the given tls.Config and configures the
|
|
// transport for HTTP/2.
|
|
func getDefaultTransport(tlsConfig *tls.Config) *http.Transport {
|
|
var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
|
|
if mTLSDialContext == nil {
|
|
d := getDefaultDialer()
|
|
dialContext = d.DialContext
|
|
} else {
|
|
dialContext = mTLSDialContext()
|
|
}
|
|
return &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
DialContext: dialContext,
|
|
ForceAttemptHTTP2: true,
|
|
MaxIdleConns: 100,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
TLSClientConfig: tlsConfig,
|
|
}
|
|
}
|
|
|
|
func getPEM(i interface{}) ([]byte, error) {
|
|
block := new(pem.Block)
|
|
switch i := i.(type) {
|
|
case api.Certificate:
|
|
block.Type = "CERTIFICATE"
|
|
block.Bytes = i.Raw
|
|
case *x509.Certificate:
|
|
block.Type = "CERTIFICATE"
|
|
block.Bytes = i.Raw
|
|
case *rsa.PrivateKey:
|
|
block.Type = "RSA PRIVATE KEY"
|
|
block.Bytes = x509.MarshalPKCS1PrivateKey(i)
|
|
case *ecdsa.PrivateKey:
|
|
var err error
|
|
block.Type = "EC PRIVATE KEY"
|
|
block.Bytes, err = x509.MarshalECPrivateKey(i)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "error marshaling private key")
|
|
}
|
|
default:
|
|
return nil, errors.Errorf("unsupported key type %T", i)
|
|
}
|
|
return pem.EncodeToMemory(block), nil
|
|
}
|
|
|
|
func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
|
|
return func() (*tls.Certificate, error) {
|
|
// Get updated list of roots
|
|
if err := ctx.applyRenew(); err != nil {
|
|
return nil, err
|
|
}
|
|
// Get new certificate
|
|
sign, err := client.Renew(tr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return TLSCertificate(sign, pk)
|
|
}
|
|
}
|