diff --git a/acme/client.go b/acme/client.go index a9438cbd..f5fa8cef 100644 --- a/acme/client.go +++ b/acme/client.go @@ -255,7 +255,7 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle bool) (Certif // The first step of renewal is to check if we get a renewed cert // directly from the cert URL. - resp, err := http.Get(cert.CertURL) + resp, err := httpGet(cert.CertURL) if err != nil { return CertificateResource{}, err } @@ -439,7 +439,6 @@ func (c *Client) requestCertificate(authz []authorizationResource, bundle bool) switch resp.StatusCode { case 202: case 201: - cert, err := ioutil.ReadAll(limitReader(resp.Body, 1024*1024)) resp.Body.Close() if err != nil { @@ -492,7 +491,7 @@ func (c *Client) requestCertificate(authz []authorizationResource, bundle bool) return CertificateResource{}, handleHTTPError(resp) } - resp, err = http.Get(cerRes.CertURL) + resp, err = httpGet(cerRes.CertURL) if err != nil { return CertificateResource{}, err } @@ -507,7 +506,7 @@ func (c *Client) getIssuerCertificate(url string) ([]byte, error) { return c.issuerCert, nil } - resp, err := http.Get(url) + resp, err := httpGet(url) if err != nil { return nil, err } @@ -589,7 +588,7 @@ func validate(j *jws, domain, uri string, chlng challenge) error { // getJSON performs an HTTP GET request and parses the response body // as JSON, into the provided respBody object. func getJSON(uri string, respBody interface{}) (http.Header, error) { - resp, err := http.Get(uri) + resp, err := httpGet(uri) if err != nil { return nil, fmt.Errorf("failed to get %q: %v", uri, err) } diff --git a/acme/crypto.go b/acme/crypto.go index b0b568da..6e9c721a 100644 --- a/acme/crypto.go +++ b/acme/crypto.go @@ -63,7 +63,7 @@ func GetOCSPForCert(bundle []byte) ([]byte, *ocsp.Response, error) { return nil, nil, errors.New("no issuing certificate URL") } - resp, err := http.Get(certificates[0].IssuingCertificateURL[0]) + resp, err := httpGet(certificates[0].IssuingCertificateURL[0]) if err != nil { return nil, nil, err } @@ -97,7 +97,7 @@ func GetOCSPForCert(bundle []byte) ([]byte, *ocsp.Response, error) { } reader := bytes.NewReader(ocspReq) - req, err := http.Post(issuedCert.OCSPServer[0], "application/ocsp-request", reader) + req, err := httpPost(issuedCert.OCSPServer[0], "application/ocsp-request", reader) if err != nil { return nil, nil, err } diff --git a/acme/http.go b/acme/http.go new file mode 100644 index 00000000..8907f892 --- /dev/null +++ b/acme/http.go @@ -0,0 +1,72 @@ +package acme + +import ( + "fmt" + "io" + "net/http" + "runtime" + "strings" +) + +// UserAgent, if non-empty, will be tacked onto the User-Agent string in requests. +var UserAgent string + +const ( + // defaultGoUserAgent is the Go HTTP package user agent string. Too + // bad it isn't exported. If it changes, we should update it here, too. + defaultGoUserAgent = "Go-http-client/1.1" + + // ourUserAgent is the User-Agent of this underlying library package. + ourUserAgent = "xenolf-acme" +) + +// httpHead performs a HEAD request with a proper User-Agent string. +// The response body (resp.Body) is already closed when this function returns. +func httpHead(url string) (resp *http.Response, err error) { + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("User-Agent", userAgent()) + + client := http.Client{} + resp, err = client.Do(req) + if resp.Body != nil { + resp.Body.Close() + } + return resp, err +} + +// httpPost performs a POST request with a proper User-Agent string. +// Callers should close resp.Body when done reading from it. +func httpPost(url string, bodyType string, body io.Reader) (resp *http.Response, err error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + req.Header.Set("User-Agent", userAgent()) + + client := http.Client{} + return client.Do(req) +} + +// httpGet performs a GET request with a proper User-Agent string. +// Callers should close resp.Body when done reading from it. +func httpGet(url string) (resp *http.Response, err error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", userAgent()) + + client := http.Client{} + return client.Do(req) +} + +// userAgent builds and returns the User-Agent string to use in requests. +func userAgent() string { + ua := fmt.Sprintf("%s (%s; %s) %s %s", defaultGoUserAgent, runtime.GOOS, runtime.GOARCH, ourUserAgent, UserAgent) + return strings.TrimSpace(ua) +} diff --git a/acme/http_challenge_test.go b/acme/http_challenge_test.go index 97a9d979..cb1aec34 100644 --- a/acme/http_challenge_test.go +++ b/acme/http_challenge_test.go @@ -3,7 +3,6 @@ package acme import ( "crypto/rsa" "io/ioutil" - "net/http" "strings" "testing" ) @@ -14,7 +13,7 @@ func TestHTTPChallenge(t *testing.T) { clientChallenge := challenge{Type: "http-01", Token: "http1"} mockValidate := func(_ *jws, _, _ string, chlng challenge) error { uri := "http://localhost:23457/.well-known/acme-challenge/" + chlng.Token - resp, err := http.Get(uri) + resp, err := httpGet(uri) if err != nil { return err } @@ -50,7 +49,7 @@ func TestHTTPChallengeInvalidPort(t *testing.T) { solver := &httpChallenge{jws: j, validate: stubValidate, optPort: "123456"} if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { - t.Error("Solve error: got %v, want error", err) + t.Errorf("Solve error: got %v, want error", err) } else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) { t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) } diff --git a/acme/http_test.go b/acme/http_test.go new file mode 100644 index 00000000..3e04e950 --- /dev/null +++ b/acme/http_test.go @@ -0,0 +1,88 @@ +package acme + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestHTTPHeadUserAgent(t *testing.T) { + var ua string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ua = r.Header.Get("User-Agent") + })) + defer ts.Close() + + _, err := httpHead(ts.URL) + if err != nil { + t.Fatal(err) + } + + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua) + } +} + +func TestHTTPGetUserAgent(t *testing.T) { + var ua string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ua = r.Header.Get("User-Agent") + })) + defer ts.Close() + + res, err := httpGet(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua) + } +} + +func TestHTTPPostUserAgent(t *testing.T) { + var ua string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ua = r.Header.Get("User-Agent") + })) + defer ts.Close() + + res, err := httpPost(ts.URL, "text/plain", strings.NewReader("falalalala")) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua) + } +} + +func TestUserAgent(t *testing.T) { + ua := userAgent() + + if !strings.Contains(ua, defaultGoUserAgent) { + t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua) + } + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua) + } + if strings.HasSuffix(ua, " ") { + t.Errorf("UA should not have trailing spaces; got '%s'", ua) + } + + // customize the UA by appending a value + UserAgent = "MyApp/1.2.3" + ua = userAgent() + if !strings.Contains(ua, defaultGoUserAgent) { + t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua) + } + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua) + } + if !strings.Contains(ua, UserAgent) { + t.Errorf("Expected custom UA to contain %s, got '%s'", UserAgent, ua) + } +} diff --git a/acme/jws.go b/acme/jws.go index ede8eff4..b676fe39 100644 --- a/acme/jws.go +++ b/acme/jws.go @@ -35,7 +35,7 @@ func (j *jws) post(url string, content []byte) (*http.Response, error) { return nil, err } - resp, err := http.Post(url, "application/jose+json", bytes.NewBuffer([]byte(signedContent.FullSerialize()))) + resp, err := httpPost(url, "application/jose+json", bytes.NewBuffer([]byte(signedContent.FullSerialize()))) if err != nil { return nil, err } @@ -71,7 +71,7 @@ func (j *jws) getNonceFromResponse(resp *http.Response) error { } func (j *jws) getNonce() error { - resp, err := http.Head(j.directoryURL) + resp, err := httpHead(j.directoryURL) if err != nil { return err } diff --git a/acme/tls_sni_challenge_test.go b/acme/tls_sni_challenge_test.go index 8f3ccbe1..562cb073 100644 --- a/acme/tls_sni_challenge_test.go +++ b/acme/tls_sni_challenge_test.go @@ -57,7 +57,7 @@ func TestTLSSNIChallengeInvalidPort(t *testing.T) { solver := &tlsSNIChallenge{jws: j, validate: stubValidate, optPort: "123456"} if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { - t.Error("Solve error: got %v, want error", err) + t.Errorf("Solve error: got %v, want error", err) } else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) { t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) } diff --git a/cli.go b/cli.go index da902ee9..7a8026d6 100644 --- a/cli.go +++ b/cli.go @@ -6,6 +6,7 @@ import ( "path" "github.com/codegangsta/cli" + "github.com/xenolf/lego/acme" ) // Logger is used to log errors; if nil, the default log.Logger is used. @@ -20,12 +21,13 @@ func logger() *log.Logger { } func main() { - app := cli.NewApp() app.Name = "lego" app.Usage = "Let's encrypt client to go!" app.Version = "0.1.0" + acme.UserAgent = "lego/" + app.Version + cwd, err := os.Getwd() if err != nil { logger().Fatal("Could not determine current working directory. Please pass --path.")