diff --git a/acme/api/account.go b/acme/api/account.go index ade51aef..8c8c4d97 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -67,7 +67,7 @@ func (u *UpdateAccountRequest) Validate() error { } // NewAccount is the handler resource for creating new ACME accounts. -func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { +func NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() payload, err := payloadFromContext(ctx) if err != nil { @@ -114,18 +114,19 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { return } - eak, err := h.validateExternalAccountBinding(ctx, &nar) + eak, err := validateExternalAccountBinding(ctx, &nar) if err != nil { render.Error(w, err) return } + db := acme.MustFromContext(ctx) acc = &acme.Account{ Key: jwk, Contact: nar.Contact, Status: acme.StatusValid, } - if err := h.db.CreateAccount(ctx, acc); err != nil { + if err := db.CreateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating account")) return } @@ -136,7 +137,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { + if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) return } @@ -147,14 +148,15 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - h.linker.LinkAccount(ctx, acc) + o := optionsFromContext(ctx) + o.linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) + w.Header().Set("Location", o.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. -func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { +func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -187,16 +189,18 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { acc.Contact = uar.Contact } - if err := h.db.UpdateAccount(ctx, acc); err != nil { + db := acme.MustFromContext(ctx) + if err := db.UpdateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating account")) return } } } - h.linker.LinkAccount(ctx, acc) + o := optionsFromContext(ctx) + o.linker.LinkAccount(ctx, acc) - w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) + w.Header().Set("Location", o.linker.GetLink(ctx, AccountLinkType, acc.ID)) render.JSON(w, acc) } @@ -210,7 +214,7 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { } // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. -func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { +func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -222,13 +226,16 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } - orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) + + db := acme.MustFromContext(ctx) + orders, err := db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { render.Error(w, err) return } - h.linker.LinkOrdersByAccountID(ctx, orders) + o := optionsFromContext(ctx) + o.linker.LinkOrdersByAccountID(ctx, orders) render.JSON(w, orders) logOrdersByAccount(w, orders) diff --git a/acme/api/eab.go b/acme/api/eab.go index 3660d066..2c94a4ed 100644 --- a/acme/api/eab.go +++ b/acme/api/eab.go @@ -16,7 +16,7 @@ type ExternalAccountBinding struct { } // validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account. -func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { +func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context") @@ -47,7 +47,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc return nil, acmeErr } - externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) + db := acme.MustFromContext(ctx) + externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) if err != nil { if _, ok := err.(*acme.Error); ok { return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key") diff --git a/acme/api/handler.go b/acme/api/handler.go index 10eb22cb..04680656 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -16,6 +16,7 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" ) @@ -39,38 +40,89 @@ type payloadInfo struct { isEmptyJSON bool } -// Handler is the ACME API request handler. -type Handler struct { - db acme.DB - backdate provisioner.Duration - ca acme.CertificateAuthority - linker Linker - validateChallengeOptions *acme.ValidateChallengeOptions - prerequisitesChecker func(ctx context.Context) (bool, error) -} - // HandlerOptions required to create a new ACME API request handler. type HandlerOptions struct { - Backdate provisioner.Duration // DB storage backend that impements the acme.DB interface. + // + // Deprecated: use acme.NewContex(context.Context, acme.DB) DB acme.DB + + // CA is the certificate authority interface. + // + // Deprecated: use authority.NewContext(context.Context, *authority.Authority) + CA acme.CertificateAuthority + + // Backdate is the duration that the CA will substract from the current time + // to set the NotBefore in the certificate. + Backdate provisioner.Duration + // DNS the host used to generate accurate ACME links. By default the authority // will use the Host from the request, so this value will only be used if // request.Host is empty. DNS string + // Prefix is a URL path prefix under which the ACME api is served. This // prefix is required to generate accurate ACME links. // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- // "acme" is the prefix from which the ACME api is accessed. Prefix string - CA acme.CertificateAuthority + // PrerequisitesChecker checks if all prerequisites for serving ACME are // met by the CA configuration. PrerequisitesChecker func(ctx context.Context) (bool, error) + + linker Linker + validateChallengeOptions *acme.ValidateChallengeOptions +} + +type optionsKey struct{} + +func newOptionsContext(ctx context.Context, o *HandlerOptions) context.Context { + return context.WithValue(ctx, optionsKey{}, o) +} + +func optionsFromContext(ctx context.Context) *HandlerOptions { + o, ok := ctx.Value(optionsKey{}).(*HandlerOptions) + if !ok { + panic("handler options are not in the context") + } + return o +} + +var mustAuthority = func(ctx context.Context) acme.CertificateAuthority { + return authority.MustFromContext(ctx) +} + +// Handler is the ACME API request handler. +type Handler struct { + opts *HandlerOptions +} + +// Route traffic and implement the Router interface. +// +// Deprecated: Use api.Route(r Router, opts *HandlerOptions) +func (h *Handler) Route(r api.Router) { + Route(r, h.opts) } // NewHandler returns a new ACME API handler. +// +// Deprecated: Use api.Route(r Router, opts *HandlerOptions) func NewHandler(ops HandlerOptions) api.RouterHandler { + return &Handler{ + opts: &ops, + } +} + +// Route traffic and implement the Router interface. +func Route(r api.Router, opts *HandlerOptions) { + // by default all prerequisites are met + if opts.PrerequisitesChecker == nil { + opts.PrerequisitesChecker = func(ctx context.Context) (bool, error) { + return true, nil + } + } + transport := &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, @@ -83,67 +135,85 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { dialer := &net.Dialer{ Timeout: 30 * time.Second, } - prerequisitesChecker := func(ctx context.Context) (bool, error) { - // by default all prerequisites are met - return true, nil - } - if ops.PrerequisitesChecker != nil { - prerequisitesChecker = ops.PrerequisitesChecker - } - return &Handler{ - ca: ops.CA, - db: ops.DB, - backdate: ops.Backdate, - linker: NewLinker(ops.DNS, ops.Prefix), - validateChallengeOptions: &acme.ValidateChallengeOptions{ - HTTPGet: client.Get, - LookupTxt: net.LookupTXT, - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, - }, - prerequisitesChecker: prerequisitesChecker, - } -} -// Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { - getPath := h.linker.GetUnescapedPathSuffix - // Standard ACME API - r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) - r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) + opts.linker = NewLinker(opts.DNS, opts.Prefix) + opts.validateChallengeOptions = &acme.ValidateChallengeOptions{ + HTTPGet: client.Get, + LookupTxt: net.LookupTXT, + TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(dialer, network, addr, config) + }, + } + + withOptions := func(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // For backward compatibility with NewHandler. + if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil { + ctx = authority.NewContext(ctx, ca) + } + if opts.DB != nil { + ctx = acme.NewContext(ctx, opts.DB) + } + + ctx = newOptionsContext(ctx, opts) + next(w, r.WithContext(ctx)) + } + } validatingMiddleware := func(next nextHTTP) nextHTTP { - return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))) + return withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) + return withOptions(validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))) } extractPayloadByKid := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next))) + return withOptions(validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))) } extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { - return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next))) + return withOptions(validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))) } - r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount)) - r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented)) - r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder)) - r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) - r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) - r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) - r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) - r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert)) + getPath := opts.linker.GetUnescapedPathSuffix + + // Standard ACME API + r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), + withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) + r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), + withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) + r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), + withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) + r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), + withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) + + r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), + extractPayloadByJWK(NewAccount)) + r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(GetOrUpdateAccount)) + r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(NotImplemented)) + r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), + extractPayloadByKid(NewOrder)) + r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), + extractPayloadByKid(isPostAsGet(GetOrder))) + r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), + extractPayloadByKid(isPostAsGet(GetOrdersByAccountID))) + r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), + extractPayloadByKid(FinalizeOrder)) + r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), + extractPayloadByKid(isPostAsGet(GetAuthorization))) + r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), + extractPayloadByKid(GetChallenge)) + r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), + extractPayloadByKid(isPostAsGet(GetCertificate))) + r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), + extractPayloadByKidOrJWK(RevokeCert)) } // GetNonce just sets the right header since a Nonce is added to each response // by middleware by default. -func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { +func GetNonce(w http.ResponseWriter, r *http.Request) { if r.Method == "HEAD" { w.WriteHeader(http.StatusOK) } else { @@ -179,8 +249,10 @@ func (d *Directory) ToLog() (interface{}, error) { // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. -func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { +func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + o := optionsFromContext(ctx) + acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { render.Error(w, err) @@ -188,11 +260,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { } render.JSON(w, &Directory{ - NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), - NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), - NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), - RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType), - KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType), + NewNonce: o.linker.GetLink(ctx, NewNonceLinkType), + NewAccount: o.linker.GetLink(ctx, NewAccountLinkType), + NewOrder: o.linker.GetLink(ctx, NewOrderLinkType), + RevokeCert: o.linker.GetLink(ctx, RevokeCertLinkType), + KeyChange: o.linker.GetLink(ctx, KeyChangeLinkType), Meta: Meta{ ExternalAccountRequired: acmeProv.RequireEAB, }, @@ -201,19 +273,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. -func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { +func NotImplemented(w http.ResponseWriter, r *http.Request) { render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. -func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { +func GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + o := optionsFromContext(ctx) + db := acme.MustFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) + az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) return @@ -223,20 +298,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } - if err = az.UpdateStatus(ctx, h.db); err != nil { + if err = az.UpdateStatus(ctx, db); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) return } - h.linker.LinkAuthorization(ctx, az) + o.linker.LinkAuthorization(ctx, az) - w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) + w.Header().Set("Location", o.linker.GetLink(ctx, AuthzLinkType, az.ID)) render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. -func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { +func GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + o := optionsFromContext(ctx) + db := acme.MustFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -257,7 +335,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { // we'll just ignore the body. azID := chi.URLParam(r, "authzID") - ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) + ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) return @@ -273,29 +351,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { + if err = ch.Validate(ctx, db, jwk, o.validateChallengeOptions); err != nil { render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } - h.linker.LinkChallenge(ctx, ch, azID) + o.linker.LinkChallenge(ctx, ch, azID) - w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) - w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) + w.Header().Add("Link", link(o.linker.GetLink(ctx, AuthzLinkType, azID), "up")) + w.Header().Set("Location", o.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. -func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { +func GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - certID := chi.URLParam(r, "certID") - cert, err := h.db.GetCertificate(ctx, certID) + certID := chi.URLParam(r, "certID") + cert, err := db.GetCertificate(ctx, certID) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) return diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 10f7841f..564a16f5 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -31,15 +31,15 @@ func logNonce(w http.ResponseWriter, nonce string) { } } -// baseURLFromRequest determines the base URL which should be used for +// getBaseURLFromRequest determines the base URL which should be used for // constructing link URLs in e.g. the ACME directory result by taking the // request Host into consideration. // // If the Request.Host is an empty string, we return an empty string, to // indicate that the configured URL values should be used instead. If this -// function returns a non-empty result, then this should be used in -// constructing ACME link URLs. -func baseURLFromRequest(r *http.Request) *url.URL { +// function returns a non-empty result, then this should be used in constructing +// ACME link URLs. +func getBaseURLFromRequest(r *http.Request) *url.URL { // NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go // for an implementation that allows HTTP requests using the x-forwarded-proto // header. @@ -53,17 +53,18 @@ func baseURLFromRequest(r *http.Request) *url.URL { // baseURLFromRequest is a middleware that extracts and caches the baseURL // from the request. // E.g. https://ca.smallstep.com/ -func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { +func baseURLFromRequest(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r)) + ctx := context.WithValue(r.Context(), baseURLContextKey, getBaseURLFromRequest(r)) next(w, r.WithContext(ctx)) } } // addNonce is a middleware that adds a nonce to the response header. -func (h *Handler) addNonce(next nextHTTP) nextHTTP { +func addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - nonce, err := h.db.CreateNonce(r.Context()) + db := acme.MustFromContext(r.Context()) + nonce, err := db.CreateNonce(r.Context()) if err != nil { render.Error(w, err) return @@ -77,25 +78,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { // addDirLink is a middleware that adds a 'Link' response reader with the // directory index url. -func (h *Handler) addDirLink(next nextHTTP) nextHTTP { +func addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index")) + ctx := r.Context() + opts := optionsFromContext(ctx) + + w.Header().Add("Link", link(opts.linker.GetLink(ctx, DirectoryLinkType), "index")) next(w, r) } } // verifyContentType is a middleware that verifies that content type is // application/jose+json. -func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { +func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { var expected []string - p, err := provisionerFromContext(r.Context()) + ctx := r.Context() + opts := optionsFromContext(ctx) + + p, err := provisionerFromContext(ctx) if err != nil { render.Error(w, err) return } - u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} + u := url.URL{Path: opts.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} @@ -117,7 +124,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { } // parseJWS is a middleware that parses a request body into a JSONWebSignature struct. -func (h *Handler) parseJWS(next nextHTTP) nextHTTP { +func parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { @@ -149,10 +156,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { // * “nonce” (defined in Section 6.5) // * “url” (defined in Section 6.4) // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below -func (h *Handler) validateJWS(next nextHTTP) nextHTTP { +func validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(r.Context()) + db := acme.MustFromContext(ctx) + + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -202,7 +211,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { } // Check the validity/freshness of the Nonce. - if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { + if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { render.Error(w, err) return } @@ -235,10 +244,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { // extractJWK is a middleware that extracts the JWK from the JWS and saves it // in the context. Make sure to parse and validate the JWS before running this // middleware. -func (h *Handler) extractJWK(next nextHTTP) nextHTTP { +func extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(r.Context()) + db := acme.MustFromContext(ctx) + + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -264,7 +275,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { ctx = context.WithValue(ctx, jwkContextKey, jwk) // Get Account OR continue to generate a new one OR continue Revoke with certificate private key - acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) + acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID) switch { case errors.Is(err, acme.ErrNotFound): // For NewAccount and Revoke requests ... @@ -285,7 +296,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. -func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { +func lookupProvisioner(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() nameEscaped := chi.URLParam(r, "provisionerID") @@ -294,7 +305,7 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } - p, err := h.ca.LoadProvisionerByName(name) + p, err := mustAuthority(r.Context()).LoadProvisionerByName(name) if err != nil { render.Error(w, err) return @@ -311,10 +322,12 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { // checkPrerequisites checks if all prerequisites for serving ACME // are met by the CA configuration. -func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { +func checkPrerequisites(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ok, err := h.prerequisitesChecker(ctx) + opts := optionsFromContext(ctx) + + ok, err := opts.PrerequisitesChecker(ctx) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) return @@ -330,16 +343,19 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { // lookupJWK loads the JWK associated with the acme account referenced by the // kid parameter of the signed payload. // Make sure to parse and validate the JWS before running this middleware. -func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { +func lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + opts := optionsFromContext(ctx) + db := acme.MustFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return } - kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") + kidPrefix := opts.linker.GetLink(ctx, AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { render.Error(w, acme.NewError(acme.ErrorMalformedType, @@ -349,7 +365,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { } accID := strings.TrimPrefix(kid, kidPrefix) - acc, err := h.db.GetAccount(ctx, accID) + acc, err := db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) @@ -372,7 +388,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { // extractOrLookupJWK forwards handling to either extractJWK or // lookupJWK based on the presence of a JWK or a KID, respectively. -func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { +func extractOrLookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) @@ -385,13 +401,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { // and it can be used to check if a JWK exists. This flow is used when the ACME client // signed the payload with a certificate private key. if canExtractJWKFrom(jws) { - h.extractJWK(next)(w, r) + extractJWK(next)(w, r) return } // default to looking up the JWK based on KeyID. This flow is used when the ACME client // signed the payload with an account private key. - h.lookupJWK(next)(w, r) + lookupJWK(next)(w, r) } } @@ -408,7 +424,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool { // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // Make sure to parse and validate the JWS before running this middleware. -func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { +func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) @@ -440,7 +456,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { } // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). -func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { +func isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { diff --git a/acme/api/order.go b/acme/api/order.go index 99eb0e95..ebd0c7f5 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -68,7 +68,7 @@ var defaultOrderExpiry = time.Hour * 24 var defaultOrderBackdate = time.Minute // NewOrder ACME api for creating a new order. -func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { +func NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -117,7 +117,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ExpiresAt: o.ExpiresAt, Status: acme.StatusPending, } - if err := h.newAuthorization(ctx, az); err != nil { + if err := newAuthorization(ctx, az); err != nil { render.Error(w, err) return } @@ -136,18 +136,20 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) } - if err := h.db.CreateOrder(ctx, o); err != nil { + db := acme.MustFromContext(ctx) + if err := db.CreateOrder(ctx, o); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } - h.linker.LinkOrder(ctx, o) + opts := optionsFromContext(ctx) + opts.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) render.JSONStatus(w, o, http.StatusCreated) } -func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { +func newAuthorization(ctx context.Context, az *acme.Authorization) error { if strings.HasPrefix(az.Identifier.Value, "*.") { az.Wildcard = true az.Identifier = acme.Identifier{ @@ -163,6 +165,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) if err != nil { return acme.WrapErrorISE(err, "error generating random alphanumeric ID") } + + db := acme.MustFromContext(ctx) az.Challenges = make([]*acme.Challenge, len(chTypes)) for i, typ := range chTypes { ch := &acme.Challenge{ @@ -172,19 +176,19 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) Token: az.Token, Status: acme.StatusPending, } - if err := h.db.CreateChallenge(ctx, ch); err != nil { + if err := db.CreateChallenge(ctx, ch); err != nil { return acme.WrapErrorISE(err, "error creating challenge") } az.Challenges[i] = ch } - if err = h.db.CreateAuthorization(ctx, az); err != nil { + if err = db.CreateAuthorization(ctx, az); err != nil { return acme.WrapErrorISE(err, "error creating authorization") } return nil } // GetOrder ACME api for retrieving an order. -func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { +func GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -196,7 +200,9 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + + db := acme.MustFromContext(ctx) + o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return @@ -211,19 +217,20 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } - if err = o.UpdateStatus(ctx, h.db); err != nil { + if err = o.UpdateStatus(ctx, db); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating order status")) return } - h.linker.LinkOrder(ctx, o) + opts := optionsFromContext(ctx) + opts.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) render.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. -func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { +func FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { @@ -251,7 +258,8 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) + db := acme.MustFromContext(ctx) + o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return @@ -266,14 +274,17 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } - if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { + + ca := mustAuthority(ctx) + if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil { render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) return } - h.linker.LinkOrder(ctx, o) + opts := optionsFromContext(ctx) + opts.linker.LinkOrder(ctx, o) - w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) render.JSON(w, o) } diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 4b71bc22..55774aea 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -26,8 +26,7 @@ type revokePayload struct { } // RevokeCert attempts to revoke a certificate. -func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { - +func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { @@ -68,8 +67,9 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { return } + db := acme.MustFromContext(ctx) serial := certToBeRevoked.SerialNumber.String() - dbCert, err := h.db.GetCertificateBySerial(ctx, serial) + dbCert, err := db.GetCertificateBySerial(ctx, serial) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return @@ -87,7 +87,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) + acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { render.Error(w, acmeErr) return @@ -103,7 +103,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } } - hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) + ca := mustAuthority(ctx) + hasBeenRevokedBefore, err := ca.IsRevoked(serial) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return @@ -130,14 +131,15 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } options := revokeOptions(serial, certToBeRevoked, reasonCode) - err = h.ca.Revoke(ctx, options) + err = ca.Revoke(ctx, options) if err != nil { render.Error(w, wrapRevokeErr(err)) return } logRevoke(w, options) - w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index")) + o := optionsFromContext(ctx) + w.Header().Add("Link", link(o.linker.GetLink(ctx, DirectoryLinkType), "index")) w.Write(nil) } @@ -148,7 +150,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { // the identifiers in the certificate are extracted and compared against the (valid) Authorizations // that are stored for the ACME Account. If these sets match, the Account is considered authorized // to revoke the certificate. If this check fails, the client will receive an unauthorized error. -func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { +func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { if !account.IsValid() { return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) }