diff --git a/acme/api/handler.go b/acme/api/handler.go index bd226e73..c3a481f9 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -43,6 +44,7 @@ type Handler struct { ca acme.CertificateAuthority linker Linker validateChallengeOptions *acme.ValidateChallengeOptions + prerequisitesChecker func(ctx context.Context) (bool, error) } // HandlerOptions required to create a new ACME API request handler. @@ -60,6 +62,9 @@ type HandlerOptions struct { // "acme" is the prefix from which the ACME api is accessed. Prefix string CA acme.CertificateAuthority + // PrerequisitesChecker checks if all prerequisites for serving ACME are + // met by the CA configuration. + PrerequisitesChecker func(ctx context.Context) (bool, error) } // NewHandler returns a new ACME API handler. @@ -76,6 +81,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { dialer := &net.Dialer{ Timeout: 30 * time.Second, } + prerequisitesChecker := func(ctx context.Context) (bool, error) { + // by default all prerequisites are met + return true, nil + } + if ops.PrerequisitesChecker != nil { + prerequisitesChecker = ops.PrerequisitesChecker + } return &Handler{ ca: ops.CA, db: ops.DB, @@ -88,6 +100,7 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { return tls.DialWithDialer(dialer, network, addr, config) }, }, + prerequisitesChecker: prerequisitesChecker, } } @@ -95,13 +108,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { func (h *Handler) Route(r api.Router) { getPath := h.linker.GetUnescapedPathSuffix // Standard ACME API - r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) - r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) + r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) + r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) + r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) + r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) validatingMiddleware := func(next nextHTTP) nextHTTP { - return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))) + return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) diff --git a/acme/api/middleware.go b/acme/api/middleware.go index de8614ee..0cdeaabb 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -283,11 +283,10 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { } // lookupProvisioner loads the provisioner associated with the request. -// Responsds 404 if the provisioner does not exist. +// Responds 404 if the provisioner does not exist. func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - nameEscaped := chi.URLParam(r, "provisionerID") name, err := url.PathUnescape(nameEscaped) if err != nil { @@ -309,6 +308,24 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { } } +// checkPrerequisites checks if all prerequisites for serving ACME +// are met by the CA configuration. +func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ok, err := h.prerequisitesChecker(ctx) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + return + } + if !ok { + api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + return + } + next(w, r.WithContext(ctx)) + } +} + // lookupJWK loads the JWK associated with the acme account referenced by the // kid parameter of the signed payload. // Make sure to parse and validate the JWS before running this middleware. diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 050b46a5..8003fa16 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -1656,3 +1656,91 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { }) } } + +func TestHandler_checkPrerequisites(t *testing.T) { + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + u := fmt.Sprintf("%s/acme/%s/account/1234", + baseURL, provName) + type test struct { + linker Linker + ctx context.Context + prerequisitesChecker func(context.Context) (bool, error) + next func(http.ResponseWriter, *http.Request) + err *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/error": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + linker: NewLinker("dns", "acme"), + ctx: ctx, + prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + err: acme.WrapErrorISE(errors.New("force"), "error checking acme provisioner prerequisites"), + statusCode: 500, + } + }, + "fail/prerequisites-nok": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + linker: NewLinker("dns", "acme"), + ctx: ctx, + prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + err: acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"), + statusCode: 501, + } + }, + "ok": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + linker: NewLinker("dns", "acme"), + ctx: ctx, + prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} + req := httptest.NewRequest("GET", u, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.checkPrerequisites(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +}