diff --git a/api/api.go b/api/api.go index 563e65ea..4bdf1b09 100644 --- a/api/api.go +++ b/api/api.go @@ -25,8 +25,8 @@ type Authority interface { Renew(peer *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) GetProvisioners(cursor string, limit int) ([]*authority.Provisioner, string, error) GetEncryptedKey(kid string) (string, error) - GetRoots(peer *x509.Certificate) (federation []*x509.Certificate, err error) - GetFederation(peer *x509.Certificate) ([]*x509.Certificate, error) + GetRoots() (federation []*x509.Certificate, err error) + GetFederation() ([]*x509.Certificate, error) } // Certificate wraps a *x509.Certificate and adds the json.Marshaler interface. @@ -334,15 +334,9 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { JSON(w, &ProvisionerKeyResponse{key}) } -// Roots returns all the root certificates for the CA. It requires a valid TLS -// client. +// Roots returns all the root certificates for the CA. func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, BadRequest(errors.New("missing peer certificate"))) - return - } - - roots, err := h.Authority.GetRoots(r.TLS.PeerCertificates[0]) + roots, err := h.Authority.GetRoots() if err != nil { WriteError(w, Forbidden(err)) return @@ -359,15 +353,9 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { }) } -// Federation returns all the public certificates in the federation. It requires -// a valid TLS client. +// Federation returns all the public certificates in the federation. func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, BadRequest(errors.New("missing peer certificate"))) - return - } - - federated, err := h.Authority.GetFederation(r.TLS.PeerCertificates[0]) + federated, err := h.Authority.GetFederation() if err != nil { WriteError(w, Forbidden(err)) return diff --git a/api/api_test.go b/api/api_test.go index 0a9eb120..e60e6ba2 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -392,8 +392,8 @@ type mockAuthority struct { renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error) getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error) getEncryptedKey func(kid string) (string, error) - getRoots func(cert *x509.Certificate) ([]*x509.Certificate, error) - getFederation func(cert *x509.Certificate) ([]*x509.Certificate, error) + getRoots func() ([]*x509.Certificate, error) + getFederation func() ([]*x509.Certificate, error) } func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) { @@ -445,16 +445,16 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { return m.ret1.(string), m.err } -func (m *mockAuthority) GetRoots(cert *x509.Certificate) ([]*x509.Certificate, error) { +func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { if m.getFederation != nil { - return m.getRoots(cert) + return m.getRoots() } return m.ret1.([]*x509.Certificate), m.err } -func (m *mockAuthority) GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) { +func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { if m.getFederation != nil { - return m.getFederation(cert) + return m.getFederation() } return m.ret1.([]*x509.Certificate), m.err } @@ -842,9 +842,8 @@ func Test_caHandler_Roots(t *testing.T) { statusCode int }{ {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, - {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, - {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, - {"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + {"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, } expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) @@ -889,9 +888,8 @@ func Test_caHandler_Federation(t *testing.T) { statusCode int }{ {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, - {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, - {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, - {"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, + {"no peer certificates", &tls.ConnectionState{}, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, } expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) diff --git a/authority/root.go b/authority/root.go index cfd63595..01b3508d 100644 --- a/authority/root.go +++ b/authority/root.go @@ -34,21 +34,12 @@ func (a *Authority) GetRootCertificates() []*x509.Certificate { } // GetRoots returns all the root certificates for this CA. -func (a *Authority) GetRoots(peer *x509.Certificate) ([]*x509.Certificate, error) { - // Check step provisioner extensions - if err := a.authorizeRenewal(peer); err != nil { - return nil, err - } +func (a *Authority) GetRoots() ([]*x509.Certificate, error) { return a.rootX509Certs, nil } // GetFederation returns all the root certificates in the federation. -func (a *Authority) GetFederation(peer *x509.Certificate) (federation []*x509.Certificate, err error) { - // Check step provisioner extensions - if err := a.authorizeRenewal(peer); err != nil { - return nil, err - } - +func (a *Authority) GetFederation() (federation []*x509.Certificate, err error) { a.certificates.Range(func(k, v interface{}) bool { crt, ok := v.(*x509.Certificate) if !ok { diff --git a/authority/root_test.go b/authority/root_test.go index 9b80cad6..17f25755 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -8,9 +8,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/cli/crypto/keys" "github.com/smallstep/cli/crypto/pemutil" - "github.com/smallstep/cli/crypto/x509util" ) func TestRoot(t *testing.T) { @@ -99,42 +97,17 @@ func TestAuthority_GetRoots(t *testing.T) { t.Fatal(err) } - a := testAuthority(t) - pub, _, err := keys.GenerateDefaultKeyPair() - assert.FatalError(t, err) - leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test")) - assert.FatalError(t, err) - crtBytes, err := leaf.CreateCertificate() - assert.FatalError(t, err) - crt, err := x509.ParseCertificate(crtBytes) - assert.FatalError(t, err) - - leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"), - withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), - ) - assert.FatalError(t, err) - crtFailBytes, err := leafFail.CreateCertificate() - assert.FatalError(t, err) - crtFail, err := x509.ParseCertificate(crtFailBytes) - assert.FatalError(t, err) - - type args struct { - peer *x509.Certificate - } tests := []struct { name string - args args want []*x509.Certificate wantErr bool }{ - {"ok", args{crt}, []*x509.Certificate{cert}, false}, - {"fail", args{crtFail}, nil, true}, + {"ok", []*x509.Certificate{cert}, false}, } for _, tt := range tests { + a := testAuthority(t) t.Run(tt.name, func(t *testing.T) { - got, err := a.GetRoots(tt.args.peer) + got, err := a.GetRoots() if (err != nil) != tt.wantErr { t.Errorf("Authority.GetRoots() error = %v, wantErr %v", err, tt.wantErr) return @@ -152,49 +125,24 @@ func TestAuthority_GetFederation(t *testing.T) { t.Fatal(err) } - a := testAuthority(t) - pub, _, err := keys.GenerateDefaultKeyPair() - assert.FatalError(t, err) - leaf, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test")) - assert.FatalError(t, err) - crtBytes, err := leaf.CreateCertificate() - assert.FatalError(t, err) - crt, err := x509.ParseCertificate(crtBytes) - assert.FatalError(t, err) - - leafFail, err := x509util.NewLeafProfile("test", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test"), - withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].Key.KeyID), - ) - assert.FatalError(t, err) - crtFailBytes, err := leafFail.CreateCertificate() - assert.FatalError(t, err) - crtFail, err := x509.ParseCertificate(crtFailBytes) - assert.FatalError(t, err) - - type args struct { - peer *x509.Certificate - } tests := []struct { name string - args args wantFederation []*x509.Certificate wantErr bool - fn func() + fn func(a *Authority) }{ - {"ok", args{crt}, []*x509.Certificate{cert}, false, nil}, - {"fail", args{crtFail}, nil, true, nil}, - {"fail not a certificate", args{crt}, nil, true, func() { + {"ok", []*x509.Certificate{cert}, false, nil}, + {"fail", nil, true, func(a *Authority) { a.certificates.Store("foo", "bar") }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + a := testAuthority(t) if tt.fn != nil { - tt.fn() + tt.fn(a) } - gotFederation, err := a.GetFederation(tt.args.peer) + gotFederation, err := a.GetFederation() if (err != nil) != tt.wantErr { t.Errorf("Authority.GetFederation() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 452f5878..a046fde2 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -3,7 +3,6 @@ package ca import ( "context" "crypto/tls" - "fmt" "io/ioutil" "net" "net/http" @@ -26,7 +25,7 @@ func newLocalListener() net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { - panic(fmt.Sprintf("failed to listen on a port: %v", err)) + panic(errors.Wrap(err, "failed to listen on a port")) } } return l @@ -345,16 +344,16 @@ func TestBootstrapClientServerRotation(t *testing.T) { // doTest does a request that requires mTLS doTest := func(client *http.Client) error { // test with ca - resp, err := client.Get(caURL + "/roots") + resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) if err != nil { - return errors.Wrapf(err, "client.Get(%s) failed", caURL+"/roots") + return errors.Wrap(err, "client.Post() failed") } - var roots api.RootsResponse - if err := readJSON(resp.Body, &roots); err != nil { - return errors.Wrap(err, "client.Get() error reading response") + var renew api.SignResponse + if err := readJSON(resp.Body, &renew); err != nil { + return errors.Wrap(err, "client.Post() error reading response") } - if len(roots.Certificates) == 0 { - return errors.New("client.Get() error not certificates found") + if renew.ServerPEM.Certificate == nil || renew.CaPEM.Certificate == nil { + return errors.New("client.Post() unexpected response found") } // test with bootstrap server resp, err = client.Get(srvURL) diff --git a/ca/client.go b/ca/client.go index b8aab67f..627cd450 100644 --- a/ca/client.go +++ b/ca/client.go @@ -416,10 +416,9 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) // Roots performs the get roots request to the CA and returns the // api.RootsResponse struct. -func (c *Client) Roots(tr http.RoundTripper) (*api.RootsResponse, error) { +func (c *Client) Roots() (*api.RootsResponse, error) { u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"}) - client := &http.Client{Transport: tr} - resp, err := client.Get(u.String()) + resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } @@ -435,10 +434,9 @@ func (c *Client) Roots(tr http.RoundTripper) (*api.RootsResponse, error) { // Federation performs the get federation request to the CA and returns the // api.FederationResponse struct. -func (c *Client) Federation(tr http.RoundTripper) (*api.FederationResponse, error) { +func (c *Client) Federation() (*api.FederationResponse, error) { u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"}) - client := &http.Client{Transport: tr} - resp, err := client.Get(u.String()) + resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } diff --git a/ca/client_test.go b/ca/client_test.go index 0ec6324b..d82afa31 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -549,7 +549,7 @@ func TestClient_Roots(t *testing.T) { api.JSON(w, tt.response) }) - got, err := c.Roots(nil) + got, err := c.Roots() if (err != nil) != tt.wantErr { fmt.Printf("%+v", err) t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) @@ -610,7 +610,7 @@ func TestClient_Federation(t *testing.T) { api.JSON(w, tt.response) }) - got, err := c.Federation(nil) + got, err := c.Federation() if (err != nil) != tt.wantErr { fmt.Printf("%+v", err) t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) diff --git a/ca/tls.go b/ca/tls.go index e8ff0a9e..31d1632b 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -41,10 +41,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig) - if err != nil { - return nil, err - } + tlsCtx := newTLSOptionCtx(c, tlsConfig) if err := tlsCtx.apply(options); err != nil { return nil, err } @@ -56,6 +53,9 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, } renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) + // Update client transport + c.client.Transport = tr + // Start renewer renewer.RunContext(ctx) return tlsConfig, nil @@ -91,10 +91,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } // Apply options if given - tlsCtx, err := newTLSOptionCtx(c, sign, pk, tlsConfig) - if err != nil { - return nil, err - } + tlsCtx := newTLSOptionCtx(c, tlsConfig) if err := tlsCtx.apply(options); err != nil { return nil, err } @@ -106,6 +103,9 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, } renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) + // Update client transport + c.client.Transport = tr + // Start renewer renewer.RunContext(ctx) return tlsConfig, nil @@ -249,7 +249,7 @@ func getPEM(i interface{}) ([]byte, error) { func getRenewFunc(ctx *TLSOptionCtx, client *Client, tr *http.Transport, pk crypto.PrivateKey) RenewFunc { return func() (*tls.Certificate, error) { // Get updated list of roots - if err := ctx.applyRenew(tr); err != nil { + if err := ctx.applyRenew(); err != nil { return nil, err } // Get new certificate diff --git a/ca/tls_options.go b/ca/tls_options.go index dc15ab18..47e2c627 100644 --- a/ca/tls_options.go +++ b/ca/tls_options.go @@ -1,12 +1,8 @@ package ca import ( - "crypto" "crypto/tls" "crypto/x509" - "net/http" - - "github.com/smallstep/certificates/api" ) // TLSOption defines the type of a function that modifies a tls.Config. @@ -15,22 +11,16 @@ type TLSOption func(ctx *TLSOptionCtx) error // TLSOptionCtx is the context modified on TLSOption methods. type TLSOptionCtx struct { Client *Client - Transport http.RoundTripper Config *tls.Config OnRenewFunc []TLSOption } // newTLSOptionCtx creates the TLSOption context. -func newTLSOptionCtx(c *Client, sign *api.SignResponse, pk crypto.PrivateKey, config *tls.Config) (*TLSOptionCtx, error) { - tr, err := getTLSOptionsTransport(sign, pk) - if err != nil { - return nil, err - } +func newTLSOptionCtx(c *Client, config *tls.Config) *TLSOptionCtx { return &TLSOptionCtx{ - Client: c, - Transport: tr, - Config: config, - }, nil + Client: c, + Config: config, + } } func (ctx *TLSOptionCtx) apply(options []TLSOption) error { @@ -42,8 +32,7 @@ func (ctx *TLSOptionCtx) apply(options []TLSOption) error { return nil } -func (ctx *TLSOptionCtx) applyRenew(tr http.RoundTripper) error { - ctx.Transport = tr +func (ctx *TLSOptionCtx) applyRenew() error { for _, fn := range ctx.OnRenewFunc { if err := fn(ctx); err != nil { return err @@ -52,26 +41,6 @@ func (ctx *TLSOptionCtx) applyRenew(tr http.RoundTripper) error { return nil } -// getTLSOptionsTransport is the transport used by TLSOptions. It is used to get -// root certificates using a mTLS connection with the CA. -func getTLSOptionsTransport(sign *api.SignResponse, pk crypto.PrivateKey) (http.RoundTripper, error) { - cert, err := TLSCertificate(sign, pk) - if err != nil { - return nil, err - } - - // Build default transport with fixed certificate - tlsConfig := getDefaultTLSConfig(sign) - tlsConfig.Certificates = []tls.Certificate{*cert} - tlsConfig.PreferServerCipherSuites = true - // Build RootCAs with given root certificate - if pool := getCertPool(sign); pool != nil { - tlsConfig.RootCAs = pool - } - - return getDefaultTransport(tlsConfig) -} - // RequireAndVerifyClientCert is a tls.Config option used on servers to enforce // a valid TLS client certificate. This is the default option for mTLS servers. func RequireAndVerifyClientCert() TLSOption { @@ -123,7 +92,7 @@ func AddClientCA(cert *x509.Certificate) TLSOption { // BootstrapServer and BootstrapClient methods include this option by default. func AddRootsToRootCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Roots(ctx.Transport) + certs, err := ctx.Client.Roots() if err != nil { return err } @@ -149,7 +118,7 @@ func AddRootsToRootCAs() TLSOption { // BootstrapServer method includes this option by default. func AddRootsToClientCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Roots(ctx.Transport) + certs, err := ctx.Client.Roots() if err != nil { return err } @@ -172,7 +141,7 @@ func AddRootsToClientCAs() TLSOption { // certificate authorities that clients use when verifying server certificates. func AddFederationToRootCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Federation(ctx.Transport) + certs, err := ctx.Client.Federation() if err != nil { return err } @@ -196,7 +165,7 @@ func AddFederationToRootCAs() TLSOption { // certificate by the policy in ClientAuth. func AddFederationToClientCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Federation(ctx.Transport) + certs, err := ctx.Client.Federation() if err != nil { return err } @@ -219,7 +188,7 @@ func AddFederationToClientCAs() TLSOption { // AddRootsToRootCAs and AddRootsToClientCAs. func AddRootsToCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Roots(ctx.Transport) + certs, err := ctx.Client.Roots() if err != nil { return err } @@ -246,7 +215,7 @@ func AddRootsToCAs() TLSOption { // AddFederationToRootCAs and AddFederationToClientCAs. func AddFederationToCAs() TLSOption { fn := func(ctx *TLSOptionCtx) error { - certs, err := ctx.Client.Federation(ctx.Transport) + certs, err := ctx.Client.Federation() if err != nil { return err } diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index b52d1c89..ceeea7dc 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -1,7 +1,6 @@ package ca import ( - "crypto" "crypto/tls" "crypto/x509" "fmt" @@ -10,32 +9,29 @@ import ( "reflect" "sort" "testing" - - "github.com/smallstep/certificates/api" ) func Test_newTLSOptionCtx(t *testing.T) { - client, sign, pk := sign("test.smallstep.com") + client, err := NewClient("https://ca.smallstep.com", WithTransport(http.DefaultTransport)) + if err != nil { + t.Fatalf("NewClient() error = %v", err) + } + type args struct { c *Client - sign *api.SignResponse - pk crypto.PrivateKey config *tls.Config } tests := []struct { - name string - args args - wantErr bool + name string + args args + want *TLSOptionCtx }{ - {"ok", args{client, sign, pk, &tls.Config{}}, false}, - {"fail", args{client, sign, "foo", &tls.Config{}}, true}, + {"ok", args{client, &tls.Config{}}, &TLSOptionCtx{Client: client, Config: &tls.Config{}}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := newTLSOptionCtx(tt.args.c, tt.args.sign, tt.args.pk, tt.args.config) - if (err != nil) != tt.wantErr { - t.Errorf("newTLSOptionCtx() error = %v, wantErr %v", err, tt.wantErr) - return + if got := newTLSOptionCtx(tt.args.c, tt.args.config); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newTLSOptionCtx() = %v, want %v", got, tt.want) } }) } @@ -188,8 +184,12 @@ func TestAddRootsToRootCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -203,21 +203,24 @@ func TestAddRootsToRootCAs(t *testing.T) { pool := x509.NewCertPool() pool.AddCert(cert) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{RootCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddRootsToRootCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddRootsToRootCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -234,8 +237,12 @@ func TestAddRootsToClientCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -249,21 +256,24 @@ func TestAddRootsToClientCAs(t *testing.T) { pool := x509.NewCertPool() pool.AddCert(cert) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{ClientCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddRootsToClientCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddRootsToClientCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -280,8 +290,12 @@ func TestAddFederationToRootCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -302,21 +316,24 @@ func TestAddFederationToRootCAs(t *testing.T) { pool.AddCert(crt1) pool.AddCert(crt2) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{RootCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddFederationToRootCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddFederationToRootCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -336,8 +353,12 @@ func TestAddFederationToClientCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -358,21 +379,24 @@ func TestAddFederationToClientCAs(t *testing.T) { pool.AddCert(crt1) pool.AddCert(crt2) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{ClientCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddFederationToClientCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddFederationToClientCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -392,8 +416,12 @@ func TestAddRootsToCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -407,21 +435,24 @@ func TestAddRootsToCAs(t *testing.T) { pool := x509.NewCertPool() pool.AddCert(cert) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddRootsToCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddRootsToCAs() error = %v, wantErr %v", err, tt.wantErr) @@ -438,8 +469,12 @@ func TestAddFederationToCAs(t *testing.T) { ca := startCATestServer() defer ca.Close() - client, sr, pk := signDuration(ca, "127.0.0.1", 0) - tr, err := getTLSOptionsTransport(sr, pk) + client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) + if err != nil { + t.Fatal(err) + } + + clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) if err != nil { t.Fatal(err) } @@ -460,21 +495,24 @@ func TestAddFederationToCAs(t *testing.T) { pool.AddCert(crt1) pool.AddCert(crt2) + type args struct { + client *Client + config *tls.Config + } tests := []struct { name string - tr http.RoundTripper + args args want *tls.Config wantErr bool }{ - {"ok", tr, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, - {"fail", http.DefaultTransport, &tls.Config{}, true}, + {"ok", args{client, &tls.Config{}}, &tls.Config{ClientCAs: pool, RootCAs: pool}, false}, + {"fail", args{clientFail, &tls.Config{}}, &tls.Config{}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &TLSOptionCtx{ - Client: client, - Config: &tls.Config{}, - Transport: tt.tr, + Client: tt.args.client, + Config: tt.args.config, } if err := AddFederationToCAs()(ctx); (err != nil) != tt.wantErr { t.Errorf("AddFederationToCAs() error = %v, wantErr %v", err, tt.wantErr)