diff --git a/api/api.go b/api/api.go index e4f34d51..4bdf1b09 100644 --- a/api/api.go +++ b/api/api.go @@ -22,9 +22,11 @@ type Authority interface { GetTLSOptions() *tlsutil.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, signOpts authority.SignOptions, extraOpts ...interface{}) (*x509.Certificate, *x509.Certificate, error) - Renew(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) + Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error) GetEncryptedKey(kid string) (string, error) + GetRoots() (federation []*x509.Certificate, err error) + GetFederation() ([]*x509.Certificate, error) } // Certificate wraps a *x509.Certificate and adds the json.Marshaler interface. @@ -186,6 +188,16 @@ type SignResponse struct { TLS *tls.ConnectionState `json:"-"` } +// RootsResponse is the response object of the roots request. +type RootsResponse struct { + Certificates []Certificate `json:"crts"` +} + +// FederationResponse is the response object of the federation request. +type FederationResponse struct { + Certificates []Certificate `json:"crts"` +} + // caHandler is the type used to implement the different CA HTTP endpoints. type caHandler struct { Authority Authority @@ -205,6 +217,8 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("POST", "/renew", h.Renew) r.MethodFunc("GET", "/provisioners", h.Provisioners) r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) + r.MethodFunc("GET", "/roots", h.Roots) + r.MethodFunc("GET", "/federation", h.Federation) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) } @@ -320,6 +334,44 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { JSON(w, &ProvisionerKeyResponse{key}) } +// Roots returns all the root certificates for the CA. +func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { + roots, err := h.Authority.GetRoots() + if err != nil { + WriteError(w, Forbidden(err)) + return + } + + certs := make([]Certificate, len(roots)) + for i := range roots { + certs[i] = Certificate{roots[i]} + } + + w.WriteHeader(http.StatusCreated) + JSON(w, &RootsResponse{ + Certificates: certs, + }) +} + +// Federation returns all the public certificates in the federation. +func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { + federated, err := h.Authority.GetFederation() + if err != nil { + WriteError(w, Forbidden(err)) + return + } + + certs := make([]Certificate, len(federated)) + for i := range federated { + certs[i] = Certificate{federated[i]} + } + + w.WriteHeader(http.StatusCreated) + JSON(w, &FederationResponse{ + Certificates: certs, + }) +} + func parseCursor(r *http.Request) (cursor string, limit int, err error) { q := r.URL.Query() cursor = q.Get("cursor") diff --git a/api/api_test.go b/api/api_test.go index e6f123b8..988b46dd 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -392,6 +392,8 @@ type mockAuthority struct { renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error) getEncryptedKey func(kid string) (string, error) + getRoots func() ([]*x509.Certificate, error) + getFederation func() ([]*x509.Certificate, error) } func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) { @@ -443,6 +445,20 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { return m.ret1.(string), m.err } +func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { + if m.getRoots != nil { + return m.getRoots() + } + return m.ret1.([]*x509.Certificate), m.err +} + +func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { + if m.getFederation != nil { + return m.getFederation() + } + return m.ret1.([]*x509.Certificate), m.err +} + func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority @@ -812,3 +828,95 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { }) } } + +func Test_caHandler_Roots(t *testing.T) { + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, + } + tests := []struct { + name string + tls *tls.ConnectionState + cert *x509.Certificate + root *x509.Certificate + err error + statusCode int + }{ + {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + } + + expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + req := httptest.NewRequest("GET", "http://example.com/roots", nil) + req.TLS = tt.tls + w := httptest.NewRecorder() + h.Roots(w, req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Roots unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), expected) { + t.Errorf("caHandler.Roots Body = %s, wants %s", body, expected) + } + } + }) + } +} + +func Test_caHandler_Federation(t *testing.T) { + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, + } + tests := []struct { + name string + tls *tls.ConnectionState + cert *x509.Certificate + root *x509.Certificate + err error + statusCode int + }{ + {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + } + + expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + req := httptest.NewRequest("GET", "http://example.com/federation", nil) + req.TLS = tt.tls + w := httptest.NewRecorder() + h.Federation(w, req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Federation unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), expected) { + t.Errorf("caHandler.Federation Body = %s, wants %s", body, expected) + } + } + }) + } +} diff --git a/authority/authority.go b/authority/authority.go index 4753f83b..5a0cf1ab 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -2,7 +2,7 @@ package authority import ( "crypto/sha256" - realx509 "crypto/x509" + "crypto/x509" "encoding/hex" "fmt" "sync" @@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority" // Authority implements the Certificate Authority internal interface. type Authority struct { config *Config - rootX509Crt *realx509.Certificate + rootX509Certs []*x509.Certificate intermediateIdentity *x509util.Identity validateOnce bool certificates *sync.Map @@ -79,15 +79,29 @@ func (a *Authority) init() error { } var err error - // First load the root using our modified pem/x509 package. - a.rootX509Crt, err = pemutil.ReadCertificate(a.config.Root) - if err != nil { - return err + + // Load the root certificates and add them to the certificate store + a.rootX509Certs = make([]*x509.Certificate, len(a.config.Root)) + for i, path := range a.config.Root { + crt, err := pemutil.ReadCertificate(path) + if err != nil { + return err + } + // Add root certificate to the certificate map + sum := sha256.Sum256(crt.Raw) + a.certificates.Store(hex.EncodeToString(sum[:]), crt) + a.rootX509Certs[i] = crt } - // Add root certificate to the certificate map - sum := sha256.Sum256(a.rootX509Crt.Raw) - a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt) + // Add federated roots + for _, path := range a.config.FederatedRoots { + crt, err := pemutil.ReadCertificate(path) + if err != nil { + return err + } + sum := sha256.Sum256(crt.Raw) + a.certificates.Store(hex.EncodeToString(sum[:]), crt) + } // Decrypt and load intermediate public / private key pair. if len(a.config.Password) > 0 { diff --git a/authority/authority_test.go b/authority/authority_test.go index ad2f4980..1020f808 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -38,7 +38,7 @@ func testAuthority(t *testing.T) *Authority { } c := &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.ca.smallstep.com"}, @@ -68,7 +68,7 @@ func TestAuthorityNew(t *testing.T) { "fail bad root": func(t *testing.T) *newTest { c, err := LoadConfiguration("../ca/testdata/ca.json") assert.FatalError(t, err) - c.Root = "foo" + c.Root = []string{"foo"} return &newTest{ config: c, err: errors.New("open foo failed: no such file or directory"), @@ -105,10 +105,10 @@ func TestAuthorityNew(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - sum := sha256.Sum256(auth.rootX509Crt.Raw) + sum := sha256.Sum256(auth.rootX509Certs[0].Raw) root, ok := auth.certificates.Load(hex.EncodeToString(sum[:])) assert.Fatal(t, ok) - assert.Equals(t, auth.rootX509Crt, root) + assert.Equals(t, auth.rootX509Certs[0], root) assert.True(t, auth.initOnce) assert.NotNil(t, auth.intermediateIdentity) diff --git a/authority/config.go b/authority/config.go index 5f9910af..a6a78523 100644 --- a/authority/config.go +++ b/authority/config.go @@ -33,38 +33,10 @@ var ( } ) -type duration struct { - time.Duration -} - -// MarshalJSON parses a duration string and sets it to the duration. -// -// A duration string is a possibly signed sequence of decimal numbers, each with -// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". -// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) MarshalJSON() ([]byte, error) { - return json.Marshal(d.String()) -} - -// UnmarshalJSON parses a duration string and sets it to the duration. -// -// A duration string is a possibly signed sequence of decimal numbers, each with -// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". -// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) UnmarshalJSON(data []byte) (err error) { - var s string - if err = json.Unmarshal(data, &s); err != nil { - return errors.Wrapf(err, "error unmarshalling %s", data) - } - if d.Duration, err = time.ParseDuration(s); err != nil { - return errors.Wrapf(err, "error parsing %s as duration", s) - } - return -} - // Config represents the CA configuration and it's mapped to a JSON object. type Config struct { - Root string `json:"root"` + Root multiString `json:"root"` + FederatedRoots []string `json:"federatedRoots"` IntermediateCert string `json:"crt"` IntermediateKey string `json:"key"` Address string `json:"address"` @@ -145,7 +117,7 @@ func (c *Config) Validate() error { case c.Address == "": return errors.New("address cannot be empty") - case c.Root == "": + case c.Root.HasEmpties(): return errors.New("root cannot be empty") case c.IntermediateCert == "": diff --git a/authority/config_test.go b/authority/config_test.go index c16b4780..01cea2a1 100644 --- a/authority/config_test.go +++ b/authority/config_test.go @@ -40,7 +40,7 @@ func TestConfigValidate(t *testing.T) { "empty-address": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -54,7 +54,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -81,7 +81,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", @@ -94,7 +94,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", DNSNames: []string{"test.smallstep.com"}, Password: "pass", @@ -107,7 +107,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", Password: "pass", @@ -120,7 +120,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -134,7 +134,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -149,7 +149,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -178,7 +178,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, diff --git a/authority/root.go b/authority/root.go index 5c918ee1..51ed6ac5 100644 --- a/authority/root.go +++ b/authority/root.go @@ -17,7 +17,7 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) { crt, ok := val.(*x509.Certificate) if !ok { - return nil, &apiError{errors.Errorf("stored value is not a *cryto/x509.Certificate"), + return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}} } return crt, nil @@ -25,5 +25,39 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) { // GetRootCertificate returns the server root certificate. func (a *Authority) GetRootCertificate() *x509.Certificate { - return a.rootX509Crt + return a.rootX509Certs[0] +} + +// GetRootCertificates returns the server root certificates. +// +// In the Authority interface we also have a similar method, GetRoots, at the +// moment the functionality of these two methods are almost identical, but this +// method is intended to be used internally by CA HTTP server to load the roots +// that will be set in the tls.Config while GetRoots will be used by the +// Authority interface and might have extra checks in the future. +func (a *Authority) GetRootCertificates() []*x509.Certificate { + return a.rootX509Certs +} + +// GetRoots returns all the root certificates for this CA. +// This method implements the Authority interface. +func (a *Authority) GetRoots() ([]*x509.Certificate, error) { + return a.rootX509Certs, nil +} + +// GetFederation returns all the root certificates in the federation. +// This method implements the Authority interface. +func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) { + a.certificates.Range(func(k, v interface{}) bool { + crt, ok := v.(*x509.Certificate) + if !ok { + federation = nil + err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"), + http.StatusInternalServerError, context{}} + return false + } + federation = append(federation, crt) + return true + }) + return } diff --git a/authority/root_test.go b/authority/root_test.go index fd4c31db..17f25755 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -1,11 +1,14 @@ package authority import ( + "crypto/x509" "net/http" + "reflect" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/cli/crypto/pemutil" ) func TestRoot(t *testing.T) { @@ -17,7 +20,7 @@ func TestRoot(t *testing.T) { err *apiError }{ "not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}}, - "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *cryto/x509.Certificate"), http.StatusInternalServerError, context{}}}, + "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}}, "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil}, } @@ -37,9 +40,116 @@ func TestRoot(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, crt, a.rootX509Crt) + assert.Equals(t, crt, a.rootX509Certs[0]) } } }) } } + +func TestAuthority_GetRootCertificate(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + want *x509.Certificate + }{ + {"ok", cert}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + if got := a.GetRootCertificate(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetRootCertificate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetRootCertificates(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + want []*x509.Certificate + }{ + {"ok", []*x509.Certificate{cert}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + if got := a.GetRootCertificates(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetRootCertificates() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetRoots(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + want []*x509.Certificate + wantErr bool + }{ + {"ok", []*x509.Certificate{cert}, false}, + } + for _, tt := range tests { + a := testAuthority(t) + t.Run(tt.name, func(t *testing.T) { + got, err := a.GetRoots() + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetRoots() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_GetFederation(t *testing.T) { + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + wantFederation []*x509.Certificate + wantErr bool + fn func(a *Authority) + }{ + {"ok", []*x509.Certificate{cert}, false, nil}, + {"fail", nil, true, func(a *Authority) { + a.certificates.Store("foo", "bar") + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) + if tt.fn != nil { + tt.fn(a) + } + gotFederation, err := a.GetFederation() + if (err != nil) != tt.wantErr { + t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotFederation, tt.wantFederation) { + t.Errorf("Authority.GetFederation() = %v, want %v", gotFederation, tt.wantFederation) + } + }) + } +} diff --git a/authority/types.go b/authority/types.go new file mode 100644 index 00000000..d9120f59 --- /dev/null +++ b/authority/types.go @@ -0,0 +1,104 @@ +package authority + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +type duration struct { + time.Duration +} + +// MarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +// UnmarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *duration) UnmarshalJSON(data []byte) (err error) { + var s string + if err = json.Unmarshal(data, &s); err != nil { + return errors.Wrapf(err, "error unmarshalling %s", data) + } + if d.Duration, err = time.ParseDuration(s); err != nil { + return errors.Wrapf(err, "error parsing %s as duration", s) + } + return +} + +// multiString represents a type that can be encoded/decoded in JSON as a single +// string or an array of strings. +type multiString []string + +// First returns the first element of a multiString. It will return an empty +// string if the multistring is empty. +func (s multiString) First() string { + if len(s) > 0 { + return s[0] + } + return "" +} + +// HasEmpties returns `true` if any string in the array is empty. +func (s multiString) HasEmpties() bool { + if len(s) == 0 { + return true + } + for _, ss := range s { + if len(ss) == 0 { + return true + } + } + return false +} + +// MarshalJSON marshals the multistring as a string or a slice of strings . With +// 0 elements it will return the empty string, with 1 element a regular string, +// otherwise a slice of strings. +func (s multiString) MarshalJSON() ([]byte, error) { + switch len(s) { + case 0: + return []byte(`""`), nil + case 1: + return json.Marshal(s[0]) + default: + return json.Marshal([]string(s)) + } +} + +// UnmarshalJSON parses a string or a slice and sets it to the multiString. +func (s *multiString) UnmarshalJSON(data []byte) error { + if s == nil { + return errors.New("multiString cannot be nil") + } + if len(data) == 0 { + *s = nil + return nil + } + // Parse string + if data[0] == '"' { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return errors.Wrapf(err, "error unmarshalling %s", data) + } + *s = []string{str} + return nil + } + // Parse array + var ss []string + if err := json.Unmarshal(data, &ss); err != nil { + return errors.Wrapf(err, "error unmarshalling %s", data) + } + *s = ss + return nil +} diff --git a/authority/types_test.go b/authority/types_test.go new file mode 100644 index 00000000..36877dcc --- /dev/null +++ b/authority/types_test.go @@ -0,0 +1,103 @@ +package authority + +import ( + "reflect" + "testing" +) + +func Test_multiString_First(t *testing.T) { + tests := []struct { + name string + s multiString + want string + }{ + {"empty", multiString{}, ""}, + {"string", multiString{"one"}, "one"}, + {"slice", multiString{"one", "two"}, "one"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.s.First(); got != tt.want { + t.Errorf("multiString.First() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiString_Empties(t *testing.T) { + tests := []struct { + name string + s multiString + want bool + }{ + {"empty", multiString{}, true}, + {"string", multiString{"one"}, false}, + {"empty string", multiString{""}, true}, + {"slice", multiString{"one", "two"}, false}, + {"empty slice", multiString{"one", ""}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.s.HasEmpties(); got != tt.want { + t.Errorf("multiString.Empties() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiString_MarshalJSON(t *testing.T) { + tests := []struct { + name string + s multiString + want []byte + wantErr bool + }{ + {"empty", []string{}, []byte(`""`), false}, + {"string", []string{"a string"}, []byte(`"a string"`), false}, + {"slice", []string{"string one", "string two"}, []byte(`["string one","string two"]`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.s.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("multiString.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("multiString.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiString_UnmarshalJSON(t *testing.T) { + + type args struct { + data []byte + } + tests := []struct { + name string + s *multiString + args args + want *multiString + wantErr bool + }{ + {"empty", new(multiString), args{[]byte{}}, new(multiString), false}, + {"empty string", new(multiString), args{[]byte(`""`)}, &multiString{""}, false}, + {"string", new(multiString), args{[]byte(`"a string"`)}, &multiString{"a string"}, false}, + {"slice", new(multiString), args{[]byte(`["string one","string two"]`)}, &multiString{"string one", "string two"}, false}, + {"error", new(multiString), args{[]byte(`["123",123]`)}, new(multiString), true}, + {"nil", nil, args{nil}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.s.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("multiString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(tt.s, tt.want) { + t.Errorf("multiString.UnmarshalJSON() = %v, want %v", tt.s, tt.want) + } + }) + } +} diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 577e4aaa..8989b3c0 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -87,6 +87,9 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio return nil, err } + // Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs + options = append(options, AddRootsToCAs()) + tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...) if err != nil { return nil, err @@ -130,6 +133,9 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (* return nil, err } + // Make sure the tlsConfig have all supported roots on RootCAs + options = append(options, AddRootsToRootCAs()) + transport, err := client.Transport(ctx, sign, pk, options...) if err != nil { return nil, err diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 241827c6..62e1493b 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -3,12 +3,15 @@ package ca import ( "context" "crypto/tls" + "io/ioutil" + "net" "net/http" "net/http/httptest" "reflect" "testing" "time" + "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" @@ -18,6 +21,24 @@ import ( "gopkg.in/square/go-jose.v2/jwt" ) +func newLocalListener() net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(errors.Wrap(err, "failed to listen on a port")) + } + } + return l +} + +func setMinCertDuration(d time.Duration) func() { + tmp := minCertDuration + minCertDuration = 1 * time.Second + return func() { + minCertDuration = tmp + } +} + func startCABootstrapServer() *httptest.Server { config, err := authority.LoadConfiguration("testdata/ca.json") if err != nil { @@ -115,8 +136,10 @@ func TestBootstrap(t *testing.T) { if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) { t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint) } - if !reflect.DeepEqual(got.certPool, tt.want.certPool) { - t.Errorf("Bootstrap() certPool = %v, want %v", got.certPool, tt.want.certPool) + gotTR := got.client.Transport.(*http.Transport) + wantTR := tt.want.client.Transport.(*http.Transport) + if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) { + t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) } } } @@ -267,3 +290,147 @@ func TestBootstrapClient(t *testing.T) { }) } } + +func TestBootstrapClientServerRotation(t *testing.T) { + reset := setMinCertDuration(1 * time.Second) + defer reset() + + // Configuration with current root + config, err := authority.LoadConfiguration("testdata/rotate-ca-0.json") + if err != nil { + t.Fatal(err) + } + + // Get local address + listener := newLocalListener() + config.Address = listener.Addr().String() + caURL := "https://" + listener.Addr().String() + + // Start CA server + ca, err := New(config) + if err != nil { + t.Fatal(err) + } + go func() { + ca.srv.Serve(listener) + }() + defer ca.Stop() + time.Sleep(1 * time.Second) + + // Create bootstrap server + token := generateBootstrapToken(caURL, "127.0.0.1", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + server, err := BootstrapServer(context.Background(), token, &http.Server{ + Addr: ":0", + Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte("ok")) + }), + }, RequireAndVerifyClientCert()) + if err != nil { + t.Fatal(err) + } + listener = newLocalListener() + srvURL := "https://" + listener.Addr().String() + go func() { + server.ServeTLS(listener, "", "") + }() + defer server.Close() + + // Create bootstrap client + token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") + client, err := BootstrapClient(context.Background(), token) + if err != nil { + t.Errorf("BootstrapClient() error = %v", err) + return + } + + // doTest does a request that requires mTLS + doTest := func(client *http.Client) error { + // test with ca + resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) + if err != nil { + return errors.Wrap(err, "client.Post() failed") + } + var renew api.SignResponse + if err := readJSON(resp.Body, &renew); err != nil { + return errors.Wrap(err, "client.Post() error reading response") + } + if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil { + return errors.New("client.Post() unexpected response found") + } + // test with bootstrap server + resp, err = client.Get(srvURL) + if err != nil { + return errors.Wrapf(err, "client.Get(%s) failed", srvURL) + } + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrap(err, "client.Get() error reading response") + } + if string(b) != "ok" { + return errors.New("client.Get() unexpected response found") + } + return nil + } + + // Test with default root + if err := doTest(client); err != nil { + t.Errorf("Test with rotate-ca-0.json failed: %v", err) + } + + // wait for renew + time.Sleep(5 * time.Second) + + // Reload with configuration with current and future root + ca.opts.configFile = "testdata/rotate-ca-1.json" + if err := doReload(ca); err != nil { + t.Errorf("ca.Reload() error = %v", err) + return + } + if err := doTest(client); err != nil { + t.Errorf("Test with rotate-ca-1.json failed: %v", err) + } + + // wait for renew + time.Sleep(5 * time.Second) + + // Reload with new and old root + ca.opts.configFile = "testdata/rotate-ca-2.json" + if err := doReload(ca); err != nil { + t.Errorf("ca.Reload() error = %v", err) + return + } + if err := doTest(client); err != nil { + t.Errorf("Test with rotate-ca-2.json failed: %v", err) + } + + // wait for renew + time.Sleep(5 * time.Second) + + // Reload with pnly the new root + ca.opts.configFile = "testdata/rotate-ca-3.json" + if err := doReload(ca); err != nil { + t.Errorf("ca.Reload() error = %v", err) + return + } + if err := doTest(client); err != nil { + t.Errorf("Test with rotate-ca-3.json failed: %v", err) + } +} + +// doReload uses the reload implementation but overwrites the new address with +// the one being used. +func doReload(ca *CA) error { + config, err := authority.LoadConfiguration(ca.opts.configFile) + if err != nil { + return errors.Wrap(err, "error reloading ca") + } + + newCA, err := New(config, WithPassword(ca.opts.password), WithConfigFile(ca.opts.configFile)) + if err != nil { + return errors.Wrap(err, "error reloading ca") + } + // Use same address in new server + newCA.srv.Addr = ca.srv.Addr + return ca.srv.Reload(newCA.srv) +} diff --git a/ca/ca.go b/ca/ca.go index 8f72984f..07ee3311 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -176,7 +176,9 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) { } certPool := x509.NewCertPool() - certPool.AddCert(auth.GetRootCertificate()) + for _, crt := range auth.GetRootCertificates() { + certPool.AddCert(crt) + } // GetCertificate will only be called if the client supplies SNI // information or if tlsConfig.Certificates is empty. diff --git a/ca/client.go b/ca/client.go index 47116245..e138698f 100644 --- a/ca/client.go +++ b/ca/client.go @@ -23,7 +23,6 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api" - "golang.org/x/net/http2" "gopkg.in/square/go-jose.v2/jwt" ) @@ -239,7 +238,6 @@ func WithProvisionerLimit(limit int) ProvisionerOption { type Client struct { client *http.Client endpoint *url.URL - certPool *x509.CertPool } // NewClient creates a new Client with the given endpoint and options. @@ -258,23 +256,11 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) { return nil, err } - var cp *x509.CertPool - switch tr := tr.(type) { - case *http.Transport: - if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil { - cp = tr.TLSClientConfig.RootCAs - } - case *http2.Transport: - if tr.TLSClientConfig != nil && tr.TLSClientConfig.RootCAs != nil { - cp = tr.TLSClientConfig.RootCAs - } - } return &Client{ client: &http.Client{ Transport: tr, }, endpoint: u, - certPool: cp, }, nil } @@ -413,6 +399,42 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) return &key, nil } +// Roots performs the get roots request to the CA and returns the +// api.RootsResponse struct. +func (c *Client) Roots() (*api.RootsResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"}) + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var roots api.RootsResponse + if err := readJSON(resp.Body, &roots); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &roots, nil +} + +// Federation performs the get federation request to the CA and returns the +// api.FederationResponse struct. +func (c *Client) Federation() (*api.FederationResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"}) + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var federation api.FederationResponse + if err := readJSON(resp.Body, &federation); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &federation, nil +} + // CreateSignRequest is a helper function that given an x509 OTT returns a // simple but secure sign request as well as the private key used. func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) { diff --git a/ca/client_test.go b/ca/client_test.go index 6d5cd22a..d82afa31 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -512,6 +512,128 @@ func TestClient_ProvisionerKey(t *testing.T) { } } +func TestClient_Roots(t *testing.T) { + ok := &api.RootsResponse{ + Certificates: []api.Certificate{ + {Certificate: parseCertificate(rootPEM)}, + }, + } + unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized")) + badRequest := api.BadRequest(fmt.Errorf("Bad Request")) + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + }{ + {"ok", ok, 200, false}, + {"unauthorized", unauthorized, 401, true}, + {"empty request", badRequest, 403, true}, + {"nil request", badRequest, 403, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(tt.responseCode) + api.JSON(w, tt.response) + }) + + got, err := c.Roots() + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Roots() = %v, want nil", got) + } + if !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.Roots() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Roots() = %v, want %v", got, tt.response) + } + } + }) + } +} + +func TestClient_Federation(t *testing.T) { + ok := &api.FederationResponse{ + Certificates: []api.Certificate{ + {Certificate: parseCertificate(rootPEM)}, + }, + } + unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized")) + badRequest := api.BadRequest(fmt.Errorf("Bad Request")) + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + }{ + {"ok", ok, 200, false}, + {"unauthorized", unauthorized, 401, true}, + {"empty request", badRequest, 403, true}, + {"nil request", badRequest, 403, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(tt.responseCode) + api.JSON(w, tt.response) + }) + + got, err := c.Federation() + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Federation() = %v, want nil", got) + } + if !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.Federation() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Federation() = %v, want %v", got, tt.response) + } + } + }) + } +} + func Test_parseEndpoint(t *testing.T) { expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"} expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"} diff --git a/ca/renew.go b/ca/renew.go index aee9dd13..44234781 100644 --- a/ca/renew.go +++ b/ca/renew.go @@ -14,6 +14,8 @@ import ( // certificate. type RenewFunc func() (*tls.Certificate, error) +var minCertDuration = time.Minute + // TLSRenewer automatically renews a tls certificate using a RenewFunc. type TLSRenewer struct { sync.RWMutex @@ -58,8 +60,8 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption } period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore) - if period < time.Minute { - return nil, errors.Errorf("period must be greater than or equal to 1 Minute, but got %v.", period) + if period < minCertDuration { + return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period) } // By default we will try to renew the cert before 2/3 of the validity // period have expired. diff --git a/ca/testdata/ca.json b/ca/testdata/ca.json index d61a8c49..f29f24c6 100644 --- a/ca/testdata/ca.json +++ b/ca/testdata/ca.json @@ -1,5 +1,6 @@ { "root": "../ca/testdata/secrets/root_ca.crt", + "federatedRoots": ["../ca/testdata/secrets/federated_ca.crt"], "crt": "../ca/testdata/secrets/intermediate_ca.crt", "key": "../ca/testdata/secrets/intermediate_ca_key", "password": "password", @@ -17,7 +18,6 @@ ] }, "authority": { - "minCertDuration": "1m", "provisioners": [ { "name": "max", @@ -72,7 +72,7 @@ "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" }, "claims": { - "minTLSCertDuration": "30s" + "minTLSCertDuration": "1s" } }, { "name": "mariano", diff --git a/ca/testdata/rotate-ca-0.json b/ca/testdata/rotate-ca-0.json new file mode 100644 index 00000000..20dd603a --- /dev/null +++ b/ca/testdata/rotate-ca-0.json @@ -0,0 +1,46 @@ +{ + "root": "testdata/secrets/root_ca.crt", + "crt": "testdata/secrets/intermediate_ca.crt", + "key": "testdata/secrets/intermediate_ca_key", + "password": "password", + "address": "127.0.0.1:0", + "dnsNames": ["127.0.0.1"], + "logger": {"format": "text"}, + "tls": { + "minVersion": 1.2, + "maxVersion": 1.2, + "renegotiation": false, + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + ] + }, + "authority": { + "provisioners": [ + { + "name": "mariano", + "type": "jwk", + "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", + "key": { + "use": "sig", + "kty": "EC", + "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", + "crv": "P-256", + "alg": "ES256", + "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", + "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" + }, + "claims": { + "minTLSCertDuration": "1s", + "defaultTLSCertDuration": "5s" + } + } + ], + "template": { + "country": "US", + "locality": "San Francisco", + "organization": "Smallstep" + } + } +} diff --git a/ca/testdata/rotate-ca-1.json b/ca/testdata/rotate-ca-1.json new file mode 100644 index 00000000..b038f694 --- /dev/null +++ b/ca/testdata/rotate-ca-1.json @@ -0,0 +1,46 @@ +{ + "root": ["testdata/secrets/root_ca.crt", "testdata/rotated/root_ca.crt"], + "crt": "testdata/secrets/intermediate_ca.crt", + "key": "testdata/secrets/intermediate_ca_key", + "password": "password", + "address": "127.0.0.1:0", + "dnsNames": ["127.0.0.1"], + "logger": {"format": "text"}, + "tls": { + "minVersion": 1.2, + "maxVersion": 1.2, + "renegotiation": false, + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + ] + }, + "authority": { + "provisioners": [ + { + "name": "mariano", + "type": "jwk", + "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", + "key": { + "use": "sig", + "kty": "EC", + "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", + "crv": "P-256", + "alg": "ES256", + "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", + "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" + }, + "claims": { + "minTLSCertDuration": "1s", + "defaultTLSCertDuration": "5s" + } + } + ], + "template": { + "country": "US", + "locality": "San Francisco", + "organization": "Smallstep" + } + } +} diff --git a/ca/testdata/rotate-ca-2.json b/ca/testdata/rotate-ca-2.json new file mode 100644 index 00000000..7ec965d0 --- /dev/null +++ b/ca/testdata/rotate-ca-2.json @@ -0,0 +1,46 @@ +{ + "root": ["testdata/rotated/root_ca.crt", "testdata/secrets/root_ca.crt"], + "crt": "testdata/rotated/intermediate_ca.crt", + "key": "testdata/rotated/intermediate_ca_key", + "password": "asdf", + "address": "127.0.0.1:0", + "dnsNames": ["127.0.0.1"], + "logger": {"format": "text"}, + "tls": { + "minVersion": 1.2, + "maxVersion": 1.2, + "renegotiation": false, + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + ] + }, + "authority": { + "provisioners": [ + { + "name": "mariano", + "type": "jwk", + "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", + "key": { + "use": "sig", + "kty": "EC", + "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", + "crv": "P-256", + "alg": "ES256", + "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", + "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" + }, + "claims": { + "minTLSCertDuration": "1s", + "defaultTLSCertDuration": "5s" + } + } + ], + "template": { + "country": "US", + "locality": "San Francisco", + "organization": "Smallstep" + } + } +} diff --git a/ca/testdata/rotate-ca-3.json b/ca/testdata/rotate-ca-3.json new file mode 100644 index 00000000..968da6ba --- /dev/null +++ b/ca/testdata/rotate-ca-3.json @@ -0,0 +1,46 @@ +{ + "root": "testdata/rotated/root_ca.crt", + "crt": "testdata/rotated/intermediate_ca.crt", + "key": "testdata/rotated/intermediate_ca_key", + "password": "asdf", + "address": "127.0.0.1:0", + "dnsNames": ["127.0.0.1"], + "logger": {"format": "text"}, + "tls": { + "minVersion": 1.2, + "maxVersion": 1.2, + "renegotiation": false, + "cipherSuites": [ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" + ] + }, + "authority": { + "provisioners": [ + { + "name": "mariano", + "type": "jwk", + "encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlB1UnJVQ1RZZkR1T2F5MEh2cGl6bncifQ.7a-OP5xWGbFra8m2MN9YuLGt6v4y0wmB.u-54daK2y-0UO9na.3GQy6E52-fOSUu5NJ_sEbxj_T3CTyWb7wOPFv2oI2PBWXp5CLpiWJbCFpF4v2oD9fN5XbxMP14ootbrFjATnoMWfWgyLwG-KOj9BqMGNxhG2v37yC7Wrris6s30nrPa3uyNEYZ12AOQW1K04cU2X0u_qJM3vzMCle548ZFTWs6_d6L8lp3o0F9MEbCmJ4p6CLqQxjxYtn1aD79lM91NbIXpRP3iUFQRly-y_iC2mSkXCdd_cQ6-dqLUchXwWRyVO5nBHb4J87aZ91VApw7ldTLtwRZ2ZGJpqGQGgjTwi4sgjEcMuGg0_83XGk2ubdlKDpmGFedOHS5rYCbxotts.vSYfxsi2UU9LQeySDjAnnQ", + "key": { + "use": "sig", + "kty": "EC", + "kid": "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", + "crv": "P-256", + "alg": "ES256", + "x": "tTKthEHN7RuybhkaC43J2oLfBG995FNSWbtahLAiK7Y", + "y": "e3wycXwVB366F0wLE5J9gIpq8EIQ4900nHBNpIGebEA" + }, + "claims": { + "minTLSCertDuration": "1s", + "defaultTLSCertDuration": "5s" + } + } + ], + "template": { + "country": "US", + "locality": "San Francisco", + "organization": "Smallstep" + } + } +} diff --git a/ca/testdata/rotated/intermediate_ca.crt b/ca/testdata/rotated/intermediate_ca.crt new file mode 100644 index 00000000..338ebb22 --- /dev/null +++ b/ca/testdata/rotated/intermediate_ca.crt @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBxTCCAWugAwIBAgIQLIY6MR/1fBRQY4ZTTsPAJjAKBggqhkjOPQQDAjAcMRow +GAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTAeFw0xOTAxMDcyMDExMzBaFw0yOTAx +MDQyMDExMzBaMCQxIjAgBgNVBAMTGVNtYWxsc3RlcCBJbnRlcm1lZGlhdGUgQ0Ew +WTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARgtjL/KLNpdq81YYWaek1lrkPM/QF1 +m+ujwv5jya21fAXljdBLh6m2xco1GPfwPBbwUGlNOdEqE9Nq3Qx3ngPKo4GGMIGD +MA4GA1UdDwEB/wQEAwIBpjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIw +EgYDVR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUqixeZ/K1HW9N6SVw7ONya98S +u8UwHwYDVR0jBBgwFoAUgIzlCLxh/RlwEany4JQHOorLAIEwCgYIKoZIzj0EAwID +SAAwRQIgdGX6lxThrKlt3v+3HJZlaWdmoeQ3vYwpJb9uHExZdVYCIQDCxsdI8EnB +bxjnJscbT4zvqVsq6AmycdbFwgy8RIeVzg== +-----END CERTIFICATE----- diff --git a/ca/testdata/rotated/intermediate_ca_key b/ca/testdata/rotated/intermediate_ca_key new file mode 100644 index 00000000..6c3b1622 --- /dev/null +++ b/ca/testdata/rotated/intermediate_ca_key @@ -0,0 +1,8 @@ +-----BEGIN EC PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,7dcc0a8c1d73c8d438184e0928875329 + +r6yrQrHg6zBZRSjQpe8RzyQALEfiT3/8lMvvPu3BX6yign5skMfCVMXZhzbmAwmR +BJBIX+5hkudR2VN+hrsOyuU7FvIk4gx2c8buIlFObfYXIml0mpuThfm52ciAtOTE +S0hkfYvPcOAjzaDZ+8Po/mYhkODgyvijogn4ioTF/Ss= +-----END EC PRIVATE KEY----- diff --git a/ca/testdata/rotated/root_ca.crt b/ca/testdata/rotated/root_ca.crt new file mode 100644 index 00000000..87cd5650 --- /dev/null +++ b/ca/testdata/rotated/root_ca.crt @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa +MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw +MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG +SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL +jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB +/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR +qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z +YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc +Sw== +-----END CERTIFICATE----- diff --git a/ca/testdata/rotated/root_ca_key b/ca/testdata/rotated/root_ca_key new file mode 100644 index 00000000..c92f587e --- /dev/null +++ b/ca/testdata/rotated/root_ca_key @@ -0,0 +1,8 @@ +-----BEGIN EC PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,8ce79d28601b9809905ef7c362a20749 + +H+pTTL3B5fLYycgHLxFOW0fZsayr7Y+BW8THKf12h8dk0/eOE1wNoX2TuMtpbZgO +lMJdFPL+SAPCCmuZOZIcQDejRHVcYBq1wvrrnw/yfVawXC4xze+J4Y+q0J2WY+rM +xcLGlEOIRZkvdDVGmSitEZBl0Ibk0p9tG++7QGqAvnk= +-----END EC PRIVATE KEY----- diff --git a/ca/testdata/secrets/federated_ca.crt b/ca/testdata/secrets/federated_ca.crt new file mode 100644 index 00000000..87cd5650 --- /dev/null +++ b/ca/testdata/secrets/federated_ca.crt @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBfTCCASKgAwIBAgIRAJPUE0MTA+fMz6f6i/XYmTwwCgYIKoZIzj0EAwIwHDEa +MBgGA1UEAxMRU21hbGxzdGVwIFJvb3QgQ0EwHhcNMTkwMTA3MjAxMTMwWhcNMjkw +MTA0MjAxMTMwWjAcMRowGAYDVQQDExFTbWFsbHN0ZXAgUm9vdCBDQTBZMBMGByqG +SM49AgEGCCqGSM49AwEHA0IABCOH/PGThn0cMOGDeqDxb22olsdCm8hVdyW9cHQL +jfIYAqpWNh9f7E5umlnxkOy6OEROTtpq7etzfBbzb52loVWjRTBDMA4GA1UdDwEB +/wQEAwIBpjASBgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBSAjOUIvGH9GXAR +qfLglAc6issAgTAKBggqhkjOPQQDAgNJADBGAiEAjs0yjbQ/9dmGoUn7JS3lE83z +YlnXZ0fHdeNakkIKhQICIQCUENhGZp63pMtm3ipgwp91EM0T7YtKgrFNvDekqufc +Sw== +-----END CERTIFICATE----- diff --git a/ca/tls.go b/ca/tls.go index 5e8c4118..31d1632b 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -41,7 +41,8 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(tlsConfig, options); err != nil { + tlsCtx := newTLSOptionCtx(c, tlsConfig) + if err := tlsCtx.apply(options); err != nil { return nil, err } @@ -50,7 +51,10 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, if err != nil { return nil, err } - renewer.RenewCertificate = getRenewFunc(c, tr, pk) + renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) + + // Update client transport + c.client.Transport = tr // Start renewer renewer.RunContext(ctx) @@ -87,7 +91,8 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(tlsConfig, options); err != nil { + tlsCtx := newTLSOptionCtx(c, tlsConfig) + if err := tlsCtx.apply(options); err != nil { return nil, err } @@ -96,7 +101,10 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, if err != nil { return nil, err } - renewer.RenewCertificate = getRenewFunc(c, tr, pk) + renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) + + // Update client transport + c.client.Transport = tr // Start renewer renewer.RunContext(ctx) @@ -238,8 +246,13 @@ func getPEM(i interface{}) ([]byte, error) { return pem.EncodeToMemory(block), nil } -func getRenewFunc(client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc { +func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc { return func() (*tls.Certificate, error) { + // Get updated list of roots + if err := ctx.applyRenew(); err != nil { + return nil, err + } + // Get new certificate sign, err := client.Renew(tr) if err != nil { return nil, err diff --git a/ca/tls_options.go b/ca/tls_options.go index fb0bb20b..47e2c627 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -6,13 +6,35 @@ import ( ) // TLSOption defines the type of a function that modifies a tls.Config. -type TLSOption func(c *tls.Config) error +type TLSOption func(ctx *TLSOptionCtx) error -// setTLSOptions takes one or more option function and applies them in order to -// a tls.Config. -func setTLSOptions(c *tls.Config, options []TLSOption) error { - for _, opt := range options { - if err := opt(c); err != nil { +// TLSOptionCtx is the context modified on TLSOption methods. +type TLSOptionCtx struct { + Client *Client + Config *tls.Config + OnRenewFunc []TLSOption +} + +// newTLSOptionCtx creates the TLSOption context. +func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx { + return &TLSOptionCtx{ + Client: c, + Config: config, + } +} + +func (ctx *TLSOptionCtx) apply(options []TLSOption) error { + for _, fn := range options { + if err := fn(ctx); err != nil { + return err + } + } + return nil +} + +func (ctx *TLSOptionCtx) applyRenew() error { + for _, fn := range ctx.OnRenewFunc { + if err := fn(ctx); err != nil { return err } } @@ -22,8 +44,8 @@ func setTLSOptions(c *tls.Config, options []TLSOption) error { // RequireAndVerifyClientCert is a tls.Config option used on servers to enforce // a valid TLS client certificate. This is the default option for mTLS servers. func RequireAndVerifyClientCert() TLSOption { - return func(c *tls.Config) error { - c.ClientAuth = tls.RequireAndVerifyClientCert + return func(ctx *TLSOptionCtx) error { + ctx.Config.ClientAuth = tls.RequireAndVerifyClientCert return nil } } @@ -31,8 +53,8 @@ func RequireAndVerifyClientCert() TLSOption { // VerifyClientCertIfGiven is a tls.Config option used on on servers to validate // a TLS client certificate if it is provided. It does not requires a certificate. func VerifyClientCertIfGiven() TLSOption { - return func(c *tls.Config) error { - c.ClientAuth = tls.VerifyClientCertIfGiven + return func(ctx *TLSOptionCtx) error { + ctx.Config.ClientAuth = tls.VerifyClientCertIfGiven return nil } } @@ -41,11 +63,11 @@ func VerifyClientCertIfGiven() TLSOption { // defines the set of root certificate authorities that clients use when // verifying server certificates. func AddRootCA(cert *x509.Certificate) TLSOption { - return func(c *tls.Config) error { - if c.RootCAs == nil { - c.RootCAs = x509.NewCertPool() + return func(ctx *TLSOptionCtx) error { + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() } - c.RootCAs.AddCert(cert) + ctx.Config.RootCAs.AddCert(cert) return nil } } @@ -54,11 +76,163 @@ func AddRootCA(cert *x509.Certificate) TLSOption { // defines the set of root certificate authorities that servers use if required // to verify a client certificate by the policy in ClientAuth. func AddClientCA(cert *x509.Certificate) TLSOption { - return func(c *tls.Config) error { - if c.ClientCAs == nil { - c.ClientCAs = x509.NewCertPool() + return func(ctx *TLSOptionCtx) error { + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() } - c.ClientCAs.AddCert(cert) + ctx.Config.ClientCAs.AddCert(cert) return nil } } + +// AddRootsToRootCAs does a roots request and adds to the tls.Config RootCAs all +// the certificates in the response. RootCAs defines the set of root certificate +// authorities that clients use when verifying server certificates. +// +// BootstrapServer and BootstrapClient methods include this option by default. +func AddRootsToRootCAs() TLSOption { + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Roots() + if err != nil { + return err + } + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.RootCAs.AddCert(cert.Certificate) + } + return nil + } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } +} + +// AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs +// all the certificates in the response. ClientCAs defines the set of root +// certificate authorities that servers use if required to verify a client +// certificate by the policy in ClientAuth. +// +// BootstrapServer method includes this option by default. +func AddRootsToClientCAs() TLSOption { + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Roots() + if err != nil { + return err + } + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.ClientCAs.AddCert(cert.Certificate) + } + return nil + } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } +} + +// AddFederationToRootCAs does a federation request and adds to the tls.Config +// RootCAs all the certificates in the response. RootCAs defines the set of root +// certificate authorities that clients use when verifying server certificates. +func AddFederationToRootCAs() TLSOption { + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Federation() + if err != nil { + return err + } + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.RootCAs.AddCert(cert.Certificate) + } + return nil + } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } +} + +// AddFederationToClientCAs does a federation request and adds to the tls.Config +// ClientCAs all the certificates in the response. ClientCAs defines the set of +// root certificate authorities that servers use if required to verify a client +// certificate by the policy in ClientAuth. +func AddFederationToClientCAs() TLSOption { + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Federation() + if err != nil { + return err + } + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.ClientCAs.AddCert(cert.Certificate) + } + return nil + } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } +} + +// AddRootsToCAs does a roots request and adds the resulting certs to the +// tls.Config RootCAs and ClientCAs. Combines the functionality of +// AddRootsToRootCAs and AddRootsToClientCAs. +func AddRootsToCAs() TLSOption { + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Roots() + if err != nil { + return err + } + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.ClientCAs.AddCert(cert.Certificate) + ctx.Config.RootCAs.AddCert(cert.Certificate) + } + return nil + } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } +} + +// AddFederationToCAs does a federation request and adds the resulting certs to the +// tls.Config RootCAs and ClientCAs. Combines the functionality of +// AddFederationToRootCAs and AddFederationToClientCAs. +func AddFederationToCAs() TLSOption { + fn := func(ctx *TLSOptionCtx) error { + certs, err := ctx.Client.Federation() + if err != nil { + return err + } + if ctx.Config.ClientCAs == nil { + ctx.Config.ClientCAs = x509.NewCertPool() + } + if ctx.Config.RootCAs == nil { + ctx.Config.RootCAs = x509.NewCertPool() + } + for _, cert := range certs.Certificates { + ctx.Config.ClientCAs.AddCert(cert.Certificate) + ctx.Config.RootCAs.AddCert(cert.Certificate) + } + return nil + } + return func(ctx *TLSOptionCtx) error { + ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn) + return fn(ctx) + } +} diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 896ff72b..ceeea7dc 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -4,33 +4,69 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io/ioutil" + "net/http" "reflect" + "sort" "testing" ) -func Test_setTLSOptions(t *testing.T) { +func Test_newTLSOptionCtx(t *testing.T) { + client, err := NewClient("https://ca.smallstep.com", WithTransport(http.DefaultTransport)) + if err != nil { + t.Fatalf("NewClient() error = %v", err) + } + + type args struct { + c *Client + config *tls.Config + } + tests := []struct { + name string + args args + want *TLSOptionCtx + }{ + {"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTLSOptionCtx_apply(t *testing.T) { fail := func() TLSOption { - return func(c *tls.Config) error { + return func(ctx *TLSOptionCtx) error { return fmt.Errorf("an error") } } + + type fields struct { + Config *tls.Config + } type args struct { - c *tls.Config options []TLSOption } tests := []struct { name string + fields fields args args wantErr bool }{ - {"ok", args{&tls.Config{}, []TLSOption{RequireAndVerifyClientCert()}}, false}, - {"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false}, - {"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true}, + {"ok", fields{&tls.Config{}}, args{[]TLSOption{RequireAndVerifyClientCert()}}, false}, + {"ok", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven()}}, false}, + {"fail", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven(), fail()}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr { - t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr) + ctx := &TLSOptionCtx{ + Config: tt.fields.Config, + } + if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr { + t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr) } }) } @@ -45,13 +81,15 @@ func TestRequireAndVerifyClientCert(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := RequireAndVerifyClientCert()(got); err != nil { + ctx := &TLSOptionCtx{ + Config: &tls.Config{}, + } + if err := RequireAndVerifyClientCert()(ctx); err != nil { t.Errorf("RequireAndVerifyClientCert() error = %v", err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("RequireAndVerifyClientCert() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("RequireAndVerifyClientCert() = %v, want %v", ctx.Config, tt.want) } }) } @@ -66,13 +104,15 @@ func TestVerifyClientCertIfGiven(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := VerifyClientCertIfGiven()(got); err != nil { + ctx := &TLSOptionCtx{ + Config: &tls.Config{}, + } + if err := VerifyClientCertIfGiven()(ctx); err != nil { t.Errorf("VerifyClientCertIfGiven() error = %v", err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("VerifyClientCertIfGiven() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("VerifyClientCertIfGiven() = %v, want %v", ctx.Config, tt.want) } }) } @@ -95,13 +135,15 @@ func TestAddRootCA(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := AddRootCA(tt.args.cert)(got); err != nil { + ctx := &TLSOptionCtx{ + Config: &tls.Config{}, + } + if err := AddRootCA(tt.args.cert)(ctx); err != nil { t.Errorf("AddRootCA() error = %v", err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("AddRootCA() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddRootCA() = %v, want %v", ctx.Config, tt.want) } }) } @@ -124,14 +166,380 @@ func TestAddClientCA(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := &tls.Config{} - if err := AddClientCA(tt.args.cert)(got); err != nil { + ctx := &TLSOptionCtx{ + Config: &tls.Config{}, + } + if err := AddClientCA(tt.args.cert)(ctx); err != nil { t.Errorf("AddClientCA() error = %v", err) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("AddClientCA() = %v, want %v", got, tt.want) + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddClientCA() = %v, want %v", ctx.Config, tt.want) } }) } } + +func TestAddRootsToRootCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + cert := parseCertificate(string(root)) + pool := x509.NewCertPool() + pool.AddCert(cert) + + type args struct { + client *Client + config *tls.Config + } + tests := []struct { + name string + args args + want *tls.Config + wantErr bool + }{ + {"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &TLSOptionCtx{ + Client: tt.args.client, + Config: tt.args.config, + } + if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr { + t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want) + } + }) + } +} + +func TestAddRootsToClientCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + cert := parseCertificate(string(root)) + pool := x509.NewCertPool() + pool.AddCert(cert) + + type args struct { + client *Client + config *tls.Config + } + tests := []struct { + name string + args args + want *tls.Config + wantErr bool + }{ + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &TLSOptionCtx{ + Client: tt.args.client, + Config: tt.args.config, + } + if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr { + t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want) + } + }) + } +} + +func TestAddFederationToRootCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") + if err != nil { + t.Fatal(err) + } + + crt1 := parseCertificate(string(root)) + crt2 := parseCertificate(string(federated)) + pool := x509.NewCertPool() + pool.AddCert(crt1) + pool.AddCert(crt2) + + type args struct { + client *Client + config *tls.Config + } + tests := []struct { + name string + args args + want *tls.Config + wantErr bool + }{ + {"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &TLSOptionCtx{ + Client: tt.args.client, + Config: tt.args.config, + } + if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr { + t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(ctx.Config, tt.want) { + // Federated roots are randomly sorted + if !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) || ctx.Config.ClientCAs != nil { + t.Errorf("AddFederationToRootCAs() = %v, want %v", ctx.Config, tt.want) + } + } + }) + } +} + +func TestAddFederationToClientCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") + if err != nil { + t.Fatal(err) + } + + crt1 := parseCertificate(string(root)) + crt2 := parseCertificate(string(federated)) + pool := x509.NewCertPool() + pool.AddCert(crt1) + pool.AddCert(crt2) + + type args struct { + client *Client + config *tls.Config + } + tests := []struct { + name string + args args + want *tls.Config + wantErr bool + }{ + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &TLSOptionCtx{ + Client: tt.args.client, + Config: tt.args.config, + } + if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr { + t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(ctx.Config, tt.want) { + // Federated roots are randomly sorted + if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || ctx.Config.RootCAs != nil { + t.Errorf("AddFederationToClientCAs() = %v, want %v", ctx.Config, tt.want) + } + } + }) + } +} + +func TestAddRootsToCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + cert := parseCertificate(string(root)) + pool := x509.NewCertPool() + pool.AddCert(cert) + + type args struct { + client *Client + config *tls.Config + } + tests := []struct { + name string + args args + want *tls.Config + wantErr bool + }{ + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &TLSOptionCtx{ + Client: tt.args.client, + Config: tt.args.config, + } + if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr { + t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(ctx.Config, tt.want) { + t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want) + } + }) + } +} + +func TestAddFederationToCAs(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Fatal(err) + } + + root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") + if err != nil { + t.Fatal(err) + } + + crt1 := parseCertificate(string(root)) + crt2 := parseCertificate(string(federated)) + pool := x509.NewCertPool() + pool.AddCert(crt1) + pool.AddCert(crt2) + + type args struct { + client *Client + config *tls.Config + } + tests := []struct { + name string + args args + want *tls.Config + wantErr bool + }{ + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &TLSOptionCtx{ + Client: tt.args.client, + Config: tt.args.config, + } + if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr { + t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(ctx.Config, tt.want) { + // Federated roots are randomly sorted + if !equalPools(ctx.Config.ClientCAs, tt.want.ClientCAs) || !equalPools(ctx.Config.RootCAs, tt.want.RootCAs) { + t.Errorf("AddFederationToCAs() = %v, want %v", ctx.Config, tt.want) + } + } + }) + } +} + +func equalPools(a, b *x509.CertPool) bool { + subjects := a.Subjects() + sA := make([]string, len(subjects)) + for i := range subjects { + sA[i] = string(subjects[i]) + } + subjects = b.Subjects() + sB := make([]string, len(subjects)) + for i := range subjects { + sB[i] = string(subjects[i]) + } + sort.Sort(sort.StringSlice(sA)) + sort.Sort(sort.StringSlice(sB)) + return reflect.DeepEqual(sA, sB) +} diff --git a/ca/tls_test.go b/ca/tls_test.go index 799b496a..6c6a5291 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -20,7 +20,7 @@ import ( "github.com/smallstep/certificates/authority" "github.com/smallstep/cli/crypto/randutil" stepJOSE "github.com/smallstep/cli/jose" - "gopkg.in/square/go-jose.v2" + jose "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" ) @@ -242,16 +242,15 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { } func TestClient_GetServerTLSConfig_renew(t *testing.T) { - if testing.Short() { - t.Skip("skipping test in short mode.") - } + reset := setMinCertDuration(1 * time.Second) + defer reset() // Start CA ca := startCATestServer() defer ca.Close() clientDomain := "test.domain" - client, sr, pk := signDuration(ca, "127.0.0.1", 1*time.Minute) + client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second) // Start mTLS server ctx, cancel := context.WithCancel(context.Background()) @@ -274,13 +273,13 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { defer srvTLS.Close() // Transport - client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute) + client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) tr1, err := client.Transport(context.Background(), sr, pk) if err != nil { t.Fatalf("Client.Transport() error = %v", err) } // Transport with tlsConfig - client, sr, pk = signDuration(ca, clientDomain, 1*time.Minute) + client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) if err != nil { t.Fatalf("Client.GetClientTLSConfig() error = %v", err) @@ -367,9 +366,9 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { t.Errorf("number of fingerprints unexpected, got %d, want 2", l) } - // Wait for renewal 40s == 1m-1m/3 - log.Printf("Sleeping for %s ...\n", 40*time.Second) - time.Sleep(40 * time.Second) + // Wait for renewal + log.Printf("Sleeping for %s ...\n", 5*time.Second) + time.Sleep(5 * time.Second) for _, tt := range tests { t.Run("renewed "+tt.name, func(t *testing.T) {