Create a custom client that sends a custom User-Agent.
This commit is contained in:
parent
15a222d354
commit
b25cbbe6ca
2 changed files with 62 additions and 20 deletions
|
@ -156,8 +156,8 @@ func TestBootstrap(t *testing.T) {
|
||||||
if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) {
|
if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) {
|
||||||
t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint)
|
t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint)
|
||||||
}
|
}
|
||||||
gotTR := got.client.Transport.(*http.Transport)
|
gotTR := got.client.GetTransport().(*http.Transport)
|
||||||
wantTR := tt.want.client.Transport.(*http.Transport)
|
wantTR := tt.want.client.GetTransport().(*http.Transport)
|
||||||
if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) {
|
if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) {
|
||||||
t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs)
|
t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs)
|
||||||
}
|
}
|
||||||
|
|
78
ca/client.go
78
ca/client.go
|
@ -32,6 +32,58 @@ import (
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UserAgent will set the User-Agent header in the client requests.
|
||||||
|
var UserAgent = "step-http-client/1.0"
|
||||||
|
|
||||||
|
type uaClient struct {
|
||||||
|
Client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClient(transport http.RoundTripper) *uaClient {
|
||||||
|
return &uaClient{
|
||||||
|
Client: &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInsecureClient() *uaClient {
|
||||||
|
return &uaClient{
|
||||||
|
Client: &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *uaClient) GetTransport() http.RoundTripper {
|
||||||
|
return c.Client.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *uaClient) SetTransport(tr http.RoundTripper) {
|
||||||
|
c.Client.Transport = tr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *uaClient) Get(url string) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "new request GET %s failed", url)
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
return c.Client.Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequest("POST", url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", contentType)
|
||||||
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
|
return c.Client.Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
// RetryFunc defines the method used to retry a request. If it returns true, the
|
// RetryFunc defines the method used to retry a request. If it returns true, the
|
||||||
// request will be retried once.
|
// request will be retried once.
|
||||||
type RetryFunc func(code int) bool
|
type RetryFunc func(code int) bool
|
||||||
|
@ -354,7 +406,7 @@ func WithProvisionerLimit(limit int) ProvisionerOption {
|
||||||
|
|
||||||
// Client implements an HTTP client for the CA server.
|
// Client implements an HTTP client for the CA server.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
client *http.Client
|
client *uaClient
|
||||||
endpoint *url.URL
|
endpoint *url.URL
|
||||||
retryFunc RetryFunc
|
retryFunc RetryFunc
|
||||||
opts []ClientOption
|
opts []ClientOption
|
||||||
|
@ -377,9 +429,7 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
client: &http.Client{
|
client: newClient(tr),
|
||||||
Transport: tr,
|
|
||||||
},
|
|
||||||
endpoint: u,
|
endpoint: u,
|
||||||
retryFunc: o.retryFunc,
|
retryFunc: o.retryFunc,
|
||||||
opts: opts,
|
opts: opts,
|
||||||
|
@ -398,7 +448,7 @@ func (c *Client) retryOnError(r *http.Response) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
r.Body.Close()
|
r.Body.Close()
|
||||||
c.client.Transport = tr
|
c.client.SetTransport(tr)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -408,7 +458,7 @@ func (c *Client) retryOnError(r *http.Response) bool {
|
||||||
// GetRootCAs returns the RootCAs certificate pool from the configured
|
// GetRootCAs returns the RootCAs certificate pool from the configured
|
||||||
// transport.
|
// transport.
|
||||||
func (c *Client) GetRootCAs() *x509.CertPool {
|
func (c *Client) GetRootCAs() *x509.CertPool {
|
||||||
switch t := c.client.Transport.(type) {
|
switch t := c.client.GetTransport().(type) {
|
||||||
case *http.Transport:
|
case *http.Transport:
|
||||||
if t.TLSClientConfig != nil {
|
if t.TLSClientConfig != nil {
|
||||||
return t.TLSClientConfig.RootCAs
|
return t.TLSClientConfig.RootCAs
|
||||||
|
@ -426,7 +476,7 @@ func (c *Client) GetRootCAs() *x509.CertPool {
|
||||||
|
|
||||||
// SetTransport updates the transport of the internal HTTP client.
|
// SetTransport updates the transport of the internal HTTP client.
|
||||||
func (c *Client) SetTransport(tr http.RoundTripper) {
|
func (c *Client) SetTransport(tr http.RoundTripper) {
|
||||||
c.client.Transport = tr
|
c.client.SetTransport(tr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Version performs the version request to the CA and returns the
|
// Version performs the version request to the CA and returns the
|
||||||
|
@ -486,7 +536,7 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
||||||
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
|
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
|
||||||
retry:
|
retry:
|
||||||
resp, err := getInsecureClient().Get(u.String())
|
resp, err := newInsecureClient().Get(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||||
}
|
}
|
||||||
|
@ -573,10 +623,10 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling request")
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
}
|
}
|
||||||
var client *http.Client
|
var client *uaClient
|
||||||
retry:
|
retry:
|
||||||
if tr != nil {
|
if tr != nil {
|
||||||
client = &http.Client{Transport: tr}
|
client = newClient(tr)
|
||||||
} else {
|
} else {
|
||||||
client = c.client
|
client = c.client
|
||||||
}
|
}
|
||||||
|
@ -1082,14 +1132,6 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva
|
||||||
return &api.CertificateRequest{CertificateRequest: cr}, key, nil
|
return &api.CertificateRequest{CertificateRequest: cr}, key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInsecureClient() *http.Client {
|
|
||||||
return &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRootCAPath returns the path where the root CA is stored based on the
|
// getRootCAPath returns the path where the root CA is stored based on the
|
||||||
// STEPPATH environment variable.
|
// STEPPATH environment variable.
|
||||||
func getRootCAPath() string {
|
func getRootCAPath() string {
|
||||||
|
|
Loading…
Reference in a new issue