Merge branch 'master' into hs/scep

This commit is contained in:
Herman Slatman 2021-04-29 22:18:00 +02:00
commit 68d5f6d0d2
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
11 changed files with 312 additions and 105 deletions

View file

@ -91,8 +91,8 @@ func (h *Handler) Route(r api.Router) {
// Standard ACME API
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))

View file

@ -156,14 +156,15 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error creating certificate", opts...)
}
if err = a.db.StoreCertificate(resp.Certificate); err != nil {
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeCertificate(fullchain); err != nil {
if err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Sign; error storing certificate in db", opts...)
}
}
return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil
return fullchain, nil
}
// Renew creates a new Certificate identical to the old certificate, except
@ -261,13 +262,29 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
}
if err = a.db.StoreCertificate(resp.Certificate); err != nil {
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeCertificate(fullchain); err != nil {
if err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
}
}
return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil
return fullchain, nil
}
// storeCertificate allows to use an extension of the db.AuthDB interface that
// can log the full chain of certificates.
//
// TODO: at some point we should replace the db.AuthDB interface to implement
// `StoreCertificate(...*x509.Certificate) error` instead of just
// `StoreCertificate(*x509.Certificate) error`.
func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
if s, ok := a.db.(interface {
StoreCertificateChain(...*x509.Certificate) error
}); ok {
return s.StoreCertificateChain(fullchain...)
}
return a.db.StoreCertificate(fullchain[0])
}
// RevokeOptions are the options for the Revoke API.

View file

@ -56,10 +56,7 @@ func newClient(transport http.RoundTripper) *uaClient {
func newInsecureClient() *uaClient {
return &uaClient{
Client: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}),
},
}
}
@ -99,12 +96,13 @@ type RetryFunc func(code int) bool
type ClientOption func(o *clientOptions) error
type clientOptions struct {
transport http.RoundTripper
rootSHA256 string
rootFilename string
rootBundle []byte
certificate tls.Certificate
retryFunc RetryFunc
transport http.RoundTripper
rootSHA256 string
rootFilename string
rootBundle []byte
certificate tls.Certificate
getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
retryFunc RetryFunc
}
func (o *clientOptions) apply(opts []ClientOption) (err error) {
@ -139,6 +137,7 @@ func (o *clientOptions) applyDefaultIdentity() error {
return nil
}
o.certificate = crt
o.getClientCertificate = i.GetClientCertificateFunc()
return nil
}
@ -193,6 +192,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
}
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
}
case *http2.Transport:
if tr.TLSClientConfig == nil {
@ -200,6 +200,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
}
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
}
default:
return nil, errors.Errorf("unsupported transport type %T", tr)
@ -288,7 +289,7 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}), nil
}
func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
@ -307,7 +308,7 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}), nil
}
func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
@ -319,7 +320,7 @@ func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}), nil
}
// parseEndpoint parses and validates the given endpoint. It supports general

View file

@ -26,9 +26,16 @@ type Type string
// Disabled represents a disabled identity type
const Disabled Type = ""
// MutualTLS represents the identity using mTLS
// MutualTLS represents the identity using mTLS.
const MutualTLS Type = "mTLS"
// TunnelTLS represents an identity using a (m)TLS tunnel.
//
// TunnelTLS can be optionally configured with client certificates and a root
// file with the CAs to trust. By default it will use the system truststore
// instead of the CA truststore.
const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute
@ -44,19 +51,30 @@ type Identity struct {
Type string `json:"type"`
Certificate string `json:"crt"`
Key string `json:"key"`
// Host is the tunnel host for a TunnelTLS (tTLS) identity.
Host string `json:"host,omitempty"`
// Root is the CA bundle of root CAs used in TunnelTLS to trust the
// certificate of the host.
Root string `json:"root,omitempty"`
}
// LoadIdentity loads an identity present in the given filename.
func LoadIdentity(filename string) (*Identity, error) {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename)
}
identity := new(Identity)
if err := json.Unmarshal(b, &identity); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", filename)
}
return identity, nil
}
// LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) {
b, err := ioutil.ReadFile(IdentityFile)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", IdentityFile)
}
identity := new(Identity)
if err := json.Unmarshal(b, &identity); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", IdentityFile)
}
return identity, nil
return LoadIdentity(IdentityFile)
}
// configDir and identityDir are used in WriteDefaultIdentity for testing
@ -81,7 +99,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
keyFilename := filepath.Join(identityDir, "identity_key")
// Write certificate
if err := WriteIdentityCertificate(certChain); err != nil {
if err := writeCertificate(certFilename, certChain); err != nil {
return err
}
@ -116,22 +134,21 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
return nil
}
// WriteIdentityCertificate writes the identity certificate in disk.
func WriteIdentityCertificate(certChain []api.Certificate) error {
// writeCertificate writes the given certificate on disk.
func writeCertificate(filename string, certChain []api.Certificate) error {
buf := new(bytes.Buffer)
certFilename := filepath.Join(identityDir, "identity.crt")
for _, crt := range certChain {
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: crt.Raw,
}
if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity certificate")
return errors.Wrap(err, "error encoding certificate")
}
}
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing certificate")
}
return nil
@ -144,6 +161,8 @@ func (i *Identity) Kind() Type {
return Disabled
case "mtls":
return MutualTLS
case "ttls":
return TunnelTLS
default:
return Type(i.Type)
}
@ -164,8 +183,26 @@ func (i *Identity) Validate() error {
if err := fileExists(i.Certificate); err != nil {
return err
}
if err := fileExists(i.Key); err != nil {
return err
return fileExists(i.Key)
case TunnelTLS:
if i.Host == "" {
return errors.New("tunnel.host cannot be empty")
}
if i.Certificate != "" {
if err := fileExists(i.Certificate); err != nil {
return err
}
if i.Key == "" {
return errors.New("tunnel.key cannot be empty")
}
if err := fileExists(i.Key); err != nil {
return err
}
}
if i.Root != "" {
if err := fileExists(i.Root); err != nil {
return err
}
}
return nil
default:
@ -179,7 +216,7 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) {
switch i.Kind() {
case Disabled:
return tls.Certificate{}, nil
case MutualTLS:
case MutualTLS, TunnelTLS:
crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key)
if err != nil {
return fail(errors.Wrap(err, "error creating identity certificate"))
@ -215,6 +252,22 @@ func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo)
}
}
// GetCertPool returns a x509.CertPool if the identity defines a custom root.
func (i *Identity) GetCertPool() (*x509.CertPool, error) {
if i.Root == "" {
return nil, nil
}
b, err := ioutil.ReadFile(i.Root)
if err != nil {
return nil, errors.Wrap(err, "error reading identity root")
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(b) {
return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root)
}
return pool, nil
}
// Renewer is that interface that a renew client must implement.
type Renewer interface {
GetRootCAs() *x509.CertPool
@ -227,7 +280,7 @@ func (i *Identity) Renew(client Renewer) error {
switch i.Kind() {
case Disabled:
return nil
case MutualTLS:
case MutualTLS, TunnelTLS:
cert, err := i.TLSCertificate()
if err != nil {
return err

View file

@ -63,6 +63,7 @@ func TestIdentity_Kind(t *testing.T) {
}{
{"disabled", fields{""}, Disabled},
{"mutualTLS", fields{"mTLS"}, MutualTLS},
{"tunnelTLS", fields{"tTLS"}, TunnelTLS},
{"unknown", fields{"unknown"}, Type("unknown")},
}
for _, tt := range tests {
@ -82,19 +83,27 @@ func TestIdentity_Validate(t *testing.T) {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, false},
{"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false},
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false},
{"ok disabled", fields{}, false},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, true},
{"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key"}, true},
{"fail key", fields{"mTLS", "testdata/identity/identity.crt", ""}, true},
{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key"}, true},
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key"}, true},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true},
{"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true},
{"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true},
{"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true},
{"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true},
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true},
{"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -102,6 +111,8 @@ func TestIdentity_Validate(t *testing.T) {
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
if err := i.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr)
@ -127,7 +138,8 @@ func TestIdentity_TLSCertificate(t *testing.T) {
want tls.Certificate
wantErr bool
}{
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok disabled", fields{}, tls.Certificate{}, false},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
{"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
@ -255,6 +267,95 @@ func TestWriteDefaultIdentity(t *testing.T) {
}
}
func TestIdentity_GetClientCertificateFunc(t *testing.T) {
expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key")
if err != nil {
t.Fatal(err)
}
type fields struct {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
want *tls.Certificate
wantErr bool
}{
{"ok mTLS", fields{"mtls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false},
{"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false},
{"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true},
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := &Identity{
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
fn := i.GetClientCertificateFunc()
got, err := fn(&tls.CertificateRequestInfo{})
if (err != nil) != tt.wantErr {
t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want)
}
})
}
}
func TestIdentity_GetCertPool(t *testing.T) {
type fields struct {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
wantSubjects [][]byte
wantErr bool
}{
{"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false},
{"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false},
{"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true},
{"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := &Identity{
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
got, err := i.GetCertPool()
if (err != nil) != tt.wantErr {
t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
subjects := got.Subjects()
if !reflect.DeepEqual(subjects, tt.wantSubjects) {
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
}
}
})
}
}
type renewer struct {
pool *x509.CertPool
sign *api.SignResponse

View file

@ -0,0 +1,7 @@
{
"type": "mTLS",
"crt": "testdata/identity/identity.crt",
"key": "testdata/identity/identity_key",
"host": "tunnel:443",
"root": "testdata/certs/root_ca.crt"
}

113
ca/tls.go
View file

@ -10,13 +10,65 @@ import (
"encoding/pem"
"net"
"net/http"
"os"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"golang.org/x/net/http2"
"github.com/smallstep/certificates/ca/identity"
)
// mTLSDialContext will hold the dial context function to use in
// getDefaultTransport.
var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error)
func init() {
// STEP_TLS_TUNNEL is an environment variable that can be set to do an TLS
// over (m)TLS tunnel to step-ca using identity-like credentials. The value
// is a path to a json file with the tunnel host, certificate, key and root
// used to create the (m)TLS tunnel.
//
// The configuration should look like:
// {
// "type": "tTLS",
// "host": "tunnel.example.com:443"
// "crt": "/path/to/tunnel.crt",
// "key": "/path/to/tunnel.key",
// "root": "/path/to/tunnel-root.crt"
// }
//
// This feature is EXPERIMENTAL and might change at any time.
if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" {
id, err := identity.LoadIdentity(path)
if err != nil {
panic(err)
}
if err := id.Validate(); err != nil {
panic(err)
}
host, port, err := net.SplitHostPort(id.Host)
if err != nil {
panic(err)
}
pool, err := id.GetCertPool()
if err != nil {
panic(err)
}
mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
d := &tls.Dialer{
NetDialer: getDefaultDialer(),
Config: &tls.Config{
RootCAs: pool,
GetClientCertificate: id.GetClientCertificateFunc(),
},
}
return func(ctx context.Context, network, address string) (net.Conn, error) {
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
}
}
}
}
// GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate.
// The client certificate will automatically rotate before expiring.
@ -51,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
}
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
return nil, nil, err
}
tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
@ -103,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
// Update renew function with transport
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
return nil, err
}
tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
@ -144,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
return func(network, addr string) (net.Conn, error) {
return tls.DialWithDialer(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}, network, addr, ctx.mutableConfig.TLSConfig())
return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig())
}
}
@ -156,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
// nolint:unused
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
d := getDefaultDialer()
// TLS dialers do not support context, but we can use the context
// deadline if it is set.
var deadline time.Time
if t, ok := ctx.Deadline(); ok {
deadline = t
d.Deadline = t
}
return tls.DialWithDialer(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Deadline: deadline,
DualStack: true,
}, network, addr, tlsCtx.mutableConfig.TLSConfig())
return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig())
}
}
@ -238,27 +275,35 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
}
}
// getDefaultDialer returns a new dialer with the default configuration.
func getDefaultDialer() *net.Dialer {
return &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
}
// getDefaultTransport returns an http.Transport with the same parameters than
// http.DefaultTransport, but adds the given tls.Config and configures the
// transport for HTTP/2.
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) {
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
func getDefaultTransport(tlsConfig *tls.Config) *http.Transport {
var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
if mTLSDialContext == nil {
d := getDefaultDialer()
dialContext = d.DialContext
} else {
dialContext = mTLSDialContext()
}
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: tlsConfig,
}
if err := http2.ConfigureTransport(tr); err != nil {
return nil, errors.Wrap(err, "error configuring transport")
}
return tr, nil
}
func getPEM(i interface{}) ([]byte, error) {

View file

@ -181,13 +181,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
t.Errorf("Client.GetClientTLSConfig() error = %v", err)
return nil
}
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
Transport: getDefaultTransport(tlsConfig),
}
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
@ -199,14 +194,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
tlsConfig := getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{
Transport: tr,
Transport: getDefaultTransport(tlsConfig),
}
}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
@ -288,10 +277,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
if err != nil {
t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
}
tr2, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
tr2 := getDefaultTransport(tlsConfig)
// No client cert
root, err := RootCertificate(sr)
if err != nil {
@ -300,10 +286,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
tlsConfig = getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root)
tr3, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
tr3 := getDefaultTransport(tlsConfig)
// Disable keep alives to force TLS handshake
tr1.DisableKeepAlives = true

View file

@ -191,7 +191,7 @@ In the ca.json configuration file, a complete JWK provisioner example looks like
### OIDC
An OIDC provisioner allows a user to get a certificate after authenticating
himself with an OAuth OpenID Connect identity provider. The ID token provided
with an OAuth OpenID Connect identity provider. The ID token provided
will be used on the CA authentication, and by default, the certificate will only
have the user's email as a Subject Alternative Name (SAN) Extension.

View file

@ -313,7 +313,7 @@ func getSlotAndName(name string) (piv.Slot, string, error) {
s, ok := slotMapping[slotID]
if !ok {
return piv.Slot{}, "", errors.Errorf("usupported slot-id '%s'", name)
return piv.Slot{}, "", errors.Errorf("unsupported slot-id '%s'", name)
}
name = "yubikey:slot-id=" + url.QueryEscape(slotID)

View file

@ -26,7 +26,7 @@ ExecStart=/usr/bin/step ca renew --force $CERT_LOCATION $KEY_LOCATION
; Try to reload or restart the systemd service that relies on this cert-renewer
; If the relying service doesn't exist, forge ahead.
ExecStartPost=/usr/bin/env bash -c "if ! systemctl --quiet is-enabled %i.service ; then exit 0; fi; systemctl try-reload-or-restart %i"
ExecStartPost=-systemctl try-reload-or-restart %i
[Install]
WantedBy=multi-user.target