forked from TrueCloudLab/certificates
Create a custom client that sends a custom User-Agent.
This commit is contained in:
parent
15a222d354
commit
b25cbbe6ca
2 changed files with 62 additions and 20 deletions
|
@ -156,8 +156,8 @@ func TestBootstrap(t *testing.T) {
|
|||
if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) {
|
||||
t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint)
|
||||
}
|
||||
gotTR := got.client.Transport.(*http.Transport)
|
||||
wantTR := tt.want.client.Transport.(*http.Transport)
|
||||
gotTR := got.client.GetTransport().(*http.Transport)
|
||||
wantTR := tt.want.client.GetTransport().(*http.Transport)
|
||||
if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) {
|
||||
t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs)
|
||||
}
|
||||
|
|
78
ca/client.go
78
ca/client.go
|
@ -32,6 +32,58 @@ import (
|
|||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
// UserAgent will set the User-Agent header in the client requests.
|
||||
var UserAgent = "step-http-client/1.0"
|
||||
|
||||
type uaClient struct {
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
func newClient(transport http.RoundTripper) *uaClient {
|
||||
return &uaClient{
|
||||
Client: &http.Client{
|
||||
Transport: transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newInsecureClient() *uaClient {
|
||||
return &uaClient{
|
||||
Client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *uaClient) GetTransport() http.RoundTripper {
|
||||
return c.Client.Transport
|
||||
}
|
||||
|
||||
func (c *uaClient) SetTransport(tr http.RoundTripper) {
|
||||
c.Client.Transport = tr
|
||||
}
|
||||
|
||||
func (c *uaClient) Get(url string) (*http.Response, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "new request GET %s failed", url)
|
||||
}
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
return c.Client.Do(req)
|
||||
}
|
||||
|
||||
func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response, error) {
|
||||
req, err := http.NewRequest("POST", url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
return c.Client.Do(req)
|
||||
}
|
||||
|
||||
// RetryFunc defines the method used to retry a request. If it returns true, the
|
||||
// request will be retried once.
|
||||
type RetryFunc func(code int) bool
|
||||
|
@ -354,7 +406,7 @@ func WithProvisionerLimit(limit int) ProvisionerOption {
|
|||
|
||||
// Client implements an HTTP client for the CA server.
|
||||
type Client struct {
|
||||
client *http.Client
|
||||
client *uaClient
|
||||
endpoint *url.URL
|
||||
retryFunc RetryFunc
|
||||
opts []ClientOption
|
||||
|
@ -377,9 +429,7 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
|
|||
}
|
||||
|
||||
return &Client{
|
||||
client: &http.Client{
|
||||
Transport: tr,
|
||||
},
|
||||
client: newClient(tr),
|
||||
endpoint: u,
|
||||
retryFunc: o.retryFunc,
|
||||
opts: opts,
|
||||
|
@ -398,7 +448,7 @@ func (c *Client) retryOnError(r *http.Response) bool {
|
|||
return false
|
||||
}
|
||||
r.Body.Close()
|
||||
c.client.Transport = tr
|
||||
c.client.SetTransport(tr)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -408,7 +458,7 @@ func (c *Client) retryOnError(r *http.Response) bool {
|
|||
// GetRootCAs returns the RootCAs certificate pool from the configured
|
||||
// transport.
|
||||
func (c *Client) GetRootCAs() *x509.CertPool {
|
||||
switch t := c.client.Transport.(type) {
|
||||
switch t := c.client.GetTransport().(type) {
|
||||
case *http.Transport:
|
||||
if t.TLSClientConfig != nil {
|
||||
return t.TLSClientConfig.RootCAs
|
||||
|
@ -426,7 +476,7 @@ func (c *Client) GetRootCAs() *x509.CertPool {
|
|||
|
||||
// SetTransport updates the transport of the internal HTTP client.
|
||||
func (c *Client) SetTransport(tr http.RoundTripper) {
|
||||
c.client.Transport = tr
|
||||
c.client.SetTransport(tr)
|
||||
}
|
||||
|
||||
// Version performs the version request to the CA and returns the
|
||||
|
@ -486,7 +536,7 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
|||
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
|
||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
|
||||
retry:
|
||||
resp, err := getInsecureClient().Get(u.String())
|
||||
resp, err := newInsecureClient().Get(u.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||
}
|
||||
|
@ -573,10 +623,10 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshaling request")
|
||||
}
|
||||
var client *http.Client
|
||||
var client *uaClient
|
||||
retry:
|
||||
if tr != nil {
|
||||
client = &http.Client{Transport: tr}
|
||||
client = newClient(tr)
|
||||
} else {
|
||||
client = c.client
|
||||
}
|
||||
|
@ -1082,14 +1132,6 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva
|
|||
return &api.CertificateRequest{CertificateRequest: cr}, key, nil
|
||||
}
|
||||
|
||||
func getInsecureClient() *http.Client {
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// getRootCAPath returns the path where the root CA is stored based on the
|
||||
// STEPPATH environment variable.
|
||||
func getRootCAPath() string {
|
||||
|
|
Loading…
Reference in a new issue