Make step handler backward compatible

This commit is contained in:
Mariano Cano 2022-05-04 19:20:34 -07:00
parent 43ddcf2efe
commit d51c6b7d83
2 changed files with 27 additions and 8 deletions

View file

@ -48,28 +48,48 @@ type Response struct {
} }
// handler is the SCEP request handler. // handler is the SCEP request handler.
type handler struct{} type handler struct {
auth *scep.Authority
}
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
// //
// Deprecated: use scep.Route(r api.Router) // Deprecated: use scep.Route(r api.Router)
func (h *handler) Route(r api.Router) { func (h *handler) Route(r api.Router) {
Route(r) route(r, func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := scep.NewContext(r.Context(), h.auth)
next(w, r.WithContext(ctx))
}
})
} }
// New returns a new SCEP API router. // New returns a new SCEP API router.
// //
// Deprecated: use scep.Route(r api.Router) // Deprecated: use scep.Route(r api.Router)
func New(auth *scep.Authority) api.RouterHandler { func New(auth *scep.Authority) api.RouterHandler {
return &handler{} return &handler{auth: auth}
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func Route(r api.Router) { func Route(r api.Router) {
r.MethodFunc(http.MethodGet, "/{provisionerName}/*", lookupProvisioner(Get)) route(r, nil)
r.MethodFunc(http.MethodGet, "/{provisionerName}", lookupProvisioner(Get)) }
r.MethodFunc(http.MethodPost, "/{provisionerName}/*", lookupProvisioner(Post))
r.MethodFunc(http.MethodPost, "/{provisionerName}", lookupProvisioner(Post)) func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc) {
getHandler := lookupProvisioner(Get)
postHandler := lookupProvisioner(Post)
// For backward compatibility.
if middleware != nil {
getHandler = middleware(getHandler)
postHandler = middleware(postHandler)
}
r.MethodFunc(http.MethodGet, "/{provisionerName}/*", getHandler)
r.MethodFunc(http.MethodGet, "/{provisionerName}", getHandler)
r.MethodFunc(http.MethodPost, "/{provisionerName}/*", postHandler)
r.MethodFunc(http.MethodPost, "/{provisionerName}", postHandler)
} }
// Get handles all SCEP GET requests // Get handles all SCEP GET requests

View file

@ -453,7 +453,6 @@ func (a *Authority) CreateFailureResponse(ctx context.Context, csr *x509.Certifi
// MatchChallengePassword verifies a SCEP challenge password // MatchChallengePassword verifies a SCEP challenge password
func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) { func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) {
p, err := provisionerFromContext(ctx) p, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
return false, err return false, err