Add initial support for federated root certificates.

This commit is contained in:
Mariano Cano 2019-01-04 17:51:32 -08:00
parent 37149ed3ea
commit 722bcb7e7a
10 changed files with 277 additions and 56 deletions

View file

@ -2,7 +2,7 @@ package authority
import ( import (
"crypto/sha256" "crypto/sha256"
realx509 "crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"sync" "sync"
@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority"
// Authority implements the Certificate Authority internal interface. // Authority implements the Certificate Authority internal interface.
type Authority struct { type Authority struct {
config *Config config *Config
rootX509Crt *realx509.Certificate rootX509Crt *x509.Certificate
intermediateIdentity *x509util.Identity intermediateIdentity *x509util.Identity
validateOnce bool validateOnce bool
certificates *sync.Map certificates *sync.Map
@ -89,6 +89,16 @@ func (a *Authority) init() error {
sum := sha256.Sum256(a.rootX509Crt.Raw) sum := sha256.Sum256(a.rootX509Crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt) a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
// Add federated roots
for _, path := range a.config.FederatedRoots {
crt, err := pemutil.ReadCertificate(path)
if err != nil {
return err
}
sum := sha256.Sum256(crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
}
// Decrypt and load intermediate public / private key pair. // Decrypt and load intermediate public / private key pair.
if len(a.config.Password) > 0 { if len(a.config.Password) > 0 {
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk( a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(

View file

@ -33,38 +33,10 @@ var (
} }
) )
type duration struct {
time.Duration
}
// MarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) UnmarshalJSON(data []byte) (err error) {
var s string
if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
if d.Duration, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s)
}
return
}
// Config represents the CA configuration and it's mapped to a JSON object. // Config represents the CA configuration and it's mapped to a JSON object.
type Config struct { type Config struct {
Root string `json:"root"` Root string `json:"root"`
FederatedRoots []string `json:"federatedRoots"`
IntermediateCert string `json:"crt"` IntermediateCert string `json:"crt"`
IntermediateKey string `json:"key"` IntermediateKey string `json:"key"`
Address string `json:"address"` Address string `json:"address"`

View file

@ -17,7 +17,7 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
crt, ok := val.(*x509.Certificate) crt, ok := val.(*x509.Certificate)
if !ok { if !ok {
return nil, &apiError{errors.Errorf("stored value is not a *cryto/x509.Certificate"), return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, context{}} http.StatusInternalServerError, context{}}
} }
return crt, nil return crt, nil
@ -27,3 +27,24 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
func (a *Authority) GetRootCertificate() *x509.Certificate { func (a *Authority) GetRootCertificate() *x509.Certificate {
return a.rootX509Crt return a.rootX509Crt
} }
// GetFederation returns all the root certificates in the federation.
func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) {
// Check step provisioner extensions
if err := a.authorizeRenewal(peer); err != nil {
return nil, err
}
a.certificates.Range(func(k, v interface{}) bool {
crt, ok := v.(*x509.Certificate)
if !ok {
federation = nil
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
http.StatusInternalServerError, context{}}
return false
}
federation = append(federation, crt)
return true
})
return
}

View file

@ -17,7 +17,7 @@ func TestRoot(t *testing.T) {
err *apiError err *apiError
}{ }{
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}}, "not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *cryto/x509.Certificate"), http.StatusInternalServerError, context{}}}, "invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}},
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil}, "success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
} }

98
authority/types.go Normal file
View file

@ -0,0 +1,98 @@
package authority
import (
"encoding/json"
"time"
"github.com/pkg/errors"
)
type duration struct {
time.Duration
}
// MarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON parses a duration string and sets it to the duration.
//
// A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) UnmarshalJSON(data []byte) (err error) {
var s string
if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
if d.Duration, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s)
}
return
}
type multiString []string
// FIXME: remove me, avoids deadcode warning
var _ = multiString{}
// First returns the first element of a multiString. It will return an empty
// string if the multistring is empty.
func (s multiString) First() string {
if len(s) > 0 {
return s[0]
}
return ""
}
// Empties checks that none of the string is empty.
func (s multiString) Empties() bool {
if len(s) == 0 {
return true
}
for _, ss := range s {
if len(ss) == 0 {
return true
}
}
return false
}
// MarshalJSON marshals the multistring as a string or a slice of strings . With
// 0 elements it will return the empty string, with 1 element a regular string,
// otherwise a slice of strings.
func (s multiString) MarshalJSON() ([]byte, error) {
switch len(s) {
case 0:
return []byte(""), nil
case 1:
return json.Marshal(s[0])
default:
return json.Marshal(s)
}
}
// UnmarshalJSON parses a string or a slice and sets it to the multiString.
func (s *multiString) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
*s = nil
return nil
}
if data[0] == '"' {
var str string
if err := json.Unmarshal(data, &str); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
*s = []string{str}
return nil
}
if err := json.Unmarshal(data, s); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data)
}
return nil
}

View file

@ -413,6 +413,25 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error)
return &key, nil return &key, nil
} }
// Federation performs the get federation request to the CA and returns the
// api.FederationResponse struct.
func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) {
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
client := &http.Client{Transport: tr}
resp, err := client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
return nil, readError(resp.Body)
}
var federation api.FederationResponse
if err := readJSON(resp.Body, &federation); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return &federation, nil
}
// CreateSignRequest is a helper function that given an x509 OTT returns a // CreateSignRequest is a helper function that given an x509 OTT returns a
// simple but secure sign request as well as the private key used. // simple but secure sign request as well as the private key used.
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) { func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {

View file

@ -512,6 +512,67 @@ func TestClient_ProvisionerKey(t *testing.T) {
} }
} }
func TestClient_Federation(t *testing.T) {
ok := &api.FederationResponse{
Certificates: []api.Certificate{
{Certificate: parseCertificate(rootPEM)},
},
}
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
}{
{"ok", ok, 200, false},
{"unauthorized", unauthorized, 401, true},
{"empty request", badRequest, 403, true},
{"nil request", badRequest, 403, true},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.Federation(nil)
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got)
}
if !reflect.DeepEqual(err, tt.response) {
t.Errorf("Client.Federation() error = %v, want %v", err, tt.response)
}
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
}
}
})
}
}
func Test_parseEndpoint(t *testing.T) { func Test_parseEndpoint(t *testing.T) {
expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"} expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"}
expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"} expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"}

View file

@ -41,7 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
} }
// Apply options if given // Apply options if given
if err := setTLSOptions(tlsConfig, options); err != nil { if err := setTLSOptions(c, tlsConfig, options); err != nil {
return nil, err return nil, err
} }
@ -87,7 +87,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
} }
// Apply options if given // Apply options if given
if err := setTLSOptions(tlsConfig, options); err != nil { if err := setTLSOptions(c, tlsConfig, options); err != nil {
return nil, err return nil, err
} }

View file

@ -6,13 +6,13 @@ import (
) )
// TLSOption defines the type of a function that modifies a tls.Config. // TLSOption defines the type of a function that modifies a tls.Config.
type TLSOption func(c *tls.Config) error type TLSOption func(c *Client, config *tls.Config) error
// setTLSOptions takes one or more option function and applies them in order to // setTLSOptions takes one or more option function and applies them in order to
// a tls.Config. // a tls.Config.
func setTLSOptions(c *tls.Config, options []TLSOption) error { func setTLSOptions(c *Client, config *tls.Config, options []TLSOption) error {
for _, opt := range options { for _, opt := range options {
if err := opt(c); err != nil { if err := opt(c, config); err != nil {
return err return err
} }
} }
@ -22,8 +22,8 @@ func setTLSOptions(c *tls.Config, options []TLSOption) error {
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce // RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
// a valid TLS client certificate. This is the default option for mTLS servers. // a valid TLS client certificate. This is the default option for mTLS servers.
func RequireAndVerifyClientCert() TLSOption { func RequireAndVerifyClientCert() TLSOption {
return func(c *tls.Config) error { return func(_ *Client, config *tls.Config) error {
c.ClientAuth = tls.RequireAndVerifyClientCert config.ClientAuth = tls.RequireAndVerifyClientCert
return nil return nil
} }
} }
@ -31,8 +31,8 @@ func RequireAndVerifyClientCert() TLSOption {
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate // VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
// a TLS client certificate if it is provided. It does not requires a certificate. // a TLS client certificate if it is provided. It does not requires a certificate.
func VerifyClientCertIfGiven() TLSOption { func VerifyClientCertIfGiven() TLSOption {
return func(c *tls.Config) error { return func(_ *Client, config *tls.Config) error {
c.ClientAuth = tls.VerifyClientCertIfGiven config.ClientAuth = tls.VerifyClientCertIfGiven
return nil return nil
} }
} }
@ -41,11 +41,11 @@ func VerifyClientCertIfGiven() TLSOption {
// defines the set of root certificate authorities that clients use when // defines the set of root certificate authorities that clients use when
// verifying server certificates. // verifying server certificates.
func AddRootCA(cert *x509.Certificate) TLSOption { func AddRootCA(cert *x509.Certificate) TLSOption {
return func(c *tls.Config) error { return func(_ *Client, config *tls.Config) error {
if c.RootCAs == nil { if config.RootCAs == nil {
c.RootCAs = x509.NewCertPool() config.RootCAs = x509.NewCertPool()
} }
c.RootCAs.AddCert(cert) config.RootCAs.AddCert(cert)
return nil return nil
} }
} }
@ -54,11 +54,51 @@ func AddRootCA(cert *x509.Certificate) TLSOption {
// defines the set of root certificate authorities that servers use if required // defines the set of root certificate authorities that servers use if required
// to verify a client certificate by the policy in ClientAuth. // to verify a client certificate by the policy in ClientAuth.
func AddClientCA(cert *x509.Certificate) TLSOption { func AddClientCA(cert *x509.Certificate) TLSOption {
return func(c *tls.Config) error { return func(_ *Client, config *tls.Config) error {
if c.ClientCAs == nil { if config.ClientCAs == nil {
c.ClientCAs = x509.NewCertPool() config.ClientCAs = x509.NewCertPool()
}
config.ClientCAs.AddCert(cert)
return nil
}
}
// AddRootFederation does a federation request and adds to the tls.Config
// RootCAs all the certificates in the response. RootCAs
// defines the set of root certificate authorities that clients use when
// verifying server certificates.
func AddRootFederation() TLSOption {
return func(c *Client, config *tls.Config) error {
if config.RootCAs == nil {
config.RootCAs = x509.NewCertPool()
}
certs, err := c.Federation(nil)
if err != nil {
return err
}
for _, cert := range certs.Certificates {
config.RootCAs.AddCert(cert.Certificate)
}
return nil
}
}
// AddClientFederation does a federation request and adds to the tls.Config
// ClientCAs all the certificates in the response. ClientCAs defines the set of
// root certificate authorities that servers use if required to verify a client
// certificate by the policy in ClientAuth.
func AddClientFederation() TLSOption {
return func(c *Client, config *tls.Config) error {
if config.ClientCAs == nil {
config.ClientCAs = x509.NewCertPool()
}
certs, err := c.Federation(nil)
if err != nil {
return err
}
for _, cert := range certs.Certificates {
config.ClientCAs.AddCert(cert.Certificate)
} }
c.ClientCAs.AddCert(cert)
return nil return nil
} }
} }

View file

@ -10,7 +10,7 @@ import (
func Test_setTLSOptions(t *testing.T) { func Test_setTLSOptions(t *testing.T) {
fail := func() TLSOption { fail := func() TLSOption {
return func(c *tls.Config) error { return func(c *Client, config *tls.Config) error {
return fmt.Errorf("an error") return fmt.Errorf("an error")
} }
} }
@ -29,7 +29,7 @@ func Test_setTLSOptions(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr { if err := setTLSOptions(nil, tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
@ -46,7 +46,7 @@ func TestRequireAndVerifyClientCert(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{} got := &tls.Config{}
if err := RequireAndVerifyClientCert()(got); err != nil { if err := RequireAndVerifyClientCert()(nil, got); err != nil {
t.Errorf("RequireAndVerifyClientCert() error = %v", err) t.Errorf("RequireAndVerifyClientCert() error = %v", err)
return return
} }
@ -67,7 +67,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{} got := &tls.Config{}
if err := VerifyClientCertIfGiven()(got); err != nil { if err := VerifyClientCertIfGiven()(nil, got); err != nil {
t.Errorf("VerifyClientCertIfGiven() error = %v", err) t.Errorf("VerifyClientCertIfGiven() error = %v", err)
return return
} }
@ -96,7 +96,7 @@ func TestAddRootCA(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{} got := &tls.Config{}
if err := AddRootCA(tt.args.cert)(got); err != nil { if err := AddRootCA(tt.args.cert)(nil, got); err != nil {
t.Errorf("AddRootCA() error = %v", err) t.Errorf("AddRootCA() error = %v", err)
return return
} }
@ -125,7 +125,7 @@ func TestAddClientCA(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := &tls.Config{} got := &tls.Config{}
if err := AddClientCA(tt.args.cert)(got); err != nil { if err := AddClientCA(tt.args.cert)(nil, got); err != nil {
t.Errorf("AddClientCA() error = %v", err) t.Errorf("AddClientCA() error = %v", err)
return return
} }