diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 21a7229d..2c189624 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -40,11 +40,11 @@ type GetExternalAccountKeysResponse struct { // requireEABEnabled is a middleware that ensures ACME EAB is enabled // before serving requests that act on ACME EAB credentials. -func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { +func requireEABEnabled(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() provName := chi.URLParam(r, "provisionerName") - eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) + eabEnabled, prov, err := provisionerHasEABEnabled(ctx, provName) if err != nil { render.Error(w, err) return @@ -60,16 +60,20 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { // provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME // provisioner is set to true and thus has EAB enabled. -func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) { +func provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) { var ( p provisioner.Interface err error ) - if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil { + + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + + if p, err = auth.LoadProvisionerByName(provisionerName); err != nil { return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName) } - prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) + prov, err := db.GetProvisioner(ctx, p.GetID()) if err != nil { return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID()) } diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 5e4b9c30..6ef6f0eb 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -81,10 +81,10 @@ type DeleteResponse struct { } // GetAdmin returns the requested admin, or an error. -func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { +func GetAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") - adm, ok := h.auth.LoadAdminByID(id) + adm, ok := mustAuthority(r.Context()).LoadAdminByID(id) if !ok { render.Error(w, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) @@ -94,7 +94,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { } // GetAdmins returns a segment of admins associated with the authority. -func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { +func GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, @@ -102,7 +102,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { return } - admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) + admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) return @@ -114,7 +114,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { } // CreateAdmin creates a new admin. -func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { +func CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) @@ -126,7 +126,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { return } - p, err := h.auth.LoadProvisionerByName(body.Provisioner) + auth := mustAuthority(r.Context()) + p, err := auth.LoadProvisionerByName(body.Provisioner) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return @@ -137,7 +138,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { Type: body.Type, } // Store to authority collection. - if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { + if err := auth.StoreAdmin(r.Context(), adm, p); err != nil { render.Error(w, admin.WrapErrorISE(err, "error storing admin")) return } @@ -146,10 +147,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { } // DeleteAdmin deletes admin. -func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { +func DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") - if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { + if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } @@ -158,7 +159,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { } // UpdateAdmin updates an existing admin. -func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { +func UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) @@ -171,8 +172,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { } id := chi.URLParam(r, "id") - - adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) + auth := mustAuthority(r.Context()) + adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) return diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index 99e74c88..0acd3ca9 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -1,56 +1,66 @@ package api import ( + "context" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" ) // Handler is the Admin API request handler. type Handler struct { - adminDB admin.DB - auth adminAuthority - acmeDB acme.DB acmeResponder acmeAdminResponderInterface } +// Route traffic and implement the Router interface. +// +// Deprecated: use Route(r api.Router, acmeResponder acmeAdminResponderInterface) +func (h *Handler) Route(r api.Router) { + Route(r, h.acmeResponder) +} + // NewHandler returns a new Authority Config Handler. +// +// Deprecated: use Route(r api.Router, acmeResponder acmeAdminResponderInterface) func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler { return &Handler{ - auth: auth, - adminDB: adminDB, - acmeDB: acmeDB, acmeResponder: acmeResponder, } } +var mustAuthority = func(ctx context.Context) adminAuthority { + return authority.MustFromContext(ctx) +} + // Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { +func Route(r api.Router, acmeResponder acmeAdminResponderInterface) { authnz := func(next nextHTTP) nextHTTP { - return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) + return extractAuthorizeTokenAdmin(requireAPIEnabled(next)) } requireEABEnabled := func(next nextHTTP) nextHTTP { - return h.requireEABEnabled(next) + return requireEABEnabled(next) } // Provisioners - r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) - r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) - r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner)) - r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner)) - r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner)) + r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner)) + r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners)) + r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner)) + r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner)) + r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner)) // Admins - r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin)) - r.MethodFunc("GET", "/admins", authnz(h.GetAdmins)) - r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin)) - r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin)) - r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin)) + r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin)) + r.MethodFunc("GET", "/admins", authnz(GetAdmins)) + r.MethodFunc("POST", "/admins", authnz(CreateAdmin)) + r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin)) + r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin)) // ACME External Account Binding Keys - r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) - r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) - r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey))) - r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey))) + r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(acmeResponder.GetExternalAccountKeys))) + r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(acmeResponder.GetExternalAccountKeys))) + r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(acmeResponder.CreateExternalAccountKey))) + r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(acmeResponder.DeleteExternalAccountKey))) } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index b57dd6eb..9bd6c698 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -12,11 +12,10 @@ type nextHTTP = func(http.ResponseWriter, *http.Request) // requireAPIEnabled is a middleware that ensures the Administration API // is enabled before servicing requests. -func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { +func requireAPIEnabled(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - if !h.auth.IsAdminAPIEnabled() { - render.Error(w, admin.NewError(admin.ErrorNotImplementedType, - "administration API not enabled")) + if !mustAuthority(r.Context()).IsAdminAPIEnabled() { + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } next(w, r) @@ -24,7 +23,7 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { } // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. -func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { +func extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { tok := r.Header.Get("Authorization") if tok == "" { @@ -33,13 +32,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return } - adm, err := h.auth.AuthorizeAdminToken(r, tok) + ctx := r.Context() + adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok) if err != nil { render.Error(w, err) return } - ctx := context.WithValue(r.Context(), adminContextKey, adm) + ctx = context.WithValue(ctx, adminContextKey, adm) next(w, r.WithContext(ctx)) } } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 1cad62dd..149f2c6a 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -23,29 +23,31 @@ type GetProvisionersResponse struct { } // GetProvisioner returns the requested provisioner, or an error. -func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - id := r.URL.Query().Get("id") - name := chi.URLParam(r, "name") - +func GetProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) + + ctx := r.Context() + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + if len(id) > 0 { - if p, err = h.auth.LoadProvisionerByID(id); err != nil { + if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { - if p, err = h.auth.LoadProvisionerByName(name); err != nil { + if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } - prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) + prov, err := db.GetProvisioner(ctx, p.GetID()) if err != nil { render.Error(w, err) return @@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { } // GetProvisioners returns the given segment of provisioners associated with the authority. -func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { +func GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, @@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { return } - p, next, err := h.auth.GetProvisioners(cursor, limit) + p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { } // CreateProvisioner creates a new prov. -func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { +func CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, prov); err != nil { render.Error(w, err) @@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { return } - if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { + if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil { render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } @@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { } // DeleteProvisioner deletes a provisioner. -func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { - id := r.URL.Query().Get("id") - name := chi.URLParam(r, "name") - +func DeleteProvisioner(w http.ResponseWriter, r *http.Request) { var ( p provisioner.Interface err error ) + + id := r.URL.Query().Get("id") + name := chi.URLParam(r, "name") + auth := mustAuthority(r.Context()) + if len(id) > 0 { - if p, err = h.auth.LoadProvisionerByID(id); err != nil { + if p, err = auth.LoadProvisionerByID(id); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { - if p, err = h.auth.LoadProvisionerByName(name); err != nil { + if p, err = auth.LoadProvisionerByName(name); err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } - if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { + if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } @@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { } // UpdateProvisioner updates an existing prov. -func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { +func UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, nu); err != nil { render.Error(w, err) return } + ctx := r.Context() name := chi.URLParam(r, "name") - _old, err := h.auth.LoadProvisionerByName(name) + auth := mustAuthority(ctx) + db := admin.MustFromContext(ctx) + + p, err := auth.LoadProvisionerByName(name) if err != nil { render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } - old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) + old, err := db.GetProvisioner(r.Context(), p.GetID()) if err != nil { - render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID())) return } @@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { return } - if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { + if err := auth.UpdateProvisioner(r.Context(), nu); err != nil { render.Error(w, err) return }