From 319333f936bdc7a4723f0bd4f329a83c7fea132a Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 21 Dec 2022 12:56:56 +0100 Subject: [PATCH 1/2] Add `WithContext` methods to the CA client --- ca/client.go | 304 +++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 232 insertions(+), 72 deletions(-) diff --git a/ca/client.go b/ca/client.go index bbafcfee..c6a7def2 100644 --- a/ca/client.go +++ b/ca/client.go @@ -2,6 +2,7 @@ package ca import ( "bytes" + "context" "crypto" "crypto/ecdsa" "crypto/elliptic" @@ -75,7 +76,11 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) { } func (c *uaClient) Get(u string) (*http.Response, error) { - req, err := http.NewRequest("GET", u, http.NoBody) + return c.GetWithContext(context.Background(), u) +} + +func (c *uaClient) GetWithContext(ctx context.Context, u string) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "GET", u, http.NoBody) if err != nil { return nil, errors.Wrapf(err, "create GET %s request failed", u) } @@ -84,7 +89,11 @@ func (c *uaClient) Get(u string) (*http.Response, error) { } func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) { - req, err := http.NewRequest("POST", u, body) + return c.PostWithContext(context.Background(), u, contentType, body) +} + +func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "POST", u, body) if err != nil { return nil, errors.Wrapf(err, "create POST %s request failed", u) } @@ -581,13 +590,19 @@ func (c *Client) SetTransport(tr http.RoundTripper) { c.client.SetTransport(tr) } -// Version performs the version request to the CA and returns the +// Version performs the version request to the CA with an empty context and returns the // api.VersionResponse struct. func (c *Client) Version() (*api.VersionResponse, error) { + return c.VersionWithContext(context.Background()) +} + +// VersionWithContext performs the version request to the CA with the provided context +// and returns the api.VersionResponse struct. +func (c *Client) VersionWithContext(ctx context.Context) (*api.VersionResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/version"}) retry: - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -605,13 +620,19 @@ retry: return &version, nil } -// Health performs the health request to the CA and returns the -// api.HealthResponse struct. +// Health performs the health request to the CA with an empty context +// and returns the api.HealthResponse struct. func (c *Client) Health() (*api.HealthResponse, error) { + return c.HealthWithContext(context.Background()) +} + +// HealthWithContext performs the health request to the CA with the provided context +// and returns the api.HealthResponse struct. +func (c *Client) HealthWithContext(ctx context.Context) (*api.HealthResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/health"}) retry: - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -629,16 +650,24 @@ retry: return &health, nil } -// Root performs the root request to the CA with the given SHA256 and returns -// the api.RootResponse struct. It uses an insecure client, but it checks the -// resulting root certificate with the given SHA256, returning an error if they -// do not match. +// Root performs the root request to the CA with an empty context and the provided +// SHA256 and returns the api.RootResponse struct. It uses an insecure client, but +// it checks the resulting root certificate with the given SHA256, returning an error +// if they do not match. func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) { + return c.RootWithContext(context.Background(), sha256Sum) +} + +// RootWithContext performs the root request to the CA with an empty context and the provided +// SHA256 and returns the api.RootResponse struct. It uses an insecure client, but +// it checks the resulting root certificate with the given SHA256, returning an error +// if they do not match. +func (c *Client) RootWithContext(ctx context.Context, sha256Sum string) (*api.RootResponse, error) { var retried bool sha256Sum = strings.ToLower(strings.ReplaceAll(sha256Sum, "-", "")) u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum}) retry: - resp, err := newInsecureClient().Get(u.String()) + resp, err := newInsecureClient().GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -661,9 +690,15 @@ retry: return &root, nil } -// Sign performs the sign request to the CA and returns the api.SignResponse -// struct. +// Sign performs the sign request to the CA with an empty context and returns +// the api.SignResponse struct. func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) { + return c.SignWithContext(context.Background(), req) +} + +// SignWithContext performs the sign request to the CA with the provided context +// and returns the api.SignResponse struct. +func (c *Client) SignWithContext(ctx context.Context, req *api.SignRequest) (*api.SignResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { @@ -671,7 +706,7 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) { } u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"}) retry: - resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -692,14 +727,25 @@ retry: return &sign, nil } -// Renew performs the renew request to the CA and returns the api.SignResponse -// struct. +// Renew performs the renew request to the CA with an empty context and +// returns the api.SignResponse struct. func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { + return c.RenewWithContext(context.Background(), tr) +} + +// RenewWithContext performs the renew request to the CA with the provided context +// and returns the api.SignResponse struct. +func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) client := &http.Client{Transport: tr} retry: - resp, err := client.Post(u.String(), "application/json", http.NoBody) + req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) if err != nil { return nil, clientError(err) } @@ -718,12 +764,19 @@ retry: } // RenewWithToken performs the renew request to the CA with the given -// authorization token and returns the api.SignResponse struct. This method is -// generally used to renew an expired certificate. +// authorization token and and empty context and returns the api.SignResponse struct. +// This method is generally used to renew an expired certificate. func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) { + return c.RenewWithTokenAndContext(context.Background(), token) +} + +// RenewWithTokenAndContext performs the renew request to the CA with the given +// authorization token and context and returns the api.SignResponse struct. +// This method is generally used to renew an expired certificate. +func (c *Client) RenewWithTokenAndContext(ctx context.Context, token string) (*api.SignResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) - req, err := http.NewRequest("POST", u.String(), http.NoBody) + req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) if err != nil { return nil, errors.Wrapf(err, "create POST %s request failed", u) } @@ -747,19 +800,29 @@ retry: return &sign, nil } -// Rekey performs the rekey request to the CA and returns the api.SignResponse -// struct. +// Rekey performs the rekey request to the CA with an empty context and +// returns the api.SignResponse struct. func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { + return c.RekeyWithContext(context.Background(), req, tr) +} + +// RekeyWithContext performs the rekey request to the CA with the provided context +// and returns the api.SignResponse struct. +func (c *Client) RekeyWithContext(ctx context.Context, req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } - u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) client := &http.Client{Transport: tr} retry: - resp, err := client.Post(u.String(), "application/json", bytes.NewReader(body)) + httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Content-Type", "application/json") + resp, err := client.Do(httpReq) if err != nil { return nil, clientError(err) } @@ -777,9 +840,15 @@ retry: return &sign, nil } -// Revoke performs the revoke request to the CA and returns the api.RevokeResponse -// struct. +// Revoke performs the revoke request to the CA with an empty context and returns +// the api.RevokeResponse struct. func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) { + return c.RevokeWithContext(context.Background(), req, tr) +} + +// RevokeWithContext performs the revoke request to the CA with the provided context and +// returns the api.RevokeResponse struct. +func (c *Client) RevokeWithContext(ctx context.Context, req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { @@ -794,7 +863,7 @@ retry: } u := c.endpoint.ResolveReference(&url.URL{Path: "/revoke"}) - resp, err := client.Post(u.String(), "application/json", bytes.NewReader(body)) + resp, err := client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -812,12 +881,21 @@ retry: return &revoke, nil } -// Provisioners performs the provisioners request to the CA and returns the -// api.ProvisionersResponse struct with a map of provisioners. +// Provisioners performs the provisioners request to the CA with an empty context +// and returns the api.ProvisionersResponse struct with a map of provisioners. // // ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to // paginate the provisioners. func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) { + return c.ProvisionersWithContext(context.Background(), opts...) +} + +// ProvisionersWithContext performs the provisioners request to the CA with the provided context +// and returns the api.ProvisionersResponse struct with a map of provisioners. +// +// ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to +// paginate the provisioners. +func (c *Client) ProvisionersWithContext(ctx context.Context, opts ...ProvisionerOption) (*api.ProvisionersResponse, error) { var retried bool o := new(ProvisionerOptions) if err := o.Apply(opts); err != nil { @@ -828,7 +906,7 @@ func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersRespo RawQuery: o.rawQuery(), }) retry: - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -846,14 +924,21 @@ retry: return &provisioners, nil } -// ProvisionerKey performs the request to the CA to get the encrypted key for -// the given provisioner kid and returns the api.ProvisionerKeyResponse struct -// with the encrypted key. +// ProvisionerKey performs the request to the CA with an empty context to get +// the encrypted key for the given provisioner kid and returns the api.ProvisionerKeyResponse +// struct with the encrypted key. func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) { + return c.ProvisionerKeyWithContext(context.Background(), kid) +} + +// ProvisionerKeyWithContext performs the request to the CA with the provided context to get +// the encrypted key for the given provisioner kid and returns the api.ProvisionerKeyResponse +// struct with the encrypted key. +func (c *Client) ProvisionerKeyWithContext(ctx context.Context, kid string) (*api.ProvisionerKeyResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"}) retry: - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -871,13 +956,19 @@ retry: return &key, nil } -// Roots performs the get roots request to the CA and returns the -// api.RootsResponse struct. +// Roots performs the get roots request to the CA with an empty context +// and returns the api.RootsResponse struct. func (c *Client) Roots() (*api.RootsResponse, error) { + return c.RootsWithContext(context.Background()) +} + +// RootsWithContext performs the get roots request to the CA with the provided context +// and returns the api.RootsResponse struct. +func (c *Client) RootsWithContext(ctx context.Context) (*api.RootsResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"}) retry: - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -895,13 +986,19 @@ retry: return &roots, nil } -// Federation performs the get federation request to the CA and returns the -// api.FederationResponse struct. +// Federation performs the get federation request to the CA with an empty context +// and returns the api.FederationResponse struct. func (c *Client) Federation() (*api.FederationResponse, error) { + return c.FederationWithContext(context.Background()) +} + +// FederationWithContext performs the get federation request to the CA with the provided context +// and returns the api.FederationResponse struct. +func (c *Client) FederationWithContext(ctx context.Context) (*api.FederationResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"}) retry: - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -919,9 +1016,15 @@ retry: return &federation, nil } -// SSHSign performs the POST /ssh/sign request to the CA and returns the -// api.SSHSignResponse struct. +// SSHSign performs the POST /ssh/sign request to the CA with an empty context +// and returns the api.SSHSignResponse struct. func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) { + return c.SSHSignWithContext(context.Background(), req) +} + +// SSHSignWithContext performs the POST /ssh/sign request to the CA with the provided context +// and returns the api.SSHSignResponse struct. +func (c *Client) SSHSignWithContext(ctx context.Context, req *api.SSHSignRequest) (*api.SSHSignResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { @@ -929,7 +1032,7 @@ func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"}) retry: - resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -947,9 +1050,15 @@ retry: return &sign, nil } -// SSHRenew performs the POST /ssh/renew request to the CA and returns the -// api.SSHRenewResponse struct. +// SSHRenew performs the POST /ssh/renew request to the CA with an empty context +// and returns the api.SSHRenewResponse struct. func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) { + return c.SSHRenewWithContext(context.Background(), req) +} + +// SSHRenewWithContext performs the POST /ssh/renew request to the CA with the provided context +// and returns the api.SSHRenewResponse struct. +func (c *Client) SSHRenewWithContext(ctx context.Context, req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { @@ -957,7 +1066,7 @@ func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, erro } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/renew"}) retry: - resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -975,9 +1084,15 @@ retry: return &renew, nil } -// SSHRekey performs the POST /ssh/rekey request to the CA and returns the -// api.SSHRekeyResponse struct. +// SSHRekey performs the POST /ssh/rekey request to the CA with an empty context +// and returns the api.SSHRekeyResponse struct. func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) { + return c.SSHRekeyWithContext(context.Background(), req) +} + +// SSHRekeyWithContext performs the POST /ssh/rekey request to the CA with the provided context +// and returns the api.SSHRekeyResponse struct. +func (c *Client) SSHRekeyWithContext(ctx context.Context, req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { @@ -985,7 +1100,7 @@ func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, erro } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/rekey"}) retry: - resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -1003,9 +1118,15 @@ retry: return &rekey, nil } -// SSHRevoke performs the POST /ssh/revoke request to the CA and returns the -// api.SSHRevokeResponse struct. +// SSHRevoke performs the POST /ssh/revoke request to the CA with an empty context +// and returns the api.SSHRevokeResponse struct. func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) { + return c.SSHRevokeWithContext(context.Background(), req) +} + +// SSHRevokeWithContext performs the POST /ssh/revoke request to the CA with the provided context +// and returns the api.SSHRevokeResponse struct. +func (c *Client) SSHRevokeWithContext(ctx context.Context, req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { @@ -1013,7 +1134,7 @@ func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, e } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/revoke"}) retry: - resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -1031,13 +1152,19 @@ retry: return &revoke, nil } -// SSHRoots performs the GET /ssh/roots request to the CA and returns the -// api.SSHRootsResponse struct. +// SSHRoots performs the GET /ssh/roots request to the CA with an empty context +// and returns the api.SSHRootsResponse struct. func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) { + return c.SSHRootsWithContext(context.Background()) +} + +// SSHRootsWithContext performs the GET /ssh/roots request to the CA with the provided context +// and returns the api.SSHRootsResponse struct. +func (c *Client) SSHRootsWithContext(ctx context.Context) (*api.SSHRootsResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"}) retry: - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -1055,13 +1182,19 @@ retry: return &keys, nil } -// SSHFederation performs the get /ssh/federation request to the CA and returns -// the api.SSHRootsResponse struct. +// SSHFederation performs the get /ssh/federation request to the CA with an empty context +// and returns the api.SSHRootsResponse struct. func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) { + return c.SSHFederationWithContext(context.Background()) +} + +// SSHFederationWithContext performs the get /ssh/federation request to the CA with the provided context +// and returns the api.SSHRootsResponse struct. +func (c *Client) SSHFederationWithContext(ctx context.Context) (*api.SSHRootsResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"}) retry: - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -1079,9 +1212,15 @@ retry: return &keys, nil } -// SSHConfig performs the POST /ssh/config request to the CA to get the ssh -// configuration templates. +// SSHConfig performs the POST /ssh/config request to the CA with an empty context +// to get the ssh configuration templates. func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) { + return c.SSHConfigWithContext(context.Background(), req) +} + +// SSHConfigWithContext performs the POST /ssh/config request to the CA with the provided context +// to get the ssh configuration templates. +func (c *Client) SSHConfigWithContext(ctx context.Context, req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { @@ -1089,7 +1228,7 @@ func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, e } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/config"}) retry: - resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -1107,9 +1246,15 @@ retry: return &cfg, nil } -// SSHCheckHost performs the POST /ssh/check-host request to the CA with the -// given principal. +// SSHCheckHost performs the POST /ssh/check-host request to the CA with an empty context, +// the principal and a token and returns the api.SSHCheckPrincipalResponse. func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalResponse, error) { + return c.SSHCheckHostWithContext(context.Background(), principal, token) +} + +// SSHCheckHostWithContext performs the POST /ssh/check-host request to the CA with the provided context, +// principal and token and returns the api.SSHCheckPrincipalResponse. +func (c *Client) SSHCheckHostWithContext(ctx context.Context, principal, token string) (*api.SSHCheckPrincipalResponse, error) { var retried bool body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ Type: provisioner.SSHHostCert, @@ -1122,7 +1267,7 @@ func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalRe } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"}) retry: - resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -1141,12 +1286,17 @@ retry: return &check, nil } -// SSHGetHosts performs the GET /ssh/get-hosts request to the CA. +// SSHGetHosts performs the GET /ssh/get-hosts request to the CA with an empty context. func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) { + return c.SSHGetHostsWithContext(context.Background()) +} + +// SSHGetHostsWithContext performs the GET /ssh/get-hosts request to the CA with the provided context. +func (c *Client) SSHGetHostsWithContext(ctx context.Context) (*api.SSHGetHostsResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/hosts"}) retry: - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return nil, clientError(err) } @@ -1164,8 +1314,13 @@ retry: return &hosts, nil } -// SSHBastion performs the POST /ssh/bastion request to the CA. +// SSHBastion performs the POST /ssh/bastion request to the CA with an empty context. func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) { + return c.SSHBastionWithContext(context.Background(), req) +} + +// SSHBastionWithContext performs the POST /ssh/bastion request to the CA with the provided context. +func (c *Client) SSHBastionWithContext(ctx context.Context, req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) { var retried bool body, err := json.Marshal(req) if err != nil { @@ -1173,7 +1328,7 @@ func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"}) retry: - resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -1192,11 +1347,16 @@ retry: } // RootFingerprint is a helper method that returns the current root fingerprint. -// It does an health connection and gets the fingerprint from the TLS verified -// chains. +// It does an health connection and gets the fingerprint from the TLS verified chains. func (c *Client) RootFingerprint() (string, error) { + return c.RootFingerprintWithContext(context.Background()) +} + +// RootFingerprintWithContext is a helper method that returns the current root fingerprint. +// It does an health connection and gets the fingerprint from the TLS verified chains. +func (c *Client) RootFingerprintWithContext(ctx context.Context) (string, error) { u := c.endpoint.ResolveReference(&url.URL{Path: "/health"}) - resp, err := c.client.Get(u.String()) + resp, err := c.client.GetWithContext(ctx, u.String()) if err != nil { return "", clientError(err) } From b5961beba9f2004a351e58a646f755c8acf8a166 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 21 Dec 2022 16:02:26 +0100 Subject: [PATCH 2/2] Fix and/or ignore linting issues --- ca/bootstrap.go | 6 ++--- ca/client.go | 44 ++++++++++++++++++------------------- ca/tls.go | 4 ++-- cas/stepcas/issuer.go | 5 +++-- cas/stepcas/issuer_test.go | 3 ++- cas/stepcas/jwk_issuer.go | 9 ++++---- cas/stepcas/stepcas.go | 4 ++-- cas/stepcas/stepcas_test.go | 2 +- 8 files changed, 40 insertions(+), 37 deletions(-) diff --git a/ca/bootstrap.go b/ca/bootstrap.go index 430f2e31..78b94ec9 100644 --- a/ca/bootstrap.go +++ b/ca/bootstrap.go @@ -61,7 +61,7 @@ func Bootstrap(token string) (*Client, error) { // } // resp, err := client.Get("https://internal.smallstep.com") func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) { - b, err := createBootstrap(token) + b, err := createBootstrap(token) //nolint:contextcheck // deeply nested context; temporary if err != nil { return nil, err } @@ -120,7 +120,7 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio return nil, errors.New("server TLSConfig is already set") } - b, err := createBootstrap(token) + b, err := createBootstrap(token) //nolint:contextcheck // deeply nested context; temporary if err != nil { return nil, err } @@ -169,7 +169,7 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio // ... // register services // srv.Serve(lis) func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) { - b, err := createBootstrap(token) + b, err := createBootstrap(token) //nolint:contextcheck // deeply nested context; temporary if err != nil { return nil, err } diff --git a/ca/client.go b/ca/client.go index c6a7def2..7321f82f 100644 --- a/ca/client.go +++ b/ca/client.go @@ -607,7 +607,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -637,7 +637,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -672,7 +672,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -711,7 +711,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -750,7 +750,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -787,7 +787,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -827,7 +827,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -868,7 +868,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -911,7 +911,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -943,7 +943,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -973,7 +973,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1003,7 +1003,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1037,7 +1037,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1071,7 +1071,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1105,7 +1105,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1139,7 +1139,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1169,7 +1169,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1199,7 +1199,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1233,7 +1233,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1272,7 +1272,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1301,7 +1301,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } @@ -1333,7 +1333,7 @@ retry: return nil, clientError(err) } if resp.StatusCode >= 400 { - if !retried && c.retryOnError(resp) { + if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context retried = true goto retry } diff --git a/ca/tls.go b/ca/tls.go index 7644b11f..d5d479f3 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -135,7 +135,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, //nolint:staticcheck // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) - renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) + renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context // Update client transport c.SetTransport(tr) @@ -183,7 +183,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, //nolint:staticcheck // Use mutable tls.Config on renew tr.DialTLS = c.buildDialTLS(tlsCtx) // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) - renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) + renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context // Update client transport c.SetTransport(tr) diff --git a/cas/stepcas/issuer.go b/cas/stepcas/issuer.go index 07607caa..cf985974 100644 --- a/cas/stepcas/issuer.go +++ b/cas/stepcas/issuer.go @@ -1,6 +1,7 @@ package stepcas import ( + "context" "net/url" "strings" "time" @@ -37,7 +38,7 @@ type stepIssuer interface { } // newStepIssuer returns the configured step issuer. -func newStepIssuer(caURL *url.URL, client *ca.Client, iss *apiv1.CertificateIssuer) (stepIssuer, error) { +func newStepIssuer(ctx context.Context, caURL *url.URL, client *ca.Client, iss *apiv1.CertificateIssuer) (stepIssuer, error) { if err := validateCertificateIssuer(iss); err != nil { return nil, err } @@ -46,7 +47,7 @@ func newStepIssuer(caURL *url.URL, client *ca.Client, iss *apiv1.CertificateIssu case "x5c": return newX5CIssuer(caURL, iss) case "jwk": - return newJWKIssuer(caURL, client, iss) + return newJWKIssuer(ctx, caURL, client, iss) default: return nil, errors.Errorf("stepCAS `certificateIssuer.type` %s is not supported", iss.Type) } diff --git a/cas/stepcas/issuer_test.go b/cas/stepcas/issuer_test.go index 7d468e38..ff4f45f5 100644 --- a/cas/stepcas/issuer_test.go +++ b/cas/stepcas/issuer_test.go @@ -1,6 +1,7 @@ package stepcas import ( + "context" "net/url" "reflect" "testing" @@ -118,7 +119,7 @@ func Test_newStepIssuer(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := newStepIssuer(tt.args.caURL, tt.args.client, tt.args.iss) + got, err := newStepIssuer(context.TODO(), tt.args.caURL, tt.args.client, tt.args.iss) if (err != nil) != tt.wantErr { t.Errorf("newStepIssuer() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/cas/stepcas/jwk_issuer.go b/cas/stepcas/jwk_issuer.go index 4ef4f541..5ef017a2 100644 --- a/cas/stepcas/jwk_issuer.go +++ b/cas/stepcas/jwk_issuer.go @@ -1,6 +1,7 @@ package stepcas import ( + "context" "crypto" "encoding/json" "net/url" @@ -21,13 +22,13 @@ type jwkIssuer struct { signer jose.Signer } -func newJWKIssuer(caURL *url.URL, client *ca.Client, cfg *apiv1.CertificateIssuer) (*jwkIssuer, error) { +func newJWKIssuer(ctx context.Context, caURL *url.URL, client *ca.Client, cfg *apiv1.CertificateIssuer) (*jwkIssuer, error) { var err error var signer jose.Signer // Read the key from the CA if not provided. // Or read it from a PEM file. if cfg.Key == "" { - p, err := findProvisioner(client, provisioner.TypeJWK, cfg.Provisioner) + p, err := findProvisioner(ctx, client, provisioner.TypeJWK, cfg.Provisioner) if err != nil { return nil, err } @@ -144,10 +145,10 @@ func newJWKSignerFromEncryptedKey(kid, key, password string) (jose.Signer, error return newJoseSigner(signer, so) } -func findProvisioner(client *ca.Client, typ provisioner.Type, name string) (provisioner.Interface, error) { +func findProvisioner(ctx context.Context, client *ca.Client, typ provisioner.Type, name string) (provisioner.Interface, error) { cursor := "" for { - ps, err := client.Provisioners(ca.WithProvisionerCursor(cursor)) + ps, err := client.ProvisionersWithContext(ctx, ca.WithProvisionerCursor(cursor)) if err != nil { return nil, err } diff --git a/cas/stepcas/stepcas.go b/cas/stepcas/stepcas.go index c64963e6..7c0dc86f 100644 --- a/cas/stepcas/stepcas.go +++ b/cas/stepcas/stepcas.go @@ -43,7 +43,7 @@ func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) { } // Create client. - client, err := ca.NewClient(opts.CertificateAuthority, ca.WithRootSHA256(opts.CertificateAuthorityFingerprint)) + client, err := ca.NewClient(opts.CertificateAuthority, ca.WithRootSHA256(opts.CertificateAuthorityFingerprint)) //nolint:contextcheck // deeply nested context if err != nil { return nil, err } @@ -52,7 +52,7 @@ func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) { // Create configured issuer unless we only want to use GetCertificateAuthority. // This avoid the request for the password if not provided. if !opts.IsCAGetter { - if iss, err = newStepIssuer(caURL, client, opts.CertificateIssuer); err != nil { + if iss, err = newStepIssuer(ctx, caURL, client, opts.CertificateIssuer); err != nil { return nil, err } } diff --git a/cas/stepcas/stepcas_test.go b/cas/stepcas/stepcas_test.go index 6691a4b4..b9dd9abd 100644 --- a/cas/stepcas/stepcas_test.go +++ b/cas/stepcas/stepcas_test.go @@ -245,7 +245,7 @@ func testJWKIssuer(t *testing.T, caURL *url.URL, password string) *jwkIssuer { key = testEncryptedKeyPath password = testPassword } - jwk, err := newJWKIssuer(caURL, client, &apiv1.CertificateIssuer{ + jwk, err := newJWKIssuer(context.TODO(), caURL, client, &apiv1.CertificateIssuer{ Type: "jwk", Provisioner: "ra@doe.org", Key: key,