certificates/ca/tls.go

245 lines
7.1 KiB
Go
Raw Normal View History

2018-10-05 21:48:36 +00:00
package ca
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"net"
"net/http"
"time"
"github.com/pkg/errors"
"github.com/smallstep/ca-component/api"
"golang.org/x/net/http2"
)
// GetClientTLSConfig builds a tls.Config for client use from the given sign
// response and private key. The client certificate is supplied through a
// TLS renewer (so it is expected to rotate before expiring), and RootCAs is
// set to the certificate pool returned by getCertPool when one is available.
// The renewer is started with ctx before the configuration is returned.
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
	tlsCert, err := TLSCertificate(sign, pk)
	if err != nil {
		return nil, err
	}

	r, err := NewTLSRenewer(tlsCert, nil)
	if err != nil {
		return nil, err
	}

	cfg := getDefaultTLSConfig(sign)
	// The certificate is served via the GetClientCertificate callback, so
	// cfg.Certificates is not consulted for client authentication.
	cfg.GetClientCertificate = r.GetClientCertificate
	cfg.PreferServerCipherSuites = true

	// Trust the pool resolved for this client / sign response, if any.
	roots := c.getCertPool(sign)
	if roots != nil {
		cfg.RootCAs = roots
	}

	// Index any certificates already present in the configuration by name.
	cfg.BuildNameToCertificate()

	// The renew function needs a transport that uses this very config, so it
	// can only be wired once the config is complete.
	transport, err := getDefaultTransport(cfg)
	if err != nil {
		return nil, err
	}
	r.RenewCertificate = getRenewFunc(c, transport, pk)

	// Start the renewer only after RenewCertificate has been set.
	r.RunContext(ctx)

	return cfg, nil
}
// GetServerTLSConfig builds a tls.Config for server use from the given sign
// response and private key. The certificate is supplied through a TLS renewer
// (so it is expected to rotate before expiring). When a certificate pool is
// available it is used both to verify client certificates and as RootCAs for
// the client that performs renewals. The renewer is started with ctx before
// the configuration is returned.
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
	tlsCert, err := TLSCertificate(sign, pk)
	if err != nil {
		return nil, err
	}

	r, err := NewTLSRenewer(tlsCert, nil)
	if err != nil {
		return nil, err
	}

	cfg := getDefaultTLSConfig(sign)
	// GetCertificate is only invoked when the client sends SNI information
	// or when cfg.Certificates is empty.
	cfg.GetCertificate = r.GetCertificate
	cfg.GetClientCertificate = r.GetClientCertificate
	cfg.PreferServerCipherSuites = true

	roots := c.getCertPool(sign)
	if roots != nil {
		// Verify client certificates against the pool when one is presented.
		cfg.ClientCAs = roots
		cfg.ClientAuth = tls.VerifyClientCertIfGiven
		// The same pool is trusted by the client used to renew the certificate.
		cfg.RootCAs = roots
	}

	// The renew function needs a transport that uses this very config, so it
	// can only be wired once the config is complete.
	transport, err := getDefaultTransport(cfg)
	if err != nil {
		return nil, err
	}
	r.RenewCertificate = getRenewFunc(c, transport, pk)

	// Start the renewer only after RenewCertificate has been set.
	r.RunContext(ctx)

	return cfg, nil
}
// Transport returns an http.Transport whose TLS configuration is the one
// produced by GetClientTLSConfig for the given sign response and private key.
// Any error from building the TLS configuration is returned unchanged.
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
	clientCfg, err := c.GetClientTLSConfig(ctx, sign, pk)
	if err != nil {
		return nil, err
	}

	return getDefaultTransport(clientCfg)
}
// getCertPool resolves the certificate pool to trust: the client's own pool
// when it is set, otherwise a new pool holding only the root certificate
// extracted from the sign response. It returns nil when neither is available.
func (c *Client) getCertPool(sign *api.SignResponse) *x509.CertPool {
	// The pool configured on the client takes precedence.
	if existing := c.certPool; existing != nil {
		return existing
	}

	// Fall back to the root certificate found in the sign response.
	root, err := RootCertificate(sign)
	if err != nil {
		return nil
	}

	roots := x509.NewCertPool()
	roots.AddCert(root)
	return roots
}
// Certificate returns the certificate stored in the ServerPEM field of the
// sign response (the server or client certificate). An error is returned
// when that certificate is not set.
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
	if crt := sign.ServerPEM.Certificate; crt != nil {
		return crt, nil
	}

	return nil, errors.New("ca: certificate does not exists")
}
// IntermediateCertificate returns the certificate stored in the CaPEM field
// of the sign response (the CA intermediate certificate). An error is
// returned when that certificate is not set.
func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
	if crt := sign.CaPEM.Certificate; crt != nil {
		return crt, nil
	}

	return nil, errors.New("ca: certificate does not exists")
}
// RootCertificate returns the root certificate from the sign response: the
// last certificate of the last verified chain recorded in sign.TLS. An error
// is returned when there is no TLS state, no verified chain, or the last
// chain is empty.
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
	state := sign.TLS
	if state == nil {
		return nil, errors.New("ca: certificate does not exists")
	}

	chains := state.VerifiedChains
	if len(chains) == 0 {
		return nil, errors.New("ca: certificate does not exists")
	}

	last := chains[len(chains)-1]
	if len(last) == 0 {
		return nil, errors.New("ca: certificate does not exists")
	}

	// The final element of the chain is taken as the root.
	return last[len(last)-1], nil
}
// TLSCertificate assembles a tls.Certificate from the sign response and the
// private key that was used for it. The certificate chain is the ServerPEM
// certificate followed by the CaPEM certificate, and the parsed leaf is
// stored in the Leaf field of the result.
func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certificate, error) {
	leafPEM, err := getPEM(sign.ServerPEM)
	if err != nil {
		return nil, err
	}

	intermediatePEM, err := getPEM(sign.CaPEM)
	if err != nil {
		return nil, err
	}

	keyPEM, err := getPEM(pk)
	if err != nil {
		return nil, err
	}

	// Leaf first, then the intermediate, as a single PEM bundle.
	bundle := append(leafPEM, intermediatePEM...)
	pair, err := tls.X509KeyPair(bundle, keyPEM)
	if err != nil {
		return nil, errors.Wrap(err, "error creating tls certificate")
	}

	// Keep the parsed leaf alongside the raw chain.
	if pair.Leaf, err = x509.ParseCertificate(pair.Certificate[0]); err != nil {
		return nil, errors.Wrap(err, "error parsing tls certificate")
	}

	return &pair, nil
}
// getDefaultTLSConfig returns the base tls.Config for a sign response: the
// one produced by sign.TLSOptions when those options are set, otherwise a
// configuration whose only setting is a minimum version of TLS 1.2.
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
	if opts := sign.TLSOptions; opts != nil {
		return opts.TLSConfig()
	}

	return &tls.Config{MinVersion: tls.VersionTLS12}
}
// getDefaultTransport returns an http.Transport that uses the given
// tls.Config and is configured for HTTP/2. The proxy, dialer and timeout
// settings are spelled out explicitly; the original author modelled them on
// http.DefaultTransport. An error is returned if the HTTP/2 configuration
// fails.
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
		DualStack: true,
	}

	transport := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialer.DialContext,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		TLSClientConfig:       tlsConfig,
	}

	err := http2.ConfigureTransport(transport)
	if err != nil {
		return nil, errors.Wrap(err, "error configuring transport")
	}

	return transport, nil
}
// getPEM encodes a certificate or a private key as a single PEM block.
//
// Supported inputs:
//   - api.Certificate and *x509.Certificate: the raw DER bytes are written
//     in a "CERTIFICATE" block.
//   - *rsa.PrivateKey: PKCS #1 encoding in an "RSA PRIVATE KEY" block.
//   - *ecdsa.PrivateKey: SEC 1 encoding in an "EC PRIVATE KEY" block.
//
// Any other type yields an "unsupported key type" error. A nil certificate
// (including an api.Certificate whose embedded certificate is nil) or a nil
// private key yields an error instead of panicking on dereference.
func getPEM(i interface{}) ([]byte, error) {
	block := new(pem.Block)
	switch i := i.(type) {
	case api.Certificate:
		// The embedded *x509.Certificate may be unset; reading Raw through
		// it would panic.
		if i.Certificate == nil {
			return nil, errors.New("certificate is nil")
		}
		block.Type = "CERTIFICATE"
		block.Bytes = i.Raw
	case *x509.Certificate:
		if i == nil {
			return nil, errors.New("certificate is nil")
		}
		block.Type = "CERTIFICATE"
		block.Bytes = i.Raw
	case *rsa.PrivateKey:
		// A typed nil key would panic inside the marshaller.
		if i == nil {
			return nil, errors.New("private key is nil")
		}
		block.Type = "RSA PRIVATE KEY"
		block.Bytes = x509.MarshalPKCS1PrivateKey(i)
	case *ecdsa.PrivateKey:
		if i == nil {
			return nil, errors.New("private key is nil")
		}
		der, err := x509.MarshalECPrivateKey(i)
		if err != nil {
			return nil, errors.Wrap(err, "error marshaling private key")
		}
		block.Type = "EC PRIVATE KEY"
		block.Bytes = der
	default:
		return nil, errors.Errorf("unsupported key type %T", i)
	}
	return pem.EncodeToMemory(block), nil
}
// getRenewFunc returns a RenewFunc that calls client.Renew with the given
// transport and turns the resulting sign response into a tls.Certificate
// using the same private key pk on every invocation.
func getRenewFunc(client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
	return func() (*tls.Certificate, error) {
		renewed, err := client.Renew(tr)
		if err != nil {
			return nil, err
		}

		return TLSCertificate(renewed, pk)
	}
}