diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 36ee1d0e..0cdeaabb 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -283,7 +283,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { } // lookupProvisioner loads the provisioner associated with the request. -// Responsds 404 if the provisioner does not exist. +// Responds 404 if the provisioner does not exist. func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -319,7 +319,7 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { return } if !ok { - api.WriteError(w, acme.NewErrorISE("acme provisioner configuration lacks prerequisites")) + api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) return } next(w, r.WithContext(ctx)) diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 050b46a5..8003fa16 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -1656,3 +1656,91 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { }) } } + +func TestHandler_checkPrerequisites(t *testing.T) { + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + u := fmt.Sprintf("%s/acme/%s/account/1234", + baseURL, provName) + type test struct { + linker Linker + ctx context.Context + prerequisitesChecker func(context.Context) (bool, error) + next func(http.ResponseWriter, *http.Request) + err *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/error": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + linker: NewLinker("dns", "acme"), + ctx: ctx, + prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + err: acme.WrapErrorISE(errors.New("force"), "error checking acme provisioner prerequisites"), + statusCode: 500, + } + }, + "fail/prerequisites-nok": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + linker: NewLinker("dns", "acme"), + ctx: ctx, + prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + err: acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"), + statusCode: 501, + } + }, + "ok": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + linker: NewLinker("dns", "acme"), + ctx: ctx, + prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} + req := httptest.NewRequest("GET", u, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.checkPrerequisites(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +}