forked from TrueCloudLab/certificates
Fix PR comment and add tests for ACME prerequisites checker
This commit is contained in:
parent
e47dd0a666
commit
b6f6bd879c
2 changed files with 90 additions and 2 deletions
|
@ -283,7 +283,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// lookupProvisioner loads the provisioner associated with the request.
|
||||
// Responsds 404 if the provisioner does not exist.
|
||||
// Responds 404 if the provisioner does not exist.
|
||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
@ -319,7 +319,7 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
|
|||
return
|
||||
}
|
||||
if !ok {
|
||||
api.WriteError(w, acme.NewErrorISE("acme provisioner configuration lacks prerequisites"))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
next(w, r.WithContext(ctx))
|
||||
|
|
|
@ -1656,3 +1656,91 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_checkPrerequisites(t *testing.T) {
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||
baseURL, provName)
|
||||
type test struct {
|
||||
linker Linker
|
||||
ctx context.Context
|
||||
prerequisitesChecker func(context.Context) (bool, error)
|
||||
next func(http.ResponseWriter, *http.Request)
|
||||
err *acme.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/error": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(testBody)
|
||||
},
|
||||
err: acme.WrapErrorISE(errors.New("force"), "error checking acme provisioner prerequisites"),
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"fail/prerequisites-nok": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(testBody)
|
||||
},
|
||||
err: acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"),
|
||||
statusCode: 501,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
ctx: ctx,
|
||||
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(testBody)
|
||||
},
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.checkPrerequisites(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||
var ae acme.Error
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
assert.Equals(t, bytes.TrimSpace(body), testBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue