From bbaf8e106edca8ca9c58364d7b06d60a5665c607 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 20 Nov 2019 11:50:46 -0800 Subject: [PATCH] Support for retry and identity files. --- ca/client.go | 157 +++++++++++++++++++++++++++++++++++++++++++++++-- ca/identity.go | 93 ++++++++++++++++++++++++++++- 2 files changed, 244 insertions(+), 6 deletions(-) diff --git a/ca/client.go b/ca/client.go index 6c043ca7..51d21199 100644 --- a/ca/client.go +++ b/ca/client.go @@ -31,6 +31,10 @@ import ( "gopkg.in/square/go-jose.v2/jwt" ) +// RetryFunc defines the method used to retry a request. If it returns true, the +// request will be retried once. +type RetryFunc func(code int) bool + // ClientOption is the type of options passed to the Client constructor. type ClientOption func(o *clientOptions) error @@ -40,6 +44,7 @@ type clientOptions struct { rootFilename string rootBundle []byte certificate tls.Certificate + retryFunc RetryFunc } func (o *clientOptions) apply(opts []ClientOption) (err error) { @@ -199,6 +204,14 @@ func WithCertificate(crt tls.Certificate) ClientOption { } } +// WithRetryFunc defines a method used to retry a request. +func WithRetryFunc(fn RetryFunc) ClientOption { + return func(o *clientOptions) error { + o.retryFunc = fn + return nil + } +} + func getTransportFromFile(filename string) (http.RoundTripper, error) { data, err := ioutil.ReadFile(filename) if err != nil { @@ -330,8 +343,10 @@ func WithProvisionerLimit(limit int) ProvisionerOption { // Client implements an HTTP client for the CA server. type Client struct { - client *http.Client - endpoint *url.URL + client *http.Client + endpoint *url.URL + retryFunc RetryFunc + opts []ClientOption } // NewClient creates a new Client with the given endpoint and options. @@ -354,10 +369,31 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) { client: &http.Client{ Transport: tr, }, - endpoint: u, + endpoint: u, + retryFunc: o.retryFunc, + opts: opts, }, nil } +func (c *Client) retryOnError(r *http.Response) bool { + if c.retryFunc != nil { + if c.retryFunc(r.StatusCode) { + o := new(clientOptions) + if err := o.apply(c.opts); err != nil { + return false + } + tr, err := o.getTransport(c.endpoint.String()) + if err != nil { + return false + } + r.Body.Close() + c.client.Transport = tr + return true + } + } + return false +} + // SetTransport updates the transport of the internal HTTP client. func (c *Client) SetTransport(tr http.RoundTripper) { c.client.Transport = tr @@ -366,12 +402,18 @@ func (c *Client) SetTransport(tr http.RoundTripper) { // Health performs the health request to the CA and returns the // api.HealthResponse struct. func (c *Client) Health() (*api.HealthResponse, error) { + var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/health"}) +retry: resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var health api.HealthResponse @@ -386,13 +428,19 @@ func (c *Client) Health() (*api.HealthResponse, error) { // resulting root certificate with the given SHA256, returning an error if they // do not match. func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) { + var retried bool sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1)) u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum}) +retry: resp, err := getInsecureClient().Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var root api.RootResponse @@ -410,16 +458,22 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) { // Sign performs the sign request to the CA and returns the api.SignResponse // struct. func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) { + var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"}) +retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var sign api.SignResponse @@ -435,13 +489,19 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) { // Renew performs the renew request to the CA and returns the api.SignResponse // struct. func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { + var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) client := &http.Client{Transport: tr} +retry: resp, err := client.Post(u.String(), "application/json", http.NoBody) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var sign api.SignResponse @@ -454,12 +514,13 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { // Revoke performs the revoke request to the CA and returns the api.RevokeResponse // struct. func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) { + var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } - var client *http.Client +retry: if tr != nil { client = &http.Client{Transport: tr} } else { @@ -472,6 +533,10 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var revoke api.RevokeResponse @@ -487,6 +552,7 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo // ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to // paginate the provisioners. func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) { + var retried bool o := new(provisionerOptions) if err := o.apply(opts); err != nil { return nil, err @@ -495,11 +561,16 @@ func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersRespo Path: "/provisioners", RawQuery: o.rawQuery(), }) +retry: resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var provisioners api.ProvisionersResponse @@ -513,12 +584,18 @@ func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersRespo // the given provisioner kid and returns the api.ProvisionerKeyResponse struct // with the encrypted key. func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) { + var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"}) +retry: resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var key api.ProvisionerKeyResponse @@ -531,12 +608,18 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) // Roots performs the get roots request to the CA and returns the // api.RootsResponse struct. func (c *Client) Roots() (*api.RootsResponse, error) { + var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"}) +retry: resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var roots api.RootsResponse @@ -549,12 +632,18 @@ func (c *Client) Roots() (*api.RootsResponse, error) { // Federation performs the get federation request to the CA and returns the // api.FederationResponse struct. func (c *Client) Federation() (*api.FederationResponse, error) { + var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"}) +retry: resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var federation api.FederationResponse @@ -567,16 +656,22 @@ func (c *Client) Federation() (*api.FederationResponse, error) { // SSHSign performs the POST /ssh/sign request to the CA and returns the // api.SSHSignResponse struct. func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) { + var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"}) +retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var sign api.SSHSignResponse @@ -589,16 +684,22 @@ func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) // SSHRenew performs the POST /ssh/renew request to the CA and returns the // api.SSHRenewResponse struct. func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) { + var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/renew"}) +retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var renew api.SSHRenewResponse @@ -611,16 +712,22 @@ func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, erro // SSHRekey performs the POST /ssh/rekey request to the CA and returns the // api.SSHRekeyResponse struct. func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) { + var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/rekey"}) +retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var rekey api.SSHRekeyResponse @@ -633,16 +740,22 @@ func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, erro // SSHRevoke performs the POST /ssh/revoke request to the CA and returns the // api.SSHRevokeResponse struct. func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) { + var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/revoke"}) +retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var revoke api.SSHRevokeResponse @@ -655,12 +768,18 @@ func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, e // SSHRoots performs the GET /ssh/roots request to the CA and returns the // api.SSHRootsResponse struct. func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) { + var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"}) +retry: resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var keys api.SSHRootsResponse @@ -673,12 +792,18 @@ func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) { // SSHFederation performs the get /ssh/federation request to the CA and returns // the api.SSHRootsResponse struct. func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) { + var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"}) +retry: resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var keys api.SSHRootsResponse @@ -691,16 +816,22 @@ func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) { // SSHConfig performs the POST /ssh/config request to the CA to get the ssh // configuration templates. func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) { + var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/config"}) +retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var config api.SSHConfigResponse @@ -713,6 +844,7 @@ func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, e // SSHCheckHost performs the POST /ssh/check-host request to the CA with the // given principal. func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, error) { + var retried bool body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ Type: provisioner.SSHHostCert, Principal: principal, @@ -721,11 +853,16 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"}) +retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var check api.SSHCheckPrincipalResponse @@ -737,12 +874,18 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, // SSHGetHosts performs the GET /ssh/get-hosts request to the CA. func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) { + var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/get-hosts"}) +retry: resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var hosts api.SSHGetHostsResponse @@ -754,16 +897,22 @@ func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) { // SSHBastion performs the POST /ssh/bastion request to the CA. func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) { + var retried bool body, err := json.Marshal(req) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"}) +retry: resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) } if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } return nil, readError(resp.Body) } var bastion api.SSHBastionResponse diff --git a/ca/identity.go b/ca/identity.go index 15f8358c..5576dde3 100644 --- a/ca/identity.go +++ b/ca/identity.go @@ -1,17 +1,30 @@ package ca import ( + "bytes" + "crypto" "crypto/tls" + "crypto/x509" + "encoding/json" + "encoding/pem" + "io/ioutil" + "os" "path/filepath" "strings" + "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/api" "github.com/smallstep/cli/config" + "github.com/smallstep/cli/crypto/pemutil" ) // IdentityType represents the different types of identity files. type IdentityType string +// Disabled represents a disabled identity type +const Disabled IdentityType = "" + // MutualTLS represents the identity using mTLS const MutualTLS IdentityType = "mTLS" @@ -26,9 +39,73 @@ type Identity struct { Key string `json:"key"` } +// WriteDefaultIdentity writes the given certificates and key and the +// identity.json pointing to the new files. +func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { + base := filepath.Join(config.StepPath(), "config") + if err := os.MkdirAll(base, 0600); err != nil { + return errors.Wrap(err, "error creating config directory") + } + + base = filepath.Join(config.StepPath(), "identity") + if err := os.MkdirAll(base, 0600); err != nil { + return errors.Wrap(err, "error creating identity directory") + } + + certFilename := filepath.Join(base, "identity.crt") + keyFilename := filepath.Join(base, "identity_key") + + // Write certificate + buf := new(bytes.Buffer) + for _, crt := range certChain { + block := &pem.Block{ + Type: "CERTIFICATE", + Bytes: crt.Raw, + } + if err := pem.Encode(buf, block); err != nil { + return errors.Wrap(err, "error encoding identity certificate") + } + } + if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { + return errors.Wrap(err, "error writing identity certificate") + } + + // Write key + buf.Reset() + block, err := pemutil.Serialize(key) + if err != nil { + return err + } + if err := pem.Encode(buf, block); err != nil { + return errors.Wrap(err, "error encoding identity key") + } + if err := ioutil.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil { + return errors.Wrap(err, "error writing identity certificate") + } + + // Write identity.json + buf.Reset() + enc := json.NewEncoder(buf) + enc.SetIndent("", " ") + if err := enc.Encode(Identity{ + Type: string(MutualTLS), + Certificate: certFilename, + Key: keyFilename, + }); err != nil { + return errors.Wrap(err, "error writing identity json") + } + if err := ioutil.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil { + return errors.Wrap(err, "error writing identity certificate") + } + + return nil +} + // Kind returns the type for the given identity. func (i *Identity) Kind() IdentityType { switch strings.ToLower(i.Type) { + case "": + return Disabled case "mtls": return MutualTLS default: @@ -39,6 +116,8 @@ func (i *Identity) Kind() IdentityType { // Validate validates the identity object. func (i *Identity) Validate() error { switch i.Kind() { + case Disabled: + return nil case MutualTLS: if i.Certificate == "" { return errors.New("identity.crt cannot be empty") @@ -47,8 +126,6 @@ func (i *Identity) Validate() error { return errors.New("identity.key cannot be empty") } return nil - case "": - return errors.New("identity.type cannot be empty") default: return errors.Errorf("unsupported identity type %s", i.Type) } @@ -57,11 +134,23 @@ func (i *Identity) Validate() error { // Options returns the ClientOptions used for the given identity. func (i *Identity) Options() ([]ClientOption, error) { switch i.Kind() { + case Disabled: + return nil, nil case MutualTLS: crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key) if err != nil { return nil, errors.Wrap(err, "error creating identity certificate") } + // Check if certificate is expired. + // Do not return any options if expired. + x509Cert, err := x509.ParseCertificate(crt.Certificate[0]) + if err != nil { + return nil, errors.Wrap(err, "error creating identity certificate") + } + now := time.Now() + if now.Before(x509Cert.NotBefore) || now.After(x509Cert.NotAfter) { + return nil, nil + } return []ClientOption{WithCertificate(crt)}, nil default: return nil, errors.Errorf("unsupported identity type %s", i.Type)