diff --git a/authority/authority.go b/authority/authority.go index 4753f83b..8c60fffc 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -2,7 +2,7 @@ package authority import ( "crypto/sha256" - realx509 "crypto/x509" + "crypto/x509" "encoding/hex" "fmt" "sync" @@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority" // Authority implements the Certificate Authority internal interface. type Authority struct { config *Config - rootX509Crt *realx509.Certificate + rootX509Crt *x509.Certificate intermediateIdentity *x509util.Identity validateOnce bool certificates *sync.Map @@ -89,6 +89,16 @@ func (a *Authority) init() error { sum := sha256.Sum256(a.rootX509Crt.Raw) a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt) + // Add federated roots + for _, path := range a.config.FederatedRoots { + crt, err := pemutil.ReadCertificate(path) + if err != nil { + return err + } + sum := sha256.Sum256(crt.Raw) + a.certificates.Store(hex.EncodeToString(sum[:]), crt) + } + // Decrypt and load intermediate public / private key pair. if len(a.config.Password) > 0 { a.intermediateIdentity, err = x509util.LoadIdentityFromDisk( diff --git a/authority/config.go b/authority/config.go index 5f9910af..0ffdbef0 100644 --- a/authority/config.go +++ b/authority/config.go @@ -33,38 +33,10 @@ var ( } ) -type duration struct { - time.Duration -} - -// MarshalJSON parses a duration string and sets it to the duration. -// -// A duration string is a possibly signed sequence of decimal numbers, each with -// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". -// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) MarshalJSON() ([]byte, error) { - return json.Marshal(d.String()) -} - -// UnmarshalJSON parses a duration string and sets it to the duration. -// -// A duration string is a possibly signed sequence of decimal numbers, each with -// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". -// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -func (d *duration) UnmarshalJSON(data []byte) (err error) { - var s string - if err = json.Unmarshal(data, &s); err != nil { - return errors.Wrapf(err, "error unmarshalling %s", data) - } - if d.Duration, err = time.ParseDuration(s); err != nil { - return errors.Wrapf(err, "error parsing %s as duration", s) - } - return -} - // Config represents the CA configuration and it's mapped to a JSON object. type Config struct { Root string `json:"root"` + FederatedRoots []string `json:"federatedRoots"` IntermediateCert string `json:"crt"` IntermediateKey string `json:"key"` Address string `json:"address"` diff --git a/authority/root.go b/authority/root.go index 5c918ee1..01710db8 100644 --- a/authority/root.go +++ b/authority/root.go @@ -17,7 +17,7 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) { crt, ok := val.(*x509.Certificate) if !ok { - return nil, &apiError{errors.Errorf("stored value is not a *cryto/x509.Certificate"), + return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}} } return crt, nil @@ -27,3 +27,24 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) { func (a *Authority) GetRootCertificate() *x509.Certificate { return a.rootX509Crt } + +// GetFederation returns all the root certificates in the federation. +func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) { + // Check step provisioner extensions + if err := a.authorizeRenewal(peer); err != nil { + return nil, err + } + + a.certificates.Range(func(k, v interface{}) bool { + crt, ok := v.(*x509.Certificate) + if !ok { + federation = nil + err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"), + http.StatusInternalServerError, context{}} + return false + } + federation = append(federation, crt) + return true + }) + return +} diff --git a/authority/root_test.go b/authority/root_test.go index fd4c31db..0db7e866 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -17,7 +17,7 @@ func TestRoot(t *testing.T) { err *apiError }{ "not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}}, - "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *cryto/x509.Certificate"), http.StatusInternalServerError, context{}}}, + "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}}, "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil}, } diff --git a/authority/types.go b/authority/types.go new file mode 100644 index 00000000..ec8f0d7b --- /dev/null +++ b/authority/types.go @@ -0,0 +1,98 @@ +package authority + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +type duration struct { + time.Duration +} + +// MarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +// UnmarshalJSON parses a duration string and sets it to the duration. +// +// A duration string is a possibly signed sequence of decimal numbers, each with +// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +func (d *duration) UnmarshalJSON(data []byte) (err error) { + var s string + if err = json.Unmarshal(data, &s); err != nil { + return errors.Wrapf(err, "error unmarshalling %s", data) + } + if d.Duration, err = time.ParseDuration(s); err != nil { + return errors.Wrapf(err, "error parsing %s as duration", s) + } + return +} + +type multiString []string + +// FIXME: remove me, avoids deadcode warning +var _ = multiString{} + +// First returns the first element of a multiString. It will return an empty +// string if the multistring is empty. +func (s multiString) First() string { + if len(s) > 0 { + return s[0] + } + return "" +} + +// Empties checks that none of the string is empty. +func (s multiString) Empties() bool { + if len(s) == 0 { + return true + } + for _, ss := range s { + if len(ss) == 0 { + return true + } + } + return false +} + +// MarshalJSON marshals the multistring as a string or a slice of strings . With +// 0 elements it will return the empty string, with 1 element a regular string, +// otherwise a slice of strings. +func (s multiString) MarshalJSON() ([]byte, error) { + switch len(s) { + case 0: + return []byte(""), nil + case 1: + return json.Marshal(s[0]) + default: + return json.Marshal(s) + } +} + +// UnmarshalJSON parses a string or a slice and sets it to the multiString. +func (s *multiString) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + *s = nil + return nil + } + if data[0] == '"' { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return errors.Wrapf(err, "error unmarshalling %s", data) + } + *s = []string{str} + return nil + } + if err := json.Unmarshal(data, s); err != nil { + return errors.Wrapf(err, "error unmarshalling %s", data) + } + return nil +} diff --git a/ca/client.go b/ca/client.go index 47116245..374a68ff 100644 --- a/ca/client.go +++ b/ca/client.go @@ -413,6 +413,25 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) return &key, nil } +// Federation performs the get federation request to the CA and returns the +// api.FederationResponse struct. +func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"}) + client := &http.Client{Transport: tr} + resp, err := client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var federation api.FederationResponse + if err := readJSON(resp.Body, &federation); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &federation, nil +} + // CreateSignRequest is a helper function that given an x509 OTT returns a // simple but secure sign request as well as the private key used. func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) { diff --git a/ca/client_test.go b/ca/client_test.go index 6d5cd22a..138b0d7d 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -512,6 +512,67 @@ func TestClient_ProvisionerKey(t *testing.T) { } } +func TestClient_Federation(t *testing.T) { + ok := &api.FederationResponse{ + Certificates: []api.Certificate{ + {Certificate: parseCertificate(rootPEM)}, + }, + } + unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized")) + badRequest := api.BadRequest(fmt.Errorf("Bad Request")) + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + }{ + {"ok", ok, 200, false}, + {"unauthorized", unauthorized, 401, true}, + {"empty request", badRequest, 403, true}, + {"nil request", badRequest, 403, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(tt.responseCode) + api.JSON(w, tt.response) + }) + + got, err := c.Federation(nil) + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Federation() = %v, want nil", got) + } + if !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.Federation() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Federation() = %v, want %v", got, tt.response) + } + } + }) + } +} + func Test_parseEndpoint(t *testing.T) { expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"} expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"} diff --git a/ca/tls.go b/ca/tls.go index 5e8c4118..bef8e553 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -41,7 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(tlsConfig, options); err != nil { + if err := setTLSOptions(c, tlsConfig, options); err != nil { return nil, err } @@ -87,7 +87,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - if err := setTLSOptions(tlsConfig, options); err != nil { + if err := setTLSOptions(c, tlsConfig, options); err != nil { return nil, err } diff --git a/ca/tls_options.go b/ca/tls_options.go index fb0bb20b..b1cd4696 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -6,13 +6,13 @@ import ( ) // TLSOption defines the type of a function that modifies a tls.Config. -type TLSOption func(c *tls.Config) error +type TLSOption func(c *Client, config *tls.Config) error // setTLSOptions takes one or more option function and applies them in order to // a tls.Config. -func setTLSOptions(c *tls.Config, options []TLSOption) error { +func setTLSOptions(c *Client, config *tls.Config, options []TLSOption) error { for _, opt := range options { - if err := opt(c); err != nil { + if err := opt(c, config); err != nil { return err } } @@ -22,8 +22,8 @@ func setTLSOptions(c *tls.Config, options []TLSOption) error { // RequireAndVerifyClientCert is a tls.Config option used on servers to enforce // a valid TLS client certificate. This is the default option for mTLS servers. func RequireAndVerifyClientCert() TLSOption { - return func(c *tls.Config) error { - c.ClientAuth = tls.RequireAndVerifyClientCert + return func(_ *Client, config *tls.Config) error { + config.ClientAuth = tls.RequireAndVerifyClientCert return nil } } @@ -31,8 +31,8 @@ func RequireAndVerifyClientCert() TLSOption { // VerifyClientCertIfGiven is a tls.Config option used on on servers to validate // a TLS client certificate if it is provided. It does not requires a certificate. func VerifyClientCertIfGiven() TLSOption { - return func(c *tls.Config) error { - c.ClientAuth = tls.VerifyClientCertIfGiven + return func(_ *Client, config *tls.Config) error { + config.ClientAuth = tls.VerifyClientCertIfGiven return nil } } @@ -41,11 +41,11 @@ func VerifyClientCertIfGiven() TLSOption { // defines the set of root certificate authorities that clients use when // verifying server certificates. func AddRootCA(cert *x509.Certificate) TLSOption { - return func(c *tls.Config) error { - if c.RootCAs == nil { - c.RootCAs = x509.NewCertPool() + return func(_ *Client, config *tls.Config) error { + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() } - c.RootCAs.AddCert(cert) + config.RootCAs.AddCert(cert) return nil } } @@ -54,11 +54,51 @@ func AddRootCA(cert *x509.Certificate) TLSOption { // defines the set of root certificate authorities that servers use if required // to verify a client certificate by the policy in ClientAuth. func AddClientCA(cert *x509.Certificate) TLSOption { - return func(c *tls.Config) error { - if c.ClientCAs == nil { - c.ClientCAs = x509.NewCertPool() + return func(_ *Client, config *tls.Config) error { + if config.ClientCAs == nil { + config.ClientCAs = x509.NewCertPool() + } + config.ClientCAs.AddCert(cert) + return nil + } +} + +// AddRootFederation does a federation request and adds to the tls.Config +// RootCAs all the certificates in the response. RootCAs +// defines the set of root certificate authorities that clients use when +// verifying server certificates. +func AddRootFederation() TLSOption { + return func(c *Client, config *tls.Config) error { + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + certs, err := c.Federation(nil) + if err != nil { + return err + } + for _, cert := range certs.Certificates { + config.RootCAs.AddCert(cert.Certificate) + } + return nil + } +} + +// AddClientFederation does a federation request and adds to the tls.Config +// ClientCAs all the certificates in the response. ClientCAs defines the set of +// root certificate authorities that servers use if required to verify a client +// certificate by the policy in ClientAuth. +func AddClientFederation() TLSOption { + return func(c *Client, config *tls.Config) error { + if config.ClientCAs == nil { + config.ClientCAs = x509.NewCertPool() + } + certs, err := c.Federation(nil) + if err != nil { + return err + } + for _, cert := range certs.Certificates { + config.ClientCAs.AddCert(cert.Certificate) } - c.ClientCAs.AddCert(cert) return nil } } diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 896ff72b..9886e487 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -10,7 +10,7 @@ import ( func Test_setTLSOptions(t *testing.T) { fail := func() TLSOption { - return func(c *tls.Config) error { + return func(c *Client, config *tls.Config) error { return fmt.Errorf("an error") } } @@ -29,7 +29,7 @@ func Test_setTLSOptions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr { + if err := setTLSOptions(nil, tt.args.c, tt.args.options); (err != nil) != tt.wantErr { t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -46,7 +46,7 @@ func TestRequireAndVerifyClientCert(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := RequireAndVerifyClientCert()(got); err != nil { + if err := RequireAndVerifyClientCert()(nil, got); err != nil { t.Errorf("RequireAndVerifyClientCert() error = %v", err) return } @@ -67,7 +67,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := VerifyClientCertIfGiven()(got); err != nil { + if err := VerifyClientCertIfGiven()(nil, got); err != nil { t.Errorf("VerifyClientCertIfGiven() error = %v", err) return } @@ -96,7 +96,7 @@ func TestAddRootCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := AddRootCA(tt.args.cert)(got); err != nil { + if err := AddRootCA(tt.args.cert)(nil, got); err != nil { t.Errorf("AddRootCA() error = %v", err) return } @@ -125,7 +125,7 @@ func TestAddClientCA(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := &tls.Config{} - if err := AddClientCA(tt.args.cert)(got); err != nil { + if err := AddClientCA(tt.args.cert)(nil, got); err != nil { t.Errorf("AddClientCA() error = %v", err) return }