Add mTLS request to get all the root CAs, not the federated ones.
This commit is contained in:
parent
98cc243a37
commit
d296cf95a9
10 changed files with 488 additions and 43 deletions
39
api/api.go
39
api/api.go
|
@ -22,10 +22,11 @@ type Authority interface {
|
||||||
GetTLSOptions() *tlsutil.TLSOptions
|
GetTLSOptions() *tlsutil.TLSOptions
|
||||||
Root(shasum string) (*x509.Certificate, error)
|
Root(shasum string) (*x509.Certificate, error)
|
||||||
Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
|
Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error)
|
||||||
Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||||
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
|
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
|
||||||
GetEncryptedKey(kid string) (string, error)
|
GetEncryptedKey(kid string) (string, error)
|
||||||
GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error)
|
GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error)
|
||||||
|
GetFederation(peer *x509.Certificate) ([]*x509.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Certificate wraps a *x509.Certificate and adds the json.Marshaler interface.
|
// Certificate wraps a *x509.Certificate and adds the json.Marshaler interface.
|
||||||
|
@ -187,6 +188,11 @@ type SignResponse struct {
|
||||||
TLS *tls.ConnectionState `json:"-"`
|
TLS *tls.ConnectionState `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RootsResponse is the response object of the roots request.
|
||||||
|
type RootsResponse struct {
|
||||||
|
Certificates []Certificate `json:"crts"`
|
||||||
|
}
|
||||||
|
|
||||||
// FederationResponse is the response object of the federation request.
|
// FederationResponse is the response object of the federation request.
|
||||||
type FederationResponse struct {
|
type FederationResponse struct {
|
||||||
Certificates []Certificate `json:"crts"`
|
Certificates []Certificate `json:"crts"`
|
||||||
|
@ -211,6 +217,7 @@ func (h *caHandler) Route(r Router) {
|
||||||
r.MethodFunc("POST", "/renew", h.Renew)
|
r.MethodFunc("POST", "/renew", h.Renew)
|
||||||
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
||||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
||||||
|
r.MethodFunc("GET", "/roots", h.Roots)
|
||||||
r.MethodFunc("GET", "/federation", h.Federation)
|
r.MethodFunc("GET", "/federation", h.Federation)
|
||||||
// For compatibility with old code:
|
// For compatibility with old code:
|
||||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||||
|
@ -327,7 +334,33 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||||
JSON(w, &ProvisionerKeyResponse{key})
|
JSON(w, &ProvisionerKeyResponse{key})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Federation returns all the public certificates in the federation.
|
// Roots returns all the root certificates for the CA. It requires a valid TLS
|
||||||
|
// client.
|
||||||
|
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||||
|
WriteError(w, BadRequest(errors.New("missing peer certificate")))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
roots, err := h.Authority.GetRoots(r.TLS.PeerCertificates[0])
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, Forbidden(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
certs := make([]Certificate, len(roots))
|
||||||
|
for i := range roots {
|
||||||
|
certs[i] = Certificate{roots[i]}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
JSON(w, &RootsResponse{
|
||||||
|
Certificates: certs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Federation returns all the public certificates in the federation. It requires
|
||||||
|
// a valid TLS client.
|
||||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||||
WriteError(w, BadRequest(errors.New("missing peer certificate")))
|
WriteError(w, BadRequest(errors.New("missing peer certificate")))
|
||||||
|
|
|
@ -392,6 +392,7 @@ type mockAuthority struct {
|
||||||
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||||
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
|
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
|
||||||
getEncryptedKey func(kid string) (string, error)
|
getEncryptedKey func(kid string) (string, error)
|
||||||
|
getRoots func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||||
getFederation func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
getFederation func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -444,6 +445,13 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
|
||||||
return m.ret1.(string), m.err
|
return m.ret1.(string), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthority) GetRoots(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||||
|
if m.getFederation != nil {
|
||||||
|
return m.getRoots(cert)
|
||||||
|
}
|
||||||
|
return m.ret1.([]*x509.Certificate), m.err
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
func (m *mockAuthority) GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||||
if m.getFederation != nil {
|
if m.getFederation != nil {
|
||||||
return m.getFederation(cert)
|
return m.getFederation(cert)
|
||||||
|
@ -821,6 +829,53 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_caHandler_Roots(t *testing.T) {
|
||||||
|
cs := &tls.ConnectionState{
|
||||||
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tls *tls.ConnectionState
|
||||||
|
cert *x509.Certificate
|
||||||
|
root *x509.Certificate
|
||||||
|
err error
|
||||||
|
statusCode int
|
||||||
|
}{
|
||||||
|
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||||
|
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
||||||
|
req.TLS = tt.tls
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.Roots(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("caHandler.Roots unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
|
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||||||
|
t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_caHandler_Federation(t *testing.T) {
|
func Test_caHandler_Federation(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
|
@ -851,17 +906,17 @@ func Test_caHandler_Federation(t *testing.T) {
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(res.Body)
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("caHandler.Root unexpected error = %v", err)
|
t.Errorf("caHandler.Federation unexpected error = %v", err)
|
||||||
}
|
}
|
||||||
if tt.statusCode < http.StatusBadRequest {
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||||||
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
|
t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -33,6 +33,15 @@ func (a *Authority) GetRootCertificates() []*x509.Certificate {
|
||||||
return a.rootX509Certs
|
return a.rootX509Certs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRoots returns all the root certificates for this CA.
|
||||||
|
func (a *Authority) GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error) {
|
||||||
|
// Check step provisioner extensions
|
||||||
|
if err := a.authorizeRenewal(peer); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return a.rootX509Certs, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetFederation returns all the root certificates in the federation.
|
// GetFederation returns all the root certificates in the federation.
|
||||||
func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) {
|
func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) {
|
||||||
// Check step provisioner extensions
|
// Check step provisioner extensions
|
||||||
|
|
26
ca/client.go
26
ca/client.go
|
@ -237,9 +237,10 @@ func WithProvisionerLimit(limit int) ProvisionerOption {
|
||||||
|
|
||||||
// Client implements an HTTP client for the CA server.
|
// Client implements an HTTP client for the CA server.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
endpoint *url.URL
|
endpoint *url.URL
|
||||||
certPool *x509.CertPool
|
certPool *x509.CertPool
|
||||||
|
cachedSign *api.SignResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new Client with the given endpoint and options.
|
// NewClient creates a new Client with the given endpoint and options.
|
||||||
|
@ -413,6 +414,25 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error)
|
||||||
return &key, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Roots performs the get roots request to the CA and returns the
|
||||||
|
// api.RootsResponse struct.
|
||||||
|
func (c *Client) Roots(tr http.RoundTripper) (*api.RootsResponse, error) {
|
||||||
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
resp, err := client.Get(u.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readError(resp.Body)
|
||||||
|
}
|
||||||
|
var federation api.RootsResponse
|
||||||
|
if err := readJSON(resp.Body, &federation); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||||
|
}
|
||||||
|
return &federation, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Federation performs the get federation request to the CA and returns the
|
// Federation performs the get federation request to the CA and returns the
|
||||||
// api.FederationResponse struct.
|
// api.FederationResponse struct.
|
||||||
func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) {
|
func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) {
|
||||||
|
|
|
@ -512,6 +512,67 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_Roots(t *testing.T) {
|
||||||
|
ok := &api.RootsResponse{
|
||||||
|
Certificates: []api.Certificate{
|
||||||
|
{Certificate: parseCertificate(rootPEM)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||||
|
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
response interface{}
|
||||||
|
responseCode int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", ok, 200, false},
|
||||||
|
{"unauthorized", unauthorized, 401, true},
|
||||||
|
{"empty request", badRequest, 403, true},
|
||||||
|
{"nil request", badRequest, 403, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(nil)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewClient() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
w.WriteHeader(tt.responseCode)
|
||||||
|
api.JSON(w, tt.response)
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := c.Roots(nil)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
fmt.Printf("%+v", err)
|
||||||
|
t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
if got != nil {
|
||||||
|
t.Errorf("Client.Roots() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(err, tt.response) {
|
||||||
|
t.Errorf("Client.Roots() error = %v, want %v", err, tt.response)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if !reflect.DeepEqual(got, tt.response) {
|
||||||
|
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClient_Federation(t *testing.T) {
|
func TestClient_Federation(t *testing.T) {
|
||||||
ok := &api.FederationResponse{
|
ok := &api.FederationResponse{
|
||||||
Certificates: []api.Certificate{
|
Certificates: []api.Certificate{
|
||||||
|
|
1
ca/testdata/ca.json
vendored
1
ca/testdata/ca.json
vendored
|
@ -1,5 +1,6 @@
|
||||||
{
|
{
|
||||||
"root": "../ca/testdata/secrets/root_ca.crt",
|
"root": "../ca/testdata/secrets/root_ca.crt",
|
||||||
|
"federatedRoots": ["../ca/testdata/secrets/federated_ca.crt"],
|
||||||
"crt": "../ca/testdata/secrets/intermediate_ca.crt",
|
"crt": "../ca/testdata/secrets/intermediate_ca.crt",
|
||||||
"key": "../ca/testdata/secrets/intermediate_ca_key",
|
"key": "../ca/testdata/secrets/intermediate_ca_key",
|
||||||
"password": "password",
|
"password": "password",
|
||||||
|
|
11
ca/testdata/secrets/federated_ca.crt
vendored
Normal file
11
ca/testdata/secrets/federated_ca.crt
vendored
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa
|
||||||
|
MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw
|
||||||
|
MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG
|
||||||
|
SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL
|
||||||
|
jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB
|
||||||
|
/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR
|
||||||
|
qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z
|
||||||
|
YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc
|
||||||
|
Sw==
|
||||||
|
-----END CERTIFICATE-----
|
|
@ -41,7 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options if given
|
||||||
if err := setTLSOptions(c, tlsConfig, options); err != nil {
|
if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options if given
|
||||||
if err := setTLSOptions(c, tlsConfig, options); err != nil {
|
if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,28 +1,57 @@
|
||||||
package ca
|
package ca
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||||
type TLSOption func(c *Client, config *tls.Config) error
|
type TLSOption func(c *Client, tr http.RoundTripper, config *tls.Config) error
|
||||||
|
|
||||||
// setTLSOptions takes one or more option function and applies them in order to
|
// setTLSOptions takes one or more option function and applies them in order to
|
||||||
// a tls.Config.
|
// a tls.Config.
|
||||||
func setTLSOptions(c *Client, config *tls.Config, options []TLSOption) error {
|
func setTLSOptions(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config, options []TLSOption) error {
|
||||||
|
tr, err := getTLSOptionsTransport(sign, pk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, opt := range options {
|
for _, opt := range options {
|
||||||
if err := opt(c, config); err != nil {
|
if err := opt(c, tr, config); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getTLSOptionsTransport is the transport used by TLSOptions. It is used to get
|
||||||
|
// root certificates using a mTLS connection with the CA.
|
||||||
|
func getTLSOptionsTransport(sign *api.SignResponse, pk crypto.PrivateKey) (http.RoundTripper, error) {
|
||||||
|
cert, err := TLSCertificate(sign, pk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build default transport with fixed certificate
|
||||||
|
tlsConfig := getDefaultTLSConfig(sign)
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{*cert}
|
||||||
|
tlsConfig.PreferServerCipherSuites = true
|
||||||
|
// Build RootCAs with given root certificate
|
||||||
|
if pool := getCertPool(sign); pool != nil {
|
||||||
|
tlsConfig.RootCAs = pool
|
||||||
|
}
|
||||||
|
|
||||||
|
return getDefaultTransport(tlsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
||||||
// a valid TLS client certificate. This is the default option for mTLS servers.
|
// a valid TLS client certificate. This is the default option for mTLS servers.
|
||||||
func RequireAndVerifyClientCert() TLSOption {
|
func RequireAndVerifyClientCert() TLSOption {
|
||||||
return func(_ *Client, config *tls.Config) error {
|
return func(_ *Client, _ http.RoundTripper, config *tls.Config) error {
|
||||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -31,7 +60,7 @@ func RequireAndVerifyClientCert() TLSOption {
|
||||||
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
||||||
// a TLS client certificate if it is provided. It does not requires a certificate.
|
// a TLS client certificate if it is provided. It does not requires a certificate.
|
||||||
func VerifyClientCertIfGiven() TLSOption {
|
func VerifyClientCertIfGiven() TLSOption {
|
||||||
return func(_ *Client, config *tls.Config) error {
|
return func(_ *Client, _ http.RoundTripper, config *tls.Config) error {
|
||||||
config.ClientAuth = tls.VerifyClientCertIfGiven
|
config.ClientAuth = tls.VerifyClientCertIfGiven
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -41,7 +70,7 @@ func VerifyClientCertIfGiven() TLSOption {
|
||||||
// defines the set of root certificate authorities that clients use when
|
// defines the set of root certificate authorities that clients use when
|
||||||
// verifying server certificates.
|
// verifying server certificates.
|
||||||
func AddRootCA(cert *x509.Certificate) TLSOption {
|
func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||||
return func(_ *Client, config *tls.Config) error {
|
return func(_ *Client, _ http.RoundTripper, config *tls.Config) error {
|
||||||
if config.RootCAs == nil {
|
if config.RootCAs == nil {
|
||||||
config.RootCAs = x509.NewCertPool()
|
config.RootCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
|
@ -54,7 +83,7 @@ func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||||
// defines the set of root certificate authorities that servers use if required
|
// defines the set of root certificate authorities that servers use if required
|
||||||
// to verify a client certificate by the policy in ClientAuth.
|
// to verify a client certificate by the policy in ClientAuth.
|
||||||
func AddClientCA(cert *x509.Certificate) TLSOption {
|
func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||||
return func(_ *Client, config *tls.Config) error {
|
return func(_ *Client, _ http.RoundTripper, config *tls.Config) error {
|
||||||
if config.ClientCAs == nil {
|
if config.ClientCAs == nil {
|
||||||
config.ClientCAs = x509.NewCertPool()
|
config.ClientCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
|
@ -63,19 +92,18 @@ func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRootFederation does a federation request and adds to the tls.Config
|
// AddRootsToRootCAs does a roots request and adds to the tls.Config RootCAs all
|
||||||
// RootCAs all the certificates in the response. RootCAs
|
// the certificates in the response. RootCAs defines the set of root certificate
|
||||||
// defines the set of root certificate authorities that clients use when
|
// authorities that clients use when verifying server certificates.
|
||||||
// verifying server certificates.
|
func AddRootsToRootCAs() TLSOption {
|
||||||
func AddRootFederation() TLSOption {
|
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
||||||
return func(c *Client, config *tls.Config) error {
|
certs, err := c.Roots(tr)
|
||||||
if config.RootCAs == nil {
|
|
||||||
config.RootCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
certs, err := c.Federation(nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if config.RootCAs == nil {
|
||||||
|
config.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
for _, cert := range certs.Certificates {
|
for _, cert := range certs.Certificates {
|
||||||
config.RootCAs.AddCert(cert.Certificate)
|
config.RootCAs.AddCert(cert.Certificate)
|
||||||
}
|
}
|
||||||
|
@ -83,19 +111,58 @@ func AddRootFederation() TLSOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddClientFederation does a federation request and adds to the tls.Config
|
// AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs
|
||||||
// ClientCAs all the certificates in the response. ClientCAs defines the set of
|
// all the certificates in the response. ClientCAs defines the set of root
|
||||||
// root certificate authorities that servers use if required to verify a client
|
// certificate authorities that servers use if required to verify a client
|
||||||
// certificate by the policy in ClientAuth.
|
// certificate by the policy in ClientAuth.
|
||||||
func AddClientFederation() TLSOption {
|
func AddRootsToClientCAs() TLSOption {
|
||||||
return func(c *Client, config *tls.Config) error {
|
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
||||||
if config.ClientCAs == nil {
|
certs, err := c.Roots(tr)
|
||||||
config.ClientCAs = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
certs, err := c.Federation(nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if config.ClientCAs == nil {
|
||||||
|
config.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
for _, cert := range certs.Certificates {
|
||||||
|
config.ClientCAs.AddCert(cert.Certificate)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFederationToRootCAs does a federation request and adds to the tls.Config
|
||||||
|
// RootCAs all the certificates in the response. RootCAs defines the set of root
|
||||||
|
// certificate authorities that clients use when verifying server certificates.
|
||||||
|
func AddFederationToRootCAs() TLSOption {
|
||||||
|
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
||||||
|
certs, err := c.Federation(tr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if config.RootCAs == nil {
|
||||||
|
config.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
for _, cert := range certs.Certificates {
|
||||||
|
config.RootCAs.AddCert(cert.Certificate)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFederationToClientCAs does a federation request and adds to the tls.Config
|
||||||
|
// ClientCAs all the certificates in the response. ClientCAs defines the set of
|
||||||
|
// root certificate authorities that servers use if required to verify a client
|
||||||
|
// certificate by the policy in ClientAuth.
|
||||||
|
func AddFederationToClientCAs() TLSOption {
|
||||||
|
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
||||||
|
certs, err := c.Federation(tr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if config.ClientCAs == nil {
|
||||||
|
config.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
for _, cert := range certs.Certificates {
|
for _, cert := range certs.Certificates {
|
||||||
config.ClientCAs.AddCert(cert.Certificate)
|
config.ClientCAs.AddCert(cert.Certificate)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,13 +4,15 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_setTLSOptions(t *testing.T) {
|
func Test_setTLSOptions(t *testing.T) {
|
||||||
fail := func() TLSOption {
|
fail := func() TLSOption {
|
||||||
return func(c *Client, config *tls.Config) error {
|
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
||||||
return fmt.Errorf("an error")
|
return fmt.Errorf("an error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,9 +29,13 @@ func Test_setTLSOptions(t *testing.T) {
|
||||||
{"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false},
|
{"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false},
|
||||||
{"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
|
{"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ca := startCATestServer()
|
||||||
|
defer ca.Close()
|
||||||
|
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := setTLSOptions(nil, tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
if err := setTLSOptions(client, sr, pk, tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -46,7 +52,7 @@ func TestRequireAndVerifyClientCert(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
got := &tls.Config{}
|
||||||
if err := RequireAndVerifyClientCert()(nil, got); err != nil {
|
if err := RequireAndVerifyClientCert()(nil, nil, got); err != nil {
|
||||||
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -67,7 +73,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
got := &tls.Config{}
|
||||||
if err := VerifyClientCertIfGiven()(nil, got); err != nil {
|
if err := VerifyClientCertIfGiven()(nil, nil, got); err != nil {
|
||||||
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -96,7 +102,7 @@ func TestAddRootCA(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
got := &tls.Config{}
|
||||||
if err := AddRootCA(tt.args.cert)(nil, got); err != nil {
|
if err := AddRootCA(tt.args.cert)(nil, nil, got); err != nil {
|
||||||
t.Errorf("AddRootCA() error = %v", err)
|
t.Errorf("AddRootCA() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -125,7 +131,7 @@ func TestAddClientCA(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
got := &tls.Config{}
|
||||||
if err := AddClientCA(tt.args.cert)(nil, got); err != nil {
|
if err := AddClientCA(tt.args.cert)(nil, nil, got); err != nil {
|
||||||
t.Errorf("AddClientCA() error = %v", err)
|
t.Errorf("AddClientCA() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -135,3 +141,185 @@ func TestAddClientCA(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddRootsToRootCAs(t *testing.T) {
|
||||||
|
ca := startCATestServer()
|
||||||
|
defer ca.Close()
|
||||||
|
|
||||||
|
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
||||||
|
tr, err := getTLSOptionsTransport(sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := parseCertificate(string(root))
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(cert)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tr http.RoundTripper
|
||||||
|
want *tls.Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", tr, &tls.Config{RootCAs: pool}, false},
|
||||||
|
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := &tls.Config{}
|
||||||
|
if err := AddRootsToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("AddRootsToRootCAs() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddRootsToClientCAs(t *testing.T) {
|
||||||
|
ca := startCATestServer()
|
||||||
|
defer ca.Close()
|
||||||
|
|
||||||
|
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
||||||
|
tr, err := getTLSOptionsTransport(sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := parseCertificate(string(root))
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(cert)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tr http.RoundTripper
|
||||||
|
want *tls.Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", tr, &tls.Config{ClientCAs: pool}, false},
|
||||||
|
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := &tls.Config{}
|
||||||
|
if err := AddRootsToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("AddRootsToClientCAs() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddFederationToRootCAs(t *testing.T) {
|
||||||
|
ca := startCATestServer()
|
||||||
|
defer ca.Close()
|
||||||
|
|
||||||
|
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
||||||
|
tr, err := getTLSOptionsTransport(sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
crt1 := parseCertificate(string(root))
|
||||||
|
crt2 := parseCertificate(string(federated))
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(crt1)
|
||||||
|
pool.AddCert(crt2)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tr http.RoundTripper
|
||||||
|
want *tls.Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", tr, &tls.Config{RootCAs: pool}, false},
|
||||||
|
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := &tls.Config{}
|
||||||
|
if err := AddFederationToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("AddFederationToRootCAs() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddFederationToClientCAs(t *testing.T) {
|
||||||
|
ca := startCATestServer()
|
||||||
|
defer ca.Close()
|
||||||
|
|
||||||
|
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
||||||
|
tr, err := getTLSOptionsTransport(sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
crt1 := parseCertificate(string(root))
|
||||||
|
crt2 := parseCertificate(string(federated))
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(crt1)
|
||||||
|
pool.AddCert(crt2)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tr http.RoundTripper
|
||||||
|
want *tls.Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", tr, &tls.Config{ClientCAs: pool}, false},
|
||||||
|
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := &tls.Config{}
|
||||||
|
if err := AddFederationToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("AddFederationToClientCAs() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue