forked from TrueCloudLab/certificates
Remove mTLS client requirement in /roots and /federation
This commit is contained in:
parent
9adc65febf
commit
518b597535
10 changed files with 162 additions and 233 deletions
24
api/api.go
24
api/api.go
|
@ -25,8 +25,8 @@ type Authority interface {
|
||||||
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||||
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
|
GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error)
|
||||||
GetEncryptedKey(kid string) (string, error)
|
GetEncryptedKey(kid string) (string, error)
|
||||||
GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error)
|
GetRoots() (federation []*x509.Certificate, err error)
|
||||||
GetFederation(peer *x509.Certificate) ([]*x509.Certificate, error)
|
GetFederation() ([]*x509.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Certificate wraps a *x509.Certificate and adds the json.Marshaler interface.
|
// Certificate wraps a *x509.Certificate and adds the json.Marshaler interface.
|
||||||
|
@ -334,15 +334,9 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||||
JSON(w, &ProvisionerKeyResponse{key})
|
JSON(w, &ProvisionerKeyResponse{key})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Roots returns all the root certificates for the CA. It requires a valid TLS
|
// Roots returns all the root certificates for the CA.
|
||||||
// client.
|
|
||||||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
roots, err := h.Authority.GetRoots()
|
||||||
WriteError(w, BadRequest(errors.New("missing peer certificate")))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
roots, err := h.Authority.GetRoots(r.TLS.PeerCertificates[0])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, Forbidden(err))
|
WriteError(w, Forbidden(err))
|
||||||
return
|
return
|
||||||
|
@ -359,15 +353,9 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Federation returns all the public certificates in the federation. It requires
|
// Federation returns all the public certificates in the federation.
|
||||||
// a valid TLS client.
|
|
||||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
federated, err := h.Authority.GetFederation()
|
||||||
WriteError(w, BadRequest(errors.New("missing peer certificate")))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
federated, err := h.Authority.GetFederation(r.TLS.PeerCertificates[0])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, Forbidden(err))
|
WriteError(w, Forbidden(err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -392,8 +392,8 @@ type mockAuthority struct {
|
||||||
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
|
||||||
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
|
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
|
||||||
getEncryptedKey func(kid string) (string, error)
|
getEncryptedKey func(kid string) (string, error)
|
||||||
getRoots func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
getRoots func() ([]*x509.Certificate, error)
|
||||||
getFederation func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
getFederation func() ([]*x509.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) {
|
func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) {
|
||||||
|
@ -445,16 +445,16 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
|
||||||
return m.ret1.(string), m.err
|
return m.ret1.(string), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetRoots(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) {
|
||||||
if m.getFederation != nil {
|
if m.getFederation != nil {
|
||||||
return m.getRoots(cert)
|
return m.getRoots()
|
||||||
}
|
}
|
||||||
return m.ret1.([]*x509.Certificate), m.err
|
return m.ret1.([]*x509.Certificate), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
|
||||||
if m.getFederation != nil {
|
if m.getFederation != nil {
|
||||||
return m.getFederation(cert)
|
return m.getFederation()
|
||||||
}
|
}
|
||||||
return m.ret1.([]*x509.Certificate), m.err
|
return m.ret1.([]*x509.Certificate), m.err
|
||||||
}
|
}
|
||||||
|
@ -842,9 +842,8 @@ func Test_caHandler_Roots(t *testing.T) {
|
||||||
statusCode int
|
statusCode int
|
||||||
}{
|
}{
|
||||||
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||||
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
|
{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||||
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
|
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||||
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||||
|
@ -889,9 +888,8 @@ func Test_caHandler_Federation(t *testing.T) {
|
||||||
statusCode int
|
statusCode int
|
||||||
}{
|
}{
|
||||||
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||||
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
|
{"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||||
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
|
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||||
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
||||||
|
|
|
@ -34,21 +34,12 @@ func (a *Authority) GetRootCertificates() []*x509.Certificate {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoots returns all the root certificates for this CA.
|
// GetRoots returns all the root certificates for this CA.
|
||||||
func (a *Authority) GetRoots(peer *x509.Certificate) ([]*x509.Certificate, error) {
|
func (a *Authority) GetRoots() ([]*x509.Certificate, error) {
|
||||||
// Check step provisioner extensions
|
|
||||||
if err := a.authorizeRenewal(peer); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return a.rootX509Certs, nil
|
return a.rootX509Certs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFederation returns all the root certificates in the federation.
|
// GetFederation returns all the root certificates in the federation.
|
||||||
func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) {
|
func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) {
|
||||||
// Check step provisioner extensions
|
|
||||||
if err := a.authorizeRenewal(peer); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
a.certificates.Range(func(k, v interface{}) bool {
|
a.certificates.Range(func(k, v interface{}) bool {
|
||||||
crt, ok := v.(*x509.Certificate)
|
crt, ok := v.(*x509.Certificate)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|
|
@ -8,9 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/cli/crypto/keys"
|
|
||||||
"github.com/smallstep/cli/crypto/pemutil"
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
"github.com/smallstep/cli/crypto/x509util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRoot(t *testing.T) {
|
func TestRoot(t *testing.T) {
|
||||||
|
@ -99,42 +97,17 @@ func TestAuthority_GetRoots(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
a := testAuthority(t)
|
|
||||||
pub, _, err := keys.GenerateDefaultKeyPair()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key,
|
|
||||||
withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"))
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
crtBytes, err := leaf.CreateCertificate()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
crt, err := x509.ParseCertificate(crtBytes)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key,
|
|
||||||
withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"),
|
|
||||||
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID),
|
|
||||||
)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
crtFailBytes, err := leafFail.CreateCertificate()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
crtFail, err := x509.ParseCertificate(crtFailBytes)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
peer *x509.Certificate
|
|
||||||
}
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
|
||||||
want []*x509.Certificate
|
want []*x509.Certificate
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", args{crt}, []*x509.Certificate{cert}, false},
|
{"ok", []*x509.Certificate{cert}, false},
|
||||||
{"fail", args{crtFail}, nil, true},
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
a := testAuthority(t)
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := a.GetRoots(tt.args.peer)
|
got, err := a.GetRoots()
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -152,49 +125,24 @@ func TestAuthority_GetFederation(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
a := testAuthority(t)
|
|
||||||
pub, _, err := keys.GenerateDefaultKeyPair()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key,
|
|
||||||
withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"))
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
crtBytes, err := leaf.CreateCertificate()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
crt, err := x509.ParseCertificate(crtBytes)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key,
|
|
||||||
withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"),
|
|
||||||
withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID),
|
|
||||||
)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
crtFailBytes, err := leafFail.CreateCertificate()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
crtFail, err := x509.ParseCertificate(crtFailBytes)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
type args struct {
|
|
||||||
peer *x509.Certificate
|
|
||||||
}
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
|
||||||
wantFederation []*x509.Certificate
|
wantFederation []*x509.Certificate
|
||||||
wantErr bool
|
wantErr bool
|
||||||
fn func()
|
fn func(a *Authority)
|
||||||
}{
|
}{
|
||||||
{"ok", args{crt}, []*x509.Certificate{cert}, false, nil},
|
{"ok", []*x509.Certificate{cert}, false, nil},
|
||||||
{"fail", args{crtFail}, nil, true, nil},
|
{"fail", nil, true, func(a *Authority) {
|
||||||
{"fail not a certificate", args{crt}, nil, true, func() {
|
|
||||||
a.certificates.Store("foo", "bar")
|
a.certificates.Store("foo", "bar")
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := testAuthority(t)
|
||||||
if tt.fn != nil {
|
if tt.fn != nil {
|
||||||
tt.fn()
|
tt.fn(a)
|
||||||
}
|
}
|
||||||
gotFederation, err := a.GetFederation(tt.args.peer)
|
gotFederation, err := a.GetFederation()
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
|
|
@ -3,7 +3,6 @@ package ca
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -26,7 +25,7 @@ func newLocalListener() net.Listener {
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
|
||||||
panic(fmt.Sprintf("failed to listen on a port: %v", err))
|
panic(errors.Wrap(err, "failed to listen on a port"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return l
|
return l
|
||||||
|
@ -345,16 +344,16 @@ func TestBootstrapClientServerRotation(t *testing.T) {
|
||||||
// doTest does a request that requires mTLS
|
// doTest does a request that requires mTLS
|
||||||
doTest := func(client *http.Client) error {
|
doTest := func(client *http.Client) error {
|
||||||
// test with ca
|
// test with ca
|
||||||
resp, err := client.Get(caURL + "/roots")
|
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "client.Get(%s) failed", caURL+"/roots")
|
return errors.Wrap(err, "client.Post() failed")
|
||||||
}
|
}
|
||||||
var roots api.RootsResponse
|
var renew api.SignResponse
|
||||||
if err := readJSON(resp.Body, &roots); err != nil {
|
if err := readJSON(resp.Body, &renew); err != nil {
|
||||||
return errors.Wrap(err, "client.Get() error reading response")
|
return errors.Wrap(err, "client.Post() error reading response")
|
||||||
}
|
}
|
||||||
if len(roots.Certificates) == 0 {
|
if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil {
|
||||||
return errors.New("client.Get() error not certificates found")
|
return errors.New("client.Post() unexpected response found")
|
||||||
}
|
}
|
||||||
// test with bootstrap server
|
// test with bootstrap server
|
||||||
resp, err = client.Get(srvURL)
|
resp, err = client.Get(srvURL)
|
||||||
|
|
10
ca/client.go
10
ca/client.go
|
@ -416,10 +416,9 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error)
|
||||||
|
|
||||||
// Roots performs the get roots request to the CA and returns the
|
// Roots performs the get roots request to the CA and returns the
|
||||||
// api.RootsResponse struct.
|
// api.RootsResponse struct.
|
||||||
func (c *Client) Roots(tr http.RoundTripper) (*api.RootsResponse, error) {
|
func (c *Client) Roots() (*api.RootsResponse, error) {
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
|
||||||
client := &http.Client{Transport: tr}
|
resp, err := c.client.Get(u.String())
|
||||||
resp, err := client.Get(u.String())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
|
@ -435,10 +434,9 @@ func (c *Client) Roots(tr http.RoundTripper) (*api.RootsResponse, error) {
|
||||||
|
|
||||||
// Federation performs the get federation request to the CA and returns the
|
// Federation performs the get federation request to the CA and returns the
|
||||||
// api.FederationResponse struct.
|
// api.FederationResponse struct.
|
||||||
func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) {
|
func (c *Client) Federation() (*api.FederationResponse, error) {
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
|
||||||
client := &http.Client{Transport: tr}
|
resp, err := c.client.Get(u.String())
|
||||||
resp, err := client.Get(u.String())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
|
|
|
@ -549,7 +549,7 @@ func TestClient_Roots(t *testing.T) {
|
||||||
api.JSON(w, tt.response)
|
api.JSON(w, tt.response)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Roots(nil)
|
got, err := c.Roots()
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
fmt.Printf("%+v", err)
|
fmt.Printf("%+v", err)
|
||||||
t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -610,7 +610,7 @@ func TestClient_Federation(t *testing.T) {
|
||||||
api.JSON(w, tt.response)
|
api.JSON(w, tt.response)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Federation(nil)
|
got, err := c.Federation()
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
fmt.Printf("%+v", err)
|
fmt.Printf("%+v", err)
|
||||||
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
18
ca/tls.go
18
ca/tls.go
|
@ -41,10 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options if given
|
||||||
tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig)
|
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := tlsCtx.apply(options); err != nil {
|
if err := tlsCtx.apply(options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -56,6 +53,9 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
|
// Update client transport
|
||||||
|
c.client.Transport = tr
|
||||||
|
|
||||||
// Start renewer
|
// Start renewer
|
||||||
renewer.RunContext(ctx)
|
renewer.RunContext(ctx)
|
||||||
return tlsConfig, nil
|
return tlsConfig, nil
|
||||||
|
@ -91,10 +91,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply options if given
|
// Apply options if given
|
||||||
tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig)
|
tlsCtx := newTLSOptionCtx(c, tlsConfig)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := tlsCtx.apply(options); err != nil {
|
if err := tlsCtx.apply(options); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -106,6 +103,9 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
}
|
}
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
|
// Update client transport
|
||||||
|
c.client.Transport = tr
|
||||||
|
|
||||||
// Start renewer
|
// Start renewer
|
||||||
renewer.RunContext(ctx)
|
renewer.RunContext(ctx)
|
||||||
return tlsConfig, nil
|
return tlsConfig, nil
|
||||||
|
@ -249,7 +249,7 @@ func getPEM(i interface{}) ([]byte, error) {
|
||||||
func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
|
func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc {
|
||||||
return func() (*tls.Certificate, error) {
|
return func() (*tls.Certificate, error) {
|
||||||
// Get updated list of roots
|
// Get updated list of roots
|
||||||
if err := ctx.applyRenew(tr); err != nil {
|
if err := ctx.applyRenew(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// Get new certificate
|
// Get new certificate
|
||||||
|
|
|
@ -1,12 +1,8 @@
|
||||||
package ca
|
package ca
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/smallstep/certificates/api"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLSOption defines the type of a function that modifies a tls.Config.
|
// TLSOption defines the type of a function that modifies a tls.Config.
|
||||||
|
@ -15,22 +11,16 @@ type TLSOption func(ctx *TLSOptionCtx) error
|
||||||
// TLSOptionCtx is the context modified on TLSOption methods.
|
// TLSOptionCtx is the context modified on TLSOption methods.
|
||||||
type TLSOptionCtx struct {
|
type TLSOptionCtx struct {
|
||||||
Client *Client
|
Client *Client
|
||||||
Transport http.RoundTripper
|
|
||||||
Config *tls.Config
|
Config *tls.Config
|
||||||
OnRenewFunc []TLSOption
|
OnRenewFunc []TLSOption
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTLSOptionCtx creates the TLSOption context.
|
// newTLSOptionCtx creates the TLSOption context.
|
||||||
func newTLSOptionCtx(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config) (*TLSOptionCtx, error) {
|
func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx {
|
||||||
tr, err := getTLSOptionsTransport(sign, pk)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &TLSOptionCtx{
|
return &TLSOptionCtx{
|
||||||
Client: c,
|
Client: c,
|
||||||
Transport: tr,
|
|
||||||
Config: config,
|
Config: config,
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
||||||
|
@ -42,8 +32,7 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx *TLSOptionCtx) applyRenew(tr http.RoundTripper) error {
|
func (ctx *TLSOptionCtx) applyRenew() error {
|
||||||
ctx.Transport = tr
|
|
||||||
for _, fn := range ctx.OnRenewFunc {
|
for _, fn := range ctx.OnRenewFunc {
|
||||||
if err := fn(ctx); err != nil {
|
if err := fn(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -52,26 +41,6 @@ func (ctx *TLSOptionCtx) applyRenew(tr http.RoundTripper) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTLSOptionsTransport is the transport used by TLSOptions. It is used to get
|
|
||||||
// root certificates using a mTLS connection with the CA.
|
|
||||||
func getTLSOptionsTransport(sign *api.SignResponse, pk crypto.PrivateKey) (http.RoundTripper, error) {
|
|
||||||
cert, err := TLSCertificate(sign, pk)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build default transport with fixed certificate
|
|
||||||
tlsConfig := getDefaultTLSConfig(sign)
|
|
||||||
tlsConfig.Certificates = []tls.Certificate{*cert}
|
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
|
||||||
// Build RootCAs with given root certificate
|
|
||||||
if pool := getCertPool(sign); pool != nil {
|
|
||||||
tlsConfig.RootCAs = pool
|
|
||||||
}
|
|
||||||
|
|
||||||
return getDefaultTransport(tlsConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
|
||||||
// a valid TLS client certificate. This is the default option for mTLS servers.
|
// a valid TLS client certificate. This is the default option for mTLS servers.
|
||||||
func RequireAndVerifyClientCert() TLSOption {
|
func RequireAndVerifyClientCert() TLSOption {
|
||||||
|
@ -123,7 +92,7 @@ func AddClientCA(cert *x509.Certificate) TLSOption {
|
||||||
// BootstrapServer and BootstrapClient methods include this option by default.
|
// BootstrapServer and BootstrapClient methods include this option by default.
|
||||||
func AddRootsToRootCAs() TLSOption {
|
func AddRootsToRootCAs() TLSOption {
|
||||||
fn := func(ctx *TLSOptionCtx) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := ctx.Client.Roots(ctx.Transport)
|
certs, err := ctx.Client.Roots()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -149,7 +118,7 @@ func AddRootsToRootCAs() TLSOption {
|
||||||
// BootstrapServer method includes this option by default.
|
// BootstrapServer method includes this option by default.
|
||||||
func AddRootsToClientCAs() TLSOption {
|
func AddRootsToClientCAs() TLSOption {
|
||||||
fn := func(ctx *TLSOptionCtx) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := ctx.Client.Roots(ctx.Transport)
|
certs, err := ctx.Client.Roots()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -172,7 +141,7 @@ func AddRootsToClientCAs() TLSOption {
|
||||||
// certificate authorities that clients use when verifying server certificates.
|
// certificate authorities that clients use when verifying server certificates.
|
||||||
func AddFederationToRootCAs() TLSOption {
|
func AddFederationToRootCAs() TLSOption {
|
||||||
fn := func(ctx *TLSOptionCtx) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := ctx.Client.Federation(ctx.Transport)
|
certs, err := ctx.Client.Federation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -196,7 +165,7 @@ func AddFederationToRootCAs() TLSOption {
|
||||||
// certificate by the policy in ClientAuth.
|
// certificate by the policy in ClientAuth.
|
||||||
func AddFederationToClientCAs() TLSOption {
|
func AddFederationToClientCAs() TLSOption {
|
||||||
fn := func(ctx *TLSOptionCtx) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := ctx.Client.Federation(ctx.Transport)
|
certs, err := ctx.Client.Federation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -219,7 +188,7 @@ func AddFederationToClientCAs() TLSOption {
|
||||||
// AddRootsToRootCAs and AddRootsToClientCAs.
|
// AddRootsToRootCAs and AddRootsToClientCAs.
|
||||||
func AddRootsToCAs() TLSOption {
|
func AddRootsToCAs() TLSOption {
|
||||||
fn := func(ctx *TLSOptionCtx) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := ctx.Client.Roots(ctx.Transport)
|
certs, err := ctx.Client.Roots()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -246,7 +215,7 @@ func AddRootsToCAs() TLSOption {
|
||||||
// AddFederationToRootCAs and AddFederationToClientCAs.
|
// AddFederationToRootCAs and AddFederationToClientCAs.
|
||||||
func AddFederationToCAs() TLSOption {
|
func AddFederationToCAs() TLSOption {
|
||||||
fn := func(ctx *TLSOptionCtx) error {
|
fn := func(ctx *TLSOptionCtx) error {
|
||||||
certs, err := ctx.Client.Federation(ctx.Transport)
|
certs, err := ctx.Client.Federation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package ca
|
package ca
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -10,32 +9,29 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/api"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_newTLSOptionCtx(t *testing.T) {
|
func Test_newTLSOptionCtx(t *testing.T) {
|
||||||
client, sign, pk := sign("test.smallstep.com")
|
client, err := NewClient("https://ca.smallstep.com", WithTransport(http.DefaultTransport))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewClient() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
c *Client
|
c *Client
|
||||||
sign *api.SignResponse
|
|
||||||
pk crypto.PrivateKey
|
|
||||||
config *tls.Config
|
config *tls.Config
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
wantErr bool
|
want *TLSOptionCtx
|
||||||
}{
|
}{
|
||||||
{"ok", args{client, sign, pk, &tls.Config{}}, false},
|
{"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}}},
|
||||||
{"fail", args{client, sign, "foo", &tls.Config{}}, true},
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := newTLSOptionCtx(tt.args.c, tt.args.sign, tt.args.pk, tt.args.config)
|
if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) {
|
||||||
if (err != nil) != tt.wantErr {
|
t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want)
|
||||||
t.Errorf("newTLSOptionCtx() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -188,8 +184,12 @@ func TestAddRootsToRootCAs(t *testing.T) {
|
||||||
ca := startCATestServer()
|
ca := startCATestServer()
|
||||||
defer ca.Close()
|
defer ca.Close()
|
||||||
|
|
||||||
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||||
tr, err := getTLSOptionsTransport(sr, pk)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -203,21 +203,24 @@ func TestAddRootsToRootCAs(t *testing.T) {
|
||||||
pool := x509.NewCertPool()
|
pool := x509.NewCertPool()
|
||||||
pool.AddCert(cert)
|
pool.AddCert(cert)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
client *Client
|
||||||
|
config *tls.Config
|
||||||
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tr http.RoundTripper
|
args args
|
||||||
want *tls.Config
|
want *tls.Config
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", tr, &tls.Config{RootCAs: pool}, false},
|
{"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false},
|
||||||
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: client,
|
Client: tt.args.client,
|
||||||
Config: &tls.Config{},
|
Config: tt.args.config,
|
||||||
Transport: tt.tr,
|
|
||||||
}
|
}
|
||||||
if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -234,8 +237,12 @@ func TestAddRootsToClientCAs(t *testing.T) {
|
||||||
ca := startCATestServer()
|
ca := startCATestServer()
|
||||||
defer ca.Close()
|
defer ca.Close()
|
||||||
|
|
||||||
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||||
tr, err := getTLSOptionsTransport(sr, pk)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -249,21 +256,24 @@ func TestAddRootsToClientCAs(t *testing.T) {
|
||||||
pool := x509.NewCertPool()
|
pool := x509.NewCertPool()
|
||||||
pool.AddCert(cert)
|
pool.AddCert(cert)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
client *Client
|
||||||
|
config *tls.Config
|
||||||
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tr http.RoundTripper
|
args args
|
||||||
want *tls.Config
|
want *tls.Config
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", tr, &tls.Config{ClientCAs: pool}, false},
|
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false},
|
||||||
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: client,
|
Client: tt.args.client,
|
||||||
Config: &tls.Config{},
|
Config: tt.args.config,
|
||||||
Transport: tt.tr,
|
|
||||||
}
|
}
|
||||||
if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -280,8 +290,12 @@ func TestAddFederationToRootCAs(t *testing.T) {
|
||||||
ca := startCATestServer()
|
ca := startCATestServer()
|
||||||
defer ca.Close()
|
defer ca.Close()
|
||||||
|
|
||||||
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||||
tr, err := getTLSOptionsTransport(sr, pk)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -302,21 +316,24 @@ func TestAddFederationToRootCAs(t *testing.T) {
|
||||||
pool.AddCert(crt1)
|
pool.AddCert(crt1)
|
||||||
pool.AddCert(crt2)
|
pool.AddCert(crt2)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
client *Client
|
||||||
|
config *tls.Config
|
||||||
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tr http.RoundTripper
|
args args
|
||||||
want *tls.Config
|
want *tls.Config
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", tr, &tls.Config{RootCAs: pool}, false},
|
{"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false},
|
||||||
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: client,
|
Client: tt.args.client,
|
||||||
Config: &tls.Config{},
|
Config: tt.args.config,
|
||||||
Transport: tt.tr,
|
|
||||||
}
|
}
|
||||||
if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -336,8 +353,12 @@ func TestAddFederationToClientCAs(t *testing.T) {
|
||||||
ca := startCATestServer()
|
ca := startCATestServer()
|
||||||
defer ca.Close()
|
defer ca.Close()
|
||||||
|
|
||||||
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||||
tr, err := getTLSOptionsTransport(sr, pk)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -358,21 +379,24 @@ func TestAddFederationToClientCAs(t *testing.T) {
|
||||||
pool.AddCert(crt1)
|
pool.AddCert(crt1)
|
||||||
pool.AddCert(crt2)
|
pool.AddCert(crt2)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
client *Client
|
||||||
|
config *tls.Config
|
||||||
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tr http.RoundTripper
|
args args
|
||||||
want *tls.Config
|
want *tls.Config
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", tr, &tls.Config{ClientCAs: pool}, false},
|
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false},
|
||||||
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: client,
|
Client: tt.args.client,
|
||||||
Config: &tls.Config{},
|
Config: tt.args.config,
|
||||||
Transport: tt.tr,
|
|
||||||
}
|
}
|
||||||
if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -392,8 +416,12 @@ func TestAddRootsToCAs(t *testing.T) {
|
||||||
ca := startCATestServer()
|
ca := startCATestServer()
|
||||||
defer ca.Close()
|
defer ca.Close()
|
||||||
|
|
||||||
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||||
tr, err := getTLSOptionsTransport(sr, pk)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -407,21 +435,24 @@ func TestAddRootsToCAs(t *testing.T) {
|
||||||
pool := x509.NewCertPool()
|
pool := x509.NewCertPool()
|
||||||
pool.AddCert(cert)
|
pool.AddCert(cert)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
client *Client
|
||||||
|
config *tls.Config
|
||||||
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tr http.RoundTripper
|
args args
|
||||||
want *tls.Config
|
want *tls.Config
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
|
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
|
||||||
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: client,
|
Client: tt.args.client,
|
||||||
Config: &tls.Config{},
|
Config: tt.args.config,
|
||||||
Transport: tt.tr,
|
|
||||||
}
|
}
|
||||||
if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -438,8 +469,12 @@ func TestAddFederationToCAs(t *testing.T) {
|
||||||
ca := startCATestServer()
|
ca := startCATestServer()
|
||||||
defer ca.Close()
|
defer ca.Close()
|
||||||
|
|
||||||
client, sr, pk := signDuration(ca, "127.0.0.1", 0)
|
client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt"))
|
||||||
tr, err := getTLSOptionsTransport(sr, pk)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -460,21 +495,24 @@ func TestAddFederationToCAs(t *testing.T) {
|
||||||
pool.AddCert(crt1)
|
pool.AddCert(crt1)
|
||||||
pool.AddCert(crt2)
|
pool.AddCert(crt2)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
client *Client
|
||||||
|
config *tls.Config
|
||||||
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tr http.RoundTripper
|
args args
|
||||||
want *tls.Config
|
want *tls.Config
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
|
{"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false},
|
||||||
{"fail", http.DefaultTransport, &tls.Config{}, true},
|
{"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := &TLSOptionCtx{
|
ctx := &TLSOptionCtx{
|
||||||
Client: client,
|
Client: tt.args.client,
|
||||||
Config: &tls.Config{},
|
Config: tt.args.config,
|
||||||
Transport: tt.tr,
|
|
||||||
}
|
}
|
||||||
if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr {
|
if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
Loading…
Reference in a new issue