From d51c6b7d83b566182f0584ea6d6e82332057ed39 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 4 May 2022 19:20:34 -0700 Subject: [PATCH] Make step handler backward compatible --- scep/api/api.go | 34 +++++++++++++++++++++++++++------- scep/authority.go | 1 - 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/scep/api/api.go b/scep/api/api.go index e513aa43..49a5267a 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -48,28 +48,48 @@ type Response struct { } // handler is the SCEP request handler. -type handler struct{} +type handler struct { + auth *scep.Authority +} // Route traffic and implement the Router interface. // // Deprecated: use scep.Route(r api.Router) func (h *handler) Route(r api.Router) { - Route(r) + route(r, func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := scep.NewContext(r.Context(), h.auth) + next(w, r.WithContext(ctx)) + } + }) } // New returns a new SCEP API router. // // Deprecated: use scep.Route(r api.Router) func New(auth *scep.Authority) api.RouterHandler { - return &handler{} + return &handler{auth: auth} } // Route traffic and implement the Router interface. func Route(r api.Router) { - r.MethodFunc(http.MethodGet, "/{provisionerName}/*", lookupProvisioner(Get)) - r.MethodFunc(http.MethodGet, "/{provisionerName}", lookupProvisioner(Get)) - r.MethodFunc(http.MethodPost, "/{provisionerName}/*", lookupProvisioner(Post)) - r.MethodFunc(http.MethodPost, "/{provisionerName}", lookupProvisioner(Post)) + route(r, nil) +} + +func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc) { + getHandler := lookupProvisioner(Get) + postHandler := lookupProvisioner(Post) + + // For backward compatibility. + if middleware != nil { + getHandler = middleware(getHandler) + postHandler = middleware(postHandler) + } + + r.MethodFunc(http.MethodGet, "/{provisionerName}/*", getHandler) + r.MethodFunc(http.MethodGet, "/{provisionerName}", getHandler) + r.MethodFunc(http.MethodPost, "/{provisionerName}/*", postHandler) + r.MethodFunc(http.MethodPost, "/{provisionerName}", postHandler) } // Get handles all SCEP GET requests diff --git a/scep/authority.go b/scep/authority.go index 7fe01c1d..7dbbb8c5 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -453,7 +453,6 @@ func (a *Authority) CreateFailureResponse(ctx context.Context, csr *x509.Certifi // MatchChallengePassword verifies a SCEP challenge password func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) { - p, err := provisionerFromContext(ctx) if err != nil { return false, err