From d296cf95a9b495bf2abc6dc0e72e054148167c15 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 7 Jan 2019 17:48:56 -0800 Subject: [PATCH] Add mTLS request to get all the root CAs, not the federated ones. --- api/api.go | 39 +++++- api/api_test.go | 61 +++++++- authority/root.go | 9 ++ ca/client.go | 26 +++- ca/client_test.go | 61 ++++++++ ca/testdata/ca.json | 1 + ca/testdata/secrets/federated_ca.crt | 11 ++ ca/tls.go | 4 +- ca/tls_options.go | 119 ++++++++++++---- ca/tls_options_test.go | 200 ++++++++++++++++++++++++++- 10 files changed, 488 insertions(+), 43 deletions(-) create mode 100644 ca/testdata/secrets/federated_ca.crt diff --git a/api/api.go b/api/api.go index e427bee3..563e65ea 100644 --- a/api/api.go +++ b/api/api.go @@ -22,10 +22,11 @@ type Authority interface { GetTLSOptions() *tlsutil.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) - Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) + Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error) GetEncryptedKey(kid string) (string, error) - GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) + GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error) + GetFederation(peer *x509.Certificate) ([]*x509.Certificate, error) } // Certificate wraps a *x509.Certificate and adds the json.Marshaler interface. @@ -187,6 +188,11 @@ type SignResponse struct { TLS *tls.ConnectionState `json:"-"` } +// RootsResponse is the response object of the roots request. +type RootsResponse struct { + Certificates []Certificate `json:"crts"` +} + // FederationResponse is the response object of the federation request. type FederationResponse struct { Certificates []Certificate `json:"crts"` @@ -211,6 +217,7 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("POST", "/renew", h.Renew) r.MethodFunc("GET", "/provisioners", h.Provisioners) r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) + r.MethodFunc("GET", "/roots", h.Roots) r.MethodFunc("GET", "/federation", h.Federation) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) @@ -327,7 +334,33 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { JSON(w, &ProvisionerKeyResponse{key}) } -// Federation returns all the public certificates in the federation. +// Roots returns all the root certificates for the CA. It requires a valid TLS +// client. +func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + WriteError(w, BadRequest(errors.New("missing peer certificate"))) + return + } + + roots, err := h.Authority.GetRoots(r.TLS.PeerCertificates[0]) + if err != nil { + WriteError(w, Forbidden(err)) + return + } + + certs := make([]Certificate, len(roots)) + for i := range roots { + certs[i] = Certificate{roots[i]} + } + + w.WriteHeader(http.StatusCreated) + JSON(w, &RootsResponse{ + Certificates: certs, + }) +} + +// Federation returns all the public certificates in the federation. It requires +// a valid TLS client. func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { WriteError(w, BadRequest(errors.New("missing peer certificate"))) diff --git a/api/api_test.go b/api/api_test.go index 82e12c8c..0a9eb120 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -392,6 +392,7 @@ type mockAuthority struct { renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error) getEncryptedKey func(kid string) (string, error) + getRoots func(cert *x509.Certificate) ([]*x509.Certificate, error) getFederation func(cert *x509.Certificate) ([]*x509.Certificate, error) } @@ -444,6 +445,13 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { return m.ret1.(string), m.err } +func (m *mockAuthority) GetRoots(cert *x509.Certificate) ([]*x509.Certificate, error) { + if m.getFederation != nil { + return m.getRoots(cert) + } + return m.ret1.([]*x509.Certificate), m.err +} + func (m *mockAuthority) GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) { if m.getFederation != nil { return m.getFederation(cert) @@ -821,6 +829,53 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { } } +func Test_caHandler_Roots(t *testing.T) { + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, + } + tests := []struct { + name string + tls *tls.ConnectionState + cert *x509.Certificate + root *x509.Certificate + err error + statusCode int + }{ + {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, + {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, + {"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + } + + expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + req := httptest.NewRequest("GET", "http://example.com/roots", nil) + req.TLS = tt.tls + w := httptest.NewRecorder() + h.Roots(w, req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Roots unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), expected) { + t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected) + } + } + }) + } +} + func Test_caHandler_Federation(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, @@ -851,17 +906,17 @@ func Test_caHandler_Federation(t *testing.T) { res := w.Result() if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { - t.Errorf("caHandler.Root unexpected error = %v", err) + t.Errorf("caHandler.Federation unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), expected) { - t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) + t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected) } } }) diff --git a/authority/root.go b/authority/root.go index d041ae8f..98974904 100644 --- a/authority/root.go +++ b/authority/root.go @@ -33,6 +33,15 @@ func (a *Authority) GetRootCertificates() []*x509.Certificate { return a.rootX509Certs } +// GetRoots returns all the root certificates for this CA. +func (a *Authority) GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error) { + // Check step provisioner extensions + if err := a.authorizeRenewal(peer); err != nil { + return nil, err + } + return a.rootX509Certs, nil +} + // GetFederation returns all the root certificates in the federation. func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) { // Check step provisioner extensions diff --git a/ca/client.go b/ca/client.go index 374a68ff..b8aab67f 100644 --- a/ca/client.go +++ b/ca/client.go @@ -237,9 +237,10 @@ func WithProvisionerLimit(limit int) ProvisionerOption { // Client implements an HTTP client for the CA server. type Client struct { - client *http.Client - endpoint *url.URL - certPool *x509.CertPool + client *http.Client + endpoint *url.URL + certPool *x509.CertPool + cachedSign *api.SignResponse } // NewClient creates a new Client with the given endpoint and options. @@ -413,6 +414,25 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) return &key, nil } +// Roots performs the get roots request to the CA and returns the +// api.RootsResponse struct. +func (c *Client) Roots(tr http.RoundTripper) (*api.RootsResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"}) + client := &http.Client{Transport: tr} + resp, err := client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var federation api.RootsResponse + if err := readJSON(resp.Body, &federation); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &federation, nil +} + // Federation performs the get federation request to the CA and returns the // api.FederationResponse struct. func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) { diff --git a/ca/client_test.go b/ca/client_test.go index 138b0d7d..0ec6324b 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -512,6 +512,67 @@ func TestClient_ProvisionerKey(t *testing.T) { } } +func TestClient_Roots(t *testing.T) { + ok := &api.RootsResponse{ + Certificates: []api.Certificate{ + {Certificate: parseCertificate(rootPEM)}, + }, + } + unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized")) + badRequest := api.BadRequest(fmt.Errorf("Bad Request")) + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + }{ + {"ok", ok, 200, false}, + {"unauthorized", unauthorized, 401, true}, + {"empty request", badRequest, 403, true}, + {"nil request", badRequest, 403, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(tt.responseCode) + api.JSON(w, tt.response) + }) + + got, err := c.Roots(nil) + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Roots() = %v, want nil", got) + } + if !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.Roots() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Roots() = %v, want %v", got, tt.response) + } + } + }) + } +} + func TestClient_Federation(t *testing.T) { ok := &api.FederationResponse{ Certificates: []api.Certificate{ diff --git a/ca/testdata/ca.json b/ca/testdata/ca.json index d61a8c49..f5484a7c 100644 --- a/ca/testdata/ca.json +++ b/ca/testdata/ca.json @@ -1,5 +1,6 @@ { "root": "../ca/testdata/secrets/root_ca.crt", + "federatedRoots": ["../ca/testdata/secrets/federated_ca.crt"], "crt": "../ca/testdata/secrets/intermediate_ca.crt", "key": "../ca/testdata/secrets/intermediate_ca_key", "password": "password", diff --git a/ca/testdata/secrets/federated_ca.crt b/ca/testdata/secrets/federated_ca.crt new file mode 100644 index 00000000..87cd5650 --- /dev/null +++ b/ca/testdata/secrets/federated_ca.crt @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa +MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw +MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG +SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL +jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB +/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR +qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z +YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc +Sw== +-----END CERTIFICATE----- diff --git a/ca/tls.go b/ca/tls.go index bef8e553..22eb667b 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -41,7 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(c, tlsConfig, options); err != nil { + if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil { return nil, err } @@ -87,7 +87,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(c, tlsConfig, options); err != nil { + if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil { return nil, err } diff --git a/ca/tls_options.go b/ca/tls_options.go index b1cd4696..26eae156 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -1,28 +1,57 @@ package ca import ( + "crypto" "crypto/tls" "crypto/x509" + "net/http" + + "github.com/smallstep/certificates/api" ) // TLSOption defines the type of a function that modifies a tls.Config. -type TLSOption func(c *Client, config *tls.Config) error +type TLSOption func(c *Client, tr http.RoundTripper, config *tls.Config) error // setTLSOptions takes one or more option function and applies them in order to // a tls.Config. -func setTLSOptions(c *Client, config *tls.Config, options []TLSOption) error { +func setTLSOptions(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config, options []TLSOption) error { + tr, err := getTLSOptionsTransport(sign, pk) + if err != nil { + return err + } + for _, opt := range options { - if err := opt(c, config); err != nil { + if err := opt(c, tr, config); err != nil { return err } } return nil } +// getTLSOptionsTransport is the transport used by TLSOptions. It is used to get +// root certificates using a mTLS connection with the CA. +func getTLSOptionsTransport(sign *api.SignResponse, pk crypto.PrivateKey) (http.RoundTripper, error) { + cert, err := TLSCertificate(sign, pk) + if err != nil { + return nil, err + } + + // Build default transport with fixed certificate + tlsConfig := getDefaultTLSConfig(sign) + tlsConfig.Certificates = []tls.Certificate{*cert} + tlsConfig.PreferServerCipherSuites = true + // Build RootCAs with given root certificate + if pool := getCertPool(sign); pool != nil { + tlsConfig.RootCAs = pool + } + + return getDefaultTransport(tlsConfig) +} + // RequireAndVerifyClientCert is a tls.Config option used on servers to enforce // a valid TLS client certificate. This is the default option for mTLS servers. func RequireAndVerifyClientCert() TLSOption { - return func(_ *Client, config *tls.Config) error { + return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { config.ClientAuth = tls.RequireAndVerifyClientCert return nil } @@ -31,7 +60,7 @@ func RequireAndVerifyClientCert() TLSOption { // VerifyClientCertIfGiven is a tls.Config option used on on servers to validate // a TLS client certificate if it is provided. It does not requires a certificate. func VerifyClientCertIfGiven() TLSOption { - return func(_ *Client, config *tls.Config) error { + return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { config.ClientAuth = tls.VerifyClientCertIfGiven return nil } @@ -41,7 +70,7 @@ func VerifyClientCertIfGiven() TLSOption { // defines the set of root certificate authorities that clients use when // verifying server certificates. func AddRootCA(cert *x509.Certificate) TLSOption { - return func(_ *Client, config *tls.Config) error { + return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { if config.RootCAs == nil { config.RootCAs = x509.NewCertPool() } @@ -54,7 +83,7 @@ func AddRootCA(cert *x509.Certificate) TLSOption { // defines the set of root certificate authorities that servers use if required // to verify a client certificate by the policy in ClientAuth. func AddClientCA(cert *x509.Certificate) TLSOption { - return func(_ *Client, config *tls.Config) error { + return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { if config.ClientCAs == nil { config.ClientCAs = x509.NewCertPool() } @@ -63,19 +92,18 @@ func AddClientCA(cert *x509.Certificate) TLSOption { } } -// AddRootFederation does a federation request and adds to the tls.Config -// RootCAs all the certificates in the response. RootCAs -// defines the set of root certificate authorities that clients use when -// verifying server certificates. -func AddRootFederation() TLSOption { - return func(c *Client, config *tls.Config) error { - if config.RootCAs == nil { - config.RootCAs = x509.NewCertPool() - } - certs, err := c.Federation(nil) +// AddRootsToRootCAs does a roots request and adds to the tls.Config RootCAs all +// the certificates in the response. RootCAs defines the set of root certificate +// authorities that clients use when verifying server certificates. +func AddRootsToRootCAs() TLSOption { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { + certs, err := c.Roots(tr) if err != nil { return err } + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } for _, cert := range certs.Certificates { config.RootCAs.AddCert(cert.Certificate) } @@ -83,19 +111,58 @@ func AddRootFederation() TLSOption { } } -// AddClientFederation does a federation request and adds to the tls.Config -// ClientCAs all the certificates in the response. ClientCAs defines the set of -// root certificate authorities that servers use if required to verify a client +// AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs +// all the certificates in the response. ClientCAs defines the set of root +// certificate authorities that servers use if required to verify a client // certificate by the policy in ClientAuth. -func AddClientFederation() TLSOption { - return func(c *Client, config *tls.Config) error { - if config.ClientCAs == nil { - config.ClientCAs = x509.NewCertPool() - } - certs, err := c.Federation(nil) +func AddRootsToClientCAs() TLSOption { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { + certs, err := c.Roots(tr) if err != nil { return err } + if config.ClientCAs == nil { + config.ClientCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + config.ClientCAs.AddCert(cert.Certificate) + } + return nil + } +} + +// AddFederationToRootCAs does a federation request and adds to the tls.Config +// RootCAs all the certificates in the response. RootCAs defines the set of root +// certificate authorities that clients use when verifying server certificates. +func AddFederationToRootCAs() TLSOption { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { + certs, err := c.Federation(tr) + if err != nil { + return err + } + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + config.RootCAs.AddCert(cert.Certificate) + } + return nil + } +} + +// AddFederationToClientCAs does a federation request and adds to the tls.Config +// ClientCAs all the certificates in the response. ClientCAs defines the set of +// root certificate authorities that servers use if required to verify a client +// certificate by the policy in ClientAuth. +func AddFederationToClientCAs() TLSOption { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { + certs, err := c.Federation(tr) + if err != nil { + return err + } + if config.ClientCAs == nil { + config.ClientCAs = x509.NewCertPool() + } for _, cert := range certs.Certificates { config.ClientCAs.AddCert(cert.Certificate) } diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 9886e487..07068ca4 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -4,13 +4,15 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io/ioutil" + "net/http" "reflect" "testing" ) func Test_setTLSOptions(t *testing.T) { fail := func() TLSOption { - return func(c *Client, config *tls.Config) error { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { return fmt.Errorf("an error") } } @@ -27,9 +29,13 @@ func Test_setTLSOptions(t *testing.T) { {"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false}, {"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true}, } + + ca := startCATestServer() + defer ca.Close() + client, sr, pk := signDuration(ca, "127.0.0.1", 0) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := setTLSOptions(nil, tt.args.c, tt.args.options); (err != nil) != tt.wantErr { + if err := setTLSOptions(client, sr, pk, tt.args.c, tt.args.options); (err != nil) != tt.wantErr { t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -46,7 +52,7 @@ func TestRequireAndVerifyClientCert(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := RequireAndVerifyClientCert()(nil, got); err != nil { + if err := RequireAndVerifyClientCert()(nil, nil, got); err != nil { t.Errorf("RequireAndVerifyClientCert() error = %v", err) return } @@ -67,7 +73,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := VerifyClientCertIfGiven()(nil, got); err != nil { + if err := VerifyClientCertIfGiven()(nil, nil, got); err != nil { t.Errorf("VerifyClientCertIfGiven() error = %v", err) return } @@ -96,7 +102,7 @@ func TestAddRootCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := AddRootCA(tt.args.cert)(nil, got); err != nil { + if err := AddRootCA(tt.args.cert)(nil, nil, got); err != nil { t.Errorf("AddRootCA() error = %v", err) return } @@ -125,7 +131,7 @@ func TestAddClientCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := AddClientCA(tt.args.cert)(nil, got); err != nil { + if err := AddClientCA(tt.args.cert)(nil, nil, got); err != nil { t.Errorf("AddClientCA() error = %v", err) return } @@ -135,3 +141,185 @@ func TestAddClientCA(t *testing.T) { }) } } + +func TestAddRootsToRootCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + cert := parseCertificate(string(root)) + pool := x509.NewCertPool() + pool.AddCert(cert) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{RootCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddRootsToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddRootsToRootCAs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddRootsToClientCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + cert := parseCertificate(string(root)) + pool := x509.NewCertPool() + pool.AddCert(cert) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{ClientCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddRootsToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddRootsToClientCAs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddFederationToRootCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") + if err != nil { + t.Fatal(err) + } + + crt1 := parseCertificate(string(root)) + crt2 := parseCertificate(string(federated)) + pool := x509.NewCertPool() + pool.AddCert(crt1) + pool.AddCert(crt2) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{RootCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddFederationToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddFederationToRootCAs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddFederationToClientCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") + if err != nil { + t.Fatal(err) + } + + crt1 := parseCertificate(string(root)) + crt2 := parseCertificate(string(federated)) + pool := x509.NewCertPool() + pool.AddCert(crt1) + pool.AddCert(crt2) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{ClientCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddFederationToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddFederationToClientCAs() = %v, want %v", got, tt.want) + } + }) + } +}