Add TOSError and change ObtainCertificates to return errors by domain.

This commit is contained in:
xenolf 2015-11-02 01:01:00 +01:00
parent ee58d205a5
commit a2867a0c18
3 changed files with 80 additions and 41 deletions

View file

@ -183,17 +183,30 @@ func (c *Client) AgreeToTOS() error {
// PEM encoded byte slices. // PEM encoded byte slices.
// If bundle is true, the []byte contains both the issuer certificate and // If bundle is true, the []byte contains both the issuer certificate and
// your issued certificate as a bundle. // your issued certificate as a bundle.
func (c *Client) ObtainCertificates(domains []string, bundle bool) ([]CertificateResource, error) { func (c *Client) ObtainCertificates(domains []string, bundle bool) ([]CertificateResource, map[string]error) {
logger().Print("Obtaining certificates...") logger().Print("Obtaining certificates...")
challenges := c.getChallenges(domains) challenges, failures := c.getChallenges(domains)
if len(challenges) == 0 {
return nil, failures
}
err := c.solveChallenges(challenges) err := c.solveChallenges(challenges)
if err != nil { for k, v := range err {
return nil, err failures[k] = v
}
if len(failures) == len(domains) {
return nil, failures
} }
logger().Print("Validations succeeded. Getting certificates") logger().Print("Validations succeeded. Getting certificates")
return c.requestCertificates(challenges, bundle) certs, err := c.requestCertificates(challenges, bundle)
for k, v := range err {
failures[k] = v
}
return certs, failures
} }
// RevokeCertificate takes a PEM encoded certificate or bundle and tries to revoke it at the CA. // RevokeCertificate takes a PEM encoded certificate or bundle and tries to revoke it at the CA.
@ -299,9 +312,9 @@ func (c *Client) RenewCertificate(cert CertificateResource, revokeOld bool, bund
return cert, nil return cert, nil
} }
newCerts, err := c.ObtainCertificates([]string{cert.Domain}, bundle) newCerts, failures := c.ObtainCertificates([]string{cert.Domain}, bundle)
if err != nil { if len(failures) > 0 {
return CertificateResource{}, err return CertificateResource{}, failures[cert.Domain]
} }
if revokeOld { if revokeOld {
@ -313,8 +326,9 @@ func (c *Client) RenewCertificate(cert CertificateResource, revokeOld bool, bund
// Looks through the challenge combinations to find a solvable match. // Looks through the challenge combinations to find a solvable match.
// Then solves the challenges in series and returns. // Then solves the challenges in series and returns.
func (c *Client) solveChallenges(challenges []*authorizationResource) error { func (c *Client) solveChallenges(challenges []*authorizationResource) map[string]error {
// loop through the resources, basically through the domains. // loop through the resources, basically through the domains.
failures := make(map[string]error)
for _, authz := range challenges { for _, authz := range challenges {
// no solvers - no solving // no solvers - no solving
if solvers := c.chooseSolvers(authz.Body, authz.Domain); solvers != nil { if solvers := c.chooseSolvers(authz.Body, authz.Domain); solvers != nil {
@ -322,15 +336,15 @@ func (c *Client) solveChallenges(challenges []*authorizationResource) error {
// TODO: do not immediately fail if one domain fails to validate. // TODO: do not immediately fail if one domain fails to validate.
err := solver.Solve(authz.Body.Challenges[i], authz.Domain) err := solver.Solve(authz.Body.Challenges[i], authz.Domain)
if err != nil { if err != nil {
return err failures[authz.Domain] = err
} }
} }
} else { } else {
return fmt.Errorf("Could not determine solvers for %s", authz.Domain) failures[authz.Domain] = fmt.Errorf("Could not determine solvers for %s", authz.Domain)
} }
} }
return nil return failures
} }
// Checks all combinations from the server and returns an array of // Checks all combinations from the server and returns an array of
@ -355,25 +369,25 @@ func (c *Client) chooseSolvers(auth authorization, domain string) map[int]solver
} }
// Get the challenges needed to proof our identifier to the ACME server. // Get the challenges needed to proof our identifier to the ACME server.
func (c *Client) getChallenges(domains []string) []*authorizationResource { func (c *Client) getChallenges(domains []string) ([]*authorizationResource, map[string]error) {
resc, errc := make(chan *authorizationResource), make(chan error) resc, errc := make(chan *authorizationResource), make(chan domainError)
for _, domain := range domains { for _, domain := range domains {
go func(domain string) { go func(domain string) {
jsonBytes, err := json.Marshal(authorization{Resource: "new-authz", Identifier: identifier{Type: "dns", Value: domain}}) jsonBytes, err := json.Marshal(authorization{Resource: "new-authz", Identifier: identifier{Type: "dns", Value: domain}})
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: domain, Error: err}
return return
} }
resp, err := c.jws.post(c.user.GetRegistration().NewAuthzURL, jsonBytes) resp, err := c.jws.post(c.user.GetRegistration().NewAuthzURL, jsonBytes)
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: domain, Error: err}
return return
} }
if resp.StatusCode != http.StatusCreated { if resp.StatusCode != http.StatusCreated {
errc <- handleHTTPError(resp) errc <- domainError{Domain: domain, Error: handleHTTPError(resp)}
} }
links := parseLinks(resp.Header["Link"]) links := parseLinks(resp.Header["Link"])
@ -386,7 +400,7 @@ func (c *Client) getChallenges(domains []string) []*authorizationResource {
decoder := json.NewDecoder(resp.Body) decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(&authz) err = decoder.Decode(&authz)
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: domain, Error: err}
} }
resp.Body.Close() resp.Body.Close()
@ -395,37 +409,39 @@ func (c *Client) getChallenges(domains []string) []*authorizationResource {
} }
var responses []*authorizationResource var responses []*authorizationResource
failures := make(map[string]error)
for i := 0; i < len(domains); i++ { for i := 0; i < len(domains); i++ {
select { select {
case res := <-resc: case res := <-resc:
responses = append(responses, res) responses = append(responses, res)
case err := <-errc: case err := <-errc:
logger().Printf("%v", err) failures[err.Domain] = err.Error
} }
} }
close(resc) close(resc)
close(errc) close(errc)
return responses return responses, failures
} }
// requestCertificates iterates all granted authorizations, creates RSA private keys and CSRs. // requestCertificates iterates all granted authorizations, creates RSA private keys and CSRs.
// It then uses these to request a certificate from the CA and returns the list of successfully // It then uses these to request a certificate from the CA and returns the list of successfully
// granted certificates. // granted certificates.
func (c *Client) requestCertificates(challenges []*authorizationResource, bundle bool) ([]CertificateResource, error) { func (c *Client) requestCertificates(challenges []*authorizationResource, bundle bool) ([]CertificateResource, map[string]error) {
resc, errc := make(chan CertificateResource), make(chan error) resc, errc := make(chan CertificateResource), make(chan domainError)
for _, authz := range challenges { for _, authz := range challenges {
go c.requestCertificate(authz, resc, errc, bundle) go c.requestCertificate(authz, resc, errc, bundle)
} }
var certs []CertificateResource var certs []CertificateResource
failures := make(map[string]error)
for i := 0; i < len(challenges); i++ { for i := 0; i < len(challenges); i++ {
select { select {
case res := <-resc: case res := <-resc:
certs = append(certs, res) certs = append(certs, res)
case err := <-errc: case err := <-errc:
logger().Printf("%v", err) failures[err.Domain] = err.Error
} }
} }
@ -435,30 +451,30 @@ func (c *Client) requestCertificates(challenges []*authorizationResource, bundle
return certs, nil return certs, nil
} }
func (c *Client) requestCertificate(authz *authorizationResource, result chan CertificateResource, errc chan error, bundle bool) { func (c *Client) requestCertificate(authz *authorizationResource, result chan CertificateResource, errc chan domainError, bundle bool) {
privKey, err := generatePrivateKey(rsakey, c.keyBits) privKey, err := generatePrivateKey(rsakey, c.keyBits)
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: authz.Domain, Error: err}
return return
} }
// TODO: should the CSR be customizable? // TODO: should the CSR be customizable?
csr, err := generateCsr(privKey.(*rsa.PrivateKey), authz.Domain) csr, err := generateCsr(privKey.(*rsa.PrivateKey), authz.Domain)
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: authz.Domain, Error: err}
return return
} }
csrString := base64.URLEncoding.EncodeToString(csr) csrString := base64.URLEncoding.EncodeToString(csr)
jsonBytes, err := json.Marshal(csrMessage{Resource: "new-cert", Csr: csrString, Authorizations: []string{authz.AuthURL}}) jsonBytes, err := json.Marshal(csrMessage{Resource: "new-cert", Csr: csrString, Authorizations: []string{authz.AuthURL}})
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: authz.Domain, Error: err}
return return
} }
resp, err := c.jws.post(authz.NewCertURL, jsonBytes) resp, err := c.jws.post(authz.NewCertURL, jsonBytes)
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: authz.Domain, Error: err}
return return
} }
@ -477,7 +493,7 @@ func (c *Client) requestCertificate(authz *authorizationResource, result chan Ce
cert, err := ioutil.ReadAll(resp.Body) cert, err := ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: authz.Domain, Error: err}
return return
} }
@ -517,7 +533,7 @@ func (c *Client) requestCertificate(authz *authorizationResource, result chan Ce
ra := resp.Header.Get("Retry-After") ra := resp.Header.Get("Retry-After")
retryAfter, err := strconv.Atoi(ra) retryAfter, err := strconv.Atoi(ra)
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: authz.Domain, Error: err}
return return
} }
@ -526,13 +542,13 @@ func (c *Client) requestCertificate(authz *authorizationResource, result chan Ce
break break
default: default:
errc <- handleHTTPError(resp) errc <- domainError{Domain: authz.Domain, Error: handleHTTPError(resp)}
return return
} }
resp, err = http.Get(cerRes.CertURL) resp, err = http.Get(cerRes.CertURL)
if err != nil { if err != nil {
errc <- err errc <- domainError{Domain: authz.Domain, Error: err}
return return
} }
} }

View file

@ -6,19 +6,30 @@ import (
"net/http" "net/http"
) )
// Error is the base type for all errors specific to the ACME protocol. const (
type Error struct { tosAgreementError = "Must agree to subscriber agreement before any further actions"
)
// RemoteError is the base type for all errors specific to the ACME protocol.
type RemoteError struct {
StatusCode int `json:"status,omitempty"` StatusCode int `json:"status,omitempty"`
Type string `json:"type"` Type string `json:"type"`
Detail string `json:"detail"` Detail string `json:"detail"`
} }
func (e Error) Error() string { func (e RemoteError) Error() string {
return fmt.Sprintf("[%d] Type: %s Detail: %s", e.StatusCode, e.Type, e.Detail) return fmt.Sprintf("[%d] Type: %s Detail: %s", e.StatusCode, e.Type, e.Detail)
} }
// TOSError represents the error which is returned if the user needs to
// accept the TOS.
// TODO: include the new TOS url if we can somehow obtain it.
type TOSError struct {
RemoteError
}
func handleHTTPError(resp *http.Response) error { func handleHTTPError(resp *http.Response) error {
var errorDetail Error var errorDetail RemoteError
decoder := json.NewDecoder(resp.Body) decoder := json.NewDecoder(resp.Body)
err := decoder.Decode(&errorDetail) err := decoder.Decode(&errorDetail)
if err != nil { if err != nil {
@ -26,5 +37,16 @@ func handleHTTPError(resp *http.Response) error {
} }
errorDetail.StatusCode = resp.StatusCode errorDetail.StatusCode = resp.StatusCode
// Check for errors we handle specifically
if errorDetail.StatusCode == http.StatusForbidden && errorDetail.Detail == tosAgreementError {
return TOSError{errorDetail}
}
return errorDetail return errorDetail
} }
type domainError struct {
Domain string
Error error
}

View file

@ -70,7 +70,6 @@ func saveCertRes(certRes acme.CertificateResource, conf *Configuration) {
} }
func run(c *cli.Context) { func run(c *cli.Context) {
conf, acc, client := setup(c) conf, acc, client := setup(c)
if acc.Registration == nil { if acc.Registration == nil {
reg, err := client.Register() reg, err := client.Register()
@ -126,12 +125,14 @@ func run(c *cli.Context) {
logger().Fatal("Please specify --domains") logger().Fatal("Please specify --domains")
} }
certs, err := client.ObtainCertificates(c.GlobalStringSlice("domains"), true) certs, failures := client.ObtainCertificates(c.GlobalStringSlice("domains"), true)
if err != nil { if len(failures) > 0 {
logger().Fatalf("Could not obtain certificates\n\t%v", err) for k, v := range failures {
logger().Fatalf("[%s] Could not obtain certificates\n\t%v", k, v)
}
} }
err = checkFolder(conf.CertPath()) err := checkFolder(conf.CertPath())
if err != nil { if err != nil {
logger().Fatalf("Cound not check/create path: %v", err) logger().Fatalf("Cound not check/create path: %v", err)
} }