Create a custom client that sends a custom User-Agent.

This commit is contained in:
Mariano Cano 2019-11-27 17:30:06 -08:00
parent 15a222d354
commit b25cbbe6ca
2 changed files with 62 additions and 20 deletions

View file

@ -156,8 +156,8 @@ func TestBootstrap(t *testing.T) {
if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) { if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) {
t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint) t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint)
} }
gotTR := got.client.Transport.(*http.Transport) gotTR := got.client.GetTransport().(*http.Transport)
wantTR := tt.want.client.Transport.(*http.Transport) wantTR := tt.want.client.GetTransport().(*http.Transport)
if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) { if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) {
t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs)
} }

View file

@ -32,6 +32,58 @@ import (
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
) )
// UserAgent will set the User-Agent header in the client requests.
var UserAgent = "step-http-client/1.0"
type uaClient struct {
Client *http.Client
}
func newClient(transport http.RoundTripper) *uaClient {
return &uaClient{
Client: &http.Client{
Transport: transport,
},
}
}
func newInsecureClient() *uaClient {
return &uaClient{
Client: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
},
}
}
func (c *uaClient) GetTransport() http.RoundTripper {
return c.Client.Transport
}
func (c *uaClient) SetTransport(tr http.RoundTripper) {
c.Client.Transport = tr
}
func (c *uaClient) Get(url string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, errors.Wrapf(err, "new request GET %s failed", url)
}
req.Header.Set("User-Agent", UserAgent)
return c.Client.Do(req)
}
func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
req.Header.Set("User-Agent", UserAgent)
return c.Client.Do(req)
}
// RetryFunc defines the method used to retry a request. If it returns true, the // RetryFunc defines the method used to retry a request. If it returns true, the
// request will be retried once. // request will be retried once.
type RetryFunc func(code int) bool type RetryFunc func(code int) bool
@ -354,7 +406,7 @@ func WithProvisionerLimit(limit int) ProvisionerOption {
// Client implements an HTTP client for the CA server. // Client implements an HTTP client for the CA server.
type Client struct { type Client struct {
client *http.Client client *uaClient
endpoint *url.URL endpoint *url.URL
retryFunc RetryFunc retryFunc RetryFunc
opts []ClientOption opts []ClientOption
@ -377,9 +429,7 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
} }
return &Client{ return &Client{
client: &http.Client{ client: newClient(tr),
Transport: tr,
},
endpoint: u, endpoint: u,
retryFunc: o.retryFunc, retryFunc: o.retryFunc,
opts: opts, opts: opts,
@ -398,7 +448,7 @@ func (c *Client) retryOnError(r *http.Response) bool {
return false return false
} }
r.Body.Close() r.Body.Close()
c.client.Transport = tr c.client.SetTransport(tr)
return true return true
} }
} }
@ -408,7 +458,7 @@ func (c *Client) retryOnError(r *http.Response) bool {
// GetRootCAs returns the RootCAs certificate pool from the configured // GetRootCAs returns the RootCAs certificate pool from the configured
// transport. // transport.
func (c *Client) GetRootCAs() *x509.CertPool { func (c *Client) GetRootCAs() *x509.CertPool {
switch t := c.client.Transport.(type) { switch t := c.client.GetTransport().(type) {
case *http.Transport: case *http.Transport:
if t.TLSClientConfig != nil { if t.TLSClientConfig != nil {
return t.TLSClientConfig.RootCAs return t.TLSClientConfig.RootCAs
@ -426,7 +476,7 @@ func (c *Client) GetRootCAs() *x509.CertPool {
// SetTransport updates the transport of the internal HTTP client. // SetTransport updates the transport of the internal HTTP client.
func (c *Client) SetTransport(tr http.RoundTripper) { func (c *Client) SetTransport(tr http.RoundTripper) {
c.client.Transport = tr c.client.SetTransport(tr)
} }
// Version performs the version request to the CA and returns the // Version performs the version request to the CA and returns the
@ -486,7 +536,7 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1)) sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum}) u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
retry: retry:
resp, err := getInsecureClient().Get(u.String()) resp, err := newInsecureClient().Get(u.String())
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u) return nil, errors.Wrapf(err, "client GET %s failed", u)
} }
@ -573,10 +623,10 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")
} }
var client *http.Client var client *uaClient
retry: retry:
if tr != nil { if tr != nil {
client = &http.Client{Transport: tr} client = newClient(tr)
} else { } else {
client = c.client client = c.client
} }
@ -1082,14 +1132,6 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva
return &api.CertificateRequest{CertificateRequest: cr}, key, nil return &api.CertificateRequest{CertificateRequest: cr}, key, nil
} }
func getInsecureClient() *http.Client {
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
}
// getRootCAPath returns the path where the root CA is stored based on the // getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // STEPPATH environment variable.
func getRootCAPath() string { func getRootCAPath() string {