Merge branch 'master' into hs/scep

This commit is contained in:
Herman Slatman 2021-04-29 22:18:00 +02:00
commit 68d5f6d0d2
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
11 changed files with 312 additions and 105 deletions

View file

@ -91,8 +91,8 @@ func (h *Handler) Route(r api.Router) {
// Standard ACME API // Standard ACME API
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
extractPayloadByJWK := func(next nextHTTP) nextHTTP { extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))

View file

@ -156,14 +156,15 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error creating certificate", opts...) return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error creating certificate", opts...)
} }
if err = a.db.StoreCertificate(resp.Certificate); err != nil { fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeCertificate(fullchain); err != nil {
if err != db.ErrNotImplemented { if err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Sign; error storing certificate in db", opts...) "authority.Sign; error storing certificate in db", opts...)
} }
} }
return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil return fullchain, nil
} }
// Renew creates a new Certificate identical to the old certificate, except // Renew creates a new Certificate identical to the old certificate, except
@ -261,13 +262,29 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...) return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey", opts...)
} }
if err = a.db.StoreCertificate(resp.Certificate); err != nil { fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeCertificate(fullchain); err != nil {
if err != db.ErrNotImplemented { if err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...) return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
} }
} }
return append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...), nil return fullchain, nil
}
// storeCertificate allows to use an extension of the db.AuthDB interface that
// can log the full chain of certificates.
//
// TODO: at some point we should replace the db.AuthDB interface to implement
// `StoreCertificate(...*x509.Certificate) error` instead of just
// `StoreCertificate(*x509.Certificate) error`.
func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
if s, ok := a.db.(interface {
StoreCertificateChain(...*x509.Certificate) error
}); ok {
return s.StoreCertificateChain(fullchain...)
}
return a.db.StoreCertificate(fullchain[0])
} }
// RevokeOptions are the options for the Revoke API. // RevokeOptions are the options for the Revoke API.

View file

@ -56,10 +56,7 @@ func newClient(transport http.RoundTripper) *uaClient {
func newInsecureClient() *uaClient { func newInsecureClient() *uaClient {
return &uaClient{ return &uaClient{
Client: &http.Client{ Client: &http.Client{
Transport: &http.Transport{ Transport: getDefaultTransport(&tls.Config{InsecureSkipVerify: true}),
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}, },
} }
} }
@ -99,12 +96,13 @@ type RetryFunc func(code int) bool
type ClientOption func(o *clientOptions) error type ClientOption func(o *clientOptions) error
type clientOptions struct { type clientOptions struct {
transport http.RoundTripper transport http.RoundTripper
rootSHA256 string rootSHA256 string
rootFilename string rootFilename string
rootBundle []byte rootBundle []byte
certificate tls.Certificate certificate tls.Certificate
retryFunc RetryFunc getClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
retryFunc RetryFunc
} }
func (o *clientOptions) apply(opts []ClientOption) (err error) { func (o *clientOptions) apply(opts []ClientOption) (err error) {
@ -139,6 +137,7 @@ func (o *clientOptions) applyDefaultIdentity() error {
return nil return nil
} }
o.certificate = crt o.certificate = crt
o.getClientCertificate = i.GetClientCertificateFunc()
return nil return nil
} }
@ -193,6 +192,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
} }
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
} }
case *http2.Transport: case *http2.Transport:
if tr.TLSClientConfig == nil { if tr.TLSClientConfig == nil {
@ -200,6 +200,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
} }
if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil { if len(tr.TLSClientConfig.Certificates) == 0 && tr.TLSClientConfig.GetClientCertificate == nil {
tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate} tr.TLSClientConfig.Certificates = []tls.Certificate{o.certificate}
tr.TLSClientConfig.GetClientCertificate = o.getClientCertificate
} }
default: default:
return nil, errors.Errorf("unsupported transport type %T", tr) return nil, errors.Errorf("unsupported transport type %T", tr)
@ -288,7 +289,7 @@ func getTransportFromFile(filename string) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true, PreferServerCipherSuites: true,
RootCAs: pool, RootCAs: pool,
}) }), nil
} }
func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
@ -307,7 +308,7 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true, PreferServerCipherSuites: true,
RootCAs: pool, RootCAs: pool,
}) }), nil
} }
func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) { func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
@ -319,7 +320,7 @@ func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true, PreferServerCipherSuites: true,
RootCAs: pool, RootCAs: pool,
}) }), nil
} }
// parseEndpoint parses and validates the given endpoint. It supports general // parseEndpoint parses and validates the given endpoint. It supports general

View file

@ -26,9 +26,16 @@ type Type string
// Disabled represents a disabled identity type // Disabled represents a disabled identity type
const Disabled Type = "" const Disabled Type = ""
// MutualTLS represents the identity using mTLS // MutualTLS represents the identity using mTLS.
const MutualTLS Type = "mTLS" const MutualTLS Type = "mTLS"
// TunnelTLS represents an identity using a (m)TLS tunnel.
//
// TunnelTLS can be optionally configured with client certificates and a root
// file with the CAs to trust. By default it will use the system truststore
// instead of the CA truststore.
const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims. // DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute const DefaultLeeway = 1 * time.Minute
@ -44,19 +51,30 @@ type Identity struct {
Type string `json:"type"` Type string `json:"type"`
Certificate string `json:"crt"` Certificate string `json:"crt"`
Key string `json:"key"` Key string `json:"key"`
// Host is the tunnel host for a TunnelTLS (tTLS) identity.
Host string `json:"host,omitempty"`
// Root is the CA bundle of root CAs used in TunnelTLS to trust the
// certificate of the host.
Root string `json:"root,omitempty"`
}
// LoadIdentity loads an identity present in the given filename.
func LoadIdentity(filename string) (*Identity, error) {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename)
}
identity := new(Identity)
if err := json.Unmarshal(b, &identity); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", filename)
}
return identity, nil
} }
// LoadDefaultIdentity loads the default identity. // LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) { func LoadDefaultIdentity() (*Identity, error) {
b, err := ioutil.ReadFile(IdentityFile) return LoadIdentity(IdentityFile)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", IdentityFile)
}
identity := new(Identity)
if err := json.Unmarshal(b, &identity); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", IdentityFile)
}
return identity, nil
} }
// configDir and identityDir are used in WriteDefaultIdentity for testing // configDir and identityDir are used in WriteDefaultIdentity for testing
@ -81,7 +99,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
keyFilename := filepath.Join(identityDir, "identity_key") keyFilename := filepath.Join(identityDir, "identity_key")
// Write certificate // Write certificate
if err := WriteIdentityCertificate(certChain); err != nil { if err := writeCertificate(certFilename, certChain); err != nil {
return err return err
} }
@ -116,22 +134,21 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
return nil return nil
} }
// WriteIdentityCertificate writes the identity certificate in disk. // writeCertificate writes the given certificate on disk.
func WriteIdentityCertificate(certChain []api.Certificate) error { func writeCertificate(filename string, certChain []api.Certificate) error {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
certFilename := filepath.Join(identityDir, "identity.crt")
for _, crt := range certChain { for _, crt := range certChain {
block := &pem.Block{ block := &pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: crt.Raw, Bytes: crt.Raw,
} }
if err := pem.Encode(buf, block); err != nil { if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity certificate") return errors.Wrap(err, "error encoding certificate")
} }
} }
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing certificate")
} }
return nil return nil
@ -144,6 +161,8 @@ func (i *Identity) Kind() Type {
return Disabled return Disabled
case "mtls": case "mtls":
return MutualTLS return MutualTLS
case "ttls":
return TunnelTLS
default: default:
return Type(i.Type) return Type(i.Type)
} }
@ -164,8 +183,26 @@ func (i *Identity) Validate() error {
if err := fileExists(i.Certificate); err != nil { if err := fileExists(i.Certificate); err != nil {
return err return err
} }
if err := fileExists(i.Key); err != nil { return fileExists(i.Key)
return err case TunnelTLS:
if i.Host == "" {
return errors.New("tunnel.host cannot be empty")
}
if i.Certificate != "" {
if err := fileExists(i.Certificate); err != nil {
return err
}
if i.Key == "" {
return errors.New("tunnel.key cannot be empty")
}
if err := fileExists(i.Key); err != nil {
return err
}
}
if i.Root != "" {
if err := fileExists(i.Root); err != nil {
return err
}
} }
return nil return nil
default: default:
@ -179,7 +216,7 @@ func (i *Identity) TLSCertificate() (tls.Certificate, error) {
switch i.Kind() { switch i.Kind() {
case Disabled: case Disabled:
return tls.Certificate{}, nil return tls.Certificate{}, nil
case MutualTLS: case MutualTLS, TunnelTLS:
crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key)
if err != nil { if err != nil {
return fail(errors.Wrap(err, "error creating identity certificate")) return fail(errors.Wrap(err, "error creating identity certificate"))
@ -215,6 +252,22 @@ func (i *Identity) GetClientCertificateFunc() func(*tls.CertificateRequestInfo)
} }
} }
// GetCertPool returns a x509.CertPool if the identity defines a custom root.
func (i *Identity) GetCertPool() (*x509.CertPool, error) {
if i.Root == "" {
return nil, nil
}
b, err := ioutil.ReadFile(i.Root)
if err != nil {
return nil, errors.Wrap(err, "error reading identity root")
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(b) {
return nil, errors.Errorf("error pasing identity root: %s does not contain any certificate", i.Root)
}
return pool, nil
}
// Renewer is that interface that a renew client must implement. // Renewer is that interface that a renew client must implement.
type Renewer interface { type Renewer interface {
GetRootCAs() *x509.CertPool GetRootCAs() *x509.CertPool
@ -227,7 +280,7 @@ func (i *Identity) Renew(client Renewer) error {
switch i.Kind() { switch i.Kind() {
case Disabled: case Disabled:
return nil return nil
case MutualTLS: case MutualTLS, TunnelTLS:
cert, err := i.TLSCertificate() cert, err := i.TLSCertificate()
if err != nil { if err != nil {
return err return err

View file

@ -63,6 +63,7 @@ func TestIdentity_Kind(t *testing.T) {
}{ }{
{"disabled", fields{""}, Disabled}, {"disabled", fields{""}, Disabled},
{"mutualTLS", fields{"mTLS"}, MutualTLS}, {"mutualTLS", fields{"mTLS"}, MutualTLS},
{"tunnelTLS", fields{"tTLS"}, TunnelTLS},
{"unknown", fields{"unknown"}, Type("unknown")}, {"unknown", fields{"unknown"}, Type("unknown")},
} }
for _, tt := range tests { for _, tt := range tests {
@ -82,19 +83,27 @@ func TestIdentity_Validate(t *testing.T) {
Type string Type string
Certificate string Certificate string
Key string Key string
Host string
Root string
} }
tests := []struct { tests := []struct {
name string name string
fields fields fields fields
wantErr bool wantErr bool
}{ }{
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, false}, {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false},
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false},
{"ok disabled", fields{}, false}, {"ok disabled", fields{}, false},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, true}, {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true},
{"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key"}, true}, {"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true},
{"fail key", fields{"mTLS", "testdata/identity/identity.crt", ""}, true}, {"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true},
{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key"}, true}, {"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key"}, true}, {"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true},
{"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true},
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
{"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true},
{"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -102,6 +111,8 @@ func TestIdentity_Validate(t *testing.T) {
Type: tt.fields.Type, Type: tt.fields.Type,
Certificate: tt.fields.Certificate, Certificate: tt.fields.Certificate,
Key: tt.fields.Key, Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
} }
if err := i.Validate(); (err != nil) != tt.wantErr { if err := i.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr)
@ -127,7 +138,8 @@ func TestIdentity_TLSCertificate(t *testing.T) {
want tls.Certificate want tls.Certificate
wantErr bool wantErr bool
}{ }{
{"ok", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false}, {"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
{"ok disabled", fields{}, tls.Certificate{}, false}, {"ok disabled", fields{}, tls.Certificate{}, false},
{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
{"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true}, {"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
@ -255,6 +267,95 @@ func TestWriteDefaultIdentity(t *testing.T) {
} }
} }
func TestIdentity_GetClientCertificateFunc(t *testing.T) {
expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key")
if err != nil {
t.Fatal(err)
}
type fields struct {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
want *tls.Certificate
wantErr bool
}{
{"ok mTLS", fields{"mtls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false},
{"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false},
{"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true},
{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := &Identity{
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
fn := i.GetClientCertificateFunc()
got, err := fn(&tls.CertificateRequestInfo{})
if (err != nil) != tt.wantErr {
t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want)
}
})
}
}
func TestIdentity_GetCertPool(t *testing.T) {
type fields struct {
Type string
Certificate string
Key string
Host string
Root string
}
tests := []struct {
name string
fields fields
wantSubjects [][]byte
wantErr bool
}{
{"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false},
{"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false},
{"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true},
{"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := &Identity{
Type: tt.fields.Type,
Certificate: tt.fields.Certificate,
Key: tt.fields.Key,
Host: tt.fields.Host,
Root: tt.fields.Root,
}
got, err := i.GetCertPool()
if (err != nil) != tt.wantErr {
t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
subjects := got.Subjects()
if !reflect.DeepEqual(subjects, tt.wantSubjects) {
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
}
}
})
}
}
type renewer struct { type renewer struct {
pool *x509.CertPool pool *x509.CertPool
sign *api.SignResponse sign *api.SignResponse

View file

@ -0,0 +1,7 @@
{
"type": "mTLS",
"crt": "testdata/identity/identity.crt",
"key": "testdata/identity/identity_key",
"host": "tunnel:443",
"root": "testdata/certs/root_ca.crt"
}

113
ca/tls.go
View file

@ -10,13 +10,65 @@ import (
"encoding/pem" "encoding/pem"
"net" "net"
"net/http" "net/http"
"os"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"golang.org/x/net/http2" "github.com/smallstep/certificates/ca/identity"
) )
// mTLSDialContext will hold the dial context function to use in
// getDefaultTransport.
var mTLSDialContext func() func(ctx context.Context, network, address string) (net.Conn, error)
func init() {
// STEP_TLS_TUNNEL is an environment variable that can be set to do an TLS
// over (m)TLS tunnel to step-ca using identity-like credentials. The value
// is a path to a json file with the tunnel host, certificate, key and root
// used to create the (m)TLS tunnel.
//
// The configuration should look like:
// {
// "type": "tTLS",
// "host": "tunnel.example.com:443"
// "crt": "/path/to/tunnel.crt",
// "key": "/path/to/tunnel.key",
// "root": "/path/to/tunnel-root.crt"
// }
//
// This feature is EXPERIMENTAL and might change at any time.
if path := os.Getenv("STEP_TLS_TUNNEL"); path != "" {
id, err := identity.LoadIdentity(path)
if err != nil {
panic(err)
}
if err := id.Validate(); err != nil {
panic(err)
}
host, port, err := net.SplitHostPort(id.Host)
if err != nil {
panic(err)
}
pool, err := id.GetCertPool()
if err != nil {
panic(err)
}
mTLSDialContext = func() func(ctx context.Context, network, address string) (net.Conn, error) {
d := &tls.Dialer{
NetDialer: getDefaultDialer(),
Config: &tls.Config{
RootCAs: pool,
GetClientCertificate: id.GetClientCertificateFunc(),
},
}
return func(ctx context.Context, network, address string) (net.Conn, error) {
return d.DialContext(ctx, "tcp", net.JoinHostPort(host, port))
}
}
}
}
// GetClientTLSConfig returns a tls.Config for client use configured with the // GetClientTLSConfig returns a tls.Config for client use configured with the
// sign certificate, and a new certificate pool with the sign root certificate. // sign certificate, and a new certificate pool with the sign root certificate.
// The client certificate will automatically rotate before expiring. // The client certificate will automatically rotate before expiring.
@ -51,10 +103,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
} }
// Update renew function with transport // Update renew function with transport
tr, err := getDefaultTransport(tlsConfig) tr := getDefaultTransport(tlsConfig)
if err != nil {
return nil, nil, err
}
// Use mutable tls.Config on renew // Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
@ -103,10 +152,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx) tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
// Update renew function with transport // Update renew function with transport
tr, err := getDefaultTransport(tlsConfig) tr := getDefaultTransport(tlsConfig)
if err != nil {
return nil, err
}
// Use mutable tls.Config on renew // Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
@ -144,11 +190,7 @@ func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHell
// buildDialTLS returns an implementation of DialTLS callback in http.Transport. // buildDialTLS returns an implementation of DialTLS callback in http.Transport.
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) { func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
return func(network, addr string) (net.Conn, error) { return func(network, addr string) (net.Conn, error) {
return tls.DialWithDialer(&net.Dialer{ return tls.DialWithDialer(getDefaultDialer(), network, addr, ctx.mutableConfig.TLSConfig())
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}, network, addr, ctx.mutableConfig.TLSConfig())
} }
} }
@ -156,18 +198,13 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
// nolint:unused // nolint:unused
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) { func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) {
d := getDefaultDialer()
// TLS dialers do not support context, but we can use the context // TLS dialers do not support context, but we can use the context
// deadline if it is set. // deadline if it is set.
var deadline time.Time
if t, ok := ctx.Deadline(); ok { if t, ok := ctx.Deadline(); ok {
deadline = t d.Deadline = t
} }
return tls.DialWithDialer(&net.Dialer{ return tls.DialWithDialer(d, network, addr, tlsCtx.mutableConfig.TLSConfig())
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Deadline: deadline,
DualStack: true,
}, network, addr, tlsCtx.mutableConfig.TLSConfig())
} }
} }
@ -238,27 +275,35 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
} }
} }
// getDefaultDialer returns a new dialer with the default configuration.
func getDefaultDialer() *net.Dialer {
return &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
}
// getDefaultTransport returns an http.Transport with the same parameters than // getDefaultTransport returns an http.Transport with the same parameters than
// http.DefaultTransport, but adds the given tls.Config and configures the // http.DefaultTransport, but adds the given tls.Config and configures the
// transport for HTTP/2. // transport for HTTP/2.
func getDefaultTransport(tlsConfig *tls.Config) (*http.Transport, error) { func getDefaultTransport(tlsConfig *tls.Config) *http.Transport {
tr := &http.Transport{ var dialContext func(ctx context.Context, network string, addr string) (net.Conn, error)
Proxy: http.ProxyFromEnvironment, if mTLSDialContext == nil {
DialContext: (&net.Dialer{ d := getDefaultDialer()
Timeout: 30 * time.Second, dialContext = d.DialContext
KeepAlive: 30 * time.Second, } else {
DualStack: true, dialContext = mTLSDialContext()
}).DialContext, }
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100, MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
} }
if err := http2.ConfigureTransport(tr); err != nil {
return nil, errors.Wrap(err, "error configuring transport")
}
return tr, nil
} }
func getPEM(i interface{}) ([]byte, error) { func getPEM(i interface{}) ([]byte, error) {

View file

@ -181,13 +181,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
t.Errorf("Client.GetClientTLSConfig() error = %v", err) t.Errorf("Client.GetClientTLSConfig() error = %v", err)
return nil return nil
} }
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{ return &http.Client{
Transport: tr, Transport: getDefaultTransport(tlsConfig),
} }
}, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}}, }, map[string]bool{srvTLS.URL: false, srvMTLS.URL: false}},
{"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { {"with no ClientCert", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
@ -199,14 +194,8 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
tlsConfig := getDefaultTLSConfig(sr) tlsConfig := getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root) tlsConfig.RootCAs.AddCert(root)
tr, err := getDefaultTransport(tlsConfig)
if err != nil {
t.Errorf("getDefaultTransport() error = %v", err)
return nil
}
return &http.Client{ return &http.Client{
Transport: tr, Transport: getDefaultTransport(tlsConfig),
} }
}, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}}, }, map[string]bool{srvTLS.URL + "/no-cert": false, srvMTLS.URL + "/no-cert": true}},
{"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client { {"fail with default", func(t *testing.T, client *Client, sr *api.SignResponse, pk crypto.PrivateKey) *http.Client {
@ -288,10 +277,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Client.GetClientTLSConfig() error = %v", err) t.Fatalf("Client.GetClientTLSConfig() error = %v", err)
} }
tr2, err := getDefaultTransport(tlsConfig) tr2 := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
// No client cert // No client cert
root, err := RootCertificate(sr) root, err := RootCertificate(sr)
if err != nil { if err != nil {
@ -300,10 +286,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
tlsConfig = getDefaultTLSConfig(sr) tlsConfig = getDefaultTLSConfig(sr)
tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs = x509.NewCertPool()
tlsConfig.RootCAs.AddCert(root) tlsConfig.RootCAs.AddCert(root)
tr3, err := getDefaultTransport(tlsConfig) tr3 := getDefaultTransport(tlsConfig)
if err != nil {
t.Fatalf("getDefaultTransport() error = %v", err)
}
// Disable keep alives to force TLS handshake // Disable keep alives to force TLS handshake
tr1.DisableKeepAlives = true tr1.DisableKeepAlives = true

View file

@ -191,7 +191,7 @@ In the ca.json configuration file, a complete JWK provisioner example looks like
### OIDC ### OIDC
An OIDC provisioner allows a user to get a certificate after authenticating An OIDC provisioner allows a user to get a certificate after authenticating
himself with an OAuth OpenID Connect identity provider. The ID token provided with an OAuth OpenID Connect identity provider. The ID token provided
will be used on the CA authentication, and by default, the certificate will only will be used on the CA authentication, and by default, the certificate will only
have the user's email as a Subject Alternative Name (SAN) Extension. have the user's email as a Subject Alternative Name (SAN) Extension.

View file

@ -313,7 +313,7 @@ func getSlotAndName(name string) (piv.Slot, string, error) {
s, ok := slotMapping[slotID] s, ok := slotMapping[slotID]
if !ok { if !ok {
return piv.Slot{}, "", errors.Errorf("usupported slot-id '%s'", name) return piv.Slot{}, "", errors.Errorf("unsupported slot-id '%s'", name)
} }
name = "yubikey:slot-id=" + url.QueryEscape(slotID) name = "yubikey:slot-id=" + url.QueryEscape(slotID)

View file

@ -26,7 +26,7 @@ ExecStart=/usr/bin/step ca renew --force $CERT_LOCATION $KEY_LOCATION
; Try to reload or restart the systemd service that relies on this cert-renewer ; Try to reload or restart the systemd service that relies on this cert-renewer
; If the relying service doesn't exist, forge ahead. ; If the relying service doesn't exist, forge ahead.
ExecStartPost=/usr/bin/env bash -c "if ! systemctl --quiet is-enabled %i.service ; then exit 0; fi; systemctl try-reload-or-restart %i" ExecStartPost=-systemctl try-reload-or-restart %i
[Install] [Install]
WantedBy=multi-user.target WantedBy=multi-user.target