feat: change ObtainForCSR signature.

This commit is contained in:
Fernandez Ludovic 2020-09-02 02:31:53 +02:00 committed by Ludovic Fernandez
parent 30e4987f99
commit 6ad8c6c16f
4 changed files with 37 additions and 9 deletions

View file

@ -58,6 +58,15 @@ type ObtainRequest struct {
PreferredChain string PreferredChain string
} }
// ObtainForCSRRequest The request to obtain a certificate matching the CSR passed into it.
//
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
type ObtainForCSRRequest struct {
CSR *x509.CertificateRequest
Bundle bool
PreferredChain string
}
type resolver interface { type resolver interface {
Solve(authorizations []acme.Authorization) error Solve(authorizations []acme.Authorization) error
} }
@ -146,12 +155,16 @@ func (c *Certifier) Obtain(request ObtainRequest) (*Resource, error) {
// //
// This function will never return a partial certificate. // This function will never return a partial certificate.
// If one domain in the list fails, the whole certificate will fail. // If one domain in the list fails, the whole certificate will fail.
func (c *Certifier) ObtainForCSR(csr x509.CertificateRequest, bundle bool, preferredChain string) (*Resource, error) { func (c *Certifier) ObtainForCSR(request ObtainForCSRRequest) (*Resource, error) {
if request.CSR == nil {
return nil, errors.New("cannot obtain resource for CSR: CSR is missing")
}
// figure out what domains it concerns // figure out what domains it concerns
// start with the common name // start with the common name
domains := certcrypto.ExtractDomainsCSR(&csr) domains := certcrypto.ExtractDomainsCSR(request.CSR)
if bundle { if request.Bundle {
log.Infof("[%s] acme: Obtaining bundled SAN certificate given a CSR", strings.Join(domains, ", ")) log.Infof("[%s] acme: Obtaining bundled SAN certificate given a CSR", strings.Join(domains, ", "))
} else { } else {
log.Infof("[%s] acme: Obtaining SAN certificate given a CSR", strings.Join(domains, ", ")) log.Infof("[%s] acme: Obtaining SAN certificate given a CSR", strings.Join(domains, ", "))
@ -179,7 +192,7 @@ func (c *Certifier) ObtainForCSR(csr x509.CertificateRequest, bundle bool, prefe
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", ")) log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := make(obtainError) failures := make(obtainError)
cert, err := c.getForCSR(domains, order, bundle, csr.Raw, nil, preferredChain) cert, err := c.getForCSR(domains, order, request.Bundle, request.CSR.Raw, nil, request.PreferredChain)
if err != nil { if err != nil {
for _, auth := range authz { for _, auth := range authz {
failures[challenge.GetTargetedDomain(auth)] = err failures[challenge.GetTargetedDomain(auth)] = err
@ -188,7 +201,7 @@ func (c *Certifier) ObtainForCSR(csr x509.CertificateRequest, bundle bool, prefe
if cert != nil { if cert != nil {
// Add the CSR to the certificate so that it can be used for renewals. // Add the CSR to the certificate so that it can be used for renewals.
cert.CSR = certcrypto.PEMEncode(&csr) cert.CSR = certcrypto.PEMEncode(request.CSR)
} }
// Do not return an empty failures map, // Do not return an empty failures map,
@ -394,7 +407,11 @@ func (c *Certifier) Renew(certRes Resource, bundle, mustStaple bool, preferredCh
return nil, errP return nil, errP
} }
return c.ObtainForCSR(*csr, bundle, preferredChain) return c.ObtainForCSR(ObtainForCSRRequest{
CSR: csr,
Bundle: bundle,
PreferredChain: preferredChain,
})
} }
var privateKey crypto.PrivateKey var privateKey crypto.PrivateKey

View file

@ -173,7 +173,11 @@ func renewForCSR(ctx *cli.Context, client *lego.Client, certsStorage *Certificat
timeLeft := cert.NotAfter.Sub(time.Now().UTC()) timeLeft := cert.NotAfter.Sub(time.Now().UTC())
log.Infof("[%s] acme: Trying renewal with %d hours remaining", domain, int(timeLeft.Hours())) log.Infof("[%s] acme: Trying renewal with %d hours remaining", domain, int(timeLeft.Hours()))
certRes, err := client.Certificate.ObtainForCSR(*csr, bundle, ctx.String("preferred-chain")) certRes, err := client.Certificate.ObtainForCSR(certificate.ObtainForCSRRequest{
CSR: csr,
Bundle: bundle,
PreferredChain: ctx.String("preferred-chain"),
})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View file

@ -178,5 +178,9 @@ func obtainCertificate(ctx *cli.Context, client *lego.Client) (*certificate.Reso
} }
// obtain a certificate for this CSR // obtain a certificate for this CSR
return client.Certificate.ObtainForCSR(*csr, bundle, ctx.String("preferred-chain")) return client.Certificate.ObtainForCSR(certificate.ObtainForCSRRequest{
CSR: csr,
Bundle: bundle,
PreferredChain: ctx.String("preferred-chain"),
})
} }

View file

@ -307,7 +307,10 @@ func TestChallengeTLS_Client_ObtainForCSR(t *testing.T) {
csr, err := x509.ParseCertificateRequest(csrRaw) csr, err := x509.ParseCertificateRequest(csrRaw)
require.NoError(t, err) require.NoError(t, err)
resource, err := client.Certificate.ObtainForCSR(*csr, true, "") resource, err := client.Certificate.ObtainForCSR(certificate.ObtainForCSRRequest{
CSR: csr,
Bundle: true,
})
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, resource) require.NotNil(t, resource)