forked from TrueCloudLab/certificates
Add initial support for federated root certificates.
This commit is contained in:
parent
37149ed3ea
commit
722bcb7e7a
10 changed files with 277 additions and 56 deletions
|
@ -2,7 +2,7 @@ package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
realx509 "crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority"
|
||||||
// Authority implements the Certificate Authority internal interface.
|
// Authority implements the Certificate Authority internal interface.
|
||||||
type Authority struct {
|
type Authority struct {
|
||||||
config *Config
|
config *Config
|
||||||
rootX509Crt *realx509.Certificate
|
rootX509Crt *x509.Certificate
|
||||||
intermediateIdentity *x509util.Identity
|
intermediateIdentity *x509util.Identity
|
||||||
validateOnce bool
|
validateOnce bool
|
||||||
certificates *sync.Map
|
certificates *sync.Map
|
||||||
|
@ -89,6 +89,16 @@ func (a *Authority) init() error {
|
||||||
sum := sha256.Sum256(a.rootX509Crt.Raw)
|
sum := sha256.Sum256(a.rootX509Crt.Raw)
|
||||||
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
|
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
|
||||||
|
|
||||||
|
// Add federated roots
|
||||||
|
for _, path := range a.config.FederatedRoots {
|
||||||
|
crt, err := pemutil.ReadCertificate(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256(crt.Raw)
|
||||||
|
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
|
||||||
|
}
|
||||||
|
|
||||||
// Decrypt and load intermediate public / private key pair.
|
// Decrypt and load intermediate public / private key pair.
|
||||||
if len(a.config.Password) > 0 {
|
if len(a.config.Password) > 0 {
|
||||||
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(
|
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(
|
||||||
|
|
|
@ -33,38 +33,10 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type duration struct {
|
|
||||||
time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalJSON parses a duration string and sets it to the duration.
|
|
||||||
//
|
|
||||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
|
||||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
|
||||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
|
||||||
func (d *duration) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(d.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalJSON parses a duration string and sets it to the duration.
|
|
||||||
//
|
|
||||||
// A duration string is a possibly signed sequence of decimal numbers, each with
|
|
||||||
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
|
||||||
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
|
||||||
func (d *duration) UnmarshalJSON(data []byte) (err error) {
|
|
||||||
var s string
|
|
||||||
if err = json.Unmarshal(data, &s); err != nil {
|
|
||||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
|
||||||
}
|
|
||||||
if d.Duration, err = time.ParseDuration(s); err != nil {
|
|
||||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config represents the CA configuration and it's mapped to a JSON object.
|
// Config represents the CA configuration and it's mapped to a JSON object.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Root string `json:"root"`
|
Root string `json:"root"`
|
||||||
|
FederatedRoots []string `json:"federatedRoots"`
|
||||||
IntermediateCert string `json:"crt"`
|
IntermediateCert string `json:"crt"`
|
||||||
IntermediateKey string `json:"key"`
|
IntermediateKey string `json:"key"`
|
||||||
Address string `json:"address"`
|
Address string `json:"address"`
|
||||||
|
|
|
@ -17,7 +17,7 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
||||||
|
|
||||||
crt, ok := val.(*x509.Certificate)
|
crt, ok := val.(*x509.Certificate)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &apiError{errors.Errorf("stored value is not a *cryto/x509.Certificate"),
|
return nil, &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||||
http.StatusInternalServerError, context{}}
|
http.StatusInternalServerError, context{}}
|
||||||
}
|
}
|
||||||
return crt, nil
|
return crt, nil
|
||||||
|
@ -27,3 +27,24 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) {
|
||||||
func (a *Authority) GetRootCertificate() *x509.Certificate {
|
func (a *Authority) GetRootCertificate() *x509.Certificate {
|
||||||
return a.rootX509Crt
|
return a.rootX509Crt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetFederation returns all the root certificates in the federation.
|
||||||
|
func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) {
|
||||||
|
// Check step provisioner extensions
|
||||||
|
if err := a.authorizeRenewal(peer); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.certificates.Range(func(k, v interface{}) bool {
|
||||||
|
crt, ok := v.(*x509.Certificate)
|
||||||
|
if !ok {
|
||||||
|
federation = nil
|
||||||
|
err = &apiError{errors.Errorf("stored value is not a *x509.Certificate"),
|
||||||
|
http.StatusInternalServerError, context{}}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
federation = append(federation, crt)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ func TestRoot(t *testing.T) {
|
||||||
err *apiError
|
err *apiError
|
||||||
}{
|
}{
|
||||||
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
|
"not-found": {"foo", &apiError{errors.New("certificate with fingerprint foo was not found"), http.StatusNotFound, context{}}},
|
||||||
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *cryto/x509.Certificate"), http.StatusInternalServerError, context{}}},
|
"invalid-stored-certificate": {"invaliddata", &apiError{errors.New("stored value is not a *x509.Certificate"), http.StatusInternalServerError, context{}}},
|
||||||
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
|
"success": {"189f573cfa159251e445530847ef80b1b62a3a380ee670dcb49e33ed34da0616", nil},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
98
authority/types.go
Normal file
98
authority/types.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package authority
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type duration struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON parses a duration string and sets it to the duration.
|
||||||
|
//
|
||||||
|
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||||
|
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||||
|
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||||
|
func (d *duration) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(d.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON parses a duration string and sets it to the duration.
|
||||||
|
//
|
||||||
|
// A duration string is a possibly signed sequence of decimal numbers, each with
|
||||||
|
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
|
||||||
|
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||||
|
func (d *duration) UnmarshalJSON(data []byte) (err error) {
|
||||||
|
var s string
|
||||||
|
if err = json.Unmarshal(data, &s); err != nil {
|
||||||
|
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||||
|
}
|
||||||
|
if d.Duration, err = time.ParseDuration(s); err != nil {
|
||||||
|
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type multiString []string
|
||||||
|
|
||||||
|
// FIXME: remove me, avoids deadcode warning
|
||||||
|
var _ = multiString{}
|
||||||
|
|
||||||
|
// First returns the first element of a multiString. It will return an empty
|
||||||
|
// string if the multistring is empty.
|
||||||
|
func (s multiString) First() string {
|
||||||
|
if len(s) > 0 {
|
||||||
|
return s[0]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empties checks that none of the string is empty.
|
||||||
|
func (s multiString) Empties() bool {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, ss := range s {
|
||||||
|
if len(ss) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON marshals the multistring as a string or a slice of strings . With
|
||||||
|
// 0 elements it will return the empty string, with 1 element a regular string,
|
||||||
|
// otherwise a slice of strings.
|
||||||
|
func (s multiString) MarshalJSON() ([]byte, error) {
|
||||||
|
switch len(s) {
|
||||||
|
case 0:
|
||||||
|
return []byte(""), nil
|
||||||
|
case 1:
|
||||||
|
return json.Marshal(s[0])
|
||||||
|
default:
|
||||||
|
return json.Marshal(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON parses a string or a slice and sets it to the multiString.
|
||||||
|
func (s *multiString) UnmarshalJSON(data []byte) error {
|
||||||
|
if len(data) == 0 {
|
||||||
|
*s = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data[0] == '"' {
|
||||||
|
var str string
|
||||||
|
if err := json.Unmarshal(data, &str); err != nil {
|
||||||
|
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||||
|
}
|
||||||
|
*s = []string{str}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, s); err != nil {
|
||||||
|
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
19
ca/client.go
19
ca/client.go
|
@ -413,6 +413,25 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error)
|
||||||
return &key, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Federation performs the get federation request to the CA and returns the
|
||||||
|
// api.FederationResponse struct.
|
||||||
|
func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) {
|
||||||
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
resp, err := client.Get(u.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, readError(resp.Body)
|
||||||
|
}
|
||||||
|
var federation api.FederationResponse
|
||||||
|
if err := readJSON(resp.Body, &federation); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||||
|
}
|
||||||
|
return &federation, nil
|
||||||
|
}
|
||||||
|
|
||||||
// CreateSignRequest is a helper function that given an x509 OTT returns a
|
// CreateSignRequest is a helper function that given an x509 OTT returns a
|
||||||
// simple but secure sign request as well as the private key used.
|
// simple but secure sign request as well as the private key used.
|
||||||
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {
|
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {
|
||||||
|
|
|
@ -512,6 +512,67 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_Federation(t *testing.T) {
|
||||||
|
ok := &api.FederationResponse{
|
||||||
|
Certificates: []api.Certificate{
|
||||||
|
{Certificate: parseCertificate(rootPEM)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized"))
|
||||||
|
badRequest := api.BadRequest(fmt.Errorf("Bad Request"))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
response interface{}
|
||||||
|
responseCode int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", ok, 200, false},
|
||||||
|
{"unauthorized", unauthorized, 401, true},
|
||||||
|
{"empty request", badRequest, 403, true},
|
||||||
|
{"nil request", badRequest, 403, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(nil)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewClient() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
w.WriteHeader(tt.responseCode)
|
||||||
|
api.JSON(w, tt.response)
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := c.Federation(nil)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
fmt.Printf("%+v", err)
|
||||||
|
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
if got != nil {
|
||||||
|
t.Errorf("Client.Federation() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(err, tt.response) {
|
||||||
|
t.Errorf("Client.Federation() error = %v, want %v", err, tt.response)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if !reflect.DeepEqual(got, tt.response) {
|
||||||
|
t.Errorf("Client.Federation() = %v, want %v", got, tt.response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_parseEndpoint(t *testing.T) {
|
func Test_parseEndpoint(t *testing.T) {
|
||||||
expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"}
|
expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"}
|
||||||
expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"}
|
expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"}
|
||||||
|
|
|
@ -41,7 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options if given
|
||||||
if err := setTLSOptions(tlsConfig, options); err != nil {
|
if err := setTLSOptions(c, tlsConfig, options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options if given
|
||||||
if err := setTLSOptions(tlsConfig, options); err != nil {
|
if err := setTLSOptions(c, tlsConfig, options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,13 +6,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||||
type TLSOption func(c *tls.Config) error
|
type TLSOption func(c *Client, config *tls.Config) error
|
||||||
|
|
||||||
// setTLSOptions takes one or more option function and applies them in order to
|
// setTLSOptions takes one or more option function and applies them in order to
|
||||||
// a tls.Config.
|
// a tls.Config.
|
||||||
func setTLSOptions(c *tls.Config, options []TLSOption) error {
|
func setTLSOptions(c *Client, config *tls.Config, options []TLSOption) error {
|
||||||
for _, opt := range options {
|
for _, opt := range options {
|
||||||
if err := opt(c); err != nil {
|
if err := opt(c, config); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -22,8 +22,8 @@ func setTLSOptions(c *tls.Config, options []TLSOption) error {
|
||||||
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
||||||
// a valid TLS client certificate. This is the default option for mTLS servers.
|
// a valid TLS client certificate. This is the default option for mTLS servers.
|
||||||
func RequireAndVerifyClientCert() TLSOption {
|
func RequireAndVerifyClientCert() TLSOption {
|
||||||
return func(c *tls.Config) error {
|
return func(_ *Client, config *tls.Config) error {
|
||||||
c.ClientAuth = tls.RequireAndVerifyClientCert
|
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -31,8 +31,8 @@ func RequireAndVerifyClientCert() TLSOption {
|
||||||
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
||||||
// a TLS client certificate if it is provided. It does not requires a certificate.
|
// a TLS client certificate if it is provided. It does not requires a certificate.
|
||||||
func VerifyClientCertIfGiven() TLSOption {
|
func VerifyClientCertIfGiven() TLSOption {
|
||||||
return func(c *tls.Config) error {
|
return func(_ *Client, config *tls.Config) error {
|
||||||
c.ClientAuth = tls.VerifyClientCertIfGiven
|
config.ClientAuth = tls.VerifyClientCertIfGiven
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -41,11 +41,11 @@ func VerifyClientCertIfGiven() TLSOption {
|
||||||
// defines the set of root certificate authorities that clients use when
|
// defines the set of root certificate authorities that clients use when
|
||||||
// verifying server certificates.
|
// verifying server certificates.
|
||||||
func AddRootCA(cert *x509.Certificate) TLSOption {
|
func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||||
return func(c *tls.Config) error {
|
return func(_ *Client, config *tls.Config) error {
|
||||||
if c.RootCAs == nil {
|
if config.RootCAs == nil {
|
||||||
c.RootCAs = x509.NewCertPool()
|
config.RootCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
c.RootCAs.AddCert(cert)
|
config.RootCAs.AddCert(cert)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -54,11 +54,51 @@ func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||||
// defines the set of root certificate authorities that servers use if required
|
// defines the set of root certificate authorities that servers use if required
|
||||||
// to verify a client certificate by the policy in ClientAuth.
|
// to verify a client certificate by the policy in ClientAuth.
|
||||||
func AddClientCA(cert *x509.Certificate) TLSOption {
|
func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||||
return func(c *tls.Config) error {
|
return func(_ *Client, config *tls.Config) error {
|
||||||
if c.ClientCAs == nil {
|
if config.ClientCAs == nil {
|
||||||
c.ClientCAs = x509.NewCertPool()
|
config.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
config.ClientCAs.AddCert(cert)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRootFederation does a federation request and adds to the tls.Config
|
||||||
|
// RootCAs all the certificates in the response. RootCAs
|
||||||
|
// defines the set of root certificate authorities that clients use when
|
||||||
|
// verifying server certificates.
|
||||||
|
func AddRootFederation() TLSOption {
|
||||||
|
return func(c *Client, config *tls.Config) error {
|
||||||
|
if config.RootCAs == nil {
|
||||||
|
config.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
certs, err := c.Federation(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, cert := range certs.Certificates {
|
||||||
|
config.RootCAs.AddCert(cert.Certificate)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddClientFederation does a federation request and adds to the tls.Config
|
||||||
|
// ClientCAs all the certificates in the response. ClientCAs defines the set of
|
||||||
|
// root certificate authorities that servers use if required to verify a client
|
||||||
|
// certificate by the policy in ClientAuth.
|
||||||
|
func AddClientFederation() TLSOption {
|
||||||
|
return func(c *Client, config *tls.Config) error {
|
||||||
|
if config.ClientCAs == nil {
|
||||||
|
config.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
certs, err := c.Federation(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, cert := range certs.Certificates {
|
||||||
|
config.ClientCAs.AddCert(cert.Certificate)
|
||||||
}
|
}
|
||||||
c.ClientCAs.AddCert(cert)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
|
|
||||||
func Test_setTLSOptions(t *testing.T) {
|
func Test_setTLSOptions(t *testing.T) {
|
||||||
fail := func() TLSOption {
|
fail := func() TLSOption {
|
||||||
return func(c *tls.Config) error {
|
return func(c *Client, config *tls.Config) error {
|
||||||
return fmt.Errorf("an error")
|
return fmt.Errorf("an error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ func Test_setTLSOptions(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := setTLSOptions(tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
if err := setTLSOptions(nil, tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -46,7 +46,7 @@ func TestRequireAndVerifyClientCert(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
got := &tls.Config{}
|
||||||
if err := RequireAndVerifyClientCert()(got); err != nil {
|
if err := RequireAndVerifyClientCert()(nil, got); err != nil {
|
||||||
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
got := &tls.Config{}
|
||||||
if err := VerifyClientCertIfGiven()(got); err != nil {
|
if err := VerifyClientCertIfGiven()(nil, got); err != nil {
|
||||||
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -96,7 +96,7 @@ func TestAddRootCA(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
got := &tls.Config{}
|
||||||
if err := AddRootCA(tt.args.cert)(got); err != nil {
|
if err := AddRootCA(tt.args.cert)(nil, got); err != nil {
|
||||||
t.Errorf("AddRootCA() error = %v", err)
|
t.Errorf("AddRootCA() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -125,7 +125,7 @@ func TestAddClientCA(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
got := &tls.Config{}
|
||||||
if err := AddClientCA(tt.args.cert)(got); err != nil {
|
if err := AddClientCA(tt.args.cert)(nil, got); err != nil {
|
||||||
t.Errorf("AddClientCA() error = %v", err)
|
t.Errorf("AddClientCA() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue