From 37149ed3ea76777406d3cadc73d412533b0c2b4b Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 4 Jan 2019 16:51:37 -0800 Subject: [PATCH 01/20] Add method to get all the certs. --- api/api.go | 31 ++++++++++++++++++++++++++++ api/api_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/api/api.go b/api/api.go index e4f34d51..e427bee3 100644 --- a/api/api.go +++ b/api/api.go @@ -25,6 +25,7 @@ type Authority interface { Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error) GetEncryptedKey(kid string) (string, error) + GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) } // Certificate wraps a *x509.Certificate and adds the json.Marshaler interface. @@ -186,6 +187,11 @@ type SignResponse struct { TLS *tls.ConnectionState `json:"-"` } +// FederationResponse is the response object of the federation request. +type FederationResponse struct { + Certificates []Certificate `json:"crts"` +} + // caHandler is the type used to implement the different CA HTTP endpoints. type caHandler struct { Authority Authority @@ -205,6 +211,7 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("POST", "/renew", h.Renew) r.MethodFunc("GET", "/provisioners", h.Provisioners) r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) + r.MethodFunc("GET", "/federation", h.Federation) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) } @@ -320,6 +327,30 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { JSON(w, &ProvisionerKeyResponse{key}) } +// Federation returns all the public certificates in the federation. +func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + WriteError(w, BadRequest(errors.New("missing peer certificate"))) + return + } + + federated, err := h.Authority.GetFederation(r.TLS.PeerCertificates[0]) + if err != nil { + WriteError(w, Forbidden(err)) + return + } + + certs := make([]Certificate, len(federated)) + for i := range federated { + certs[i] = Certificate{federated[i]} + } + + w.WriteHeader(http.StatusCreated) + JSON(w, &FederationResponse{ + Certificates: certs, + }) +} + func parseCursor(r *http.Request) (cursor string, limit int, err error) { q := r.URL.Query() cursor = q.Get("cursor") diff --git a/api/api_test.go b/api/api_test.go index e6f123b8..82e12c8c 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -392,6 +392,7 @@ type mockAuthority struct { renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error) getEncryptedKey func(kid string) (string, error) + getFederation func(cert *x509.Certificate) ([]*x509.Certificate, error) } func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) { @@ -443,6 +444,13 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { return m.ret1.(string), m.err } +func (m *mockAuthority) GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) { + if m.getFederation != nil { + return m.getFederation(cert) + } + return m.ret1.([]*x509.Certificate), m.err +} + func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority @@ -812,3 +820,50 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { }) } } + +func Test_caHandler_Federation(t *testing.T) { + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, + } + tests := []struct { + name string + tls *tls.ConnectionState + cert *x509.Certificate + root *x509.Certificate + err error + statusCode int + }{ + {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, + {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, + {"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + } + + expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + req := httptest.NewRequest("GET", "http://example.com/federation", nil) + req.TLS = tt.tls + w := httptest.NewRecorder() + h.Federation(w, req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Root unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), expected) { + t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) + } + } + }) + } +} From 722bcb7e7ad86160fb9a321b8dd2ebfecb37fb13 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 4 Jan 2019 17:51:32 -0800 Subject: [PATCH 02/20] Add initial support for federated root certificates. --- authority/authority.go | 14 +++++- authority/config.go | 30 +------------ authority/root.go | 23 +++++++++- authority/root_test.go | 2 +- authority/types.go | 98 ++++++++++++++++++++++++++++++++++++++++++ ca/client.go | 19 ++++++++ ca/client_test.go | 61 ++++++++++++++++++++++++++ ca/tls.go | 4 +- ca/tls_options.go | 70 +++++++++++++++++++++++------- ca/tls_options_test.go | 12 +++--- 10 files changed, 277 insertions(+), 56 deletions(-) create mode 100644 authority/types.go diff --git a/authority/authority.go b/authority/authority.go index 4753f83b..8c60fffc 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -2,7 +2,7 @@ package authority import ( "crypto/sha256" - realx509 "crypto/x509" + "crypto/x509" "encoding/hex" "fmt" "sync" @@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority" // Authority implements the Certificate Authority internal interface. type Authority struct { config *Config - rootX509Crt *realx509.Certificate + rootX509Crt *x509.Certificate intermediateIdentity *x509util.Identity validateOnce bool certificates *sync.Map @@ -89,6 +89,16 @@ func (a *Authority) init() error { sum := sha256.Sum256(a.rootX509Crt.Raw) a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt) + // Add federated roots + for _, path := range a.config.FederatedRoots { + crt, err := pemutil.ReadCertificate(path) + if err != nil { + return err + } + sum := sha256.Sum256(crt.Raw) + a.certificates.Store(hex.EncodeToString(sum[:]), crt) + } + // Decrypt and load intermediate public / private key pair. if len(a.config.Password) > 0 { a.intermediateIdentity, err = x509util.LoadIdentityFromDisk( diff --git a/authority/config.go b/authority/config.go index 5f9910af..0ffdbef0 100644 --- a/authority/config.go +++ b/authority/config.go @@ -33,38 +33,10 @@ var ( } ) -type duration struct { - time.Duration -} - -// MarshalJSON parses a duration string and sets it to the duration. -// -// A duration string is a possibly signed sequence of decimal numbers, each with -// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". -// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) MarshalJSON() ([]byte, error) { - return json.Marshal(d.String()) -} - -// UnmarshalJSON parses a duration string and sets it to the duration. -// -// A duration string is a possibly signed sequence of decimal numbers, each with -// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". -// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) UnmarshalJSON(data []byte) (err error) { - var s string - if err = json.Unmarshal(data, &s); err != nil { - return errors.Wrapf(err, "error unmarshalling %s", data) - } - if d.Duration, err = time.ParseDuration(s); err != nil { - return errors.Wrapf(err, "error parsing %s as duration", s) - } - return -} - // Config represents the CA configuration and it's mapped to a JSON object. type Config struct { Root string `json:"root"` + FederatedRoots []string `json:"federatedRoots"` IntermediateCert string `json:"crt"` IntermediateKey string `json:"key"` Address string `json:"address"` diff --git a/authority/root.go b/authority/root.go index 5c918ee1..01710db8 100644 --- a/authority/root.go +++ b/authority/root.go @@ -17,7 +17,7 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) { crt, ok := val.(*x509.Certificate) if !ok { - return nil, &apiError{errors.Errorf("stored value is not a *cryto/x509.Certificate"), + return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}} } return crt, nil @@ -27,3 +27,24 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) { func (a *Authority) GetRootCertificate() *x509.Certificate { return a.rootX509Crt } + +// GetFederation returns all the root certificates in the federation. +func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) { + // Check step provisioner extensions + if err := a.authorizeRenewal(peer); err != nil { + return nil, err + } + + a.certificates.Range(func(k, v interface{}) bool { + crt, ok := v.(*x509.Certificate) + if !ok { + federation = nil + err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"), + http.StatusInternalServerError, context{}} + return false + } + federation = append(federation, crt) + return true + }) + return +} diff --git a/authority/root_test.go b/authority/root_test.go index fd4c31db..0db7e866 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -17,7 +17,7 @@ func TestRoot(t *testing.T) { err *apiError }{ "not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}}, - "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *cryto/x509.Certificate"), http.StatusInternalServerError, context{}}}, + "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}}, "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil}, } diff --git a/authority/types.go b/authority/types.go new file mode 100644 index 00000000..ec8f0d7b --- /dev/null +++ b/authority/types.go @@ -0,0 +1,98 @@ +package authority + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +type duration struct { + time.Duration +} + +// MarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +// UnmarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *duration) UnmarshalJSON(data []byte) (err error) { + var s string + if err = json.Unmarshal(data, &s); err != nil { + return errors.Wrapf(err, "error unmarshalling %s", data) + } + if d.Duration, err = time.ParseDuration(s); err != nil { + return errors.Wrapf(err, "error parsing %s as duration", s) + } + return +} + +type multiString []string + +// FIXME: remove me, avoids deadcode warning +var _ = multiString{} + +// First returns the first element of a multiString. It will return an empty +// string if the multistring is empty. +func (s multiString) First() string { + if len(s) > 0 { + return s[0] + } + return "" +} + +// Empties checks that none of the string is empty. +func (s multiString) Empties() bool { + if len(s) == 0 { + return true + } + for _, ss := range s { + if len(ss) == 0 { + return true + } + } + return false +} + +// MarshalJSON marshals the multistring as a string or a slice of strings . With +// 0 elements it will return the empty string, with 1 element a regular string, +// otherwise a slice of strings. +func (s multiString) MarshalJSON() ([]byte, error) { + switch len(s) { + case 0: + return []byte(""), nil + case 1: + return json.Marshal(s[0]) + default: + return json.Marshal(s) + } +} + +// UnmarshalJSON parses a string or a slice and sets it to the multiString. +func (s *multiString) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + *s = nil + return nil + } + if data[0] == '"' { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return errors.Wrapf(err, "error unmarshalling %s", data) + } + *s = []string{str} + return nil + } + if err := json.Unmarshal(data, s); err != nil { + return errors.Wrapf(err, "error unmarshalling %s", data) + } + return nil +} diff --git a/ca/client.go b/ca/client.go index 47116245..374a68ff 100644 --- a/ca/client.go +++ b/ca/client.go @@ -413,6 +413,25 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) return &key, nil } +// Federation performs the get federation request to the CA and returns the +// api.FederationResponse struct. +func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"}) + client := &http.Client{Transport: tr} + resp, err := client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var federation api.FederationResponse + if err := readJSON(resp.Body, &federation); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &federation, nil +} + // CreateSignRequest is a helper function that given an x509 OTT returns a // simple but secure sign request as well as the private key used. func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) { diff --git a/ca/client_test.go b/ca/client_test.go index 6d5cd22a..138b0d7d 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -512,6 +512,67 @@ func TestClient_ProvisionerKey(t *testing.T) { } } +func TestClient_Federation(t *testing.T) { + ok := &api.FederationResponse{ + Certificates: []api.Certificate{ + {Certificate: parseCertificate(rootPEM)}, + }, + } + unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized")) + badRequest := api.BadRequest(fmt.Errorf("Bad Request")) + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + }{ + {"ok", ok, 200, false}, + {"unauthorized", unauthorized, 401, true}, + {"empty request", badRequest, 403, true}, + {"nil request", badRequest, 403, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(tt.responseCode) + api.JSON(w, tt.response) + }) + + got, err := c.Federation(nil) + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Federation() = %v, want nil", got) + } + if !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.Federation() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Federation() = %v, want %v", got, tt.response) + } + } + }) + } +} + func Test_parseEndpoint(t *testing.T) { expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"} expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"} diff --git a/ca/tls.go b/ca/tls.go index 5e8c4118..bef8e553 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -41,7 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(tlsConfig, options); err != nil { + if err := setTLSOptions(c, tlsConfig, options); err != nil { return nil, err } @@ -87,7 +87,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(tlsConfig, options); err != nil { + if err := setTLSOptions(c, tlsConfig, options); err != nil { return nil, err } diff --git a/ca/tls_options.go b/ca/tls_options.go index fb0bb20b..b1cd4696 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -6,13 +6,13 @@ import ( ) // TLSOption defines the type of a function that modifies a tls.Config. -type TLSOption func(c *tls.Config) error +type TLSOption func(c *Client, config *tls.Config) error // setTLSOptions takes one or more option function and applies them in order to // a tls.Config. -func setTLSOptions(c *tls.Config, options []TLSOption) error { +func setTLSOptions(c *Client, config *tls.Config, options []TLSOption) error { for _, opt := range options { - if err := opt(c); err != nil { + if err := opt(c, config); err != nil { return err } } @@ -22,8 +22,8 @@ func setTLSOptions(c *tls.Config, options []TLSOption) error { // RequireAndVerifyClientCert is a tls.Config option used on servers to enforce // a valid TLS client certificate. This is the default option for mTLS servers. func RequireAndVerifyClientCert() TLSOption { - return func(c *tls.Config) error { - c.ClientAuth = tls.RequireAndVerifyClientCert + return func(_ *Client, config *tls.Config) error { + config.ClientAuth = tls.RequireAndVerifyClientCert return nil } } @@ -31,8 +31,8 @@ func RequireAndVerifyClientCert() TLSOption { // VerifyClientCertIfGiven is a tls.Config option used on on servers to validate // a TLS client certificate if it is provided. It does not requires a certificate. func VerifyClientCertIfGiven() TLSOption { - return func(c *tls.Config) error { - c.ClientAuth = tls.VerifyClientCertIfGiven + return func(_ *Client, config *tls.Config) error { + config.ClientAuth = tls.VerifyClientCertIfGiven return nil } } @@ -41,11 +41,11 @@ func VerifyClientCertIfGiven() TLSOption { // defines the set of root certificate authorities that clients use when // verifying server certificates. func AddRootCA(cert *x509.Certificate) TLSOption { - return func(c *tls.Config) error { - if c.RootCAs == nil { - c.RootCAs = x509.NewCertPool() + return func(_ *Client, config *tls.Config) error { + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() } - c.RootCAs.AddCert(cert) + config.RootCAs.AddCert(cert) return nil } } @@ -54,11 +54,51 @@ func AddRootCA(cert *x509.Certificate) TLSOption { // defines the set of root certificate authorities that servers use if required // to verify a client certificate by the policy in ClientAuth. func AddClientCA(cert *x509.Certificate) TLSOption { - return func(c *tls.Config) error { - if c.ClientCAs == nil { - c.ClientCAs = x509.NewCertPool() + return func(_ *Client, config *tls.Config) error { + if config.ClientCAs == nil { + config.ClientCAs = x509.NewCertPool() + } + config.ClientCAs.AddCert(cert) + return nil + } +} + +// AddRootFederation does a federation request and adds to the tls.Config +// RootCAs all the certificates in the response. RootCAs +// defines the set of root certificate authorities that clients use when +// verifying server certificates. +func AddRootFederation() TLSOption { + return func(c *Client, config *tls.Config) error { + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + certs, err := c.Federation(nil) + if err != nil { + return err + } + for _, cert := range certs.Certificates { + config.RootCAs.AddCert(cert.Certificate) + } + return nil + } +} + +// AddClientFederation does a federation request and adds to the tls.Config +// ClientCAs all the certificates in the response. ClientCAs defines the set of +// root certificate authorities that servers use if required to verify a client +// certificate by the policy in ClientAuth. +func AddClientFederation() TLSOption { + return func(c *Client, config *tls.Config) error { + if config.ClientCAs == nil { + config.ClientCAs = x509.NewCertPool() + } + certs, err := c.Federation(nil) + if err != nil { + return err + } + for _, cert := range certs.Certificates { + config.ClientCAs.AddCert(cert.Certificate) } - c.ClientCAs.AddCert(cert) return nil } } diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 896ff72b..9886e487 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -10,7 +10,7 @@ import ( func Test_setTLSOptions(t *testing.T) { fail := func() TLSOption { - return func(c *tls.Config) error { + return func(c *Client, config *tls.Config) error { return fmt.Errorf("an error") } } @@ -29,7 +29,7 @@ func Test_setTLSOptions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr { + if err := setTLSOptions(nil, tt.args.c, tt.args.options); (err != nil) != tt.wantErr { t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -46,7 +46,7 @@ func TestRequireAndVerifyClientCert(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := RequireAndVerifyClientCert()(got); err != nil { + if err := RequireAndVerifyClientCert()(nil, got); err != nil { t.Errorf("RequireAndVerifyClientCert() error = %v", err) return } @@ -67,7 +67,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := VerifyClientCertIfGiven()(got); err != nil { + if err := VerifyClientCertIfGiven()(nil, got); err != nil { t.Errorf("VerifyClientCertIfGiven() error = %v", err) return } @@ -96,7 +96,7 @@ func TestAddRootCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := AddRootCA(tt.args.cert)(got); err != nil { + if err := AddRootCA(tt.args.cert)(nil, got); err != nil { t.Errorf("AddRootCA() error = %v", err) return } @@ -125,7 +125,7 @@ func TestAddClientCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := AddClientCA(tt.args.cert)(got); err != nil { + if err := AddClientCA(tt.args.cert)(nil, got); err != nil { t.Errorf("AddClientCA() error = %v", err) return } From 98cc243a37339af88f9b0afb5602fca330253a4a Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 7 Jan 2019 15:30:28 -0800 Subject: [PATCH 03/20] Add support for multiple roots. --- authority/authority.go | 22 ++++---- authority/authority_test.go | 8 +-- authority/config.go | 4 +- authority/config_test.go | 18 +++---- authority/root.go | 7 ++- authority/root_test.go | 2 +- authority/types.go | 18 ++++--- authority/types_test.go | 103 ++++++++++++++++++++++++++++++++++++ ca/ca.go | 4 +- 9 files changed, 153 insertions(+), 33 deletions(-) create mode 100644 authority/types_test.go diff --git a/authority/authority.go b/authority/authority.go index 8c60fffc..5a0cf1ab 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority" // Authority implements the Certificate Authority internal interface. type Authority struct { config *Config - rootX509Crt *x509.Certificate + rootX509Certs []*x509.Certificate intermediateIdentity *x509util.Identity validateOnce bool certificates *sync.Map @@ -79,15 +79,19 @@ func (a *Authority) init() error { } var err error - // First load the root using our modified pem/x509 package. - a.rootX509Crt, err = pemutil.ReadCertificate(a.config.Root) - if err != nil { - return err - } - // Add root certificate to the certificate map - sum := sha256.Sum256(a.rootX509Crt.Raw) - a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt) + // Load the root certificates and add them to the certificate store + a.rootX509Certs = make([]*x509.Certificate, len(a.config.Root)) + for i, path := range a.config.Root { + crt, err := pemutil.ReadCertificate(path) + if err != nil { + return err + } + // Add root certificate to the certificate map + sum := sha256.Sum256(crt.Raw) + a.certificates.Store(hex.EncodeToString(sum[:]), crt) + a.rootX509Certs[i] = crt + } // Add federated roots for _, path := range a.config.FederatedRoots { diff --git a/authority/authority_test.go b/authority/authority_test.go index ad2f4980..1020f808 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -38,7 +38,7 @@ func testAuthority(t *testing.T) *Authority { } c := &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.ca.smallstep.com"}, @@ -68,7 +68,7 @@ func TestAuthorityNew(t *testing.T) { "fail bad root": func(t *testing.T) *newTest { c, err := LoadConfiguration("../ca/testdata/ca.json") assert.FatalError(t, err) - c.Root = "foo" + c.Root = []string{"foo"} return &newTest{ config: c, err: errors.New("open foo failed: no such file or directory"), @@ -105,10 +105,10 @@ func TestAuthorityNew(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - sum := sha256.Sum256(auth.rootX509Crt.Raw) + sum := sha256.Sum256(auth.rootX509Certs[0].Raw) root, ok := auth.certificates.Load(hex.EncodeToString(sum[:])) assert.Fatal(t, ok) - assert.Equals(t, auth.rootX509Crt, root) + assert.Equals(t, auth.rootX509Certs[0], root) assert.True(t, auth.initOnce) assert.NotNil(t, auth.intermediateIdentity) diff --git a/authority/config.go b/authority/config.go index 0ffdbef0..f19fb202 100644 --- a/authority/config.go +++ b/authority/config.go @@ -35,7 +35,7 @@ var ( // Config represents the CA configuration and it's mapped to a JSON object. type Config struct { - Root string `json:"root"` + Root multiString `json:"root"` FederatedRoots []string `json:"federatedRoots"` IntermediateCert string `json:"crt"` IntermediateKey string `json:"key"` @@ -117,7 +117,7 @@ func (c *Config) Validate() error { case c.Address == "": return errors.New("address cannot be empty") - case c.Root == "": + case c.Root.Empties(): return errors.New("root cannot be empty") case c.IntermediateCert == "": diff --git a/authority/config_test.go b/authority/config_test.go index c16b4780..01cea2a1 100644 --- a/authority/config_test.go +++ b/authority/config_test.go @@ -40,7 +40,7 @@ func TestConfigValidate(t *testing.T) { "empty-address": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -54,7 +54,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -81,7 +81,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", @@ -94,7 +94,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", DNSNames: []string{"test.smallstep.com"}, Password: "pass", @@ -107,7 +107,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", Password: "pass", @@ -120,7 +120,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -134,7 +134,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -149,7 +149,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -178,7 +178,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, diff --git a/authority/root.go b/authority/root.go index 01710db8..d041ae8f 100644 --- a/authority/root.go +++ b/authority/root.go @@ -25,7 +25,12 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) { // GetRootCertificate returns the server root certificate. func (a *Authority) GetRootCertificate() *x509.Certificate { - return a.rootX509Crt + return a.rootX509Certs[0] +} + +// GetRootCertificates returns the server root certificates. +func (a *Authority) GetRootCertificates() []*x509.Certificate { + return a.rootX509Certs } // GetFederation returns all the root certificates in the federation. diff --git a/authority/root_test.go b/authority/root_test.go index 0db7e866..d9803d8e 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -37,7 +37,7 @@ func TestRoot(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, crt, a.rootX509Crt) + assert.Equals(t, crt, a.rootX509Certs[0]) } } }) diff --git a/authority/types.go b/authority/types.go index ec8f0d7b..50b632d6 100644 --- a/authority/types.go +++ b/authority/types.go @@ -36,11 +36,10 @@ func (d *duration) UnmarshalJSON(data []byte) (err error) { return } +// multiString represents a type that can be encoded/decoded in JSON as a single +// string or an array of strings. type multiString []string -// FIXME: remove me, avoids deadcode warning -var _ = multiString{} - // First returns the first element of a multiString. It will return an empty // string if the multistring is empty. func (s multiString) First() string { @@ -69,20 +68,24 @@ func (s multiString) Empties() bool { func (s multiString) MarshalJSON() ([]byte, error) { switch len(s) { case 0: - return []byte(""), nil + return []byte(`""`), nil case 1: return json.Marshal(s[0]) default: - return json.Marshal(s) + return json.Marshal([]string(s)) } } // UnmarshalJSON parses a string or a slice and sets it to the multiString. func (s *multiString) UnmarshalJSON(data []byte) error { + if s == nil { + return errors.New("multiString cannot be nil") + } if len(data) == 0 { *s = nil return nil } + // Parse string if data[0] == '"' { var str string if err := json.Unmarshal(data, &str); err != nil { @@ -91,8 +94,11 @@ func (s *multiString) UnmarshalJSON(data []byte) error { *s = []string{str} return nil } - if err := json.Unmarshal(data, s); err != nil { + // Parse array + var ss []string + if err := json.Unmarshal(data, &ss); err != nil { return errors.Wrapf(err, "error unmarshalling %s", data) } + *s = ss return nil } diff --git a/authority/types_test.go b/authority/types_test.go new file mode 100644 index 00000000..620751d3 --- /dev/null +++ b/authority/types_test.go @@ -0,0 +1,103 @@ +package authority + +import ( + "reflect" + "testing" +) + +func Test_multiString_First(t *testing.T) { + tests := []struct { + name string + s multiString + want string + }{ + {"empty", multiString{}, ""}, + {"string", multiString{"one"}, "one"}, + {"slice", multiString{"one", "two"}, "one"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.s.First(); got != tt.want { + t.Errorf("multiString.First() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiString_Empties(t *testing.T) { + tests := []struct { + name string + s multiString + want bool + }{ + {"empty", multiString{}, true}, + {"string", multiString{"one"}, false}, + {"empty string", multiString{""}, true}, + {"slice", multiString{"one", "two"}, false}, + {"empty slice", multiString{"one", ""}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.s.Empties(); got != tt.want { + t.Errorf("multiString.Empties() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiString_MarshalJSON(t *testing.T) { + tests := []struct { + name string + s multiString + want []byte + wantErr bool + }{ + {"empty", []string{}, []byte(`""`), false}, + {"string", []string{"a string"}, []byte(`"a string"`), false}, + {"slice", []string{"string one", "string two"}, []byte(`["string one","string two"]`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.s.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("multiString.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("multiString.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiString_UnmarshalJSON(t *testing.T) { + + type args struct { + data []byte + } + tests := []struct { + name string + s *multiString + args args + want *multiString + wantErr bool + }{ + {"empty", new(multiString), args{[]byte{}}, new(multiString), false}, + {"empty string", new(multiString), args{[]byte(`""`)}, &multiString{""}, false}, + {"string", new(multiString), args{[]byte(`"a string"`)}, &multiString{"a string"}, false}, + {"slice", new(multiString), args{[]byte(`["string one","string two"]`)}, &multiString{"string one", "string two"}, false}, + {"error", new(multiString), args{[]byte(`["123",123]`)}, new(multiString), true}, + {"nil", nil, args{nil}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.s.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("multiString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(tt.s, tt.want) { + t.Errorf("multiString.UnmarshalJSON() = %v, want %v", tt.s, tt.want) + } + }) + } +} diff --git a/ca/ca.go b/ca/ca.go index 8f72984f..07ee3311 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -176,7 +176,9 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) { } certPool := x509.NewCertPool() - certPool.AddCert(auth.GetRootCertificate()) + for _, crt := range auth.GetRootCertificates() { + certPool.AddCert(crt) + } // GetCertificate will only be called if the client supplies SNI // information or if tlsConfig.Certificates is empty. From d296cf95a9b495bf2abc6dc0e72e054148167c15 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 7 Jan 2019 17:48:56 -0800 Subject: [PATCH 04/20] Add mTLS request to get all the root CAs, not the federated ones. --- api/api.go | 39 +++++- api/api_test.go | 61 +++++++- authority/root.go | 9 ++ ca/client.go | 26 +++- ca/client_test.go | 61 ++++++++ ca/testdata/ca.json | 1 + ca/testdata/secrets/federated_ca.crt | 11 ++ ca/tls.go | 4 +- ca/tls_options.go | 119 ++++++++++++---- ca/tls_options_test.go | 200 ++++++++++++++++++++++++++- 10 files changed, 488 insertions(+), 43 deletions(-) create mode 100644 ca/testdata/secrets/federated_ca.crt diff --git a/api/api.go b/api/api.go index e427bee3..563e65ea 100644 --- a/api/api.go +++ b/api/api.go @@ -22,10 +22,11 @@ type Authority interface { GetTLSOptions() *tlsutil.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) - Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) + Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error) GetEncryptedKey(kid string) (string, error) - GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) + GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error) + GetFederation(peer *x509.Certificate) ([]*x509.Certificate, error) } // Certificate wraps a *x509.Certificate and adds the json.Marshaler interface. @@ -187,6 +188,11 @@ type SignResponse struct { TLS *tls.ConnectionState `json:"-"` } +// RootsResponse is the response object of the roots request. +type RootsResponse struct { + Certificates []Certificate `json:"crts"` +} + // FederationResponse is the response object of the federation request. type FederationResponse struct { Certificates []Certificate `json:"crts"` @@ -211,6 +217,7 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("POST", "/renew", h.Renew) r.MethodFunc("GET", "/provisioners", h.Provisioners) r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) + r.MethodFunc("GET", "/roots", h.Roots) r.MethodFunc("GET", "/federation", h.Federation) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) @@ -327,7 +334,33 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { JSON(w, &ProvisionerKeyResponse{key}) } -// Federation returns all the public certificates in the federation. +// Roots returns all the root certificates for the CA. It requires a valid TLS +// client. +func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + WriteError(w, BadRequest(errors.New("missing peer certificate"))) + return + } + + roots, err := h.Authority.GetRoots(r.TLS.PeerCertificates[0]) + if err != nil { + WriteError(w, Forbidden(err)) + return + } + + certs := make([]Certificate, len(roots)) + for i := range roots { + certs[i] = Certificate{roots[i]} + } + + w.WriteHeader(http.StatusCreated) + JSON(w, &RootsResponse{ + Certificates: certs, + }) +} + +// Federation returns all the public certificates in the federation. It requires +// a valid TLS client. func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { WriteError(w, BadRequest(errors.New("missing peer certificate"))) diff --git a/api/api_test.go b/api/api_test.go index 82e12c8c..0a9eb120 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -392,6 +392,7 @@ type mockAuthority struct { renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error) getEncryptedKey func(kid string) (string, error) + getRoots func(cert *x509.Certificate) ([]*x509.Certificate, error) getFederation func(cert *x509.Certificate) ([]*x509.Certificate, error) } @@ -444,6 +445,13 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { return m.ret1.(string), m.err } +func (m *mockAuthority) GetRoots(cert *x509.Certificate) ([]*x509.Certificate, error) { + if m.getFederation != nil { + return m.getRoots(cert) + } + return m.ret1.([]*x509.Certificate), m.err +} + func (m *mockAuthority) GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) { if m.getFederation != nil { return m.getFederation(cert) @@ -821,6 +829,53 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { } } +func Test_caHandler_Roots(t *testing.T) { + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, + } + tests := []struct { + name string + tls *tls.ConnectionState + cert *x509.Certificate + root *x509.Certificate + err error + statusCode int + }{ + {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, + {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, + {"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + } + + expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + req := httptest.NewRequest("GET", "http://example.com/roots", nil) + req.TLS = tt.tls + w := httptest.NewRecorder() + h.Roots(w, req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Roots unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), expected) { + t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected) + } + } + }) + } +} + func Test_caHandler_Federation(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, @@ -851,17 +906,17 @@ func Test_caHandler_Federation(t *testing.T) { res := w.Result() if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { - t.Errorf("caHandler.Root unexpected error = %v", err) + t.Errorf("caHandler.Federation unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), expected) { - t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) + t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected) } } }) diff --git a/authority/root.go b/authority/root.go index d041ae8f..98974904 100644 --- a/authority/root.go +++ b/authority/root.go @@ -33,6 +33,15 @@ func (a *Authority) GetRootCertificates() []*x509.Certificate { return a.rootX509Certs } +// GetRoots returns all the root certificates for this CA. +func (a *Authority) GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error) { + // Check step provisioner extensions + if err := a.authorizeRenewal(peer); err != nil { + return nil, err + } + return a.rootX509Certs, nil +} + // GetFederation returns all the root certificates in the federation. func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) { // Check step provisioner extensions diff --git a/ca/client.go b/ca/client.go index 374a68ff..b8aab67f 100644 --- a/ca/client.go +++ b/ca/client.go @@ -237,9 +237,10 @@ func WithProvisionerLimit(limit int) ProvisionerOption { // Client implements an HTTP client for the CA server. type Client struct { - client *http.Client - endpoint *url.URL - certPool *x509.CertPool + client *http.Client + endpoint *url.URL + certPool *x509.CertPool + cachedSign *api.SignResponse } // NewClient creates a new Client with the given endpoint and options. @@ -413,6 +414,25 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) return &key, nil } +// Roots performs the get roots request to the CA and returns the +// api.RootsResponse struct. +func (c *Client) Roots(tr http.RoundTripper) (*api.RootsResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"}) + client := &http.Client{Transport: tr} + resp, err := client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var federation api.RootsResponse + if err := readJSON(resp.Body, &federation); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &federation, nil +} + // Federation performs the get federation request to the CA and returns the // api.FederationResponse struct. func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) { diff --git a/ca/client_test.go b/ca/client_test.go index 138b0d7d..0ec6324b 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -512,6 +512,67 @@ func TestClient_ProvisionerKey(t *testing.T) { } } +func TestClient_Roots(t *testing.T) { + ok := &api.RootsResponse{ + Certificates: []api.Certificate{ + {Certificate: parseCertificate(rootPEM)}, + }, + } + unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized")) + badRequest := api.BadRequest(fmt.Errorf("Bad Request")) + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + }{ + {"ok", ok, 200, false}, + {"unauthorized", unauthorized, 401, true}, + {"empty request", badRequest, 403, true}, + {"nil request", badRequest, 403, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(tt.responseCode) + api.JSON(w, tt.response) + }) + + got, err := c.Roots(nil) + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Roots() = %v, want nil", got) + } + if !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.Roots() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Roots() = %v, want %v", got, tt.response) + } + } + }) + } +} + func TestClient_Federation(t *testing.T) { ok := &api.FederationResponse{ Certificates: []api.Certificate{ diff --git a/ca/testdata/ca.json b/ca/testdata/ca.json index d61a8c49..f5484a7c 100644 --- a/ca/testdata/ca.json +++ b/ca/testdata/ca.json @@ -1,5 +1,6 @@ { "root": "../ca/testdata/secrets/root_ca.crt", + "federatedRoots": ["../ca/testdata/secrets/federated_ca.crt"], "crt": "../ca/testdata/secrets/intermediate_ca.crt", "key": "../ca/testdata/secrets/intermediate_ca_key", "password": "password", diff --git a/ca/testdata/secrets/federated_ca.crt b/ca/testdata/secrets/federated_ca.crt new file mode 100644 index 00000000..87cd5650 --- /dev/null +++ b/ca/testdata/secrets/federated_ca.crt @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa +MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw +MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG +SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL +jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB +/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR +qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z +YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc +Sw== +-----END CERTIFICATE----- diff --git a/ca/tls.go b/ca/tls.go index bef8e553..22eb667b 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -41,7 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(c, tlsConfig, options); err != nil { + if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil { return nil, err } @@ -87,7 +87,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(c, tlsConfig, options); err != nil { + if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil { return nil, err } diff --git a/ca/tls_options.go b/ca/tls_options.go index b1cd4696..26eae156 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -1,28 +1,57 @@ package ca import ( + "crypto" "crypto/tls" "crypto/x509" + "net/http" + + "github.com/smallstep/certificates/api" ) // TLSOption defines the type of a function that modifies a tls.Config. -type TLSOption func(c *Client, config *tls.Config) error +type TLSOption func(c *Client, tr http.RoundTripper, config *tls.Config) error // setTLSOptions takes one or more option function and applies them in order to // a tls.Config. -func setTLSOptions(c *Client, config *tls.Config, options []TLSOption) error { +func setTLSOptions(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config, options []TLSOption) error { + tr, err := getTLSOptionsTransport(sign, pk) + if err != nil { + return err + } + for _, opt := range options { - if err := opt(c, config); err != nil { + if err := opt(c, tr, config); err != nil { return err } } return nil } +// getTLSOptionsTransport is the transport used by TLSOptions. It is used to get +// root certificates using a mTLS connection with the CA. +func getTLSOptionsTransport(sign *api.SignResponse, pk crypto.PrivateKey) (http.RoundTripper, error) { + cert, err := TLSCertificate(sign, pk) + if err != nil { + return nil, err + } + + // Build default transport with fixed certificate + tlsConfig := getDefaultTLSConfig(sign) + tlsConfig.Certificates = []tls.Certificate{*cert} + tlsConfig.PreferServerCipherSuites = true + // Build RootCAs with given root certificate + if pool := getCertPool(sign); pool != nil { + tlsConfig.RootCAs = pool + } + + return getDefaultTransport(tlsConfig) +} + // RequireAndVerifyClientCert is a tls.Config option used on servers to enforce // a valid TLS client certificate. This is the default option for mTLS servers. func RequireAndVerifyClientCert() TLSOption { - return func(_ *Client, config *tls.Config) error { + return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { config.ClientAuth = tls.RequireAndVerifyClientCert return nil } @@ -31,7 +60,7 @@ func RequireAndVerifyClientCert() TLSOption { // VerifyClientCertIfGiven is a tls.Config option used on on servers to validate // a TLS client certificate if it is provided. It does not requires a certificate. func VerifyClientCertIfGiven() TLSOption { - return func(_ *Client, config *tls.Config) error { + return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { config.ClientAuth = tls.VerifyClientCertIfGiven return nil } @@ -41,7 +70,7 @@ func VerifyClientCertIfGiven() TLSOption { // defines the set of root certificate authorities that clients use when // verifying server certificates. func AddRootCA(cert *x509.Certificate) TLSOption { - return func(_ *Client, config *tls.Config) error { + return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { if config.RootCAs == nil { config.RootCAs = x509.NewCertPool() } @@ -54,7 +83,7 @@ func AddRootCA(cert *x509.Certificate) TLSOption { // defines the set of root certificate authorities that servers use if required // to verify a client certificate by the policy in ClientAuth. func AddClientCA(cert *x509.Certificate) TLSOption { - return func(_ *Client, config *tls.Config) error { + return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { if config.ClientCAs == nil { config.ClientCAs = x509.NewCertPool() } @@ -63,19 +92,18 @@ func AddClientCA(cert *x509.Certificate) TLSOption { } } -// AddRootFederation does a federation request and adds to the tls.Config -// RootCAs all the certificates in the response. RootCAs -// defines the set of root certificate authorities that clients use when -// verifying server certificates. -func AddRootFederation() TLSOption { - return func(c *Client, config *tls.Config) error { - if config.RootCAs == nil { - config.RootCAs = x509.NewCertPool() - } - certs, err := c.Federation(nil) +// AddRootsToRootCAs does a roots request and adds to the tls.Config RootCAs all +// the certificates in the response. RootCAs defines the set of root certificate +// authorities that clients use when verifying server certificates. +func AddRootsToRootCAs() TLSOption { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { + certs, err := c.Roots(tr) if err != nil { return err } + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } for _, cert := range certs.Certificates { config.RootCAs.AddCert(cert.Certificate) } @@ -83,19 +111,58 @@ func AddRootFederation() TLSOption { } } -// AddClientFederation does a federation request and adds to the tls.Config -// ClientCAs all the certificates in the response. ClientCAs defines the set of -// root certificate authorities that servers use if required to verify a client +// AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs +// all the certificates in the response. ClientCAs defines the set of root +// certificate authorities that servers use if required to verify a client // certificate by the policy in ClientAuth. -func AddClientFederation() TLSOption { - return func(c *Client, config *tls.Config) error { - if config.ClientCAs == nil { - config.ClientCAs = x509.NewCertPool() - } - certs, err := c.Federation(nil) +func AddRootsToClientCAs() TLSOption { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { + certs, err := c.Roots(tr) if err != nil { return err } + if config.ClientCAs == nil { + config.ClientCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + config.ClientCAs.AddCert(cert.Certificate) + } + return nil + } +} + +// AddFederationToRootCAs does a federation request and adds to the tls.Config +// RootCAs all the certificates in the response. RootCAs defines the set of root +// certificate authorities that clients use when verifying server certificates. +func AddFederationToRootCAs() TLSOption { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { + certs, err := c.Federation(tr) + if err != nil { + return err + } + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + config.RootCAs.AddCert(cert.Certificate) + } + return nil + } +} + +// AddFederationToClientCAs does a federation request and adds to the tls.Config +// ClientCAs all the certificates in the response. ClientCAs defines the set of +// root certificate authorities that servers use if required to verify a client +// certificate by the policy in ClientAuth. +func AddFederationToClientCAs() TLSOption { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { + certs, err := c.Federation(tr) + if err != nil { + return err + } + if config.ClientCAs == nil { + config.ClientCAs = x509.NewCertPool() + } for _, cert := range certs.Certificates { config.ClientCAs.AddCert(cert.Certificate) } diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 9886e487..07068ca4 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -4,13 +4,15 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io/ioutil" + "net/http" "reflect" "testing" ) func Test_setTLSOptions(t *testing.T) { fail := func() TLSOption { - return func(c *Client, config *tls.Config) error { + return func(c *Client, tr http.RoundTripper, config *tls.Config) error { return fmt.Errorf("an error") } } @@ -27,9 +29,13 @@ func Test_setTLSOptions(t *testing.T) { {"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false}, {"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true}, } + + ca := startCATestServer() + defer ca.Close() + client, sr, pk := signDuration(ca, "127.0.0.1", 0) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := setTLSOptions(nil, tt.args.c, tt.args.options); (err != nil) != tt.wantErr { + if err := setTLSOptions(client, sr, pk, tt.args.c, tt.args.options); (err != nil) != tt.wantErr { t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -46,7 +52,7 @@ func TestRequireAndVerifyClientCert(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := RequireAndVerifyClientCert()(nil, got); err != nil { + if err := RequireAndVerifyClientCert()(nil, nil, got); err != nil { t.Errorf("RequireAndVerifyClientCert() error = %v", err) return } @@ -67,7 +73,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := VerifyClientCertIfGiven()(nil, got); err != nil { + if err := VerifyClientCertIfGiven()(nil, nil, got); err != nil { t.Errorf("VerifyClientCertIfGiven() error = %v", err) return } @@ -96,7 +102,7 @@ func TestAddRootCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := AddRootCA(tt.args.cert)(nil, got); err != nil { + if err := AddRootCA(tt.args.cert)(nil, nil, got); err != nil { t.Errorf("AddRootCA() error = %v", err) return } @@ -125,7 +131,7 @@ func TestAddClientCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := AddClientCA(tt.args.cert)(nil, got); err != nil { + if err := AddClientCA(tt.args.cert)(nil, nil, got); err != nil { t.Errorf("AddClientCA() error = %v", err) return } @@ -135,3 +141,185 @@ func TestAddClientCA(t *testing.T) { }) } } + +func TestAddRootsToRootCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + cert := parseCertificate(string(root)) + pool := x509.NewCertPool() + pool.AddCert(cert) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{RootCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddRootsToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddRootsToRootCAs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddRootsToClientCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + cert := parseCertificate(string(root)) + pool := x509.NewCertPool() + pool.AddCert(cert) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{ClientCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddRootsToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddRootsToClientCAs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddFederationToRootCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") + if err != nil { + t.Fatal(err) + } + + crt1 := parseCertificate(string(root)) + crt2 := parseCertificate(string(federated)) + pool := x509.NewCertPool() + pool.AddCert(crt1) + pool.AddCert(crt2) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{RootCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddFederationToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddFederationToRootCAs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAddFederationToClientCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") + if err != nil { + t.Fatal(err) + } + + crt1 := parseCertificate(string(root)) + crt2 := parseCertificate(string(federated)) + pool := x509.NewCertPool() + pool.AddCert(crt1) + pool.AddCert(crt2) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{ClientCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &tls.Config{} + if err := AddFederationToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddFederationToClientCAs() = %v, want %v", got, tt.want) + } + }) + } +} From 6d3e8ed93c70e8263c9eda24493dbd2e0cd2ea58 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 7 Jan 2019 18:55:40 -0800 Subject: [PATCH 05/20] Add all root certificates by default on bootstrap methods. --- ca/bootstrap.go | 6 ++++++ ca/tls_options.go | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 577e4aaa..fd91a0fe 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -87,6 +87,9 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio return nil, err } + // Make sure the tlsConfig have all supported roots + options = append(options, AddRootsToClientCAs(), AddRootsToRootCAs()) + tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) if err != nil { return nil, err @@ -130,6 +133,9 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (* return nil, err } + // Make sure the tlsConfig have all supported roots + options = append(options, AddRootsToRootCAs()) + transport, err := client.Transport(ctx, sign, pk, options...) if err != nil { return nil, err diff --git a/ca/tls_options.go b/ca/tls_options.go index 26eae156..2414b313 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -95,6 +95,8 @@ func AddClientCA(cert *x509.Certificate) TLSOption { // AddRootsToRootCAs does a roots request and adds to the tls.Config RootCAs all // the certificates in the response. RootCAs defines the set of root certificate // authorities that clients use when verifying server certificates. +// +// BootstrapServer and BootstrapClient methods include this option by default. func AddRootsToRootCAs() TLSOption { return func(c *Client, tr http.RoundTripper, config *tls.Config) error { certs, err := c.Roots(tr) @@ -115,6 +117,8 @@ func AddRootsToRootCAs() TLSOption { // all the certificates in the response. ClientCAs defines the set of root // certificate authorities that servers use if required to verify a client // certificate by the policy in ClientAuth. +// +// BootstrapServer method includes this option by default. func AddRootsToClientCAs() TLSOption { return func(c *Client, tr http.RoundTripper, config *tls.Config) error { certs, err := c.Roots(tr) From 10aaece1b0bfe83acfe78d12ed744a3d8cdeec95 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Jan 2019 13:20:28 -0800 Subject: [PATCH 06/20] Update root certificates on renew. --- ca/bootstrap.go | 6 +- ca/tls.go | 23 ++++- ca/tls_options.go | 172 +++++++++++++++++++++++++-------- ca/tls_options_test.go | 214 ++++++++++++++++++++++++++++++++--------- 4 files changed, 324 insertions(+), 91 deletions(-) diff --git a/ca/bootstrap.go b/ca/bootstrap.go index fd91a0fe..8989b3c0 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -87,8 +87,8 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio return nil, err } - // Make sure the tlsConfig have all supported roots - options = append(options, AddRootsToClientCAs(), AddRootsToRootCAs()) + // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs + options = append(options, AddRootsToCAs()) tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) if err != nil { @@ -133,7 +133,7 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (* return nil, err } - // Make sure the tlsConfig have all supported roots + // Make sure the tlsConfig have all supported roots on RootCAs options = append(options, AddRootsToRootCAs()) transport, err := client.Transport(ctx, sign, pk, options...) diff --git a/ca/tls.go b/ca/tls.go index 22eb667b..e8ff0a9e 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -41,7 +41,11 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil { + tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig) + if err != nil { + return nil, err + } + if err := tlsCtx.apply(options); err != nil { return nil, err } @@ -50,7 +54,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, if err != nil { return nil, err } - renewer.RenewCertificate = getRenewFunc(c, tr, pk) + renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Start renewer renewer.RunContext(ctx) @@ -87,7 +91,11 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil { + tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig) + if err != nil { + return nil, err + } + if err := tlsCtx.apply(options); err != nil { return nil, err } @@ -96,7 +104,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, if err != nil { return nil, err } - renewer.RenewCertificate = getRenewFunc(c, tr, pk) + renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Start renewer renewer.RunContext(ctx) @@ -238,8 +246,13 @@ func getPEM(i interface{}) ([]byte, error) { return pem.EncodeToMemory(block), nil } -func getRenewFunc(client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc { +func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc { return func() (*tls.Certificate, error) { + // Get updated list of roots + if err := ctx.applyRenew(tr); err != nil { + return nil, err + } + // Get new certificate sign, err := client.Renew(tr) if err != nil { return nil, err diff --git a/ca/tls_options.go b/ca/tls_options.go index 2414b313..dc15ab18 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -10,18 +10,42 @@ import ( ) // TLSOption defines the type of a function that modifies a tls.Config. -type TLSOption func(c *Client, tr http.RoundTripper, config *tls.Config) error +type TLSOption func(ctx *TLSOptionCtx) error -// setTLSOptions takes one or more option function and applies them in order to -// a tls.Config. -func setTLSOptions(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config, options []TLSOption) error { +// TLSOptionCtx is the context modified on TLSOption methods. +type TLSOptionCtx struct { + Client *Client + Transport http.RoundTripper + Config *tls.Config + OnRenewFunc []TLSOption +} + +// newTLSOptionCtx creates the TLSOption context. +func newTLSOptionCtx(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config) (*TLSOptionCtx, error) { tr, err := getTLSOptionsTransport(sign, pk) if err != nil { - return err + return nil, err } + return &TLSOptionCtx{ + Client: c, + Transport: tr, + Config: config, + }, nil +} - for _, opt := range options { - if err := opt(c, tr, config); err != nil { +func (ctx *TLSOptionCtx) apply(options []TLSOption) error { + for _, fn := range options { + if err := fn(ctx); err != nil { + return err + } + } + return nil +} + +func (ctx *TLSOptionCtx) applyRenew(tr http.RoundTripper) error { + ctx.Transport = tr + for _, fn := range ctx.OnRenewFunc { + if err := fn(ctx); err != nil { return err } } @@ -51,8 +75,8 @@ func getTLSOptionsTransport(sign *api.SignResponse, pk crypto.PrivateKey) (http. // RequireAndVerifyClientCert is a tls.Config option used on servers to enforce // a valid TLS client certificate. This is the default option for mTLS servers. func RequireAndVerifyClientCert() TLSOption { - return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { - config.ClientAuth = tls.RequireAndVerifyClientCert + return func(ctx *TLSOptionCtx) error { + ctx.Config.ClientAuth = tls.RequireAndVerifyClientCert return nil } } @@ -60,8 +84,8 @@ func RequireAndVerifyClientCert() TLSOption { // VerifyClientCertIfGiven is a tls.Config option used on on servers to validate // a TLS client certificate if it is provided. It does not requires a certificate. func VerifyClientCertIfGiven() TLSOption { - return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { - config.ClientAuth = tls.VerifyClientCertIfGiven + return func(ctx *TLSOptionCtx) error { + ctx.Config.ClientAuth = tls.VerifyClientCertIfGiven return nil } } @@ -70,11 +94,11 @@ func VerifyClientCertIfGiven() TLSOption { // defines the set of root certificate authorities that clients use when // verifying server certificates. func AddRootCA(cert *x509.Certificate) TLSOption { - return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { - if config.RootCAs == nil { - config.RootCAs = x509.NewCertPool() + return func(ctx *TLSOptionCtx) error { + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() } - config.RootCAs.AddCert(cert) + ctx.Config.RootCAs.AddCert(cert) return nil } } @@ -83,11 +107,11 @@ func AddRootCA(cert *x509.Certificate) TLSOption { // defines the set of root certificate authorities that servers use if required // to verify a client certificate by the policy in ClientAuth. func AddClientCA(cert *x509.Certificate) TLSOption { - return func(_ *Client, _ http.RoundTripper, config *tls.Config) error { - if config.ClientCAs == nil { - config.ClientCAs = x509.NewCertPool() + return func(ctx *TLSOptionCtx) error { + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() } - config.ClientCAs.AddCert(cert) + ctx.Config.ClientCAs.AddCert(cert) return nil } } @@ -98,19 +122,23 @@ func AddClientCA(cert *x509.Certificate) TLSOption { // // BootstrapServer and BootstrapClient methods include this option by default. func AddRootsToRootCAs() TLSOption { - return func(c *Client, tr http.RoundTripper, config *tls.Config) error { - certs, err := c.Roots(tr) + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Roots(ctx.Transport) if err != nil { return err } - if config.RootCAs == nil { - config.RootCAs = x509.NewCertPool() + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() } for _, cert := range certs.Certificates { - config.RootCAs.AddCert(cert.Certificate) + ctx.Config.RootCAs.AddCert(cert.Certificate) } return nil } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } } // AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs @@ -120,38 +148,46 @@ func AddRootsToRootCAs() TLSOption { // // BootstrapServer method includes this option by default. func AddRootsToClientCAs() TLSOption { - return func(c *Client, tr http.RoundTripper, config *tls.Config) error { - certs, err := c.Roots(tr) + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Roots(ctx.Transport) if err != nil { return err } - if config.ClientCAs == nil { - config.ClientCAs = x509.NewCertPool() + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() } for _, cert := range certs.Certificates { - config.ClientCAs.AddCert(cert.Certificate) + ctx.Config.ClientCAs.AddCert(cert.Certificate) } return nil } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } } // AddFederationToRootCAs does a federation request and adds to the tls.Config // RootCAs all the certificates in the response. RootCAs defines the set of root // certificate authorities that clients use when verifying server certificates. func AddFederationToRootCAs() TLSOption { - return func(c *Client, tr http.RoundTripper, config *tls.Config) error { - certs, err := c.Federation(tr) + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Federation(ctx.Transport) if err != nil { return err } - if config.RootCAs == nil { - config.RootCAs = x509.NewCertPool() + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() } for _, cert := range certs.Certificates { - config.RootCAs.AddCert(cert.Certificate) + ctx.Config.RootCAs.AddCert(cert.Certificate) } return nil } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } } // AddFederationToClientCAs does a federation request and adds to the tls.Config @@ -159,17 +195,75 @@ func AddFederationToRootCAs() TLSOption { // root certificate authorities that servers use if required to verify a client // certificate by the policy in ClientAuth. func AddFederationToClientCAs() TLSOption { - return func(c *Client, tr http.RoundTripper, config *tls.Config) error { - certs, err := c.Federation(tr) + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Federation(ctx.Transport) if err != nil { return err } - if config.ClientCAs == nil { - config.ClientCAs = x509.NewCertPool() + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() } for _, cert := range certs.Certificates { - config.ClientCAs.AddCert(cert.Certificate) + ctx.Config.ClientCAs.AddCert(cert.Certificate) } return nil } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } +} + +// AddRootsToCAs does a roots request and adds the resulting certs to the +// tls.Config RootCAs and ClientCAs. Combines the functionality of +// AddRootsToRootCAs and AddRootsToClientCAs. +func AddRootsToCAs() TLSOption { + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Roots(ctx.Transport) + if err != nil { + return err + } + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.ClientCAs.AddCert(cert.Certificate) + ctx.Config.RootCAs.AddCert(cert.Certificate) + } + return nil + } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } +} + +// AddFederationToCAs does a federation request and adds the resulting certs to the +// tls.Config RootCAs and ClientCAs. Combines the functionality of +// AddFederationToRootCAs and AddFederationToClientCAs. +func AddFederationToCAs() TLSOption { + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Federation(ctx.Transport) + if err != nil { + return err + } + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.ClientCAs.AddCert(cert.Certificate) + ctx.Config.RootCAs.AddCert(cert.Certificate) + } + return nil + } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } } diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 07068ca4..6c0e2b3b 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -10,33 +10,36 @@ import ( "testing" ) -func Test_setTLSOptions(t *testing.T) { +func TestTLSOptionCtx_apply(t *testing.T) { fail := func() TLSOption { - return func(c *Client, tr http.RoundTripper, config *tls.Config) error { + return func(ctx *TLSOptionCtx) error { return fmt.Errorf("an error") } } + + type fields struct { + Config *tls.Config + } type args struct { - c *tls.Config options []TLSOption } tests := []struct { name string + fields fields args args wantErr bool }{ - {"ok", args{&tls.Config{}, []TLSOption{RequireAndVerifyClientCert()}}, false}, - {"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false}, - {"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true}, + {"ok", fields{&tls.Config{}}, args{[]TLSOption{RequireAndVerifyClientCert()}}, false}, + {"ok", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven()}}, false}, + {"fail", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven(), fail()}}, true}, } - - ca := startCATestServer() - defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := setTLSOptions(client, sr, pk, tt.args.c, tt.args.options); (err != nil) != tt.wantErr { - t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr) + ctx := &TLSOptionCtx{ + Config: tt.fields.Config, + } + if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr { + t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr) } }) } @@ -51,13 +54,15 @@ func TestRequireAndVerifyClientCert(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := RequireAndVerifyClientCert()(nil, nil, got); err != nil { + ctx := &TLSOptionCtx{ + Config: &tls.Config{}, + } + if err := RequireAndVerifyClientCert()(ctx); err != nil { t.Errorf("RequireAndVerifyClientCert() error = %v", err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("RequireAndVerifyClientCert() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("RequireAndVerifyClientCert() = %v, want %v", ctx.Config, tt.want) } }) } @@ -72,13 +77,15 @@ func TestVerifyClientCertIfGiven(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := VerifyClientCertIfGiven()(nil, nil, got); err != nil { + ctx := &TLSOptionCtx{ + Config: &tls.Config{}, + } + if err := VerifyClientCertIfGiven()(ctx); err != nil { t.Errorf("VerifyClientCertIfGiven() error = %v", err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("VerifyClientCertIfGiven() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("VerifyClientCertIfGiven() = %v, want %v", ctx.Config, tt.want) } }) } @@ -101,13 +108,15 @@ func TestAddRootCA(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := AddRootCA(tt.args.cert)(nil, nil, got); err != nil { + ctx := &TLSOptionCtx{ + Config: &tls.Config{}, + } + if err := AddRootCA(tt.args.cert)(ctx); err != nil { t.Errorf("AddRootCA() error = %v", err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("AddRootCA() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddRootCA() = %v, want %v", ctx.Config, tt.want) } }) } @@ -130,13 +139,15 @@ func TestAddClientCA(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := AddClientCA(tt.args.cert)(nil, nil, got); err != nil { + ctx := &TLSOptionCtx{ + Config: &tls.Config{}, + } + if err := AddClientCA(tt.args.cert)(ctx); err != nil { t.Errorf("AddClientCA() error = %v", err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("AddClientCA() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddClientCA() = %v, want %v", ctx.Config, tt.want) } }) } @@ -172,13 +183,17 @@ func TestAddRootsToRootCAs(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := AddRootsToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + ctx := &TLSOptionCtx{ + Client: client, + Config: &tls.Config{}, + Transport: tt.tr, + } + if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("AddRootsToRootCAs() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want) } }) } @@ -214,13 +229,17 @@ func TestAddRootsToClientCAs(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := AddRootsToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + ctx := &TLSOptionCtx{ + Client: client, + Config: &tls.Config{}, + Transport: tt.tr, + } + if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("AddRootsToClientCAs() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want) } }) } @@ -263,13 +282,17 @@ func TestAddFederationToRootCAs(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := AddFederationToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + ctx := &TLSOptionCtx{ + Client: client, + Config: &tls.Config{}, + Transport: tt.tr, + } + if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("AddFederationToRootCAs() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddFederationToRootCAs() = %v, want %v", ctx.Config, tt.want) } }) } @@ -312,13 +335,116 @@ func TestAddFederationToClientCAs(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := AddFederationToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr { + ctx := &TLSOptionCtx{ + Client: client, + Config: &tls.Config{}, + Transport: tt.tr, + } + if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("AddFederationToClientCAs() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddFederationToClientCAs() = %v, want %v", ctx.Config, tt.want) + } + }) + } +} + +func TestAddRootsToCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + cert := parseCertificate(string(root)) + pool := x509.NewCertPool() + pool.AddCert(cert) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &TLSOptionCtx{ + Client: client, + Config: &tls.Config{}, + Transport: tt.tr, + } + if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr { + t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want) + } + }) + } +} + +func TestAddFederationToCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, sr, pk := signDuration(ca, "127.0.0.1", 0) + tr, err := getTLSOptionsTransport(sr, pk) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") + if err != nil { + t.Fatal(err) + } + + crt1 := parseCertificate(string(root)) + crt2 := parseCertificate(string(federated)) + pool := x509.NewCertPool() + pool.AddCert(crt1) + pool.AddCert(crt2) + + tests := []struct { + name string + tr http.RoundTripper + want *tls.Config + wantErr bool + }{ + {"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, + {"fail", http.DefaultTransport, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &TLSOptionCtx{ + Client: client, + Config: &tls.Config{}, + Transport: tt.tr, + } + if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr { + t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddFederationToCAs() = %v, want %v", ctx.Config, tt.want) } }) } From 25ddbaedffb54985841009110a93cd5f5aefed60 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Jan 2019 17:24:11 -0800 Subject: [PATCH 07/20] Allow to customize the minimal cert duration for tests. --- ca/renew.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ca/renew.go b/ca/renew.go index aee9dd13..44234781 100644 --- a/ca/renew.go +++ b/ca/renew.go @@ -14,6 +14,8 @@ import ( // certificate. type RenewFunc func() (*tls.Certificate, error) +var minCertDuration = time.Minute + // TLSRenewer automatically renews a tls certificate using a RenewFunc. type TLSRenewer struct { sync.RWMutex @@ -58,8 +60,8 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption } period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore) - if period < time.Minute { - return nil, errors.Errorf("period must be greater than or equal to 1 Minute, but got %v.", period) + if period < minCertDuration { + return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period) } // By default we will try to renew the cert before 2/3 of the validity // period have expired. From af9e6488fcad1a5bc80b1423a4e014d67b8c65fc Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Jan 2019 17:35:00 -0800 Subject: [PATCH 08/20] Make the renew test shorter. --- ca/tls_test.go | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/ca/tls_test.go b/ca/tls_test.go index 799b496a..6c6a5291 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -20,7 +20,7 @@ import ( "github.com/smallstep/certificates/authority" "github.com/smallstep/cli/crypto/randutil" stepJOSE "github.com/smallstep/cli/jose" - "gopkg.in/square/go-jose.v2" + jose "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" ) @@ -242,16 +242,15 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { } func TestClient_GetServerTLSConfig_renew(t *testing.T) { - if testing.Short() { - t.Skip("skipping test in short mode.") - } + reset := setMinCertDuration(1 * time.Second) + defer reset() // Start CA ca := startCATestServer() defer ca.Close() clientDomain := "test.domain" - client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute) + client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second) // Start mTLS server ctx, cancel := context.WithCancel(context.Background()) @@ -274,13 +273,13 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { defer srvTLS.Close() // Transport - client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute) + client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) tr1, err := client.Transport(context.Background(), sr, pk) if err != nil { t.Fatalf("Client.Transport() error = %v", err) } // Transport with tlsConfig - client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute) + client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) if err != nil { t.Fatalf("Client.GetClientTLSConfig() error = %v", err) @@ -367,9 +366,9 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { t.Errorf("number of fingerprints unexpected, got %d, want 2", l) } - // Wait for renewal 40s == 1m-1m/3 - log.Printf("Sleeping for %s ...\n", 40*time.Second) - time.Sleep(40 * time.Second) + // Wait for renewal + log.Printf("Sleeping for %s ...\n", 5*time.Second) + time.Sleep(5 * time.Second) for _, tt := range tests { t.Run("renewed "+tt.name, func(t *testing.T) { From f99ae9da93bff56b286f0c2551d00c1aa46381b6 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Jan 2019 17:55:32 -0800 Subject: [PATCH 09/20] Add root rotation test. --- ca/bootstrap_test.go | 130 ++++++++++++++++++++++++ ca/testdata/ca.json | 3 +- ca/testdata/rotate-ca-0.json | 46 +++++++++ ca/testdata/rotate-ca-1.json | 46 +++++++++ ca/testdata/rotate-ca-2.json | 46 +++++++++ ca/testdata/rotate-ca-3.json | 46 +++++++++ ca/testdata/rotated/intermediate_ca.crt | 12 +++ ca/testdata/rotated/intermediate_ca_key | 8 ++ ca/testdata/rotated/root_ca.crt | 11 ++ ca/testdata/rotated/root_ca_key | 8 ++ 10 files changed, 354 insertions(+), 2 deletions(-) create mode 100644 ca/testdata/rotate-ca-0.json create mode 100644 ca/testdata/rotate-ca-1.json create mode 100644 ca/testdata/rotate-ca-2.json create mode 100644 ca/testdata/rotate-ca-3.json create mode 100644 ca/testdata/rotated/intermediate_ca.crt create mode 100644 ca/testdata/rotated/intermediate_ca_key create mode 100644 ca/testdata/rotated/root_ca.crt create mode 100644 ca/testdata/rotated/root_ca_key diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 241827c6..00bef552 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -3,12 +3,15 @@ package ca import ( "context" "crypto/tls" + "fmt" + "net" "net/http" "net/http/httptest" "reflect" "testing" "time" + "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" @@ -18,6 +21,24 @@ import ( "gopkg.in/square/go-jose.v2/jwt" ) +func newLocalListener() net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("failed to listen on a port: %v", err)) + } + } + return l +} + +func setMinCertDuration(d time.Duration) func() { + tmp := minCertDuration + minCertDuration = 1 * time.Second + return func() { + minCertDuration = tmp + } +} + func startCABootstrapServer() *httptest.Server { config, err := authority.LoadConfiguration("testdata/ca.json") if err != nil { @@ -267,3 +288,112 @@ func TestBootstrapClient(t *testing.T) { }) } } + +func TestBootstrapClientRotation(t *testing.T) { + reset := setMinCertDuration(1 * time.Second) + defer reset() + + // Configuration with current root + config, err := authority.LoadConfiguration("testdata/rotate-ca-0.json") + if err != nil { + panic(err) + } + + // Get local address + listener := newLocalListener() + config.Address = listener.Addr().String() + srvURL := "https://" + listener.Addr().String() + + // Start CA server + ca, err := New(config) + if err != nil { + panic(err) + } + go func() { + ca.srv.Serve(listener) + }() + defer ca.Stop() + time.Sleep(1 * time.Second) + + // doTest does a request that requires mTLS + doTest := func(client *http.Client) error { + resp, err := client.Get(srvURL + "/roots") + if err != nil { + return errors.New("client.Get() failed getting roots") + } + var roots api.RootsResponse + if err := readJSON(resp.Body, &roots); err != nil { + return errors.Errorf("client.Get() error reading response: %v", err) + } + return nil + } + + // Create bootstrap client + token := generateBootstrapToken(srvURL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + client, err := BootstrapClient(context.Background(), token) + if err != nil { + t.Errorf("BootstrapClient() error = %v", err) + return + } + + // Test with default root + if err := doTest(client); err != nil { + t.Errorf("Test with rotate-ca-0.json failed: %v", err) + } + + // wait for renew + time.Sleep(5 * time.Second) + + // Reload with configuration with current and future root + ca.opts.configFile = "testdata/rotate-ca-1.json" + if err := doReload(ca); err != nil { + t.Errorf("ca.Reload() error = %v", err) + return + } + if err := doTest(client); err != nil { + t.Errorf("Test with rotate-ca-1.json failed: %v", err) + } + + // wait for renew + time.Sleep(5 * time.Second) + + // Reload with new and old root + ca.opts.configFile = "testdata/rotate-ca-2.json" + if err := doReload(ca); err != nil { + t.Errorf("ca.Reload() error = %v", err) + return + } + if err := doTest(client); err != nil { + t.Errorf("Test with rotate-ca-2.json failed: %v", err) + } + + // wait for renew + time.Sleep(5 * time.Second) + + // Reload with pnly the new root + ca.opts.configFile = "testdata/rotate-ca-3.json" + if err := doReload(ca); err != nil { + t.Errorf("ca.Reload() error = %v", err) + return + } + if err := doTest(client); err != nil { + t.Errorf("Test with rotate-ca-3.json failed: %v", err) + } +} + +// doReload uses the reload implementation but overwrites the new address with +// the one being used. +func doReload(ca *CA) error { + config, err := authority.LoadConfiguration(ca.opts.configFile) + if err != nil { + return errors.Wrap(err, "error reloading ca") + } + + newCA, err := New(config, WithPassword(ca.opts.password), WithConfigFile(ca.opts.configFile)) + if err != nil { + return errors.Wrap(err, "error reloading ca") + } + // Use same address in new server + newCA.srv.Addr = ca.srv.Addr + return ca.srv.Reload(newCA.srv) +} diff --git a/ca/testdata/ca.json b/ca/testdata/ca.json index f5484a7c..f29f24c6 100644 --- a/ca/testdata/ca.json +++ b/ca/testdata/ca.json @@ -18,7 +18,6 @@ ] }, "authority": { - "minCertDuration": "1m", "provisioners": [ { "name": "max", @@ -73,7 +72,7 @@ "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" }, "claims": { - "minTLSCertDuration": "30s" + "minTLSCertDuration": "1s" } }, { "name": "mariano", diff --git a/ca/testdata/rotate-ca-0.json b/ca/testdata/rotate-ca-0.json new file mode 100644 index 00000000..20dd603a --- /dev/null +++ b/ca/testdata/rotate-ca-0.json @@ -0,0 +1,46 @@ +{ + "root": "testdata/secrets/root_ca.crt", + "crt": "testdata/secrets/intermediate_ca.crt", + "key": "testdata/secrets/intermediate_ca_key", + "password": "password", + "address": "127.0.0.1:0", + "dnsNames": ["127.0.0.1"], + "logger": {"format": "text"}, + "tls": { + "minVersion": 1.2, + "maxVersion": 1.2, + "renegotiation": false, + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + ] + }, + "authority": { + "provisioners": [ + { + "name": "mariano", + "type": "jwk", + "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", + "key": { + "use": "sig", + "kty": "EC", + "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", + "crv": "P-256", + "alg": "ES256", + "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", + "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" + }, + "claims": { + "minTLSCertDuration": "1s", + "defaultTLSCertDuration": "5s" + } + } + ], + "template": { + "country": "US", + "locality": "San Francisco", + "organization": "Smallstep" + } + } +} diff --git a/ca/testdata/rotate-ca-1.json b/ca/testdata/rotate-ca-1.json new file mode 100644 index 00000000..b038f694 --- /dev/null +++ b/ca/testdata/rotate-ca-1.json @@ -0,0 +1,46 @@ +{ + "root": ["testdata/secrets/root_ca.crt", "testdata/rotated/root_ca.crt"], + "crt": "testdata/secrets/intermediate_ca.crt", + "key": "testdata/secrets/intermediate_ca_key", + "password": "password", + "address": "127.0.0.1:0", + "dnsNames": ["127.0.0.1"], + "logger": {"format": "text"}, + "tls": { + "minVersion": 1.2, + "maxVersion": 1.2, + "renegotiation": false, + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + ] + }, + "authority": { + "provisioners": [ + { + "name": "mariano", + "type": "jwk", + "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", + "key": { + "use": "sig", + "kty": "EC", + "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", + "crv": "P-256", + "alg": "ES256", + "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", + "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" + }, + "claims": { + "minTLSCertDuration": "1s", + "defaultTLSCertDuration": "5s" + } + } + ], + "template": { + "country": "US", + "locality": "San Francisco", + "organization": "Smallstep" + } + } +} diff --git a/ca/testdata/rotate-ca-2.json b/ca/testdata/rotate-ca-2.json new file mode 100644 index 00000000..7ec965d0 --- /dev/null +++ b/ca/testdata/rotate-ca-2.json @@ -0,0 +1,46 @@ +{ + "root": ["testdata/rotated/root_ca.crt", "testdata/secrets/root_ca.crt"], + "crt": "testdata/rotated/intermediate_ca.crt", + "key": "testdata/rotated/intermediate_ca_key", + "password": "asdf", + "address": "127.0.0.1:0", + "dnsNames": ["127.0.0.1"], + "logger": {"format": "text"}, + "tls": { + "minVersion": 1.2, + "maxVersion": 1.2, + "renegotiation": false, + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + ] + }, + "authority": { + "provisioners": [ + { + "name": "mariano", + "type": "jwk", + "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", + "key": { + "use": "sig", + "kty": "EC", + "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", + "crv": "P-256", + "alg": "ES256", + "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", + "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" + }, + "claims": { + "minTLSCertDuration": "1s", + "defaultTLSCertDuration": "5s" + } + } + ], + "template": { + "country": "US", + "locality": "San Francisco", + "organization": "Smallstep" + } + } +} diff --git a/ca/testdata/rotate-ca-3.json b/ca/testdata/rotate-ca-3.json new file mode 100644 index 00000000..968da6ba --- /dev/null +++ b/ca/testdata/rotate-ca-3.json @@ -0,0 +1,46 @@ +{ + "root": "testdata/rotated/root_ca.crt", + "crt": "testdata/rotated/intermediate_ca.crt", + "key": "testdata/rotated/intermediate_ca_key", + "password": "asdf", + "address": "127.0.0.1:0", + "dnsNames": ["127.0.0.1"], + "logger": {"format": "text"}, + "tls": { + "minVersion": 1.2, + "maxVersion": 1.2, + "renegotiation": false, + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + ] + }, + "authority": { + "provisioners": [ + { + "name": "mariano", + "type": "jwk", + "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", + "key": { + "use": "sig", + "kty": "EC", + "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", + "crv": "P-256", + "alg": "ES256", + "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", + "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" + }, + "claims": { + "minTLSCertDuration": "1s", + "defaultTLSCertDuration": "5s" + } + } + ], + "template": { + "country": "US", + "locality": "San Francisco", + "organization": "Smallstep" + } + } +} diff --git a/ca/testdata/rotated/intermediate_ca.crt b/ca/testdata/rotated/intermediate_ca.crt new file mode 100644 index 00000000..338ebb22 --- /dev/null +++ b/ca/testdata/rotated/intermediate_ca.crt @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBxTCCAWugAwIBAgIQLIY6MR/1fBRQY4ZTTsPAJjAKBggqhkjOPQQDAjAcMRow +GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTAxMDcyMDExMzBaFw0yOTAx +MDQyMDExMzBaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew +WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARgtjL/KLNpdq81YYWaek1lrkPM/QF1 +m+ujwv5jya21fAXljdBLh6m2xco1GPfwPBbwUGlNOdEqE9Nq3Qx3ngPKo4GGMIGD +MA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw +EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUqixeZ/K1HW9N6SVw7ONya98S +u8UwHwYDVR0jBBgwFoAUgIzlCLxh/RlwEany4JQHOorLAIEwCgYIKoZIzj0EAwID +SAAwRQIgdGX6lxThrKlt3v+3HJZlaWdmoeQ3vYwpJb9uHExZdVYCIQDCxsdI8EnB +bxjnJscbT4zvqVsq6AmycdbFwgy8RIeVzg== +-----END CERTIFICATE----- diff --git a/ca/testdata/rotated/intermediate_ca_key b/ca/testdata/rotated/intermediate_ca_key new file mode 100644 index 00000000..6c3b1622 --- /dev/null +++ b/ca/testdata/rotated/intermediate_ca_key @@ -0,0 +1,8 @@ +-----BEGIN EC PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,7dcc0a8c1d73c8d438184e0928875329 + +r6yrQrHg6zBZRSjQpe8RzyQALEfiT3/8lMvvPu3BX6yign5skMfCVMXZhzbmAwmR +BJBIX+5hkudR2VN+hrsOyuU7FvIk4gx2c8buIlFObfYXIml0mpuThfm52ciAtOTE +S0hkfYvPcOAjzaDZ+8Po/mYhkODgyvijogn4ioTF/Ss= +-----END EC PRIVATE KEY----- diff --git a/ca/testdata/rotated/root_ca.crt b/ca/testdata/rotated/root_ca.crt new file mode 100644 index 00000000..87cd5650 --- /dev/null +++ b/ca/testdata/rotated/root_ca.crt @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa +MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw +MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG +SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL +jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB +/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR +qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z +YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc +Sw== +-----END CERTIFICATE----- diff --git a/ca/testdata/rotated/root_ca_key b/ca/testdata/rotated/root_ca_key new file mode 100644 index 00000000..c92f587e --- /dev/null +++ b/ca/testdata/rotated/root_ca_key @@ -0,0 +1,8 @@ +-----BEGIN EC PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,8ce79d28601b9809905ef7c362a20749 + +H+pTTL3B5fLYycgHLxFOW0fZsayr7Y+BW8THKf12h8dk0/eOE1wNoX2TuMtpbZgO +lMJdFPL+SAPCCmuZOZIcQDejRHVcYBq1wvrrnw/yfVawXC4xze+J4Y+q0J2WY+rM +xcLGlEOIRZkvdDVGmSitEZBl0Ibk0p9tG++7QGqAvnk= +-----END EC PRIVATE KEY----- From 8510e25b3bcdad3a38a858953be3faace6350633 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Jan 2019 18:48:15 -0800 Subject: [PATCH 10/20] Add test with bootstrap server. --- ca/bootstrap_test.go | 68 +++++++++++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 16 deletions(-) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 00bef552..452f5878 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "io/ioutil" "net" "net/http" "net/http/httptest" @@ -289,25 +290,25 @@ func TestBootstrapClient(t *testing.T) { } } -func TestBootstrapClientRotation(t *testing.T) { +func TestBootstrapClientServerRotation(t *testing.T) { reset := setMinCertDuration(1 * time.Second) defer reset() // Configuration with current root config, err := authority.LoadConfiguration("testdata/rotate-ca-0.json") if err != nil { - panic(err) + t.Fatal(err) } // Get local address listener := newLocalListener() config.Address = listener.Addr().String() - srvURL := "https://" + listener.Addr().String() + caURL := "https://" + listener.Addr().String() // Start CA server ca, err := New(config) if err != nil { - panic(err) + t.Fatal(err) } go func() { ca.srv.Serve(listener) @@ -315,27 +316,62 @@ func TestBootstrapClientRotation(t *testing.T) { defer ca.Stop() time.Sleep(1 * time.Second) - // doTest does a request that requires mTLS - doTest := func(client *http.Client) error { - resp, err := client.Get(srvURL + "/roots") - if err != nil { - return errors.New("client.Get() failed getting roots") - } - var roots api.RootsResponse - if err := readJSON(resp.Body, &roots); err != nil { - return errors.Errorf("client.Get() error reading response: %v", err) - } - return nil + // Create bootstrap server + token := generateBootstrapToken(caURL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + server, err := BootstrapServer(context.Background(), token, &http.Server{ + Addr: ":0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte("ok")) + }), + }, RequireAndVerifyClientCert()) + if err != nil { + t.Fatal(err) } + listener = newLocalListener() + srvURL := "https://" + listener.Addr().String() + go func() { + server.ServeTLS(listener, "", "") + }() + defer server.Close() // Create bootstrap client - token := generateBootstrapToken(srvURL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") client, err := BootstrapClient(context.Background(), token) if err != nil { t.Errorf("BootstrapClient() error = %v", err) return } + // doTest does a request that requires mTLS + doTest := func(client *http.Client) error { + // test with ca + resp, err := client.Get(caURL + "/roots") + if err != nil { + return errors.Wrapf(err, "client.Get(%s) failed", caURL+"/roots") + } + var roots api.RootsResponse + if err := readJSON(resp.Body, &roots); err != nil { + return errors.Wrap(err, "client.Get() error reading response") + } + if len(roots.Certificates) == 0 { + return errors.New("client.Get() error not certificates found") + } + // test with bootstrap server + resp, err = client.Get(srvURL) + if err != nil { + return errors.Wrapf(err, "client.Get(%s) failed", srvURL) + } + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrap(err, "client.Get() error reading response") + } + if string(b) != "ok" { + return errors.New("client.Get() unexpected response found") + } + return nil + } + // Test with default root if err := doTest(client); err != nil { t.Errorf("Test with rotate-ca-0.json failed: %v", err) From 61165230556d12a6b681f36864ef19749a8db261 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 10 Jan 2019 10:57:06 -0800 Subject: [PATCH 11/20] Fix random order in tests. --- ca/tls_options_test.go | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 6c0e2b3b..df3ee62a 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net/http" "reflect" + "sort" "testing" ) @@ -292,7 +293,10 @@ func TestAddFederationToRootCAs(t *testing.T) { return } if !reflect.DeepEqual(ctx.Config, tt.want) { - t.Errorf("AddFederationToRootCAs() = %v, want %v", ctx.Config, tt.want) + // Federated roots are randomly sorted + if !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) || ctx.Config.ClientCAs != nil { + t.Errorf("AddFederationToRootCAs() = %v, want %v", ctx.Config, tt.want) + } } }) } @@ -345,7 +349,10 @@ func TestAddFederationToClientCAs(t *testing.T) { return } if !reflect.DeepEqual(ctx.Config, tt.want) { - t.Errorf("AddFederationToClientCAs() = %v, want %v", ctx.Config, tt.want) + // Federated roots are randomly sorted + if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || ctx.Config.RootCAs != nil { + t.Errorf("AddFederationToClientCAs() = %v, want %v", ctx.Config, tt.want) + } } }) } @@ -444,8 +451,27 @@ func TestAddFederationToCAs(t *testing.T) { return } if !reflect.DeepEqual(ctx.Config, tt.want) { - t.Errorf("AddFederationToCAs() = %v, want %v", ctx.Config, tt.want) + // Federated roots are randomly sorted + if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) { + t.Errorf("AddFederationToCAs() = %v, want %v", ctx.Config, tt.want) + } } }) } } + +func equalPools(a, b *x509.CertPool) bool { + subjects := a.Subjects() + sA := make([]string, len(subjects)) + for i := range subjects { + sA[i] = string(subjects[i]) + } + subjects = b.Subjects() + sB := make([]string, len(subjects)) + for i := range subjects { + sB[i] = string(subjects[i]) + } + sort.Sort(sort.StringSlice(sA)) + sort.Sort(sort.StringSlice(sB)) + return reflect.DeepEqual(sA, sB) +} From 1763ede99d2c9eb57e5693f6b63275069a3d40ac Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 10 Jan 2019 13:19:51 -0800 Subject: [PATCH 12/20] Add tests for new methods. --- authority/root.go | 2 +- authority/root_test.go | 162 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 1 deletion(-) diff --git a/authority/root.go b/authority/root.go index 98974904..cfd63595 100644 --- a/authority/root.go +++ b/authority/root.go @@ -34,7 +34,7 @@ func (a *Authority) GetRootCertificates() []*x509.Certificate { } // GetRoots returns all the root certificates for this CA. -func (a *Authority) GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error) { +func (a *Authority) GetRoots(peer *x509.Certificate) ([]*x509.Certificate, error) { // Check step provisioner extensions if err := a.authorizeRenewal(peer); err != nil { return nil, err diff --git a/authority/root_test.go b/authority/root_test.go index d9803d8e..9b80cad6 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -1,11 +1,16 @@ package authority import ( + "crypto/x509" "net/http" + "reflect" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/cli/crypto/keys" + "github.com/smallstep/cli/crypto/pemutil" + "github.com/smallstep/cli/crypto/x509util" ) func TestRoot(t *testing.T) { @@ -43,3 +48,160 @@ func TestRoot(t *testing.T) { }) } } + +func TestAuthority_GetRootCertificate(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + want *x509.Certificate + }{ + {"ok", cert}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + if got := a.GetRootCertificate(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetRootCertificate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetRootCertificates(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + want []*x509.Certificate + }{ + {"ok", []*x509.Certificate{cert}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + if got := a.GetRootCertificates(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetRootCertificates() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetRoots(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + a := testAuthority(t) + pub, _, err := keys.GenerateDefaultKeyPair() + assert.FatalError(t, err) + leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, + withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test")) + assert.FatalError(t, err) + crtBytes, err := leaf.CreateCertificate() + assert.FatalError(t, err) + crt, err := x509.ParseCertificate(crtBytes) + assert.FatalError(t, err) + + leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, + withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"), + withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), + ) + assert.FatalError(t, err) + crtFailBytes, err := leafFail.CreateCertificate() + assert.FatalError(t, err) + crtFail, err := x509.ParseCertificate(crtFailBytes) + assert.FatalError(t, err) + + type args struct { + peer *x509.Certificate + } + tests := []struct { + name string + args args + want []*x509.Certificate + wantErr bool + }{ + {"ok", args{crt}, []*x509.Certificate{cert}, false}, + {"fail", args{crtFail}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := a.GetRoots(tt.args.peer) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetRoots() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetFederation(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + a := testAuthority(t) + pub, _, err := keys.GenerateDefaultKeyPair() + assert.FatalError(t, err) + leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, + withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test")) + assert.FatalError(t, err) + crtBytes, err := leaf.CreateCertificate() + assert.FatalError(t, err) + crt, err := x509.ParseCertificate(crtBytes) + assert.FatalError(t, err) + + leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, + withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"), + withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), + ) + assert.FatalError(t, err) + crtFailBytes, err := leafFail.CreateCertificate() + assert.FatalError(t, err) + crtFail, err := x509.ParseCertificate(crtFailBytes) + assert.FatalError(t, err) + + type args struct { + peer *x509.Certificate + } + tests := []struct { + name string + args args + wantFederation []*x509.Certificate + wantErr bool + fn func() + }{ + {"ok", args{crt}, []*x509.Certificate{cert}, false, nil}, + {"fail", args{crtFail}, nil, true, nil}, + {"fail not a certificate", args{crt}, nil, true, func() { + a.certificates.Store("foo", "bar") + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.fn != nil { + tt.fn() + } + gotFederation, err := a.GetFederation(tt.args.peer) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotFederation, tt.wantFederation) { + t.Errorf("Authority.GetFederation() = %v, want %v", gotFederation, tt.wantFederation) + } + }) + } +} From 9adc65febf5b42dfe7f5c8ed5eb91014cbe431b8 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 10 Jan 2019 15:31:40 -0800 Subject: [PATCH 13/20] Add test for newTLSOptionCtx --- ca/tls_options_test.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index df3ee62a..b52d1c89 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -1,6 +1,7 @@ package ca import ( + "crypto" "crypto/tls" "crypto/x509" "fmt" @@ -9,8 +10,37 @@ import ( "reflect" "sort" "testing" + + "github.com/smallstep/certificates/api" ) +func Test_newTLSOptionCtx(t *testing.T) { + client, sign, pk := sign("test.smallstep.com") + type args struct { + c *Client + sign *api.SignResponse + pk crypto.PrivateKey + config *tls.Config + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{client, sign, pk, &tls.Config{}}, false}, + {"fail", args{client, sign, "foo", &tls.Config{}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := newTLSOptionCtx(tt.args.c, tt.args.sign, tt.args.pk, tt.args.config) + if (err != nil) != tt.wantErr { + t.Errorf("newTLSOptionCtx() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + func TestTLSOptionCtx_apply(t *testing.T) { fail := func() TLSOption { return func(ctx *TLSOptionCtx) error { From 518b597535070da9dc58880ba44fb8a3a9e86b6c Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 11 Jan 2019 19:08:08 -0800 Subject: [PATCH 14/20] Remove mTLS client requirement in /roots and /federation --- api/api.go | 24 ++---- api/api_test.go | 22 +++--- authority/root.go | 13 +--- authority/root_test.go | 70 +++--------------- ca/bootstrap_test.go | 17 ++--- ca/client.go | 10 +-- ca/client_test.go | 4 +- ca/tls.go | 18 ++--- ca/tls_options.go | 53 +++---------- ca/tls_options_test.go | 164 +++++++++++++++++++++++++---------------- 10 files changed, 162 insertions(+), 233 deletions(-) diff --git a/api/api.go b/api/api.go index 563e65ea..4bdf1b09 100644 --- a/api/api.go +++ b/api/api.go @@ -25,8 +25,8 @@ type Authority interface { Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error) GetEncryptedKey(kid string) (string, error) - GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error) - GetFederation(peer *x509.Certificate) ([]*x509.Certificate, error) + GetRoots() (federation []*x509.Certificate, err error) + GetFederation() ([]*x509.Certificate, error) } // Certificate wraps a *x509.Certificate and adds the json.Marshaler interface. @@ -334,15 +334,9 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { JSON(w, &ProvisionerKeyResponse{key}) } -// Roots returns all the root certificates for the CA. It requires a valid TLS -// client. +// Roots returns all the root certificates for the CA. func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, BadRequest(errors.New("missing peer certificate"))) - return - } - - roots, err := h.Authority.GetRoots(r.TLS.PeerCertificates[0]) + roots, err := h.Authority.GetRoots() if err != nil { WriteError(w, Forbidden(err)) return @@ -359,15 +353,9 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { }) } -// Federation returns all the public certificates in the federation. It requires -// a valid TLS client. +// Federation returns all the public certificates in the federation. func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, BadRequest(errors.New("missing peer certificate"))) - return - } - - federated, err := h.Authority.GetFederation(r.TLS.PeerCertificates[0]) + federated, err := h.Authority.GetFederation() if err != nil { WriteError(w, Forbidden(err)) return diff --git a/api/api_test.go b/api/api_test.go index 0a9eb120..e60e6ba2 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -392,8 +392,8 @@ type mockAuthority struct { renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error) getEncryptedKey func(kid string) (string, error) - getRoots func(cert *x509.Certificate) ([]*x509.Certificate, error) - getFederation func(cert *x509.Certificate) ([]*x509.Certificate, error) + getRoots func() ([]*x509.Certificate, error) + getFederation func() ([]*x509.Certificate, error) } func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) { @@ -445,16 +445,16 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { return m.ret1.(string), m.err } -func (m *mockAuthority) GetRoots(cert *x509.Certificate) ([]*x509.Certificate, error) { +func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { if m.getFederation != nil { - return m.getRoots(cert) + return m.getRoots() } return m.ret1.([]*x509.Certificate), m.err } -func (m *mockAuthority) GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) { +func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { if m.getFederation != nil { - return m.getFederation(cert) + return m.getFederation() } return m.ret1.([]*x509.Certificate), m.err } @@ -842,9 +842,8 @@ func Test_caHandler_Roots(t *testing.T) { statusCode int }{ {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, - {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, - {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, - {"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + {"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, } expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) @@ -889,9 +888,8 @@ func Test_caHandler_Federation(t *testing.T) { statusCode int }{ {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, - {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, - {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, - {"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + {"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, } expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) diff --git a/authority/root.go b/authority/root.go index cfd63595..01b3508d 100644 --- a/authority/root.go +++ b/authority/root.go @@ -34,21 +34,12 @@ func (a *Authority) GetRootCertificates() []*x509.Certificate { } // GetRoots returns all the root certificates for this CA. -func (a *Authority) GetRoots(peer *x509.Certificate) ([]*x509.Certificate, error) { - // Check step provisioner extensions - if err := a.authorizeRenewal(peer); err != nil { - return nil, err - } +func (a *Authority) GetRoots() ([]*x509.Certificate, error) { return a.rootX509Certs, nil } // GetFederation returns all the root certificates in the federation. -func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) { - // Check step provisioner extensions - if err := a.authorizeRenewal(peer); err != nil { - return nil, err - } - +func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) { a.certificates.Range(func(k, v interface{}) bool { crt, ok := v.(*x509.Certificate) if !ok { diff --git a/authority/root_test.go b/authority/root_test.go index 9b80cad6..17f25755 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -8,9 +8,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/pemutil" - "github.com/smallstep/cli/crypto/x509util" ) func TestRoot(t *testing.T) { @@ -99,42 +97,17 @@ func TestAuthority_GetRoots(t *testing.T) { t.Fatal(err) } - a := testAuthority(t) - pub, _, err := keys.GenerateDefaultKeyPair() - assert.FatalError(t, err) - leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test")) - assert.FatalError(t, err) - crtBytes, err := leaf.CreateCertificate() - assert.FatalError(t, err) - crt, err := x509.ParseCertificate(crtBytes) - assert.FatalError(t, err) - - leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"), - withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), - ) - assert.FatalError(t, err) - crtFailBytes, err := leafFail.CreateCertificate() - assert.FatalError(t, err) - crtFail, err := x509.ParseCertificate(crtFailBytes) - assert.FatalError(t, err) - - type args struct { - peer *x509.Certificate - } tests := []struct { name string - args args want []*x509.Certificate wantErr bool }{ - {"ok", args{crt}, []*x509.Certificate{cert}, false}, - {"fail", args{crtFail}, nil, true}, + {"ok", []*x509.Certificate{cert}, false}, } for _, tt := range tests { + a := testAuthority(t) t.Run(tt.name, func(t *testing.T) { - got, err := a.GetRoots(tt.args.peer) + got, err := a.GetRoots() if (err != nil) != tt.wantErr { t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr) return @@ -152,49 +125,24 @@ func TestAuthority_GetFederation(t *testing.T) { t.Fatal(err) } - a := testAuthority(t) - pub, _, err := keys.GenerateDefaultKeyPair() - assert.FatalError(t, err) - leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test")) - assert.FatalError(t, err) - crtBytes, err := leaf.CreateCertificate() - assert.FatalError(t, err) - crt, err := x509.ParseCertificate(crtBytes) - assert.FatalError(t, err) - - leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"), - withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), - ) - assert.FatalError(t, err) - crtFailBytes, err := leafFail.CreateCertificate() - assert.FatalError(t, err) - crtFail, err := x509.ParseCertificate(crtFailBytes) - assert.FatalError(t, err) - - type args struct { - peer *x509.Certificate - } tests := []struct { name string - args args wantFederation []*x509.Certificate wantErr bool - fn func() + fn func(a *Authority) }{ - {"ok", args{crt}, []*x509.Certificate{cert}, false, nil}, - {"fail", args{crtFail}, nil, true, nil}, - {"fail not a certificate", args{crt}, nil, true, func() { + {"ok", []*x509.Certificate{cert}, false, nil}, + {"fail", nil, true, func(a *Authority) { a.certificates.Store("foo", "bar") }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) if tt.fn != nil { - tt.fn() + tt.fn(a) } - gotFederation, err := a.GetFederation(tt.args.peer) + gotFederation, err := a.GetFederation() if (err != nil) != tt.wantErr { t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 452f5878..a046fde2 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -3,7 +3,6 @@ package ca import ( "context" "crypto/tls" - "fmt" "io/ioutil" "net" "net/http" @@ -26,7 +25,7 @@ func newLocalListener() net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { - panic(fmt.Sprintf("failed to listen on a port: %v", err)) + panic(errors.Wrap(err, "failed to listen on a port")) } } return l @@ -345,16 +344,16 @@ func TestBootstrapClientServerRotation(t *testing.T) { // doTest does a request that requires mTLS doTest := func(client *http.Client) error { // test with ca - resp, err := client.Get(caURL + "/roots") + resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) if err != nil { - return errors.Wrapf(err, "client.Get(%s) failed", caURL+"/roots") + return errors.Wrap(err, "client.Post() failed") } - var roots api.RootsResponse - if err := readJSON(resp.Body, &roots); err != nil { - return errors.Wrap(err, "client.Get() error reading response") + var renew api.SignResponse + if err := readJSON(resp.Body, &renew); err != nil { + return errors.Wrap(err, "client.Post() error reading response") } - if len(roots.Certificates) == 0 { - return errors.New("client.Get() error not certificates found") + if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil { + return errors.New("client.Post() unexpected response found") } // test with bootstrap server resp, err = client.Get(srvURL) diff --git a/ca/client.go b/ca/client.go index b8aab67f..627cd450 100644 --- a/ca/client.go +++ b/ca/client.go @@ -416,10 +416,9 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) // Roots performs the get roots request to the CA and returns the // api.RootsResponse struct. -func (c *Client) Roots(tr http.RoundTripper) (*api.RootsResponse, error) { +func (c *Client) Roots() (*api.RootsResponse, error) { u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"}) - client := &http.Client{Transport: tr} - resp, err := client.Get(u.String()) + resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } @@ -435,10 +434,9 @@ func (c *Client) Roots(tr http.RoundTripper) (*api.RootsResponse, error) { // Federation performs the get federation request to the CA and returns the // api.FederationResponse struct. -func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) { +func (c *Client) Federation() (*api.FederationResponse, error) { u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"}) - client := &http.Client{Transport: tr} - resp, err := client.Get(u.String()) + resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } diff --git a/ca/client_test.go b/ca/client_test.go index 0ec6324b..d82afa31 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -549,7 +549,7 @@ func TestClient_Roots(t *testing.T) { api.JSON(w, tt.response) }) - got, err := c.Roots(nil) + got, err := c.Roots() if (err != nil) != tt.wantErr { fmt.Printf("%+v", err) t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) @@ -610,7 +610,7 @@ func TestClient_Federation(t *testing.T) { api.JSON(w, tt.response) }) - got, err := c.Federation(nil) + got, err := c.Federation() if (err != nil) != tt.wantErr { fmt.Printf("%+v", err) t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) diff --git a/ca/tls.go b/ca/tls.go index e8ff0a9e..31d1632b 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -41,10 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig) - if err != nil { - return nil, err - } + tlsCtx := newTLSOptionCtx(c, tlsConfig) if err := tlsCtx.apply(options); err != nil { return nil, err } @@ -56,6 +53,9 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) + // Update client transport + c.client.Transport = tr + // Start renewer renewer.RunContext(ctx) return tlsConfig, nil @@ -91,10 +91,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig) - if err != nil { - return nil, err - } + tlsCtx := newTLSOptionCtx(c, tlsConfig) if err := tlsCtx.apply(options); err != nil { return nil, err } @@ -106,6 +103,9 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) + // Update client transport + c.client.Transport = tr + // Start renewer renewer.RunContext(ctx) return tlsConfig, nil @@ -249,7 +249,7 @@ func getPEM(i interface{}) ([]byte, error) { func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc { return func() (*tls.Certificate, error) { // Get updated list of roots - if err := ctx.applyRenew(tr); err != nil { + if err := ctx.applyRenew(); err != nil { return nil, err } // Get new certificate diff --git a/ca/tls_options.go b/ca/tls_options.go index dc15ab18..47e2c627 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -1,12 +1,8 @@ package ca import ( - "crypto" "crypto/tls" "crypto/x509" - "net/http" - - "github.com/smallstep/certificates/api" ) // TLSOption defines the type of a function that modifies a tls.Config. @@ -15,22 +11,16 @@ type TLSOption func(ctx *TLSOptionCtx) error // TLSOptionCtx is the context modified on TLSOption methods. type TLSOptionCtx struct { Client *Client - Transport http.RoundTripper Config *tls.Config OnRenewFunc []TLSOption } // newTLSOptionCtx creates the TLSOption context. -func newTLSOptionCtx(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config) (*TLSOptionCtx, error) { - tr, err := getTLSOptionsTransport(sign, pk) - if err != nil { - return nil, err - } +func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx { return &TLSOptionCtx{ - Client: c, - Transport: tr, - Config: config, - }, nil + Client: c, + Config: config, + } } func (ctx *TLSOptionCtx) apply(options []TLSOption) error { @@ -42,8 +32,7 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error { return nil } -func (ctx *TLSOptionCtx) applyRenew(tr http.RoundTripper) error { - ctx.Transport = tr +func (ctx *TLSOptionCtx) applyRenew() error { for _, fn := range ctx.OnRenewFunc { if err := fn(ctx); err != nil { return err @@ -52,26 +41,6 @@ func (ctx *TLSOptionCtx) applyRenew(tr http.RoundTripper) error { return nil } -// getTLSOptionsTransport is the transport used by TLSOptions. It is used to get -// root certificates using a mTLS connection with the CA. -func getTLSOptionsTransport(sign *api.SignResponse, pk crypto.PrivateKey) (http.RoundTripper, error) { - cert, err := TLSCertificate(sign, pk) - if err != nil { - return nil, err - } - - // Build default transport with fixed certificate - tlsConfig := getDefaultTLSConfig(sign) - tlsConfig.Certificates = []tls.Certificate{*cert} - tlsConfig.PreferServerCipherSuites = true - // Build RootCAs with given root certificate - if pool := getCertPool(sign); pool != nil { - tlsConfig.RootCAs = pool - } - - return getDefaultTransport(tlsConfig) -} - // RequireAndVerifyClientCert is a tls.Config option used on servers to enforce // a valid TLS client certificate. This is the default option for mTLS servers. func RequireAndVerifyClientCert() TLSOption { @@ -123,7 +92,7 @@ func AddClientCA(cert *x509.Certificate) TLSOption { // BootstrapServer and BootstrapClient methods include this option by default. func AddRootsToRootCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Roots(ctx.Transport) + certs, err := ctx.Client.Roots() if err != nil { return err } @@ -149,7 +118,7 @@ func AddRootsToRootCAs() TLSOption { // BootstrapServer method includes this option by default. func AddRootsToClientCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Roots(ctx.Transport) + certs, err := ctx.Client.Roots() if err != nil { return err } @@ -172,7 +141,7 @@ func AddRootsToClientCAs() TLSOption { // certificate authorities that clients use when verifying server certificates. func AddFederationToRootCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Federation(ctx.Transport) + certs, err := ctx.Client.Federation() if err != nil { return err } @@ -196,7 +165,7 @@ func AddFederationToRootCAs() TLSOption { // certificate by the policy in ClientAuth. func AddFederationToClientCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Federation(ctx.Transport) + certs, err := ctx.Client.Federation() if err != nil { return err } @@ -219,7 +188,7 @@ func AddFederationToClientCAs() TLSOption { // AddRootsToRootCAs and AddRootsToClientCAs. func AddRootsToCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Roots(ctx.Transport) + certs, err := ctx.Client.Roots() if err != nil { return err } @@ -246,7 +215,7 @@ func AddRootsToCAs() TLSOption { // AddFederationToRootCAs and AddFederationToClientCAs. func AddFederationToCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Federation(ctx.Transport) + certs, err := ctx.Client.Federation() if err != nil { return err } diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index b52d1c89..ceeea7dc 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -1,7 +1,6 @@ package ca import ( - "crypto" "crypto/tls" "crypto/x509" "fmt" @@ -10,32 +9,29 @@ import ( "reflect" "sort" "testing" - - "github.com/smallstep/certificates/api" ) func Test_newTLSOptionCtx(t *testing.T) { - client, sign, pk := sign("test.smallstep.com") + client, err := NewClient("https://ca.smallstep.com", WithTransport(http.DefaultTransport)) + if err != nil { + t.Fatalf("NewClient() error = %v", err) + } + type args struct { c *Client - sign *api.SignResponse - pk crypto.PrivateKey config *tls.Config } tests := []struct { - name string - args args - wantErr bool + name string + args args + want *TLSOptionCtx }{ - {"ok", args{client, sign, pk, &tls.Config{}}, false}, - {"fail", args{client, sign, "foo", &tls.Config{}}, true}, + {"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := newTLSOptionCtx(tt.args.c, tt.args.sign, tt.args.pk, tt.args.config) - if (err != nil) != tt.wantErr { - t.Errorf("newTLSOptionCtx() error = %v, wantErr %v", err, tt.wantErr) - return + if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want) } }) } @@ -188,8 +184,12 @@ func TestAddRootsToRootCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -203,21 +203,24 @@ func TestAddRootsToRootCAs(t *testing.T) { pool := x509.NewCertPool() pool.AddCert(cert) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{RootCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -234,8 +237,12 @@ func TestAddRootsToClientCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -249,21 +256,24 @@ func TestAddRootsToClientCAs(t *testing.T) { pool := x509.NewCertPool() pool.AddCert(cert) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{ClientCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -280,8 +290,12 @@ func TestAddFederationToRootCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -302,21 +316,24 @@ func TestAddFederationToRootCAs(t *testing.T) { pool.AddCert(crt1) pool.AddCert(crt2) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{RootCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -336,8 +353,12 @@ func TestAddFederationToClientCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -358,21 +379,24 @@ func TestAddFederationToClientCAs(t *testing.T) { pool.AddCert(crt1) pool.AddCert(crt2) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{ClientCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -392,8 +416,12 @@ func TestAddRootsToCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -407,21 +435,24 @@ func TestAddRootsToCAs(t *testing.T) { pool := x509.NewCertPool() pool.AddCert(cert) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -438,8 +469,12 @@ func TestAddFederationToCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -460,21 +495,24 @@ func TestAddFederationToCAs(t *testing.T) { pool.AddCert(crt1) pool.AddCert(crt2) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr) From 7dc61bf2338df98dd1a45b0767a35b0517976408 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 11 Jan 2019 19:13:06 -0800 Subject: [PATCH 15/20] Remove deprecated code --- ca/bootstrap_test.go | 6 ++++-- ca/client.go | 19 ++----------------- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index a046fde2..62e1493b 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -136,8 +136,10 @@ func TestBootstrap(t *testing.T) { if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) { t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint) } - if !reflect.DeepEqual(got.certPool, tt.want.certPool) { - t.Errorf("Bootstrap() certPool = %v, want %v", got.certPool, tt.want.certPool) + gotTR := got.client.Transport.(*http.Transport) + wantTR := tt.want.client.Transport.(*http.Transport) + if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) { + t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) } } } diff --git a/ca/client.go b/ca/client.go index 627cd450..83ee73db 100644 --- a/ca/client.go +++ b/ca/client.go @@ -23,7 +23,6 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api" - "golang.org/x/net/http2" "gopkg.in/square/go-jose.v2/jwt" ) @@ -237,10 +236,8 @@ func WithProvisionerLimit(limit int) ProvisionerOption { // Client implements an HTTP client for the CA server. type Client struct { - client *http.Client - endpoint *url.URL - certPool *x509.CertPool - cachedSign *api.SignResponse + client *http.Client + endpoint *url.URL } // NewClient creates a new Client with the given endpoint and options. @@ -259,23 +256,11 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) { return nil, err } - var cp *x509.CertPool - switch tr := tr.(type) { - case *http.Transport: - if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil { - cp = tr.TLSClientConfig.RootCAs - } - case *http2.Transport: - if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil { - cp = tr.TLSClientConfig.RootCAs - } - } return &Client{ client: &http.Client{ Transport: tr, }, endpoint: u, - certPool: cp, }, nil } From 8252608ca2542ca1cc574adc28673ffd0d68e1e7 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 14 Jan 2019 14:33:00 -0800 Subject: [PATCH 16/20] Fix mock --- api/api_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/api_test.go b/api/api_test.go index e60e6ba2..988b46dd 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -446,7 +446,7 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { } func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { - if m.getFederation != nil { + if m.getRoots != nil { return m.getRoots() } return m.ret1.([]*x509.Certificate), m.err From dbd1bf11f126739676cfb54704a3f07b67ff3f80 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 14 Jan 2019 17:35:38 -0800 Subject: [PATCH 17/20] Rename variable. --- ca/client.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ca/client.go b/ca/client.go index 83ee73db..e138698f 100644 --- a/ca/client.go +++ b/ca/client.go @@ -410,11 +410,11 @@ func (c *Client) Roots() (*api.RootsResponse, error) { if resp.StatusCode >= 400 { return nil, readError(resp.Body) } - var federation api.RootsResponse - if err := readJSON(resp.Body, &federation); err != nil { + var roots api.RootsResponse + if err := readJSON(resp.Body, &roots); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } - return &federation, nil + return &roots, nil } // Federation performs the get federation request to the CA and returns the From cfbb2a6f41478683683a504b5a92161c6f2abcd7 Mon Sep 17 00:00:00 2001 From: max furman Date: Mon, 14 Jan 2019 17:55:01 -0800 Subject: [PATCH 18/20] method documentation grammar fix --- authority/types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/authority/types.go b/authority/types.go index 50b632d6..29c12282 100644 --- a/authority/types.go +++ b/authority/types.go @@ -49,7 +49,7 @@ func (s multiString) First() string { return "" } -// Empties checks that none of the string is empty. +// Empties returns `true` if any string in the array is empty. func (s multiString) Empties() bool { if len(s) == 0 { return true From 6e620073f5bb05733af28fe7680991beca7de7bb Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 14 Jan 2019 17:59:31 -0800 Subject: [PATCH 19/20] Rename method Empties to HasEmpties --- authority/config.go | 2 +- authority/types.go | 4 ++-- authority/types_test.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/authority/config.go b/authority/config.go index f19fb202..a6a78523 100644 --- a/authority/config.go +++ b/authority/config.go @@ -117,7 +117,7 @@ func (c *Config) Validate() error { case c.Address == "": return errors.New("address cannot be empty") - case c.Root.Empties(): + case c.Root.HasEmpties(): return errors.New("root cannot be empty") case c.IntermediateCert == "": diff --git a/authority/types.go b/authority/types.go index 29c12282..d9120f59 100644 --- a/authority/types.go +++ b/authority/types.go @@ -49,8 +49,8 @@ func (s multiString) First() string { return "" } -// Empties returns `true` if any string in the array is empty. -func (s multiString) Empties() bool { +// HasEmpties returns `true` if any string in the array is empty. +func (s multiString) HasEmpties() bool { if len(s) == 0 { return true } diff --git a/authority/types_test.go b/authority/types_test.go index 620751d3..36877dcc 100644 --- a/authority/types_test.go +++ b/authority/types_test.go @@ -38,7 +38,7 @@ func Test_multiString_Empties(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := tt.s.Empties(); got != tt.want { + if got := tt.s.HasEmpties(); got != tt.want { t.Errorf("multiString.Empties() = %v, want %v", got, tt.want) } }) From e8ac3f488818976891e8bab6cf1f491e3bd8f5b4 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Mon, 14 Jan 2019 18:09:06 -0800 Subject: [PATCH 20/20] Add comment to differentiate GetRootCertificates and GetRoots. --- authority/root.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/authority/root.go b/authority/root.go index 01b3508d..51ed6ac5 100644 --- a/authority/root.go +++ b/authority/root.go @@ -29,16 +29,24 @@ func (a *Authority) GetRootCertificate() *x509.Certificate { } // GetRootCertificates returns the server root certificates. +// +// In the Authority interface we also have a similar method, GetRoots, at the +// moment the functionality of these two methods are almost identical, but this +// method is intended to be used internally by CA HTTP server to load the roots +// that will be set in the tls.Config while GetRoots will be used by the +// Authority interface and might have extra checks in the future. func (a *Authority) GetRootCertificates() []*x509.Certificate { return a.rootX509Certs } // GetRoots returns all the root certificates for this CA. +// This method implements the Authority interface. func (a *Authority) GetRoots() ([]*x509.Certificate, error) { return a.rootX509Certs, nil } // GetFederation returns all the root certificates in the federation. +// This method implements the Authority interface. func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) { a.certificates.Range(func(k, v interface{}) bool { crt, ok := v.(*x509.Certificate)