Merge pull request #66 from xenolf/user-agent-string

Implement custom User-Agent string
This commit is contained in:
xenolf 2016-01-07 04:51:31 +01:00
commit 1193ae895a
8 changed files with 174 additions and 14 deletions

View file

@ -255,7 +255,7 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle bool) (Certif
// The first step of renewal is to check if we get a renewed cert // The first step of renewal is to check if we get a renewed cert
// directly from the cert URL. // directly from the cert URL.
resp, err := http.Get(cert.CertURL) resp, err := httpGet(cert.CertURL)
if err != nil { if err != nil {
return CertificateResource{}, err return CertificateResource{}, err
} }
@ -439,7 +439,6 @@ func (c *Client) requestCertificate(authz []authorizationResource, bundle bool)
switch resp.StatusCode { switch resp.StatusCode {
case 202: case 202:
case 201: case 201:
cert, err := ioutil.ReadAll(limitReader(resp.Body, 1024*1024)) cert, err := ioutil.ReadAll(limitReader(resp.Body, 1024*1024))
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
@ -492,7 +491,7 @@ func (c *Client) requestCertificate(authz []authorizationResource, bundle bool)
return CertificateResource{}, handleHTTPError(resp) return CertificateResource{}, handleHTTPError(resp)
} }
resp, err = http.Get(cerRes.CertURL) resp, err = httpGet(cerRes.CertURL)
if err != nil { if err != nil {
return CertificateResource{}, err return CertificateResource{}, err
} }
@ -507,7 +506,7 @@ func (c *Client) getIssuerCertificate(url string) ([]byte, error) {
return c.issuerCert, nil return c.issuerCert, nil
} }
resp, err := http.Get(url) resp, err := httpGet(url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -589,7 +588,7 @@ func validate(j *jws, domain, uri string, chlng challenge) error {
// getJSON performs an HTTP GET request and parses the response body // getJSON performs an HTTP GET request and parses the response body
// as JSON, into the provided respBody object. // as JSON, into the provided respBody object.
func getJSON(uri string, respBody interface{}) (http.Header, error) { func getJSON(uri string, respBody interface{}) (http.Header, error) {
resp, err := http.Get(uri) resp, err := httpGet(uri)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get %q: %v", uri, err) return nil, fmt.Errorf("failed to get %q: %v", uri, err)
} }

View file

@ -63,7 +63,7 @@ func GetOCSPForCert(bundle []byte) ([]byte, *ocsp.Response, error) {
return nil, nil, errors.New("no issuing certificate URL") return nil, nil, errors.New("no issuing certificate URL")
} }
resp, err := http.Get(certificates[0].IssuingCertificateURL[0]) resp, err := httpGet(certificates[0].IssuingCertificateURL[0])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -97,7 +97,7 @@ func GetOCSPForCert(bundle []byte) ([]byte, *ocsp.Response, error) {
} }
reader := bytes.NewReader(ocspReq) reader := bytes.NewReader(ocspReq)
req, err := http.Post(issuedCert.OCSPServer[0], "application/ocsp-request", reader) req, err := httpPost(issuedCert.OCSPServer[0], "application/ocsp-request", reader)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

72
acme/http.go Normal file
View file

@ -0,0 +1,72 @@
package acme
import (
"fmt"
"io"
"net/http"
"runtime"
"strings"
)
// UserAgent, if non-empty, will be tacked onto the User-Agent string in requests.
var UserAgent string
const (
// defaultGoUserAgent is the Go HTTP package user agent string. Too
// bad it isn't exported. If it changes, we should update it here, too.
defaultGoUserAgent = "Go-http-client/1.1"
// ourUserAgent is the User-Agent of this underlying library package.
ourUserAgent = "xenolf-acme"
)
// httpHead performs a HEAD request with a proper User-Agent string.
// The response body (resp.Body) is already closed when this function returns.
func httpHead(url string) (resp *http.Response, err error) {
req, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", userAgent())
client := http.Client{}
resp, err = client.Do(req)
if resp.Body != nil {
resp.Body.Close()
}
return resp, err
}
// httpPost performs a POST request with a proper User-Agent string.
// Callers should close resp.Body when done reading from it.
func httpPost(url string, bodyType string, body io.Reader) (resp *http.Response, err error) {
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", bodyType)
req.Header.Set("User-Agent", userAgent())
client := http.Client{}
return client.Do(req)
}
// httpGet performs a GET request with a proper User-Agent string.
// Callers should close resp.Body when done reading from it.
func httpGet(url string) (resp *http.Response, err error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", userAgent())
client := http.Client{}
return client.Do(req)
}
// userAgent builds and returns the User-Agent string to use in requests.
func userAgent() string {
ua := fmt.Sprintf("%s (%s; %s) %s %s", defaultGoUserAgent, runtime.GOOS, runtime.GOARCH, ourUserAgent, UserAgent)
return strings.TrimSpace(ua)
}

View file

@ -3,7 +3,6 @@ package acme
import ( import (
"crypto/rsa" "crypto/rsa"
"io/ioutil" "io/ioutil"
"net/http"
"strings" "strings"
"testing" "testing"
) )
@ -14,7 +13,7 @@ func TestHTTPChallenge(t *testing.T) {
clientChallenge := challenge{Type: "http-01", Token: "http1"} clientChallenge := challenge{Type: "http-01", Token: "http1"}
mockValidate := func(_ *jws, _, _ string, chlng challenge) error { mockValidate := func(_ *jws, _, _ string, chlng challenge) error {
uri := "http://localhost:23457/.well-known/acme-challenge/" + chlng.Token uri := "http://localhost:23457/.well-known/acme-challenge/" + chlng.Token
resp, err := http.Get(uri) resp, err := httpGet(uri)
if err != nil { if err != nil {
return err return err
} }
@ -50,7 +49,7 @@ func TestHTTPChallengeInvalidPort(t *testing.T) {
solver := &httpChallenge{jws: j, validate: stubValidate, optPort: "123456"} solver := &httpChallenge{jws: j, validate: stubValidate, optPort: "123456"}
if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil {
t.Error("Solve error: got %v, want error", err) t.Errorf("Solve error: got %v, want error", err)
} else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) { } else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) {
t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want)
} }

88
acme/http_test.go Normal file
View file

@ -0,0 +1,88 @@
package acme
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestHTTPHeadUserAgent(t *testing.T) {
var ua string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ua = r.Header.Get("User-Agent")
}))
defer ts.Close()
_, err := httpHead(ts.URL)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(ua, ourUserAgent) {
t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua)
}
}
func TestHTTPGetUserAgent(t *testing.T) {
var ua string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ua = r.Header.Get("User-Agent")
}))
defer ts.Close()
res, err := httpGet(ts.URL)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
if !strings.Contains(ua, ourUserAgent) {
t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua)
}
}
func TestHTTPPostUserAgent(t *testing.T) {
var ua string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ua = r.Header.Get("User-Agent")
}))
defer ts.Close()
res, err := httpPost(ts.URL, "text/plain", strings.NewReader("falalalala"))
if err != nil {
t.Fatal(err)
}
res.Body.Close()
if !strings.Contains(ua, ourUserAgent) {
t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua)
}
}
func TestUserAgent(t *testing.T) {
ua := userAgent()
if !strings.Contains(ua, defaultGoUserAgent) {
t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua)
}
if !strings.Contains(ua, ourUserAgent) {
t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua)
}
if strings.HasSuffix(ua, " ") {
t.Errorf("UA should not have trailing spaces; got '%s'", ua)
}
// customize the UA by appending a value
UserAgent = "MyApp/1.2.3"
ua = userAgent()
if !strings.Contains(ua, defaultGoUserAgent) {
t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua)
}
if !strings.Contains(ua, ourUserAgent) {
t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua)
}
if !strings.Contains(ua, UserAgent) {
t.Errorf("Expected custom UA to contain %s, got '%s'", UserAgent, ua)
}
}

View file

@ -35,7 +35,7 @@ func (j *jws) post(url string, content []byte) (*http.Response, error) {
return nil, err return nil, err
} }
resp, err := http.Post(url, "application/jose+json", bytes.NewBuffer([]byte(signedContent.FullSerialize()))) resp, err := httpPost(url, "application/jose+json", bytes.NewBuffer([]byte(signedContent.FullSerialize())))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,7 +71,7 @@ func (j *jws) getNonceFromResponse(resp *http.Response) error {
} }
func (j *jws) getNonce() error { func (j *jws) getNonce() error {
resp, err := http.Head(j.directoryURL) resp, err := httpHead(j.directoryURL)
if err != nil { if err != nil {
return err return err
} }

View file

@ -57,7 +57,7 @@ func TestTLSSNIChallengeInvalidPort(t *testing.T) {
solver := &tlsSNIChallenge{jws: j, validate: stubValidate, optPort: "123456"} solver := &tlsSNIChallenge{jws: j, validate: stubValidate, optPort: "123456"}
if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil {
t.Error("Solve error: got %v, want error", err) t.Errorf("Solve error: got %v, want error", err)
} else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) { } else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) {
t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want)
} }

4
cli.go
View file

@ -6,6 +6,7 @@ import (
"path" "path"
"github.com/codegangsta/cli" "github.com/codegangsta/cli"
"github.com/xenolf/lego/acme"
) )
// Logger is used to log errors; if nil, the default log.Logger is used. // Logger is used to log errors; if nil, the default log.Logger is used.
@ -20,12 +21,13 @@ func logger() *log.Logger {
} }
func main() { func main() {
app := cli.NewApp() app := cli.NewApp()
app.Name = "lego" app.Name = "lego"
app.Usage = "Let's encrypt client to go!" app.Usage = "Let's encrypt client to go!"
app.Version = "0.1.0" app.Version = "0.1.0"
acme.UserAgent = "lego/" + app.Version
cwd, err := os.Getwd() cwd, err := os.Getwd()
if err != nil { if err != nil {
logger().Fatal("Could not determine current working directory. Please pass --path.") logger().Fatal("Could not determine current working directory. Please pass --path.")