diff --git a/acme/api/handler.go b/acme/api/handler.go index bd226e73..c3a481f9 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -43,6 +44,7 @@ type Handler struct { ca acme.CertificateAuthority linker Linker validateChallengeOptions *acme.ValidateChallengeOptions + prerequisitesChecker func(ctx context.Context) (bool, error) } // HandlerOptions required to create a new ACME API request handler. @@ -60,6 +62,9 @@ type HandlerOptions struct { // "acme" is the prefix from which the ACME api is accessed. Prefix string CA acme.CertificateAuthority + // PrerequisitesChecker checks if all prerequisites for serving ACME are + // met by the CA configuration. + PrerequisitesChecker func(ctx context.Context) (bool, error) } // NewHandler returns a new ACME API handler. @@ -76,6 +81,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { dialer := &net.Dialer{ Timeout: 30 * time.Second, } + prerequisitesChecker := func(ctx context.Context) (bool, error) { + // by default all prerequisites are met + return true, nil + } + if ops.PrerequisitesChecker != nil { + prerequisitesChecker = ops.PrerequisitesChecker + } return &Handler{ ca: ops.CA, db: ops.DB, @@ -88,6 +100,7 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { return tls.DialWithDialer(dialer, network, addr, config) }, }, + prerequisitesChecker: prerequisitesChecker, } } @@ -95,13 +108,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { func (h *Handler) Route(r api.Router) { getPath := h.linker.GetUnescapedPathSuffix // Standard ACME API - r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) - r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) + r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) + r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) + r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) + r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) validatingMiddleware := func(next nextHTTP) nextHTTP { - return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))) + return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) diff --git a/acme/api/middleware.go b/acme/api/middleware.go index de8614ee..36ee1d0e 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -287,7 +287,6 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - nameEscaped := chi.URLParam(r, "provisionerID") name, err := url.PathUnescape(nameEscaped) if err != nil { @@ -309,6 +308,24 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { } } +// checkPrerequisites checks if all prerequisites for serving ACME +// are met by the CA configuration. +func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ok, err := h.prerequisitesChecker(ctx) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + return + } + if !ok { + api.WriteError(w, acme.NewErrorISE("acme provisioner configuration lacks prerequisites")) + return + } + next(w, r.WithContext(ctx)) + } +} + // lookupJWK loads the JWK associated with the acme account referenced by the // kid parameter of the signed payload. // Make sure to parse and validate the JWS before running this middleware.