Add initial support for federated root certificates.
This commit is contained in:
parent
37149ed3ea
commit
722bcb7e7a
10 changed files with 277 additions and 56 deletions
|
@ -2,7 +2,7 @@ package authority
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
realx509 "crypto/x509"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority"
|
|||
// Authority implements the Certificate Authority internal interface.
|
||||
type Authority struct {
|
||||
config *Config
|
||||
rootX509Crt *realx509.Certificate
|
||||
rootX509Crt *x509.Certificate
|
||||
intermediateIdentity *x509util.Identity
|
||||
validateOnce bool
|
||||
certificates *sync.Map
|
||||
|
@ -89,6 +89,16 @@ func (a *Authority) init() error {
|
|||
sum := sha256.Sum256(a.rootX509Crt.Raw)
|
||||
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
|
||||
|
||||
// Add federated roots
|
||||
for _, path := range a.config.FederatedRoots {
|
||||
crt, err := pemutil.ReadCertificate(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sum := sha256.Sum256(crt.Raw)
|
||||
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
|
||||
}
|
||||
|
||||
// Decrypt and load intermediate public / private key pair.
|
||||
if len(a.config.Password) > 0 {
|
||||
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(
|
||||
|
|
|
@ -33,38 +33,10 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
type duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
// MarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *duration) UnmarshalJSON(data []byte) (err error) {
|
||||
var s string
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
}
|
||||
if d.Duration, err = time.ParseDuration(s); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Config represents the CA configuration and it's mapped to a JSON object.
|
||||
type Config struct {
|
||||
Root string `json:"root"`
|
||||
FederatedRoots []string `json:"federatedRoots"`
|
||||
IntermediateCert string `json:"crt"`
|
||||
IntermediateKey string `json:"key"`
|
||||
Address string `json:"address"`
|
||||
|
|
|
@ -17,7 +17,7 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
|||
|
||||
crt, ok := val.(*x509.Certificate)
|
||||
if !ok {
|
||||
return nil, &apiError{errors.Errorf("stored value is not a *cryto/x509.Certificate"),
|
||||
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
}
|
||||
return crt, nil
|
||||
|
@ -27,3 +27,24 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
|||
func (a *Authority) GetRootCertificate() *x509.Certificate {
|
||||
return a.rootX509Crt
|
||||
}
|
||||
|
||||
// GetFederation returns all the root certificates in the federation.
|
||||
func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) {
|
||||
// Check step provisioner extensions
|
||||
if err := a.authorizeRenewal(peer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.certificates.Range(func(k, v interface{}) bool {
|
||||
crt, ok := v.(*x509.Certificate)
|
||||
if !ok {
|
||||
federation = nil
|
||||
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
return false
|
||||
}
|
||||
federation = append(federation, crt)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ func TestRoot(t *testing.T) {
|
|||
err *apiError
|
||||
}{
|
||||
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
|
||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *cryto/x509.Certificate"), http.StatusInternalServerError, context{}}},
|
||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}},
|
||||
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
|
||||
}
|
||||
|
||||
|
|
98
authority/types.go
Normal file
98
authority/types.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
// MarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a duration string and sets it to the duration.
|
||||
//
|
||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
func (d *duration) UnmarshalJSON(data []byte) (err error) {
|
||||
var s string
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
}
|
||||
if d.Duration, err = time.ParseDuration(s); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type multiString []string
|
||||
|
||||
// FIXME: remove me, avoids deadcode warning
|
||||
var _ = multiString{}
|
||||
|
||||
// First returns the first element of a multiString. It will return an empty
|
||||
// string if the multistring is empty.
|
||||
func (s multiString) First() string {
|
||||
if len(s) > 0 {
|
||||
return s[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Empties checks that none of the string is empty.
|
||||
func (s multiString) Empties() bool {
|
||||
if len(s) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, ss := range s {
|
||||
if len(ss) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalJSON marshals the multistring as a string or a slice of strings . With
|
||||
// 0 elements it will return the empty string, with 1 element a regular string,
|
||||
// otherwise a slice of strings.
|
||||
func (s multiString) MarshalJSON() ([]byte, error) {
|
||||
switch len(s) {
|
||||
case 0:
|
||||
return []byte(""), nil
|
||||
case 1:
|
||||
return json.Marshal(s[0])
|
||||
default:
|
||||
return json.Marshal(s)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a string or a slice and sets it to the multiString.
|
||||
func (s *multiString) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
*s = nil
|
||||
return nil
|
||||
}
|
||||
if data[0] == '"' {
|
||||
var str string
|
||||
if err := json.Unmarshal(data, &str); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
}
|
||||
*s = []string{str}
|
||||
return nil
|
||||
}
|
||||
if err := json.Unmarshal(data, s); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
}
|
||||
return nil
|
||||
}
|
19
ca/client.go
19
ca/client.go
|
@ -413,6 +413,25 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error)
|
|||
return &key, nil
|
||||
}
|
||||
|
||||
// Federation performs the get federation request to the CA and returns the
|
||||
// api.FederationResponse struct.
|
||||
func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) {
|
||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
|
||||
client := &http.Client{Transport: tr}
|
||||
resp, err := client.Get(u.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readError(resp.Body)
|
||||
}
|
||||
var federation api.FederationResponse
|
||||
if err := readJSON(resp.Body, &federation); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
}
|
||||
return &federation, nil
|
||||
}
|
||||
|
||||
// CreateSignRequest is a helper function that given an x509 OTT returns a
|
||||
// simple but secure sign request as well as the private key used.
|
||||
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {
|
||||
|
|
|
@ -512,6 +512,67 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestClient_Federation(t *testing.T) {
|
||||
ok := &api.FederationResponse{
|
||||
Certificates: []api.Certificate{
|
||||
{Certificate: parseCertificate(rootPEM)},
|
||||
},
|
||||
}
|
||||
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", ok, 200, false},
|
||||
{"unauthorized", unauthorized, 401, true},
|
||||
{"empty request", badRequest, 403, true},
|
||||
{"nil request", badRequest, 403, true},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
defer srv.Close()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Errorf("NewClient() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(tt.responseCode)
|
||||
api.JSON(w, tt.response)
|
||||
})
|
||||
|
||||
got, err := c.Federation(nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Printf("%+v", err)
|
||||
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case err != nil:
|
||||
if got != nil {
|
||||
t.Errorf("Client.Federation() = %v, want nil", got)
|
||||
}
|
||||
if !reflect.DeepEqual(err, tt.response) {
|
||||
t.Errorf("Client.Federation() error = %v, want %v", err, tt.response)
|
||||
}
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseEndpoint(t *testing.T) {
|
||||
expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"}
|
||||
expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"}
|
||||
|
|
|
@ -41,7 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
}
|
||||
|
||||
// Apply options if given
|
||||
if err := setTLSOptions(tlsConfig, options); err != nil {
|
||||
if err := setTLSOptions(c, tlsConfig, options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -87,7 +87,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
|||
}
|
||||
|
||||
// Apply options if given
|
||||
if err := setTLSOptions(tlsConfig, options); err != nil {
|
||||
if err := setTLSOptions(c, tlsConfig, options); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -6,13 +6,13 @@ import (
|
|||
)
|
||||
|
||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||
type TLSOption func(c *tls.Config) error
|
||||
type TLSOption func(c *Client, config *tls.Config) error
|
||||
|
||||
// setTLSOptions takes one or more option function and applies them in order to
|
||||
// a tls.Config.
|
||||
func setTLSOptions(c *tls.Config, options []TLSOption) error {
|
||||
func setTLSOptions(c *Client, config *tls.Config, options []TLSOption) error {
|
||||
for _, opt := range options {
|
||||
if err := opt(c); err != nil {
|
||||
if err := opt(c, config); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -22,8 +22,8 @@ func setTLSOptions(c *tls.Config, options []TLSOption) error {
|
|||
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
||||
// a valid TLS client certificate. This is the default option for mTLS servers.
|
||||
func RequireAndVerifyClientCert() TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
c.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
return func(_ *Client, config *tls.Config) error {
|
||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -31,8 +31,8 @@ func RequireAndVerifyClientCert() TLSOption {
|
|||
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
||||
// a TLS client certificate if it is provided. It does not requires a certificate.
|
||||
func VerifyClientCertIfGiven() TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
c.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
return func(_ *Client, config *tls.Config) error {
|
||||
config.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -41,11 +41,11 @@ func VerifyClientCertIfGiven() TLSOption {
|
|||
// defines the set of root certificate authorities that clients use when
|
||||
// verifying server certificates.
|
||||
func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
if c.RootCAs == nil {
|
||||
c.RootCAs = x509.NewCertPool()
|
||||
return func(_ *Client, config *tls.Config) error {
|
||||
if config.RootCAs == nil {
|
||||
config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
c.RootCAs.AddCert(cert)
|
||||
config.RootCAs.AddCert(cert)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -54,11 +54,51 @@ func AddRootCA(cert *x509.Certificate) TLSOption {
|
|||
// defines the set of root certificate authorities that servers use if required
|
||||
// to verify a client certificate by the policy in ClientAuth.
|
||||
func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
if c.ClientCAs == nil {
|
||||
c.ClientCAs = x509.NewCertPool()
|
||||
return func(_ *Client, config *tls.Config) error {
|
||||
if config.ClientCAs == nil {
|
||||
config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
config.ClientCAs.AddCert(cert)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddRootFederation does a federation request and adds to the tls.Config
|
||||
// RootCAs all the certificates in the response. RootCAs
|
||||
// defines the set of root certificate authorities that clients use when
|
||||
// verifying server certificates.
|
||||
func AddRootFederation() TLSOption {
|
||||
return func(c *Client, config *tls.Config) error {
|
||||
if config.RootCAs == nil {
|
||||
config.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
certs, err := c.Federation(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
config.RootCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddClientFederation does a federation request and adds to the tls.Config
|
||||
// ClientCAs all the certificates in the response. ClientCAs defines the set of
|
||||
// root certificate authorities that servers use if required to verify a client
|
||||
// certificate by the policy in ClientAuth.
|
||||
func AddClientFederation() TLSOption {
|
||||
return func(c *Client, config *tls.Config) error {
|
||||
if config.ClientCAs == nil {
|
||||
config.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
certs, err := c.Federation(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, cert := range certs.Certificates {
|
||||
config.ClientCAs.AddCert(cert.Certificate)
|
||||
}
|
||||
c.ClientCAs.AddCert(cert)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
func Test_setTLSOptions(t *testing.T) {
|
||||
fail := func() TLSOption {
|
||||
return func(c *tls.Config) error {
|
||||
return func(c *Client, config *tls.Config) error {
|
||||
return fmt.Errorf("an error")
|
||||
}
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ func Test_setTLSOptions(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
||||
if err := setTLSOptions(nil, tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
||||
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
|
@ -46,7 +46,7 @@ func TestRequireAndVerifyClientCert(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := RequireAndVerifyClientCert()(got); err != nil {
|
||||
if err := RequireAndVerifyClientCert()(nil, got); err != nil {
|
||||
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := VerifyClientCertIfGiven()(got); err != nil {
|
||||
if err := VerifyClientCertIfGiven()(nil, got); err != nil {
|
||||
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ func TestAddRootCA(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := AddRootCA(tt.args.cert)(got); err != nil {
|
||||
if err := AddRootCA(tt.args.cert)(nil, got); err != nil {
|
||||
t.Errorf("AddRootCA() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
@ -125,7 +125,7 @@ func TestAddClientCA(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &tls.Config{}
|
||||
if err := AddClientCA(tt.args.cert)(got); err != nil {
|
||||
if err := AddClientCA(tt.args.cert)(nil, got); err != nil {
|
||||
t.Errorf("AddClientCA() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue