Support for retry and identity files.

This commit is contained in:
Mariano Cano 2019-11-20 11:50:46 -08:00 committed by max furman
parent d555f310dc
commit bbaf8e106e
2 changed files with 244 additions and 6 deletions

View file

@ -31,6 +31,10 @@ import (
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
) )
// RetryFunc defines the method used to retry a request. If it returns true, the
// request will be retried once.
type RetryFunc func(code int) bool
// ClientOption is the type of options passed to the Client constructor. // ClientOption is the type of options passed to the Client constructor.
type ClientOption func(o *clientOptions) error type ClientOption func(o *clientOptions) error
@ -40,6 +44,7 @@ type clientOptions struct {
rootFilename string rootFilename string
rootBundle []byte rootBundle []byte
certificate tls.Certificate certificate tls.Certificate
retryFunc RetryFunc
} }
func (o *clientOptions) apply(opts []ClientOption) (err error) { func (o *clientOptions) apply(opts []ClientOption) (err error) {
@ -199,6 +204,14 @@ func WithCertificate(crt tls.Certificate) ClientOption {
} }
} }
// WithRetryFunc defines a method used to retry a request.
func WithRetryFunc(fn RetryFunc) ClientOption {
return func(o *clientOptions) error {
o.retryFunc = fn
return nil
}
}
func getTransportFromFile(filename string) (http.RoundTripper, error) { func getTransportFromFile(filename string) (http.RoundTripper, error) {
data, err := ioutil.ReadFile(filename) data, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
@ -330,8 +343,10 @@ func WithProvisionerLimit(limit int) ProvisionerOption {
// Client implements an HTTP client for the CA server. // Client implements an HTTP client for the CA server.
type Client struct { type Client struct {
client *http.Client client *http.Client
endpoint *url.URL endpoint *url.URL
retryFunc RetryFunc
opts []ClientOption
} }
// NewClient creates a new Client with the given endpoint and options. // NewClient creates a new Client with the given endpoint and options.
@ -354,10 +369,31 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
client: &http.Client{ client: &http.Client{
Transport: tr, Transport: tr,
}, },
endpoint: u, endpoint: u,
retryFunc: o.retryFunc,
opts: opts,
}, nil }, nil
} }
func (c *Client) retryOnError(r *http.Response) bool {
if c.retryFunc != nil {
if c.retryFunc(r.StatusCode) {
o := new(clientOptions)
if err := o.apply(c.opts); err != nil {
return false
}
tr, err := o.getTransport(c.endpoint.String())
if err != nil {
return false
}
r.Body.Close()
c.client.Transport = tr
return true
}
}
return false
}
// SetTransport updates the transport of the internal HTTP client. // SetTransport updates the transport of the internal HTTP client.
func (c *Client) SetTransport(tr http.RoundTripper) { func (c *Client) SetTransport(tr http.RoundTripper) {
c.client.Transport = tr c.client.Transport = tr
@ -366,12 +402,18 @@ func (c *Client) SetTransport(tr http.RoundTripper) {
// Health performs the health request to the CA and returns the // Health performs the health request to the CA and returns the
// api.HealthResponse struct. // api.HealthResponse struct.
func (c *Client) Health() (*api.HealthResponse, error) { func (c *Client) Health() (*api.HealthResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/health"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var health api.HealthResponse var health api.HealthResponse
@ -386,13 +428,19 @@ func (c *Client) Health() (*api.HealthResponse, error) {
// resulting root certificate with the given SHA256, returning an error if they // resulting root certificate with the given SHA256, returning an error if they
// do not match. // do not match.
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) { func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
var retried bool
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1)) sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum}) u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
retry:
resp, err := getInsecureClient().Get(u.String()) resp, err := getInsecureClient().Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var root api.RootResponse var root api.RootResponse
@ -410,16 +458,22 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
// Sign performs the sign request to the CA and returns the api.SignResponse // Sign performs the sign request to the CA and returns the api.SignResponse
// struct. // struct.
func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) { func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var sign api.SignResponse var sign api.SignResponse
@ -435,13 +489,19 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
// Renew performs the renew request to the CA and returns the api.SignResponse // Renew performs the renew request to the CA and returns the api.SignResponse
// struct. // struct.
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
client := &http.Client{Transport: tr} client := &http.Client{Transport: tr}
retry:
resp, err := client.Post(u.String(), "application/json", http.NoBody) resp, err := client.Post(u.String(), "application/json", http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var sign api.SignResponse var sign api.SignResponse
@ -454,12 +514,13 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
// Revoke performs the revoke request to the CA and returns the api.RevokeResponse // Revoke performs the revoke request to the CA and returns the api.RevokeResponse
// struct. // struct.
func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) { func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {
var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
var client *http.Client var client *http.Client
retry:
if tr != nil { if tr != nil {
client = &http.Client{Transport: tr} client = &http.Client{Transport: tr}
} else { } else {
@ -472,6 +533,10 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var revoke api.RevokeResponse var revoke api.RevokeResponse
@ -487,6 +552,7 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo
// ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to // ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to
// paginate the provisioners. // paginate the provisioners.
func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) { func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) {
var retried bool
o := new(provisionerOptions) o := new(provisionerOptions)
if err := o.apply(opts); err != nil { if err := o.apply(opts); err != nil {
return nil, err return nil, err
@ -495,11 +561,16 @@ func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersRespo
Path: "/provisioners", Path: "/provisioners",
RawQuery: o.rawQuery(), RawQuery: o.rawQuery(),
}) })
retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var provisioners api.ProvisionersResponse var provisioners api.ProvisionersResponse
@ -513,12 +584,18 @@ func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersRespo
// the given provisioner kid and returns the api.ProvisionerKeyResponse struct // the given provisioner kid and returns the api.ProvisionerKeyResponse struct
// with the encrypted key. // with the encrypted key.
func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) { func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"})
retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var key api.ProvisionerKeyResponse var key api.ProvisionerKeyResponse
@ -531,12 +608,18 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error)
// Roots performs the get roots request to the CA and returns the // Roots performs the get roots request to the CA and returns the
// api.RootsResponse struct. // api.RootsResponse struct.
func (c *Client) Roots() (*api.RootsResponse, error) { func (c *Client) Roots() (*api.RootsResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var roots api.RootsResponse var roots api.RootsResponse
@ -549,12 +632,18 @@ func (c *Client) Roots() (*api.RootsResponse, error) {
// Federation performs the get federation request to the CA and returns the // Federation performs the get federation request to the CA and returns the
// api.FederationResponse struct. // api.FederationResponse struct.
func (c *Client) Federation() (*api.FederationResponse, error) { func (c *Client) Federation() (*api.FederationResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var federation api.FederationResponse var federation api.FederationResponse
@ -567,16 +656,22 @@ func (c *Client) Federation() (*api.FederationResponse, error) {
// SSHSign performs the POST /ssh/sign request to the CA and returns the // SSHSign performs the POST /ssh/sign request to the CA and returns the
// api.SSHSignResponse struct. // api.SSHSignResponse struct.
func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) { func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) {
var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var sign api.SSHSignResponse var sign api.SSHSignResponse
@ -589,16 +684,22 @@ func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error)
// SSHRenew performs the POST /ssh/renew request to the CA and returns the // SSHRenew performs the POST /ssh/renew request to the CA and returns the
// api.SSHRenewResponse struct. // api.SSHRenewResponse struct.
func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) { func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) {
var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/renew"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/renew"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var renew api.SSHRenewResponse var renew api.SSHRenewResponse
@ -611,16 +712,22 @@ func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, erro
// SSHRekey performs the POST /ssh/rekey request to the CA and returns the // SSHRekey performs the POST /ssh/rekey request to the CA and returns the
// api.SSHRekeyResponse struct. // api.SSHRekeyResponse struct.
func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) { func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) {
var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/rekey"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/rekey"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var rekey api.SSHRekeyResponse var rekey api.SSHRekeyResponse
@ -633,16 +740,22 @@ func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, erro
// SSHRevoke performs the POST /ssh/revoke request to the CA and returns the // SSHRevoke performs the POST /ssh/revoke request to the CA and returns the
// api.SSHRevokeResponse struct. // api.SSHRevokeResponse struct.
func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) { func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) {
var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/revoke"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/revoke"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var revoke api.SSHRevokeResponse var revoke api.SSHRevokeResponse
@ -655,12 +768,18 @@ func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, e
// SSHRoots performs the GET /ssh/roots request to the CA and returns the // SSHRoots performs the GET /ssh/roots request to the CA and returns the
// api.SSHRootsResponse struct. // api.SSHRootsResponse struct.
func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) { func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"})
retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var keys api.SSHRootsResponse var keys api.SSHRootsResponse
@ -673,12 +792,18 @@ func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) {
// SSHFederation performs the get /ssh/federation request to the CA and returns // SSHFederation performs the get /ssh/federation request to the CA and returns
// the api.SSHRootsResponse struct. // the api.SSHRootsResponse struct.
func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) { func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"})
retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var keys api.SSHRootsResponse var keys api.SSHRootsResponse
@ -691,16 +816,22 @@ func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) {
// SSHConfig performs the POST /ssh/config request to the CA to get the ssh // SSHConfig performs the POST /ssh/config request to the CA to get the ssh
// configuration templates. // configuration templates.
func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) { func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) {
var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/config"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/config"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var config api.SSHConfigResponse var config api.SSHConfigResponse
@ -713,6 +844,7 @@ func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, e
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the // SSHCheckHost performs the POST /ssh/check-host request to the CA with the
// given principal. // given principal.
func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, error) { func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, error) {
var retried bool
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
Type: provisioner.SSHHostCert, Type: provisioner.SSHHostCert,
Principal: principal, Principal: principal,
@ -721,11 +853,16 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse,
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var check api.SSHCheckPrincipalResponse var check api.SSHCheckPrincipalResponse
@ -737,12 +874,18 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse,
// SSHGetHosts performs the GET /ssh/get-hosts request to the CA. // SSHGetHosts performs the GET /ssh/get-hosts request to the CA.
func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) { func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/get-hosts"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/get-hosts"})
retry:
resp, err := c.client.Get(u.String()) resp, err := c.client.Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var hosts api.SSHGetHostsResponse var hosts api.SSHGetHostsResponse
@ -754,16 +897,22 @@ func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) {
// SSHBastion performs the POST /ssh/bastion request to the CA. // SSHBastion performs the POST /ssh/bastion request to the CA.
func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) { func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) {
var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"}) u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u) return nil, errors.Wrapf(err, "client POST %s failed", u)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var bastion api.SSHBastionResponse var bastion api.SSHBastionResponse

View file

@ -1,17 +1,30 @@
package ca package ca
import ( import (
"bytes"
"crypto"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"io/ioutil"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"github.com/smallstep/cli/config" "github.com/smallstep/cli/config"
"github.com/smallstep/cli/crypto/pemutil"
) )
// IdentityType represents the different types of identity files. // IdentityType represents the different types of identity files.
type IdentityType string type IdentityType string
// Disabled represents a disabled identity type
const Disabled IdentityType = ""
// MutualTLS represents the identity using mTLS // MutualTLS represents the identity using mTLS
const MutualTLS IdentityType = "mTLS" const MutualTLS IdentityType = "mTLS"
@ -26,9 +39,73 @@ type Identity struct {
Key string `json:"key"` Key string `json:"key"`
} }
// WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
base := filepath.Join(config.StepPath(), "config")
if err := os.MkdirAll(base, 0600); err != nil {
return errors.Wrap(err, "error creating config directory")
}
base = filepath.Join(config.StepPath(), "identity")
if err := os.MkdirAll(base, 0600); err != nil {
return errors.Wrap(err, "error creating identity directory")
}
certFilename := filepath.Join(base, "identity.crt")
keyFilename := filepath.Join(base, "identity_key")
// Write certificate
buf := new(bytes.Buffer)
for _, crt := range certChain {
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: crt.Raw,
}
if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity certificate")
}
}
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}
// Write key
buf.Reset()
block, err := pemutil.Serialize(key)
if err != nil {
return err
}
if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity key")
}
if err := ioutil.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}
// Write identity.json
buf.Reset()
enc := json.NewEncoder(buf)
enc.SetIndent("", " ")
if err := enc.Encode(Identity{
Type: string(MutualTLS),
Certificate: certFilename,
Key: keyFilename,
}); err != nil {
return errors.Wrap(err, "error writing identity json")
}
if err := ioutil.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}
return nil
}
// Kind returns the type for the given identity. // Kind returns the type for the given identity.
func (i *Identity) Kind() IdentityType { func (i *Identity) Kind() IdentityType {
switch strings.ToLower(i.Type) { switch strings.ToLower(i.Type) {
case "":
return Disabled
case "mtls": case "mtls":
return MutualTLS return MutualTLS
default: default:
@ -39,6 +116,8 @@ func (i *Identity) Kind() IdentityType {
// Validate validates the identity object. // Validate validates the identity object.
func (i *Identity) Validate() error { func (i *Identity) Validate() error {
switch i.Kind() { switch i.Kind() {
case Disabled:
return nil
case MutualTLS: case MutualTLS:
if i.Certificate == "" { if i.Certificate == "" {
return errors.New("identity.crt cannot be empty") return errors.New("identity.crt cannot be empty")
@ -47,8 +126,6 @@ func (i *Identity) Validate() error {
return errors.New("identity.key cannot be empty") return errors.New("identity.key cannot be empty")
} }
return nil return nil
case "":
return errors.New("identity.type cannot be empty")
default: default:
return errors.Errorf("unsupported identity type %s", i.Type) return errors.Errorf("unsupported identity type %s", i.Type)
} }
@ -57,11 +134,23 @@ func (i *Identity) Validate() error {
// Options returns the ClientOptions used for the given identity. // Options returns the ClientOptions used for the given identity.
func (i *Identity) Options() ([]ClientOption, error) { func (i *Identity) Options() ([]ClientOption, error) {
switch i.Kind() { switch i.Kind() {
case Disabled:
return nil, nil
case MutualTLS: case MutualTLS:
crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error creating identity certificate") return nil, errors.Wrap(err, "error creating identity certificate")
} }
// Check if certificate is expired.
// Do not return any options if expired.
x509Cert, err := x509.ParseCertificate(crt.Certificate[0])
if err != nil {
return nil, errors.Wrap(err, "error creating identity certificate")
}
now := time.Now()
if now.Before(x509Cert.NotBefore) || now.After(x509Cert.NotAfter) {
return nil, nil
}
return []ClientOption{WithCertificate(crt)}, nil return []ClientOption{WithCertificate(crt)}, nil
default: default:
return nil, errors.Errorf("unsupported identity type %s", i.Type) return nil, errors.Errorf("unsupported identity type %s", i.Type)