forked from TrueCloudLab/certificates
245 lines
7.1 KiB
Go
245 lines
7.1 KiB
Go
|
package ca
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto"
|
||
|
"crypto/ecdsa"
|
||
|
"crypto/rsa"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"encoding/pem"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pkg/errors"
|
||
|
"github.com/smallstep/ca-component/api"
|
||
|
"golang.org/x/net/http2"
|
||
|
)
|
||
|
|
||
|
// GetClientTLSConfig returns a tls.Config for client use configured with the
|
||
|
// sign certificate, and a new certificate pool with the sign root certificate.
|
||
|
// The certificate will automatically rotate before expiring.
|
||
|
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
||
|
cert, err := TLSCertificate(sign, pk)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
renewer, err := NewTLSRenewer(cert, nil)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
tlsConfig := getDefaultTLSConfig(sign)
|
||
|
// Note that with GetClientCertificate tlsConfig.Certificates is not used.
|
||
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||
|
tlsConfig.PreferServerCipherSuites = true
|
||
|
// Build RootCAs with given root certificate
|
||
|
if pool := c.getCertPool(sign); pool != nil {
|
||
|
tlsConfig.RootCAs = pool
|
||
|
}
|
||
|
|
||
|
// Parse Certificates and build NameToCertificate
|
||
|
tlsConfig.BuildNameToCertificate()
|
||
|
|
||
|
// Update renew function with transport
|
||
|
tr, err := getDefaultTransport(tlsConfig)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
|
||
|
|
||
|
// Start renewer
|
||
|
renewer.RunContext(ctx)
|
||
|
return tlsConfig, nil
|
||
|
}
|
||
|
|
||
|
// GetServerTLSConfig returns a tls.Config for server use configured with the
|
||
|
// sign certificate, and a new certificate pool with the sign root certificate.
|
||
|
// The certificate will automatically rotate before expiring.
|
||
|
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
|
||
|
cert, err := TLSCertificate(sign, pk)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
renewer, err := NewTLSRenewer(cert, nil)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
tlsConfig := getDefaultTLSConfig(sign)
|
||
|
// Note that GetCertificate will only be called if the client supplies SNI
|
||
|
// information or if tlsConfig.Certificates is empty.
|
||
|
tlsConfig.GetCertificate = renewer.GetCertificate
|
||
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||
|
tlsConfig.PreferServerCipherSuites = true
|
||
|
// Build RootCAs with given root certificate
|
||
|
if pool := c.getCertPool(sign); pool != nil {
|
||
|
tlsConfig.ClientCAs = pool
|
||
|
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
||
|
// Add RootCAs for refresh client
|
||
|
tlsConfig.RootCAs = pool
|
||
|
}
|
||
|
|
||
|
// Update renew function with transport
|
||
|
tr, err := getDefaultTransport(tlsConfig)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
|
||
|
|
||
|
// Start renewer
|
||
|
renewer.RunContext(ctx)
|
||
|
return tlsConfig, nil
|
||
|
}
|
||
|
|
||
|
// Transport returns an http.Transport configured to use the client certificate from the sign response.
|
||
|
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
|
||
|
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return getDefaultTransport(tlsConfig)
|
||
|
}
|
||
|
|
||
|
// getCertPool returns the transport x509.CertPool or the one from the sign
|
||
|
// request.
|
||
|
func (c *Client) getCertPool(sign *api.SignResponse) *x509.CertPool {
|
||
|
// Return the transport certPool
|
||
|
if c.certPool != nil {
|
||
|
return c.certPool
|
||
|
}
|
||
|
// Return certificate used in sign request.
|
||
|
if root, err := RootCertificate(sign); err == nil {
|
||
|
pool := x509.NewCertPool()
|
||
|
pool.AddCert(root)
|
||
|
return pool
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Certificate returns the server or client certificate from the sign response.
|
||
|
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||
|
if sign.ServerPEM.Certificate == nil {
|
||
|
return nil, errors.New("ca: certificate does not exists")
|
||
|
}
|
||
|
return sign.ServerPEM.Certificate, nil
|
||
|
}
|
||
|
|
||
|
// IntermediateCertificate returns the CA intermediate certificate from the sign
|
||
|
// response.
|
||
|
func IntermediateCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||
|
if sign.CaPEM.Certificate == nil {
|
||
|
return nil, errors.New("ca: certificate does not exists")
|
||
|
}
|
||
|
return sign.CaPEM.Certificate, nil
|
||
|
}
|
||
|
|
||
|
// RootCertificate returns the root certificate from the sign response.
|
||
|
func RootCertificate(sign *api.SignResponse) (*x509.Certificate, error) {
|
||
|
if sign.TLS == nil || len(sign.TLS.VerifiedChains) == 0 {
|
||
|
return nil, errors.New("ca: certificate does not exists")
|
||
|
}
|
||
|
lastChain := sign.TLS.VerifiedChains[len(sign.TLS.VerifiedChains)-1]
|
||
|
if len(lastChain) == 0 {
|
||
|
return nil, errors.New("ca: certificate does not exists")
|
||
|
}
|
||
|
return lastChain[len(lastChain)-1], nil
|
||
|
}
|
||
|
|
||
|
// TLSCertificate creates a new TLS certificate from the sign response and the
|
||
|
// private key used.
|
||
|
func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certificate, error) {
|
||
|
certPEM, err := getPEM(sign.ServerPEM)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
caPEM, err := getPEM(sign.CaPEM)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
keyPEM, err := getPEM(pk)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
chain := append(certPEM, caPEM...)
|
||
|
cert, err := tls.X509KeyPair(chain, keyPEM)
|
||
|
if err != nil {
|
||
|
return nil, errors.Wrap(err, "error creating tls certificate")
|
||
|
}
|
||
|
leaf, err := x509.ParseCertificate(cert.Certificate[0])
|
||
|
if err != nil {
|
||
|
return nil, errors.Wrap(err, "error parsing tls certificate")
|
||
|
}
|
||
|
cert.Leaf = leaf
|
||
|
return &cert, nil
|
||
|
}
|
||
|
|
||
|
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
||
|
if sign.TLSOptions != nil {
|
||
|
return sign.TLSOptions.TLSConfig()
|
||
|
}
|
||
|
return &tls.Config{
|
||
|
MinVersion: tls.VersionTLS12,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// getDefaultTransport returns an http.Transport with the same parameters than
|
||
|
// http.DefaultTransport, but adds the given tls.Config and configures the
|
||
|
// transport for HTTP/2.
|
||
|
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
|
||
|
tr := &http.Transport{
|
||
|
Proxy: http.ProxyFromEnvironment,
|
||
|
DialContext: (&net.Dialer{
|
||
|
Timeout: 30 * time.Second,
|
||
|
KeepAlive: 30 * time.Second,
|
||
|
DualStack: true,
|
||
|
}).DialContext,
|
||
|
MaxIdleConns: 100,
|
||
|
IdleConnTimeout: 90 * time.Second,
|
||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||
|
ExpectContinueTimeout: 1 * time.Second,
|
||
|
TLSClientConfig: tlsConfig,
|
||
|
}
|
||
|
if err := http2.ConfigureTransport(tr); err != nil {
|
||
|
return nil, errors.Wrap(err, "error configuring transport")
|
||
|
}
|
||
|
return tr, nil
|
||
|
}
|
||
|
|
||
|
func getPEM(i interface{}) ([]byte, error) {
|
||
|
block := new(pem.Block)
|
||
|
switch i := i.(type) {
|
||
|
case api.Certificate:
|
||
|
block.Type = "CERTIFICATE"
|
||
|
block.Bytes = i.Raw
|
||
|
case *x509.Certificate:
|
||
|
block.Type = "CERTIFICATE"
|
||
|
block.Bytes = i.Raw
|
||
|
case *rsa.PrivateKey:
|
||
|
block.Type = "RSA PRIVATE KEY"
|
||
|
block.Bytes = x509.MarshalPKCS1PrivateKey(i)
|
||
|
case *ecdsa.PrivateKey:
|
||
|
var err error
|
||
|
block.Type = "EC PRIVATE KEY"
|
||
|
block.Bytes, err = x509.MarshalECPrivateKey(i)
|
||
|
if err != nil {
|
||
|
return nil, errors.Wrap(err, "error marshaling private key")
|
||
|
}
|
||
|
default:
|
||
|
return nil, errors.Errorf("unsupported key type %T", i)
|
||
|
}
|
||
|
return pem.EncodeToMemory(block), nil
|
||
|
}
|
||
|
|
||
|
func getRenewFunc(client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
|
||
|
return func() (*tls.Certificate, error) {
|
||
|
sign, err := client.Renew(tr)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return TLSCertificate(sign, pk)
|
||
|
}
|
||
|
}
|