forked from TrueCloudLab/certificates
Merge pull request #531 from smallstep/tls-tunnel
Add experimental support for a TLS over TLS tunnel.
This commit is contained in:
commit
582d6b161d
7 changed files with 287 additions and 97 deletions
27
ca/client.go
27
ca/client.go
|
@ -56,10 +56,7 @@ func newClient(transport http.RoundTripper) *uaClient {
|
|||
func newInsecureClient() *uaClient {
|
||||
return &uaClient{
|
||||
Client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -99,12 +96,13 @@ type RetryFunc func(code int) bool
|
|||
type ClientOption func(o *clientOptions) error
|
||||
|
||||
type clientOptions struct {
|
||||
transport http.RoundTripper
|
||||
rootSHA256 string
|
||||
rootFilename string
|
||||
rootBundle []byte
|
||||
certificate tls.Certificate
|
||||
retryFunc RetryFunc
|
||||
transport http.RoundTripper
|
||||
rootSHA256 string
|
||||
rootFilename string
|
||||
rootBundle []byte
|
||||
certificate tls.Certificate
|
||||
getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
|
||||
retryFunc RetryFunc
|
||||
}
|
||||
|
||||
func (o *clientOptions) apply(opts []ClientOption) (err error) {
|
||||
|
@ -139,6 +137,7 @@ func (o *clientOptions) applyDefaultIdentity() error {
|
|||
return nil
|
||||
}
|
||||
o.certificate = crt
|
||||
o.getClientCertificate = i.GetClientCertificateFunc()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -193,6 +192,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
|
|||
}
|
||||
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
|
||||
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
|
||||
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
|
||||
}
|
||||
case *http2.Transport:
|
||||
if tr.TLSClientConfig == nil {
|
||||
|
@ -200,6 +200,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
|
|||
}
|
||||
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
|
||||
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
|
||||
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
|
||||
}
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported transport type %T", tr)
|
||||
|
@ -288,7 +289,7 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) {
|
|||
MinVersion: tls.VersionTLS12,
|
||||
PreferServerCipherSuites: true,
|
||||
RootCAs: pool,
|
||||
})
|
||||
}), nil
|
||||
}
|
||||
|
||||
func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
|
||||
|
@ -307,7 +308,7 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
|
|||
MinVersion: tls.VersionTLS12,
|
||||
PreferServerCipherSuites: true,
|
||||
RootCAs: pool,
|
||||
})
|
||||
}), nil
|
||||
}
|
||||
|
||||
func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
|
||||
|
@ -319,7 +320,7 @@ func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
|
|||
MinVersion: tls.VersionTLS12,
|
||||
PreferServerCipherSuites: true,
|
||||
RootCAs: pool,
|
||||
})
|
||||
}), nil
|
||||
}
|
||||
|
||||
// parseEndpoint parses and validates the given endpoint. It supports general
|
||||
|
|
|
@ -26,9 +26,16 @@ type Type string
|
|||
// Disabled represents a disabled identity type
|
||||
const Disabled Type = ""
|
||||
|
||||
// MutualTLS represents the identity using mTLS
|
||||
// MutualTLS represents the identity using mTLS.
|
||||
const MutualTLS Type = "mTLS"
|
||||
|
||||
// TunnelTLS represents an identity using a (m)TLS tunnel.
|
||||
//
|
||||
// TunnelTLS can be optionally configured with client certificates and a root
|
||||
// file with the CAs to trust. By default it will use the system truststore
|
||||
// instead of the CA truststore.
|
||||
const TunnelTLS Type = "tTLS"
|
||||
|
||||
// DefaultLeeway is the duration for matching not before claims.
|
||||
const DefaultLeeway = 1 * time.Minute
|
||||
|
||||
|
@ -44,19 +51,30 @@ type Identity struct {
|
|||
Type string `json:"type"`
|
||||
Certificate string `json:"crt"`
|
||||
Key string `json:"key"`
|
||||
|
||||
// Host is the tunnel host for a TunnelTLS (tTLS) identity.
|
||||
Host string `json:"host,omitempty"`
|
||||
// Root is the CA bundle of root CAs used in TunnelTLS to trust the
|
||||
// certificate of the host.
|
||||
Root string `json:"root,omitempty"`
|
||||
}
|
||||
|
||||
// LoadIdentity loads an identity present in the given filename.
|
||||
func LoadIdentity(filename string) (*Identity, error) {
|
||||
b, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", filename)
|
||||
}
|
||||
identity := new(Identity)
|
||||
if err := json.Unmarshal(b, &identity); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling %s", filename)
|
||||
}
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
// LoadDefaultIdentity loads the default identity.
|
||||
func LoadDefaultIdentity() (*Identity, error) {
|
||||
b, err := ioutil.ReadFile(IdentityFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", IdentityFile)
|
||||
}
|
||||
identity := new(Identity)
|
||||
if err := json.Unmarshal(b, &identity); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling %s", IdentityFile)
|
||||
}
|
||||
return identity, nil
|
||||
return LoadIdentity(IdentityFile)
|
||||
}
|
||||
|
||||
// configDir and identityDir are used in WriteDefaultIdentity for testing
|
||||
|
@ -81,7 +99,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
|
|||
keyFilename := filepath.Join(identityDir, "identity_key")
|
||||
|
||||
// Write certificate
|
||||
if err := WriteIdentityCertificate(certChain); err != nil {
|
||||
if err := writeCertificate(certFilename, certChain); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -116,22 +134,21 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
|
|||
return nil
|
||||
}
|
||||
|
||||
// WriteIdentityCertificate writes the identity certificate in disk.
|
||||
func WriteIdentityCertificate(certChain []api.Certificate) error {
|
||||
// writeCertificate writes the given certificate on disk.
|
||||
func writeCertificate(filename string, certChain []api.Certificate) error {
|
||||
buf := new(bytes.Buffer)
|
||||
certFilename := filepath.Join(identityDir, "identity.crt")
|
||||
for _, crt := range certChain {
|
||||
block := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: crt.Raw,
|
||||
}
|
||||
if err := pem.Encode(buf, block); err != nil {
|
||||
return errors.Wrap(err, "error encoding identity certificate")
|
||||
return errors.Wrap(err, "error encoding certificate")
|
||||
}
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
|
||||
return errors.Wrap(err, "error writing identity certificate")
|
||||
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil {
|
||||
return errors.Wrap(err, "error writing certificate")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -144,6 +161,8 @@ func (i *Identity) Kind() Type {
|
|||
return Disabled
|
||||
case "mtls":
|
||||
return MutualTLS
|
||||
case "ttls":
|
||||
return TunnelTLS
|
||||
default:
|
||||
return Type(i.Type)
|
||||
}
|
||||
|
@ -164,8 +183,26 @@ func (i *Identity) Validate() error {
|
|||
if err := fileExists(i.Certificate); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileExists(i.Key); err != nil {
|
||||
return err
|
||||
return fileExists(i.Key)
|
||||
case TunnelTLS:
|
||||
if i.Host == "" {
|
||||
return errors.New("tunnel.host cannot be empty")
|
||||
}
|
||||
if i.Certificate != "" {
|
||||
if err := fileExists(i.Certificate); err != nil {
|
||||
return err
|
||||
}
|
||||
if i.Key == "" {
|
||||
return errors.New("tunnel.key cannot be empty")
|
||||
}
|
||||
if err := fileExists(i.Key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if i.Root != "" {
|
||||
if err := fileExists(i.Root); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
|
@ -179,7 +216,7 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) {
|
|||
switch i.Kind() {
|
||||
case Disabled:
|
||||
return tls.Certificate{}, nil
|
||||
case MutualTLS:
|
||||
case MutualTLS, TunnelTLS:
|
||||
crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key)
|
||||
if err != nil {
|
||||
return fail(errors.Wrap(err, "error creating identity certificate"))
|
||||
|
@ -215,6 +252,22 @@ func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo)
|
|||
}
|
||||
}
|
||||
|
||||
// GetCertPool returns a x509.CertPool if the identity defines a custom root.
|
||||
func (i *Identity) GetCertPool() (*x509.CertPool, error) {
|
||||
if i.Root == "" {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := ioutil.ReadFile(i.Root)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error reading identity root")
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(b) {
|
||||
return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Renewer is that interface that a renew client must implement.
|
||||
type Renewer interface {
|
||||
GetRootCAs() *x509.CertPool
|
||||
|
@ -227,7 +280,7 @@ func (i *Identity) Renew(client Renewer) error {
|
|||
switch i.Kind() {
|
||||
case Disabled:
|
||||
return nil
|
||||
case MutualTLS:
|
||||
case MutualTLS, TunnelTLS:
|
||||
cert, err := i.TLSCertificate()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -63,6 +63,7 @@ func TestIdentity_Kind(t *testing.T) {
|
|||
}{
|
||||
{"disabled", fields{""}, Disabled},
|
||||
{"mutualTLS", fields{"mTLS"}, MutualTLS},
|
||||
{"tunnelTLS", fields{"tTLS"}, TunnelTLS},
|
||||
{"unknown", fields{"unknown"}, Type("unknown")},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -82,19 +83,27 @@ func TestIdentity_Validate(t *testing.T) {
|
|||
Type string
|
||||
Certificate string
|
||||
Key string
|
||||
Host string
|
||||
Root string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, false},
|
||||
{"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false},
|
||||
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false},
|
||||
{"ok disabled", fields{}, false},
|
||||
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, true},
|
||||
{"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key"}, true},
|
||||
{"fail key", fields{"mTLS", "testdata/identity/identity.crt", ""}, true},
|
||||
{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key"}, true},
|
||||
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key"}, true},
|
||||
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true},
|
||||
{"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true},
|
||||
{"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true},
|
||||
{"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
|
||||
{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true},
|
||||
{"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
|
||||
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true},
|
||||
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
|
||||
{"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true},
|
||||
{"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -102,6 +111,8 @@ func TestIdentity_Validate(t *testing.T) {
|
|||
Type: tt.fields.Type,
|
||||
Certificate: tt.fields.Certificate,
|
||||
Key: tt.fields.Key,
|
||||
Host: tt.fields.Host,
|
||||
Root: tt.fields.Root,
|
||||
}
|
||||
if err := i.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
@ -127,7 +138,8 @@ func TestIdentity_TLSCertificate(t *testing.T) {
|
|||
want tls.Certificate
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
|
||||
{"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
|
||||
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
|
||||
{"ok disabled", fields{}, tls.Certificate{}, false},
|
||||
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
|
||||
{"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
|
||||
|
@ -255,6 +267,95 @@ func TestWriteDefaultIdentity(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestIdentity_GetClientCertificateFunc(t *testing.T) {
|
||||
expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
Type string
|
||||
Certificate string
|
||||
Key string
|
||||
Host string
|
||||
Root string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want *tls.Certificate
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok mTLS", fields{"mtls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false},
|
||||
{"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false},
|
||||
{"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true},
|
||||
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
i := &Identity{
|
||||
Type: tt.fields.Type,
|
||||
Certificate: tt.fields.Certificate,
|
||||
Key: tt.fields.Key,
|
||||
Host: tt.fields.Host,
|
||||
Root: tt.fields.Root,
|
||||
}
|
||||
fn := i.GetClientCertificateFunc()
|
||||
got, err := fn(&tls.CertificateRequestInfo{})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentity_GetCertPool(t *testing.T) {
|
||||
type fields struct {
|
||||
Type string
|
||||
Certificate string
|
||||
Key string
|
||||
Host string
|
||||
Root string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantSubjects [][]byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false},
|
||||
{"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false},
|
||||
{"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true},
|
||||
{"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
i := &Identity{
|
||||
Type: tt.fields.Type,
|
||||
Certificate: tt.fields.Certificate,
|
||||
Key: tt.fields.Key,
|
||||
Host: tt.fields.Host,
|
||||
Root: tt.fields.Root,
|
||||
}
|
||||
got, err := i.GetCertPool()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != nil {
|
||||
subjects := got.Subjects()
|
||||
if !reflect.DeepEqual(subjects, tt.wantSubjects) {
|
||||
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type renewer struct {
|
||||
pool *x509.CertPool
|
||||
sign *api.SignResponse
|
||||
|
|
7
ca/identity/testdata/config/tunnel.json
vendored
Normal file
7
ca/identity/testdata/config/tunnel.json
vendored
Normal file
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"type": "mTLS",
|
||||
"crt": "testdata/identity/identity.crt",
|
||||
"key": "testdata/identity/identity_key",
|
||||
"host": "tunnel:443",
|
||||
"root": "testdata/certs/root_ca.crt"
|
||||
}
|
113
ca/tls.go
113
ca/tls.go
|
@ -10,13 +10,65 @@ import (
|
|||
"encoding/pem"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"golang.org/x/net/http2"
|
||||
"github.com/smallstep/certificates/ca/identity"
|
||||
)
|
||||
|
||||
// mTLSDialContext will hold the dial context function to use in
|
||||
// getDefaultTransport.
|
||||
var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
func init() {
|
||||
// STEP_TLS_TUNNEL is an environment variable that can be set to do an TLS
|
||||
// over (m)TLS tunnel to step-ca using identity-like credentials. The value
|
||||
// is a path to a json file with the tunnel host, certificate, key and root
|
||||
// used to create the (m)TLS tunnel.
|
||||
//
|
||||
// The configuration should look like:
|
||||
// {
|
||||
// "type": "tTLS",
|
||||
// "host": "tunnel.example.com:443"
|
||||
// "crt": "/path/to/tunnel.crt",
|
||||
// "key": "/path/to/tunnel.key",
|
||||
// "root": "/path/to/tunnel-root.crt"
|
||||
// }
|
||||
//
|
||||
// This feature is EXPERIMENTAL and might change at any time.
|
||||
if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" {
|
||||
id, err := identity.LoadIdentity(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := id.Validate(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
host, port, err := net.SplitHostPort(id.Host)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
pool, err := id.GetCertPool()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := &tls.Dialer{
|
||||
NetDialer: getDefaultDialer(),
|
||||
Config: &tls.Config{
|
||||
RootCAs: pool,
|
||||
GetClientCertificate: id.GetClientCertificateFunc(),
|
||||
},
|
||||
}
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientTLSConfig returns a tls.Config for client use configured with the
|
||||
// sign certificate, and a new certificate pool with the sign root certificate.
|
||||
// The client certificate will automatically rotate before expiring.
|
||||
|
@ -51,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
}
|
||||
|
||||
// Update renew function with transport
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tr := getDefaultTransport(tlsConfig)
|
||||
// Use mutable tls.Config on renew
|
||||
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
||||
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
||||
|
@ -103,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
|
||||
|
||||
// Update renew function with transport
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr := getDefaultTransport(tlsConfig)
|
||||
// Use mutable tls.Config on renew
|
||||
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
||||
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
||||
|
@ -144,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell
|
|||
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
|
||||
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
|
||||
return func(network, addr string) (net.Conn, error) {
|
||||
return tls.DialWithDialer(&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}, network, addr, ctx.mutableConfig.TLSConfig())
|
||||
return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -156,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
|
|||
// nolint:unused
|
||||
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
d := getDefaultDialer()
|
||||
// TLS dialers do not support context, but we can use the context
|
||||
// deadline if it is set.
|
||||
var deadline time.Time
|
||||
if t, ok := ctx.Deadline(); ok {
|
||||
deadline = t
|
||||
d.Deadline = t
|
||||
}
|
||||
return tls.DialWithDialer(&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Deadline: deadline,
|
||||
DualStack: true,
|
||||
}, network, addr, tlsCtx.mutableConfig.TLSConfig())
|
||||
return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -238,27 +275,35 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
|||
}
|
||||
}
|
||||
|
||||
// getDefaultDialer returns a new dialer with the default configuration.
|
||||
func getDefaultDialer() *net.Dialer {
|
||||
return &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultTransport returns an http.Transport with the same parameters than
|
||||
// http.DefaultTransport, but adds the given tls.Config and configures the
|
||||
// transport for HTTP/2.
|
||||
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
|
||||
tr := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
func getDefaultTransport(tlsConfig *tls.Config) *http.Transport {
|
||||
var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||
if mTLSDialContext == nil {
|
||||
d := getDefaultDialer()
|
||||
dialContext = d.DialContext
|
||||
} else {
|
||||
dialContext = mTLSDialContext()
|
||||
}
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
if err := http2.ConfigureTransport(tr); err != nil {
|
||||
return nil, errors.Wrap(err, "error configuring transport")
|
||||
}
|
||||
return tr, nil
|
||||
}
|
||||
|
||||
func getPEM(i interface{}) ([]byte, error) {
|
||||
|
|
|
@ -181,13 +181,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|||
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
|
||||
return nil
|
||||
}
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
t.Errorf("getDefaultTransport() error = %v", err)
|
||||
return nil
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: tr,
|
||||
Transport: getDefaultTransport(tlsConfig),
|
||||
}
|
||||
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
|
||||
{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
|
@ -199,14 +194,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
|
|||
tlsConfig := getDefaultTLSConfig(sr)
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
tlsConfig.RootCAs.AddCert(root)
|
||||
|
||||
tr, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
t.Errorf("getDefaultTransport() error = %v", err)
|
||||
return nil
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: tr,
|
||||
Transport: getDefaultTransport(tlsConfig),
|
||||
}
|
||||
}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
|
||||
{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
|
||||
|
@ -288,10 +277,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
|
||||
}
|
||||
tr2, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("getDefaultTransport() error = %v", err)
|
||||
}
|
||||
tr2 := getDefaultTransport(tlsConfig)
|
||||
// No client cert
|
||||
root, err := RootCertificate(sr)
|
||||
if err != nil {
|
||||
|
@ -300,10 +286,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
|
|||
tlsConfig = getDefaultTLSConfig(sr)
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
tlsConfig.RootCAs.AddCert(root)
|
||||
tr3, err := getDefaultTransport(tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("getDefaultTransport() error = %v", err)
|
||||
}
|
||||
tr3 := getDefaultTransport(tlsConfig)
|
||||
|
||||
// Disable keep alives to force TLS handshake
|
||||
tr1.DisableKeepAlives = true
|
||||
|
|
|
@ -313,7 +313,7 @@ func getSlotAndName(name string) (piv.Slot, string, error) {
|
|||
|
||||
s, ok := slotMapping[slotID]
|
||||
if !ok {
|
||||
return piv.Slot{}, "", errors.Errorf("usupported slot-id '%s'", name)
|
||||
return piv.Slot{}, "", errors.Errorf("unsupported slot-id '%s'", name)
|
||||
}
|
||||
|
||||
name = "yubikey:slot-id=" + url.QueryEscape(slotID)
|
||||
|
|
Loading…
Reference in a new issue