forked from TrueCloudLab/certificates
Support for retry and identity files.
This commit is contained in:
parent
d555f310dc
commit
bbaf8e106e
2 changed files with 244 additions and 6 deletions
151
ca/client.go
151
ca/client.go
|
@ -31,6 +31,10 @@ import (
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RetryFunc defines the method used to retry a request. If it returns true, the
|
||||||
|
// request will be retried once.
|
||||||
|
type RetryFunc func(code int) bool
|
||||||
|
|
||||||
// ClientOption is the type of options passed to the Client constructor.
|
// ClientOption is the type of options passed to the Client constructor.
|
||||||
type ClientOption func(o *clientOptions) error
|
type ClientOption func(o *clientOptions) error
|
||||||
|
|
||||||
|
@ -40,6 +44,7 @@ type clientOptions struct {
|
||||||
rootFilename string
|
rootFilename string
|
||||||
rootBundle []byte
|
rootBundle []byte
|
||||||
certificate tls.Certificate
|
certificate tls.Certificate
|
||||||
|
retryFunc RetryFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *clientOptions) apply(opts []ClientOption) (err error) {
|
func (o *clientOptions) apply(opts []ClientOption) (err error) {
|
||||||
|
@ -199,6 +204,14 @@ func WithCertificate(crt tls.Certificate) ClientOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithRetryFunc defines a method used to retry a request.
|
||||||
|
func WithRetryFunc(fn RetryFunc) ClientOption {
|
||||||
|
return func(o *clientOptions) error {
|
||||||
|
o.retryFunc = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getTransportFromFile(filename string) (http.RoundTripper, error) {
|
func getTransportFromFile(filename string) (http.RoundTripper, error) {
|
||||||
data, err := ioutil.ReadFile(filename)
|
data, err := ioutil.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -332,6 +345,8 @@ func WithProvisionerLimit(limit int) ProvisionerOption {
|
||||||
type Client struct {
|
type Client struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
endpoint *url.URL
|
endpoint *url.URL
|
||||||
|
retryFunc RetryFunc
|
||||||
|
opts []ClientOption
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new Client with the given endpoint and options.
|
// NewClient creates a new Client with the given endpoint and options.
|
||||||
|
@ -355,9 +370,30 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
|
||||||
Transport: tr,
|
Transport: tr,
|
||||||
},
|
},
|
||||||
endpoint: u,
|
endpoint: u,
|
||||||
|
retryFunc: o.retryFunc,
|
||||||
|
opts: opts,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) retryOnError(r *http.Response) bool {
|
||||||
|
if c.retryFunc != nil {
|
||||||
|
if c.retryFunc(r.StatusCode) {
|
||||||
|
o := new(clientOptions)
|
||||||
|
if err := o.apply(c.opts); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tr, err := o.getTransport(c.endpoint.String())
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
r.Body.Close()
|
||||||
|
c.client.Transport = tr
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// SetTransport updates the transport of the internal HTTP client.
|
// SetTransport updates the transport of the internal HTTP client.
|
||||||
func (c *Client) SetTransport(tr http.RoundTripper) {
|
func (c *Client) SetTransport(tr http.RoundTripper) {
|
||||||
c.client.Transport = tr
|
c.client.Transport = tr
|
||||||
|
@ -366,12 +402,18 @@ func (c *Client) SetTransport(tr http.RoundTripper) {
|
||||||
// Health performs the health request to the CA and returns the
|
// Health performs the health request to the CA and returns the
|
||||||
// api.HealthResponse struct.
|
// api.HealthResponse struct.
|
||||||
func (c *Client) Health() (*api.HealthResponse, error) {
|
func (c *Client) Health() (*api.HealthResponse, error) {
|
||||||
|
var retried bool
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Get(u.String())
|
resp, err := c.client.Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var health api.HealthResponse
|
var health api.HealthResponse
|
||||||
|
@ -386,13 +428,19 @@ func (c *Client) Health() (*api.HealthResponse, error) {
|
||||||
// resulting root certificate with the given SHA256, returning an error if they
|
// resulting root certificate with the given SHA256, returning an error if they
|
||||||
// do not match.
|
// do not match.
|
||||||
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
||||||
|
var retried bool
|
||||||
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
|
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
|
||||||
|
retry:
|
||||||
resp, err := getInsecureClient().Get(u.String())
|
resp, err := getInsecureClient().Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var root api.RootResponse
|
var root api.RootResponse
|
||||||
|
@ -410,16 +458,22 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
||||||
// Sign performs the sign request to the CA and returns the api.SignResponse
|
// Sign performs the sign request to the CA and returns the api.SignResponse
|
||||||
// struct.
|
// struct.
|
||||||
func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
|
func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
|
||||||
|
var retried bool
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var sign api.SignResponse
|
var sign api.SignResponse
|
||||||
|
@ -435,13 +489,19 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
|
||||||
// Renew performs the renew request to the CA and returns the api.SignResponse
|
// Renew performs the renew request to the CA and returns the api.SignResponse
|
||||||
// struct.
|
// struct.
|
||||||
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
||||||
|
var retried bool
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
|
||||||
client := &http.Client{Transport: tr}
|
client := &http.Client{Transport: tr}
|
||||||
|
retry:
|
||||||
resp, err := client.Post(u.String(), "application/json", http.NoBody)
|
resp, err := client.Post(u.String(), "application/json", http.NoBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var sign api.SignResponse
|
var sign api.SignResponse
|
||||||
|
@ -454,12 +514,13 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
|
||||||
// Revoke performs the revoke request to the CA and returns the api.RevokeResponse
|
// Revoke performs the revoke request to the CA and returns the api.RevokeResponse
|
||||||
// struct.
|
// struct.
|
||||||
func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {
|
func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {
|
||||||
|
var retried bool
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
|
|
||||||
var client *http.Client
|
var client *http.Client
|
||||||
|
retry:
|
||||||
if tr != nil {
|
if tr != nil {
|
||||||
client = &http.Client{Transport: tr}
|
client = &http.Client{Transport: tr}
|
||||||
} else {
|
} else {
|
||||||
|
@ -472,6 +533,10 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var revoke api.RevokeResponse
|
var revoke api.RevokeResponse
|
||||||
|
@ -487,6 +552,7 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo
|
||||||
// ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to
|
// ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to
|
||||||
// paginate the provisioners.
|
// paginate the provisioners.
|
||||||
func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) {
|
func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) {
|
||||||
|
var retried bool
|
||||||
o := new(provisionerOptions)
|
o := new(provisionerOptions)
|
||||||
if err := o.apply(opts); err != nil {
|
if err := o.apply(opts); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -495,11 +561,16 @@ func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersRespo
|
||||||
Path: "/provisioners",
|
Path: "/provisioners",
|
||||||
RawQuery: o.rawQuery(),
|
RawQuery: o.rawQuery(),
|
||||||
})
|
})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Get(u.String())
|
resp, err := c.client.Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var provisioners api.ProvisionersResponse
|
var provisioners api.ProvisionersResponse
|
||||||
|
@ -513,12 +584,18 @@ func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersRespo
|
||||||
// the given provisioner kid and returns the api.ProvisionerKeyResponse struct
|
// the given provisioner kid and returns the api.ProvisionerKeyResponse struct
|
||||||
// with the encrypted key.
|
// with the encrypted key.
|
||||||
func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) {
|
func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) {
|
||||||
|
var retried bool
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Get(u.String())
|
resp, err := c.client.Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var key api.ProvisionerKeyResponse
|
var key api.ProvisionerKeyResponse
|
||||||
|
@ -531,12 +608,18 @@ func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error)
|
||||||
// Roots performs the get roots request to the CA and returns the
|
// Roots performs the get roots request to the CA and returns the
|
||||||
// api.RootsResponse struct.
|
// api.RootsResponse struct.
|
||||||
func (c *Client) Roots() (*api.RootsResponse, error) {
|
func (c *Client) Roots() (*api.RootsResponse, error) {
|
||||||
|
var retried bool
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Get(u.String())
|
resp, err := c.client.Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var roots api.RootsResponse
|
var roots api.RootsResponse
|
||||||
|
@ -549,12 +632,18 @@ func (c *Client) Roots() (*api.RootsResponse, error) {
|
||||||
// Federation performs the get federation request to the CA and returns the
|
// Federation performs the get federation request to the CA and returns the
|
||||||
// api.FederationResponse struct.
|
// api.FederationResponse struct.
|
||||||
func (c *Client) Federation() (*api.FederationResponse, error) {
|
func (c *Client) Federation() (*api.FederationResponse, error) {
|
||||||
|
var retried bool
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Get(u.String())
|
resp, err := c.client.Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var federation api.FederationResponse
|
var federation api.FederationResponse
|
||||||
|
@ -567,16 +656,22 @@ func (c *Client) Federation() (*api.FederationResponse, error) {
|
||||||
// SSHSign performs the POST /ssh/sign request to the CA and returns the
|
// SSHSign performs the POST /ssh/sign request to the CA and returns the
|
||||||
// api.SSHSignResponse struct.
|
// api.SSHSignResponse struct.
|
||||||
func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) {
|
func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) {
|
||||||
|
var retried bool
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var sign api.SSHSignResponse
|
var sign api.SSHSignResponse
|
||||||
|
@ -589,16 +684,22 @@ func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error)
|
||||||
// SSHRenew performs the POST /ssh/renew request to the CA and returns the
|
// SSHRenew performs the POST /ssh/renew request to the CA and returns the
|
||||||
// api.SSHRenewResponse struct.
|
// api.SSHRenewResponse struct.
|
||||||
func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) {
|
func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) {
|
||||||
|
var retried bool
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/renew"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/renew"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var renew api.SSHRenewResponse
|
var renew api.SSHRenewResponse
|
||||||
|
@ -611,16 +712,22 @@ func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, erro
|
||||||
// SSHRekey performs the POST /ssh/rekey request to the CA and returns the
|
// SSHRekey performs the POST /ssh/rekey request to the CA and returns the
|
||||||
// api.SSHRekeyResponse struct.
|
// api.SSHRekeyResponse struct.
|
||||||
func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) {
|
func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) {
|
||||||
|
var retried bool
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/rekey"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/rekey"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var rekey api.SSHRekeyResponse
|
var rekey api.SSHRekeyResponse
|
||||||
|
@ -633,16 +740,22 @@ func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, erro
|
||||||
// SSHRevoke performs the POST /ssh/revoke request to the CA and returns the
|
// SSHRevoke performs the POST /ssh/revoke request to the CA and returns the
|
||||||
// api.SSHRevokeResponse struct.
|
// api.SSHRevokeResponse struct.
|
||||||
func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) {
|
func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) {
|
||||||
|
var retried bool
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/revoke"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/revoke"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var revoke api.SSHRevokeResponse
|
var revoke api.SSHRevokeResponse
|
||||||
|
@ -655,12 +768,18 @@ func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, e
|
||||||
// SSHRoots performs the GET /ssh/roots request to the CA and returns the
|
// SSHRoots performs the GET /ssh/roots request to the CA and returns the
|
||||||
// api.SSHRootsResponse struct.
|
// api.SSHRootsResponse struct.
|
||||||
func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) {
|
func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) {
|
||||||
|
var retried bool
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Get(u.String())
|
resp, err := c.client.Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var keys api.SSHRootsResponse
|
var keys api.SSHRootsResponse
|
||||||
|
@ -673,12 +792,18 @@ func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) {
|
||||||
// SSHFederation performs the get /ssh/federation request to the CA and returns
|
// SSHFederation performs the get /ssh/federation request to the CA and returns
|
||||||
// the api.SSHRootsResponse struct.
|
// the api.SSHRootsResponse struct.
|
||||||
func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) {
|
func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) {
|
||||||
|
var retried bool
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Get(u.String())
|
resp, err := c.client.Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var keys api.SSHRootsResponse
|
var keys api.SSHRootsResponse
|
||||||
|
@ -691,16 +816,22 @@ func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) {
|
||||||
// SSHConfig performs the POST /ssh/config request to the CA to get the ssh
|
// SSHConfig performs the POST /ssh/config request to the CA to get the ssh
|
||||||
// configuration templates.
|
// configuration templates.
|
||||||
func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) {
|
func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) {
|
||||||
|
var retried bool
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/config"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/config"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var config api.SSHConfigResponse
|
var config api.SSHConfigResponse
|
||||||
|
@ -713,6 +844,7 @@ func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, e
|
||||||
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the
|
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the
|
||||||
// given principal.
|
// given principal.
|
||||||
func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, error) {
|
func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, error) {
|
||||||
|
var retried bool
|
||||||
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
|
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
|
||||||
Type: provisioner.SSHHostCert,
|
Type: provisioner.SSHHostCert,
|
||||||
Principal: principal,
|
Principal: principal,
|
||||||
|
@ -721,11 +853,16 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse,
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var check api.SSHCheckPrincipalResponse
|
var check api.SSHCheckPrincipalResponse
|
||||||
|
@ -737,12 +874,18 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse,
|
||||||
|
|
||||||
// SSHGetHosts performs the GET /ssh/get-hosts request to the CA.
|
// SSHGetHosts performs the GET /ssh/get-hosts request to the CA.
|
||||||
func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) {
|
func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) {
|
||||||
|
var retried bool
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/get-hosts"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/get-hosts"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Get(u.String())
|
resp, err := c.client.Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var hosts api.SSHGetHostsResponse
|
var hosts api.SSHGetHostsResponse
|
||||||
|
@ -754,16 +897,22 @@ func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) {
|
||||||
|
|
||||||
// SSHBastion performs the POST /ssh/bastion request to the CA.
|
// SSHBastion performs the POST /ssh/bastion request to the CA.
|
||||||
func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) {
|
func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) {
|
||||||
|
var retried bool
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"})
|
||||||
|
retry:
|
||||||
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
return nil, errors.Wrapf(err, "client POST %s failed", u)
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var bastion api.SSHBastionResponse
|
var bastion api.SSHBastionResponse
|
||||||
|
|
|
@ -1,17 +1,30 @@
|
||||||
package ca
|
package ca
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
"github.com/smallstep/cli/config"
|
"github.com/smallstep/cli/config"
|
||||||
|
"github.com/smallstep/cli/crypto/pemutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IdentityType represents the different types of identity files.
|
// IdentityType represents the different types of identity files.
|
||||||
type IdentityType string
|
type IdentityType string
|
||||||
|
|
||||||
|
// Disabled represents a disabled identity type
|
||||||
|
const Disabled IdentityType = ""
|
||||||
|
|
||||||
// MutualTLS represents the identity using mTLS
|
// MutualTLS represents the identity using mTLS
|
||||||
const MutualTLS IdentityType = "mTLS"
|
const MutualTLS IdentityType = "mTLS"
|
||||||
|
|
||||||
|
@ -26,9 +39,73 @@ type Identity struct {
|
||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteDefaultIdentity writes the given certificates and key and the
|
||||||
|
// identity.json pointing to the new files.
|
||||||
|
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
|
||||||
|
base := filepath.Join(config.StepPath(), "config")
|
||||||
|
if err := os.MkdirAll(base, 0600); err != nil {
|
||||||
|
return errors.Wrap(err, "error creating config directory")
|
||||||
|
}
|
||||||
|
|
||||||
|
base = filepath.Join(config.StepPath(), "identity")
|
||||||
|
if err := os.MkdirAll(base, 0600); err != nil {
|
||||||
|
return errors.Wrap(err, "error creating identity directory")
|
||||||
|
}
|
||||||
|
|
||||||
|
certFilename := filepath.Join(base, "identity.crt")
|
||||||
|
keyFilename := filepath.Join(base, "identity_key")
|
||||||
|
|
||||||
|
// Write certificate
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
for _, crt := range certChain {
|
||||||
|
block := &pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: crt.Raw,
|
||||||
|
}
|
||||||
|
if err := pem.Encode(buf, block); err != nil {
|
||||||
|
return errors.Wrap(err, "error encoding identity certificate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
|
||||||
|
return errors.Wrap(err, "error writing identity certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write key
|
||||||
|
buf.Reset()
|
||||||
|
block, err := pemutil.Serialize(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := pem.Encode(buf, block); err != nil {
|
||||||
|
return errors.Wrap(err, "error encoding identity key")
|
||||||
|
}
|
||||||
|
if err := ioutil.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil {
|
||||||
|
return errors.Wrap(err, "error writing identity certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write identity.json
|
||||||
|
buf.Reset()
|
||||||
|
enc := json.NewEncoder(buf)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
if err := enc.Encode(Identity{
|
||||||
|
Type: string(MutualTLS),
|
||||||
|
Certificate: certFilename,
|
||||||
|
Key: keyFilename,
|
||||||
|
}); err != nil {
|
||||||
|
return errors.Wrap(err, "error writing identity json")
|
||||||
|
}
|
||||||
|
if err := ioutil.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil {
|
||||||
|
return errors.Wrap(err, "error writing identity certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Kind returns the type for the given identity.
|
// Kind returns the type for the given identity.
|
||||||
func (i *Identity) Kind() IdentityType {
|
func (i *Identity) Kind() IdentityType {
|
||||||
switch strings.ToLower(i.Type) {
|
switch strings.ToLower(i.Type) {
|
||||||
|
case "":
|
||||||
|
return Disabled
|
||||||
case "mtls":
|
case "mtls":
|
||||||
return MutualTLS
|
return MutualTLS
|
||||||
default:
|
default:
|
||||||
|
@ -39,6 +116,8 @@ func (i *Identity) Kind() IdentityType {
|
||||||
// Validate validates the identity object.
|
// Validate validates the identity object.
|
||||||
func (i *Identity) Validate() error {
|
func (i *Identity) Validate() error {
|
||||||
switch i.Kind() {
|
switch i.Kind() {
|
||||||
|
case Disabled:
|
||||||
|
return nil
|
||||||
case MutualTLS:
|
case MutualTLS:
|
||||||
if i.Certificate == "" {
|
if i.Certificate == "" {
|
||||||
return errors.New("identity.crt cannot be empty")
|
return errors.New("identity.crt cannot be empty")
|
||||||
|
@ -47,8 +126,6 @@ func (i *Identity) Validate() error {
|
||||||
return errors.New("identity.key cannot be empty")
|
return errors.New("identity.key cannot be empty")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case "":
|
|
||||||
return errors.New("identity.type cannot be empty")
|
|
||||||
default:
|
default:
|
||||||
return errors.Errorf("unsupported identity type %s", i.Type)
|
return errors.Errorf("unsupported identity type %s", i.Type)
|
||||||
}
|
}
|
||||||
|
@ -57,11 +134,23 @@ func (i *Identity) Validate() error {
|
||||||
// Options returns the ClientOptions used for the given identity.
|
// Options returns the ClientOptions used for the given identity.
|
||||||
func (i *Identity) Options() ([]ClientOption, error) {
|
func (i *Identity) Options() ([]ClientOption, error) {
|
||||||
switch i.Kind() {
|
switch i.Kind() {
|
||||||
|
case Disabled:
|
||||||
|
return nil, nil
|
||||||
case MutualTLS:
|
case MutualTLS:
|
||||||
crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key)
|
crt, err := tls.LoadX509KeyPair(i.Certificate, i.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error creating identity certificate")
|
return nil, errors.Wrap(err, "error creating identity certificate")
|
||||||
}
|
}
|
||||||
|
// Check if certificate is expired.
|
||||||
|
// Do not return any options if expired.
|
||||||
|
x509Cert, err := x509.ParseCertificate(crt.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error creating identity certificate")
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if now.Before(x509Cert.NotBefore) || now.After(x509Cert.NotAfter) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
return []ClientOption{WithCertificate(crt)}, nil
|
return []ClientOption{WithCertificate(crt)}, nil
|
||||||
default:
|
default:
|
||||||
return nil, errors.Errorf("unsupported identity type %s", i.Type)
|
return nil, errors.Errorf("unsupported identity type %s", i.Type)
|
||||||
|
|
Loading…
Reference in a new issue