diff --git a/acme/api/middleware.go b/acme/api/middleware.go index c33aaeda..50f7146f 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -87,14 +87,14 @@ func (h *Handler) addDirLink(next nextHTTP) nextHTTP { // application/jose+json. func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - var ( - expected []string - provName string - ) - if p, err := provisionerFromContext(r.Context()); err == nil && p != nil { - provName = p.GetName() + var expected []string + p, err := provisionerFromContext(r.Context()) + if err != nil { + api.WriteError(w, err) + return } - u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, provName, "")} + + u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 4c316910..40090e83 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -240,6 +240,18 @@ func TestHandler_verifyContentType(t *testing.T) { url string } var tests = map[string]func(t *testing.T) test{ + "fail/provisioner-not-set": func(t *testing.T) test { + return test{ + h: Handler{ + linker: NewLinker("dns", "acme"), + }, + url: url, + ctx: context.Background(), + contentType: "foo", + statusCode: 500, + err: acme.NewErrorISE("provisioner expected in request context"), + } + }, "fail/general-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ diff --git a/ca/ca.go b/ca/ca.go index 356a1f6f..e23be140 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -168,13 +168,15 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { }) // helpful routine for logging all routes // - walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { - fmt.Printf("%s %s\n", method, route) - return nil - } - if err := chi.Walk(mux, walkFunc); err != nil { - fmt.Printf("Logging err: %s\n", err.Error()) - } + /* + walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { + fmt.Printf("%s %s\n", method, route) + return nil + } + if err := chi.Walk(mux, walkFunc); err != nil { + fmt.Printf("Logging err: %s\n", err.Error()) + } + */ // Add monitoring if configured if len(config.Monitoring) > 0 {