This commit is contained in:
Carl Tashian 2021-05-03 16:19:47 -07:00
commit 0295280c20
9 changed files with 420 additions and 101 deletions

View file

@ -156,14 +156,15 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error creating certificate", opts...)
}
if err = a.db.StoreCertificate(resp.Certificate); err != nil {
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeCertificate(fullchain); err != nil {
if err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Sign; error storing certificate in db", opts...)
}
}
return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil
return fullchain, nil
}
// Renew creates a new Certificate identical to the old certificate, except
@ -261,13 +262,42 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
}
if err = a.db.StoreCertificate(resp.Certificate); err != nil {
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil {
if err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
}
}
return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil
return fullchain, nil
}
// storeCertificate allows to use an extension of the db.AuthDB interface that
// can log the full chain of certificates.
//
// TODO: at some point we should replace the db.AuthDB interface to implement
// `StoreCertificate(...*x509.Certificate) error` instead of just
// `StoreCertificate(*x509.Certificate) error`.
func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
if s, ok := a.db.(interface {
StoreCertificateChain(...*x509.Certificate) error
}); ok {
return s.StoreCertificateChain(fullchain...)
}
return a.db.StoreCertificate(fullchain[0])
}
// storeRenewedCertificate allows to use an extension of the db.AuthDB interface
// that can log if a certificate has been renewed or rekeyed.
//
// TODO: at some point we should implement this in the standard implementation.
func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error {
if s, ok := a.db.(interface {
StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error
}); ok {
return s.StoreRenewedCertificate(oldCert, fullchain...)
}
return a.db.StoreCertificate(fullchain[0])
}
// RevokeOptions are the options for the Revoke API.

View file

@ -56,10 +56,7 @@ func newClient(transport http.RoundTripper) *uaClient {
func newInsecureClient() *uaClient {
return &uaClient{
Client: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}),
},
}
}
@ -104,6 +101,7 @@ type clientOptions struct {
rootFilename string
rootBundle []byte
certificate tls.Certificate
getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
retryFunc RetryFunc
}
@ -139,6 +137,7 @@ func (o *clientOptions) applyDefaultIdentity() error {
return nil
}
o.certificate = crt
o.getClientCertificate = i.GetClientCertificateFunc()
return nil
}
@ -193,6 +192,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
}
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
}
case *http2.Transport:
if tr.TLSClientConfig == nil {
@ -200,6 +200,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
}
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
}
default:
return nil, errors.Errorf("unsupported transport type %T", tr)
@ -288,7 +289,7 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}), nil
}
func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
@ -307,7 +308,7 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}), nil
}
func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
@ -319,7 +320,7 @@ func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}), nil
}
// parseEndpoint parses and validates the given endpoint. It supports general
@ -615,6 +616,36 @@ retry:
return &sign, nil
}
// Rekey performs the rekey request to the CA and returns the api.SignResponse
// struct.
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
return nil, errors.Wrap(err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"})
client := &http.Client{Transport: tr}
retry:
resp, err := client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; client POST %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body)
}
var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; error reading %s", u)
}
return &sign, nil
}
// Revoke performs the revoke request to the CA and returns the api.RevokeResponse
// struct.
func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {

View file

@ -529,6 +529,75 @@ func TestClient_Renew(t *testing.T) {
}
}
func TestClient_Rekey(t *testing.T) {
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)},
{Certificate: parseCertificate(rootPEM)},
},
}
request := &api.RekeyRequest{
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)},
}
tests := []struct {
name string
request *api.RekeyRequest
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Rekey(tt.request, nil)
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.Renew() = %v, want nil", got)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Provisioners(t *testing.T) {
ok := &api.ProvisionersResponse{
Provisioners: provisioner.List{},

View file

@ -26,9 +26,16 @@ type Type string
// Disabled represents a disabled identity type
const Disabled Type = ""
// MutualTLS represents the identity using mTLS
// MutualTLS represents the identity using mTLS.
const MutualTLS Type = "mTLS"
// TunnelTLS represents an identity using a (m)TLS tunnel.
//
// TunnelTLS can be optionally configured with client certificates and a root
// file with the CAs to trust. By default it will use the system truststore
// instead of the CA truststore.
const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute
@ -44,19 +51,30 @@ type Identity struct {
Type string `json:"type"`
Certificate string `json:"crt"`
Key string `json:"key"`
// Host is the tunnel host for a TunnelTLS (tTLS) identity.
Host string `json:"host,omitempty"`
// Root is the CA bundle of root CAs used in TunnelTLS to trust the
// certificate of the host.
Root string `json:"root,omitempty"`
}
// LoadIdentity loads an identity present in the given filename.
func LoadIdentity(filename string) (*Identity, error) {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename)
}
identity := new(Identity)
if err := json.Unmarshal(b, &identity); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", filename)
}
return identity, nil
}
// LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) {
b, err := ioutil.ReadFile(IdentityFile)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", IdentityFile)
}
identity := new(Identity)
if err := json.Unmarshal(b, &identity); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", IdentityFile)
}
return identity, nil
return LoadIdentity(IdentityFile)
}
// configDir and identityDir are used in WriteDefaultIdentity for testing
@ -81,7 +99,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
keyFilename := filepath.Join(identityDir, "identity_key")
// Write certificate
if err := WriteIdentityCertificate(certChain); err != nil {
if err := writeCertificate(certFilename, certChain); err != nil {
return err
}
@ -116,22 +134,21 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
return nil
}
// WriteIdentityCertificate writes the identity certificate in disk.
func WriteIdentityCertificate(certChain []api.Certificate) error {
// writeCertificate writes the given certificate on disk.
func writeCertificate(filename string, certChain []api.Certificate) error {
buf := new(bytes.Buffer)
certFilename := filepath.Join(identityDir, "identity.crt")
for _, crt := range certChain {
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: crt.Raw,
}
if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity certificate")
return errors.Wrap(err, "error encoding certificate")
}
}
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing certificate")
}
return nil
@ -144,6 +161,8 @@ func (i *Identity) Kind() Type {
return Disabled
case "mtls":
return MutualTLS
case "ttls":
return TunnelTLS
default:
return Type(i.Type)
}
@ -164,9 +183,27 @@ func (i *Identity) Validate() error {
if err := fileExists(i.Certificate); err != nil {
return err
}
return fileExists(i.Key)
case TunnelTLS:
if i.Host == "" {
return errors.New("tunnel.host cannot be empty")
}
if i.Certificate != "" {
if err := fileExists(i.Certificate); err != nil {
return err
}
if i.Key == "" {
return errors.New("tunnel.key cannot be empty")
}
if err := fileExists(i.Key); err != nil {
return err
}
}
if i.Root != "" {
if err := fileExists(i.Root); err != nil {
return err
}
}
return nil
default:
return errors.Errorf("unsupported identity type %s", i.Type)
@ -179,7 +216,7 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) {
switch i.Kind() {
case Disabled:
return tls.Certificate{}, nil
case MutualTLS:
case MutualTLS, TunnelTLS:
crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key)
if err != nil {
return fail(errors.Wrap(err, "error creating identity certificate"))
@ -215,6 +252,22 @@ func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo)
}
}
// GetCertPool returns a x509.CertPool if the identity defines a custom root.
func (i *Identity) GetCertPool() (*x509.CertPool, error) {
if i.Root == "" {
return nil, nil
}
b, err := ioutil.ReadFile(i.Root)
if err != nil {
return nil, errors.Wrap(err, "error reading identity root")
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(b) {
return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root)
}
return pool, nil
}
// Renewer is that interface that a renew client must implement.
type Renewer interface {
GetRootCAs() *x509.CertPool
@ -227,7 +280,7 @@ func (i *Identity) Renew(client Renewer) error {
switch i.Kind() {
case Disabled:
return nil
case MutualTLS:
case MutualTLS, TunnelTLS:
cert, err := i.TLSCertificate()
if err != nil {
return err

View file

@ -63,6 +63,7 @@ func TestIdentity_Kind(t *testing.T) {
}{
{"disabled", fields{""}, Disabled},
{"mutualTLS", fields{"mTLS"}, MutualTLS},
{"tunnelTLS", fields{"tTLS"}, TunnelTLS},
{"unknown", fields{"unknown"}, Type("unknown")},
}
for _, tt := range tests {
@ -82,19 +83,27 @@ func TestIdentity_Validate(t *testing.T) {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, false},
{"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false},
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false},
{"ok disabled", fields{}, false},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, true},
{"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key"}, true},
{"fail key", fields{"mTLS", "testdata/identity/identity.crt", ""}, true},
{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key"}, true},
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key"}, true},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true},
{"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true},
{"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true},
{"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true},
{"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true},
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true},
{"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -102,6 +111,8 @@ func TestIdentity_Validate(t *testing.T) {
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
if err := i.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr)
@ -127,7 +138,8 @@ func TestIdentity_TLSCertificate(t *testing.T) {
want tls.Certificate
wantErr bool
}{
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok disabled", fields{}, tls.Certificate{}, false},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
{"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
@ -255,6 +267,95 @@ func TestWriteDefaultIdentity(t *testing.T) {
}
}
func TestIdentity_GetClientCertificateFunc(t *testing.T) {
expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key")
if err != nil {
t.Fatal(err)
}
type fields struct {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
want *tls.Certificate
wantErr bool
}{
{"ok mTLS", fields{"mtls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false},
{"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false},
{"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true},
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := &Identity{
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
fn := i.GetClientCertificateFunc()
got, err := fn(&tls.CertificateRequestInfo{})
if (err != nil) != tt.wantErr {
t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want)
}
})
}
}
func TestIdentity_GetCertPool(t *testing.T) {
type fields struct {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
wantSubjects [][]byte
wantErr bool
}{
{"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false},
{"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false},
{"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true},
{"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := &Identity{
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
got, err := i.GetCertPool()
if (err != nil) != tt.wantErr {
t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
subjects := got.Subjects()
if !reflect.DeepEqual(subjects, tt.wantSubjects) {
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
}
}
})
}
}
type renewer struct {
pool *x509.CertPool
sign *api.SignResponse

View file

@ -0,0 +1,7 @@
{
"type": "mTLS",
"crt": "testdata/identity/identity.crt",
"key": "testdata/identity/identity_key",
"host": "tunnel:443",
"root": "testdata/certs/root_ca.crt"
}

111
ca/tls.go
View file

@ -10,13 +10,65 @@ import (
"encoding/pem"
"net"
"net/http"
"os"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"golang.org/x/net/http2"
"github.com/smallstep/certificates/ca/identity"
)
// mTLSDialContext will hold the dial context function to use in
// getDefaultTransport.
var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error)
func init() {
// STEP_TLS_TUNNEL is an environment variable that can be set to do an TLS
// over (m)TLS tunnel to step-ca using identity-like credentials. The value
// is a path to a json file with the tunnel host, certificate, key and root
// used to create the (m)TLS tunnel.
//
// The configuration should look like:
// {
// "type": "tTLS",
// "host": "tunnel.example.com:443"
// "crt": "/path/to/tunnel.crt",
// "key": "/path/to/tunnel.key",
// "root": "/path/to/tunnel-root.crt"
// }
//
// This feature is EXPERIMENTAL and might change at any time.
if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" {
id, err := identity.LoadIdentity(path)
if err != nil {
panic(err)
}
if err := id.Validate(); err != nil {
panic(err)
}
host, port, err := net.SplitHostPort(id.Host)
if err != nil {
panic(err)
}
pool, err := id.GetCertPool()
if err != nil {
panic(err)
}
mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
d := &tls.Dialer{
NetDialer: getDefaultDialer(),
Config: &tls.Config{
RootCAs: pool,
GetClientCertificate: id.GetClientCertificateFunc(),
},
}
return func(ctx context.Context, network, address string) (net.Conn, error) {
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
}
}
}
}
// GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The client certificate will automatically rotate before expiring.
@ -51,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
}
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
return nil, nil, err
}
tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
@ -103,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
return nil, err
}
tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
@ -144,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
return func(network, addr string) (net.Conn, error) {
return tls.DialWithDialer(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}, network, addr, ctx.mutableConfig.TLSConfig())
return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig())
}
}
@ -156,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
// nolint:unused
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
d := getDefaultDialer()
// TLS dialers do not support context, but we can use the context
// deadline if it is set.
var deadline time.Time
if t, ok := ctx.Deadline(); ok {
deadline = t
d.Deadline = t
}
return tls.DialWithDialer(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Deadline: deadline,
DualStack: true,
}, network, addr, tlsCtx.mutableConfig.TLSConfig())
return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig())
}
}
@ -238,27 +275,35 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
}
}
// getDefaultDialer returns a new dialer with the default configuration.
func getDefaultDialer() *net.Dialer {
return &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
}
// getDefaultTransport returns an http.Transport with the same parameters than
// http.DefaultTransport, but adds the given tls.Config and configures the
// transport for HTTP/2.
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
tr := &http.Transport{
func getDefaultTransport(tlsConfig *tls.Config) *http.Transport {
var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
if mTLSDialContext == nil {
d := getDefaultDialer()
dialContext = d.DialContext
} else {
dialContext = mTLSDialContext()
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
DialContext: dialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: tlsConfig,
}
if err := http2.ConfigureTransport(tr); err != nil {
return nil, errors.Wrap(err, "error configuring transport")
}
return tr, nil
}
func getPEM(i interface{}) ([]byte, error) {

View file

@ -181,13 +181,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
return nil
}
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
Transport: getDefaultTransport(tlsConfig),
}
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
@ -199,14 +194,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
tlsConfig := getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
Transport: getDefaultTransport(tlsConfig),
}
}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
@ -288,10 +277,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
if err != nil {
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
}
tr2, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
tr2 := getDefaultTransport(tlsConfig)
// No client cert
root, err := RootCertificate(sr)
if err != nil {
@ -300,10 +286,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
tlsConfig = getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr3, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
tr3 := getDefaultTransport(tlsConfig)
// Disable keep alives to force TLS handshake
tr1.DisableKeepAlives = true

View file

@ -313,7 +313,7 @@ func getSlotAndName(name string) (piv.Slot, string, error) {
s, ok := slotMapping[slotID]
if !ok {
return piv.Slot{}, "", errors.Errorf("usupported slot-id '%s'", name)
return piv.Slot{}, "", errors.Errorf("unsupported slot-id '%s'", name)
}
name = "yubikey:slot-id=" + url.QueryEscape(slotID)