forked from TrueCloudLab/certificates
Update root certificates on renew.
This commit is contained in:
parent
6d3e8ed93c
commit
10aaece1b0
4 changed files with 324 additions and 91 deletions
|
@ -87,8 +87,8 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the tlsConfig have all supported roots
|
// Make sure the tlsConfig have all supported roots on ClientCAs and RootCAs
|
||||||
options = append(options, AddRootsToClientCAs(), AddRootsToRootCAs())
|
options = append(options, AddRootsToCAs())
|
||||||
|
|
||||||
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
|
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -133,7 +133,7 @@ func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the tlsConfig have all supported roots
|
// Make sure the tlsConfig have all supported roots on RootCAs
|
||||||
options = append(options, AddRootsToRootCAs())
|
options = append(options, AddRootsToRootCAs())
|
||||||
|
|
||||||
transport, err := client.Transport(ctx, sign, pk, options...)
|
transport, err := client.Transport(ctx, sign, pk, options...)
|
||||||
|
|
23
ca/tls.go
23
ca/tls.go
|
@ -41,7 +41,11 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options if given
|
||||||
if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil {
|
tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := tlsCtx.apply(options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +54,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
// Start renewer
|
// Start renewer
|
||||||
renewer.RunContext(ctx)
|
renewer.RunContext(ctx)
|
||||||
|
@ -87,7 +91,11 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options if given
|
||||||
if err := setTLSOptions(c, sign, pk, tlsConfig, options); err != nil {
|
tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := tlsCtx.apply(options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,7 +104,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
renewer.RenewCertificate = getRenewFunc(c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
// Start renewer
|
// Start renewer
|
||||||
renewer.RunContext(ctx)
|
renewer.RunContext(ctx)
|
||||||
|
@ -238,8 +246,13 @@ func getPEM(i interface{}) ([]byte, error) {
|
||||||
return pem.EncodeToMemory(block), nil
|
return pem.EncodeToMemory(block), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRenewFunc(client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
|
func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
|
||||||
return func() (*tls.Certificate, error) {
|
return func() (*tls.Certificate, error) {
|
||||||
|
// Get updated list of roots
|
||||||
|
if err := ctx.applyRenew(tr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Get new certificate
|
||||||
sign, err := client.Renew(tr)
|
sign, err := client.Renew(tr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -10,18 +10,42 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||||
type TLSOption func(c *Client, tr http.RoundTripper, config *tls.Config) error
|
type TLSOption func(ctx *TLSOptionCtx) error
|
||||||
|
|
||||||
// setTLSOptions takes one or more option function and applies them in order to
|
// TLSOptionCtx is the context modified on TLSOption methods.
|
||||||
// a tls.Config.
|
type TLSOptionCtx struct {
|
||||||
func setTLSOptions(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config, options []TLSOption) error {
|
Client *Client
|
||||||
|
Transport http.RoundTripper
|
||||||
|
Config *tls.Config
|
||||||
|
OnRenewFunc []TLSOption
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTLSOptionCtx creates the TLSOption context.
|
||||||
|
func newTLSOptionCtx(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config) (*TLSOptionCtx, error) {
|
||||||
tr, err := getTLSOptionsTransport(sign, pk)
|
tr, err := getTLSOptionsTransport(sign, pk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return &TLSOptionCtx{
|
||||||
|
Client: c,
|
||||||
|
Transport: tr,
|
||||||
|
Config: config,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, opt := range options {
|
func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
||||||
if err := opt(c, tr, config); err != nil {
|
for _, fn := range options {
|
||||||
|
if err := fn(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *TLSOptionCtx) applyRenew(tr http.RoundTripper) error {
|
||||||
|
ctx.Transport = tr
|
||||||
|
for _, fn := range ctx.OnRenewFunc {
|
||||||
|
if err := fn(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -51,8 +75,8 @@ func getTLSOptionsTransport(sign *api.SignResponse, pk crypto.PrivateKey) (http.
|
||||||
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
||||||
// a valid TLS client certificate. This is the default option for mTLS servers.
|
// a valid TLS client certificate. This is the default option for mTLS servers.
|
||||||
func RequireAndVerifyClientCert() TLSOption {
|
func RequireAndVerifyClientCert() TLSOption {
|
||||||
return func(_ *Client, _ http.RoundTripper, config *tls.Config) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
ctx.Config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -60,8 +84,8 @@ func RequireAndVerifyClientCert() TLSOption {
|
||||||
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
|
||||||
// a TLS client certificate if it is provided. It does not requires a certificate.
|
// a TLS client certificate if it is provided. It does not requires a certificate.
|
||||||
func VerifyClientCertIfGiven() TLSOption {
|
func VerifyClientCertIfGiven() TLSOption {
|
||||||
return func(_ *Client, _ http.RoundTripper, config *tls.Config) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
config.ClientAuth = tls.VerifyClientCertIfGiven
|
ctx.Config.ClientAuth = tls.VerifyClientCertIfGiven
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -70,11 +94,11 @@ func VerifyClientCertIfGiven() TLSOption {
|
||||||
// defines the set of root certificate authorities that clients use when
|
// defines the set of root certificate authorities that clients use when
|
||||||
// verifying server certificates.
|
// verifying server certificates.
|
||||||
func AddRootCA(cert *x509.Certificate) TLSOption {
|
func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||||
return func(_ *Client, _ http.RoundTripper, config *tls.Config) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
if config.RootCAs == nil {
|
if ctx.Config.RootCAs == nil {
|
||||||
config.RootCAs = x509.NewCertPool()
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
config.RootCAs.AddCert(cert)
|
ctx.Config.RootCAs.AddCert(cert)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,11 +107,11 @@ func AddRootCA(cert *x509.Certificate) TLSOption {
|
||||||
// defines the set of root certificate authorities that servers use if required
|
// defines the set of root certificate authorities that servers use if required
|
||||||
// to verify a client certificate by the policy in ClientAuth.
|
// to verify a client certificate by the policy in ClientAuth.
|
||||||
func AddClientCA(cert *x509.Certificate) TLSOption {
|
func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||||
return func(_ *Client, _ http.RoundTripper, config *tls.Config) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
if config.ClientCAs == nil {
|
if ctx.Config.ClientCAs == nil {
|
||||||
config.ClientCAs = x509.NewCertPool()
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
config.ClientCAs.AddCert(cert)
|
ctx.Config.ClientCAs.AddCert(cert)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -98,19 +122,23 @@ func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||||
//
|
//
|
||||||
// BootstrapServer and BootstrapClient methods include this option by default.
|
// BootstrapServer and BootstrapClient methods include this option by default.
|
||||||
func AddRootsToRootCAs() TLSOption {
|
func AddRootsToRootCAs() TLSOption {
|
||||||
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := c.Roots(tr)
|
certs, err := ctx.Client.Roots(ctx.Transport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if config.RootCAs == nil {
|
if ctx.Config.RootCAs == nil {
|
||||||
config.RootCAs = x509.NewCertPool()
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
for _, cert := range certs.Certificates {
|
for _, cert := range certs.Certificates {
|
||||||
config.RootCAs.AddCert(cert.Certificate)
|
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||||
|
return fn(ctx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs
|
// AddRootsToClientCAs does a roots request and adds to the tls.Config ClientCAs
|
||||||
|
@ -120,38 +148,46 @@ func AddRootsToRootCAs() TLSOption {
|
||||||
//
|
//
|
||||||
// BootstrapServer method includes this option by default.
|
// BootstrapServer method includes this option by default.
|
||||||
func AddRootsToClientCAs() TLSOption {
|
func AddRootsToClientCAs() TLSOption {
|
||||||
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := c.Roots(tr)
|
certs, err := ctx.Client.Roots(ctx.Transport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if config.ClientCAs == nil {
|
if ctx.Config.ClientCAs == nil {
|
||||||
config.ClientCAs = x509.NewCertPool()
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
for _, cert := range certs.Certificates {
|
for _, cert := range certs.Certificates {
|
||||||
config.ClientCAs.AddCert(cert.Certificate)
|
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||||
|
return fn(ctx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFederationToRootCAs does a federation request and adds to the tls.Config
|
// AddFederationToRootCAs does a federation request and adds to the tls.Config
|
||||||
// RootCAs all the certificates in the response. RootCAs defines the set of root
|
// RootCAs all the certificates in the response. RootCAs defines the set of root
|
||||||
// certificate authorities that clients use when verifying server certificates.
|
// certificate authorities that clients use when verifying server certificates.
|
||||||
func AddFederationToRootCAs() TLSOption {
|
func AddFederationToRootCAs() TLSOption {
|
||||||
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := c.Federation(tr)
|
certs, err := ctx.Client.Federation(ctx.Transport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if config.RootCAs == nil {
|
if ctx.Config.RootCAs == nil {
|
||||||
config.RootCAs = x509.NewCertPool()
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
for _, cert := range certs.Certificates {
|
for _, cert := range certs.Certificates {
|
||||||
config.RootCAs.AddCert(cert.Certificate)
|
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||||
|
return fn(ctx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFederationToClientCAs does a federation request and adds to the tls.Config
|
// AddFederationToClientCAs does a federation request and adds to the tls.Config
|
||||||
|
@ -159,17 +195,75 @@ func AddFederationToRootCAs() TLSOption {
|
||||||
// root certificate authorities that servers use if required to verify a client
|
// root certificate authorities that servers use if required to verify a client
|
||||||
// certificate by the policy in ClientAuth.
|
// certificate by the policy in ClientAuth.
|
||||||
func AddFederationToClientCAs() TLSOption {
|
func AddFederationToClientCAs() TLSOption {
|
||||||
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := c.Federation(tr)
|
certs, err := ctx.Client.Federation(ctx.Transport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if config.ClientCAs == nil {
|
if ctx.Config.ClientCAs == nil {
|
||||||
config.ClientCAs = x509.NewCertPool()
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
for _, cert := range certs.Certificates {
|
for _, cert := range certs.Certificates {
|
||||||
config.ClientCAs.AddCert(cert.Certificate)
|
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||||
|
return fn(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRootsToCAs does a roots request and adds the resulting certs to the
|
||||||
|
// tls.Config RootCAs and ClientCAs. Combines the functionality of
|
||||||
|
// AddRootsToRootCAs and AddRootsToClientCAs.
|
||||||
|
func AddRootsToCAs() TLSOption {
|
||||||
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
|
certs, err := ctx.Client.Roots(ctx.Transport)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ctx.Config.ClientCAs == nil {
|
||||||
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
if ctx.Config.RootCAs == nil {
|
||||||
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
for _, cert := range certs.Certificates {
|
||||||
|
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||||
|
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||||
|
return fn(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFederationToCAs does a federation request and adds the resulting certs to the
|
||||||
|
// tls.Config RootCAs and ClientCAs. Combines the functionality of
|
||||||
|
// AddFederationToRootCAs and AddFederationToClientCAs.
|
||||||
|
func AddFederationToCAs() TLSOption {
|
||||||
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
|
certs, err := ctx.Client.Federation(ctx.Transport)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ctx.Config.ClientCAs == nil {
|
||||||
|
ctx.Config.ClientCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
if ctx.Config.RootCAs == nil {
|
||||||
|
ctx.Config.RootCAs = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
for _, cert := range certs.Certificates {
|
||||||
|
ctx.Config.ClientCAs.AddCert(cert.Certificate)
|
||||||
|
ctx.Config.RootCAs.AddCert(cert.Certificate)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return func(ctx *TLSOptionCtx) error {
|
||||||
|
ctx.OnRenewFunc = append(ctx.OnRenewFunc, fn)
|
||||||
|
return fn(ctx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,33 +10,36 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_setTLSOptions(t *testing.T) {
|
func TestTLSOptionCtx_apply(t *testing.T) {
|
||||||
fail := func() TLSOption {
|
fail := func() TLSOption {
|
||||||
return func(c *Client, tr http.RoundTripper, config *tls.Config) error {
|
return func(ctx *TLSOptionCtx) error {
|
||||||
return fmt.Errorf("an error")
|
return fmt.Errorf("an error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
Config *tls.Config
|
||||||
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
c *tls.Config
|
|
||||||
options []TLSOption
|
options []TLSOption
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
fields fields
|
||||||
args args
|
args args
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", args{&tls.Config{}, []TLSOption{RequireAndVerifyClientCert()}}, false},
|
{"ok", fields{&tls.Config{}}, args{[]TLSOption{RequireAndVerifyClientCert()}}, false},
|
||||||
{"ok", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven()}}, false},
|
{"ok", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven()}}, false},
|
||||||
{"fail", args{&tls.Config{}, []TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
|
{"fail", fields{&tls.Config{}}, args{[]TLSOption{VerifyClientCertIfGiven(), fail()}}, true},
|
||||||
}
|
}
|
||||||
|
|
||||||
ca := startCATestServer()
|
|
||||||
defer ca.Close()
|
|
||||||
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := setTLSOptions(client, sr, pk, tt.args.c, tt.args.options); (err != nil) != tt.wantErr {
|
ctx := &TLSOptionCtx{
|
||||||
t.Errorf("setTLSOptions() error = %v, wantErr %v", err, tt.wantErr)
|
Config: tt.fields.Config,
|
||||||
|
}
|
||||||
|
if err := ctx.apply(tt.args.options); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("TLSOptionCtx.apply() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -51,13 +54,15 @@ func TestRequireAndVerifyClientCert(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
ctx := &TLSOptionCtx{
|
||||||
if err := RequireAndVerifyClientCert()(nil, nil, got); err != nil {
|
Config: &tls.Config{},
|
||||||
|
}
|
||||||
|
if err := RequireAndVerifyClientCert()(ctx); err != nil {
|
||||||
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
t.Errorf("RequireAndVerifyClientCert() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
t.Errorf("RequireAndVerifyClientCert() = %v, want %v", got, tt.want)
|
t.Errorf("RequireAndVerifyClientCert() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -72,13 +77,15 @@ func TestVerifyClientCertIfGiven(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
ctx := &TLSOptionCtx{
|
||||||
if err := VerifyClientCertIfGiven()(nil, nil, got); err != nil {
|
Config: &tls.Config{},
|
||||||
|
}
|
||||||
|
if err := VerifyClientCertIfGiven()(ctx); err != nil {
|
||||||
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
t.Errorf("VerifyClientCertIfGiven() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
t.Errorf("VerifyClientCertIfGiven() = %v, want %v", got, tt.want)
|
t.Errorf("VerifyClientCertIfGiven() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -101,13 +108,15 @@ func TestAddRootCA(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
ctx := &TLSOptionCtx{
|
||||||
if err := AddRootCA(tt.args.cert)(nil, nil, got); err != nil {
|
Config: &tls.Config{},
|
||||||
|
}
|
||||||
|
if err := AddRootCA(tt.args.cert)(ctx); err != nil {
|
||||||
t.Errorf("AddRootCA() error = %v", err)
|
t.Errorf("AddRootCA() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
t.Errorf("AddRootCA() = %v, want %v", got, tt.want)
|
t.Errorf("AddRootCA() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -130,13 +139,15 @@ func TestAddClientCA(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
ctx := &TLSOptionCtx{
|
||||||
if err := AddClientCA(tt.args.cert)(nil, nil, got); err != nil {
|
Config: &tls.Config{},
|
||||||
|
}
|
||||||
|
if err := AddClientCA(tt.args.cert)(ctx); err != nil {
|
||||||
t.Errorf("AddClientCA() error = %v", err)
|
t.Errorf("AddClientCA() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
t.Errorf("AddClientCA() = %v, want %v", got, tt.want)
|
t.Errorf("AddClientCA() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -172,13 +183,17 @@ func TestAddRootsToRootCAs(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
ctx := &TLSOptionCtx{
|
||||||
if err := AddRootsToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr {
|
Client: client,
|
||||||
|
Config: &tls.Config{},
|
||||||
|
Transport: tt.tr,
|
||||||
|
}
|
||||||
|
if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
t.Errorf("AddRootsToRootCAs() = %v, want %v", got, tt.want)
|
t.Errorf("AddRootsToRootCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -214,13 +229,17 @@ func TestAddRootsToClientCAs(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
ctx := &TLSOptionCtx{
|
||||||
if err := AddRootsToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr {
|
Client: client,
|
||||||
|
Config: &tls.Config{},
|
||||||
|
Transport: tt.tr,
|
||||||
|
}
|
||||||
|
if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
t.Errorf("AddRootsToClientCAs() = %v, want %v", got, tt.want)
|
t.Errorf("AddRootsToClientCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -263,13 +282,17 @@ func TestAddFederationToRootCAs(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
ctx := &TLSOptionCtx{
|
||||||
if err := AddFederationToRootCAs()(client, tt.tr, got); (err != nil) != tt.wantErr {
|
Client: client,
|
||||||
|
Config: &tls.Config{},
|
||||||
|
Transport: tt.tr,
|
||||||
|
}
|
||||||
|
if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
t.Errorf("AddFederationToRootCAs() = %v, want %v", got, tt.want)
|
t.Errorf("AddFederationToRootCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -312,13 +335,116 @@ func TestAddFederationToClientCAs(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := &tls.Config{}
|
ctx := &TLSOptionCtx{
|
||||||
if err := AddFederationToClientCAs()(client, tt.tr, got); (err != nil) != tt.wantErr {
|
Client: client,
|
||||||
|
Config: &tls.Config{},
|
||||||
|
Transport: tt.tr,
|
||||||
|
}
|
||||||
|
if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
t.Errorf("AddFederationToClientCAs() = %v, want %v", got, tt.want)
|
t.Errorf("AddFederationToClientCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddRootsToCAs(t *testing.T) {
|
||||||
|
ca := startCATestServer()
|
||||||
|
defer ca.Close()
|
||||||
|
|
||||||
|
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
||||||
|
tr, err := getTLSOptionsTransport(sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := parseCertificate(string(root))
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(cert)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tr http.RoundTripper
|
||||||
|
want *tls.Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
|
||||||
|
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := &TLSOptionCtx{
|
||||||
|
Client: client,
|
||||||
|
Config: &tls.Config{},
|
||||||
|
Transport: tt.tr,
|
||||||
|
}
|
||||||
|
if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
|
t.Errorf("AddRootsToCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddFederationToCAs(t *testing.T) {
|
||||||
|
ca := startCATestServer()
|
||||||
|
defer ca.Close()
|
||||||
|
|
||||||
|
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
||||||
|
tr, err := getTLSOptionsTransport(sr, pk)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
crt1 := parseCertificate(string(root))
|
||||||
|
crt2 := parseCertificate(string(federated))
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(crt1)
|
||||||
|
pool.AddCert(crt2)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tr http.RoundTripper
|
||||||
|
want *tls.Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
|
||||||
|
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := &TLSOptionCtx{
|
||||||
|
Client: client,
|
||||||
|
Config: &tls.Config{},
|
||||||
|
Transport: tt.tr,
|
||||||
|
}
|
||||||
|
if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(ctx.Config, tt.want) {
|
||||||
|
t.Errorf("AddFederationToCAs() = %v, want %v", ctx.Config, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue