From d1f75f172078370776d74edde2086d94122ccbd9 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 28 Apr 2022 19:15:18 -0700 Subject: [PATCH] Refactor ACME api. --- acme/api/account.go | 25 ++-- acme/api/eab.go | 3 +- acme/api/handler.go | 203 ++++++++++++++----------------- acme/api/middleware.go | 106 +++++------------ acme/api/order.go | 47 +++----- acme/api/revoke.go | 14 +-- acme/challenge.go | 35 ++---- acme/client.go | 79 ++++++++++++ acme/common.go | 76 ++++++++++-- acme/db.go | 16 +-- acme/linker.go | 265 +++++++++++++++++++++++------------------ acme/linker_test.go | 43 ++++--- ca/ca.go | 36 +++--- 13 files changed, 510 insertions(+), 438 deletions(-) create mode 100644 acme/client.go diff --git a/acme/api/account.go b/acme/api/account.go index 8c8c4d97..d88c7066 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -69,6 +69,9 @@ func (u *UpdateAccountRequest) Validate() error { // NewAccount is the handler resource for creating new ACME accounts. func NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -120,7 +123,6 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { return } - db := acme.MustFromContext(ctx) acc = &acme.Account{ Key: jwk, Contact: nar.Contact, @@ -148,16 +150,18 @@ func NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - o := optionsFromContext(ctx) - o.linker.LinkAccount(ctx, acc) + linker.LinkAccount(ctx, acc) - w.Header().Set("Location", o.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) + w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID)) render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -189,7 +193,6 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { acc.Contact = uar.Contact } - db := acme.MustFromContext(ctx) if err := db.UpdateAccount(ctx, acc); err != nil { render.Error(w, acme.WrapErrorISE(err, "error updating account")) return @@ -197,10 +200,9 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { } } - o := optionsFromContext(ctx) - o.linker.LinkAccount(ctx, acc) + linker.LinkAccount(ctx, acc) - w.Header().Set("Location", o.linker.GetLink(ctx, AccountLinkType, acc.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID)) render.JSON(w, acc) } @@ -216,6 +218,9 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) @@ -227,15 +232,13 @@ func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { return } - db := acme.MustFromContext(ctx) orders, err := db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { render.Error(w, err) return } - o := optionsFromContext(ctx) - o.linker.LinkOrdersByAccountID(ctx, orders) + linker.LinkOrdersByAccountID(ctx, orders) render.JSON(w, orders) logOrdersByAccount(w, orders) diff --git a/acme/api/eab.go b/acme/api/eab.go index 2c94a4ed..13928ac4 100644 --- a/acme/api/eab.go +++ b/acme/api/eab.go @@ -47,7 +47,7 @@ func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) return nil, acmeErr } - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) if err != nil { if _, ok := err.(*acme.Error); ok { @@ -103,7 +103,6 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool { // o The "nonce" field MUST NOT be present // o The "url" field MUST be set to the same value as the outer JWS func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) { - if jws == nil { return "", acme.NewErrorISE("no JWS provided") } diff --git a/acme/api/handler.go b/acme/api/handler.go index 4b916404..efe2b780 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -2,12 +2,10 @@ package api import ( "context" - "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "fmt" - "net" "net/http" "time" @@ -70,144 +68,117 @@ type HandlerOptions struct { // PrerequisitesChecker checks if all prerequisites for serving ACME are // met by the CA configuration. PrerequisitesChecker func(ctx context.Context) (bool, error) - - linker Linker - validateChallengeOptions *acme.ValidateChallengeOptions -} - -type optionsKey struct{} - -func newOptionsContext(ctx context.Context, o *HandlerOptions) context.Context { - return context.WithValue(ctx, optionsKey{}, o) -} - -func optionsFromContext(ctx context.Context) *HandlerOptions { - o, ok := ctx.Value(optionsKey{}).(*HandlerOptions) - if !ok { - panic("acme options are not in the context") - } - return o } var mustAuthority = func(ctx context.Context) acme.CertificateAuthority { return authority.MustFromContext(ctx) } -// Handler is the ACME API request handler. -type Handler struct { +// handler is the ACME API request handler. +type handler struct { opts *HandlerOptions } // Route traffic and implement the Router interface. -// -// Deprecated: Use api.Route(r Router, opts *HandlerOptions) -func (h *Handler) Route(r api.Router) { - Route(r, h.opts) +func (h *handler) Route(r api.Router) { + route(r, h.opts) } // NewHandler returns a new ACME API handler. -// -// Deprecated: Use api.Route(r Router, opts *HandlerOptions) -func NewHandler(ops HandlerOptions) api.RouterHandler { - return &Handler{ - opts: &ops, +func NewHandler(opts HandlerOptions) api.RouterHandler { + return &handler{ + opts: &opts, } } -// Route traffic and implement the Router interface. -func Route(r api.Router, opts *HandlerOptions) { - // by default all prerequisites are met - if opts.PrerequisitesChecker == nil { - opts.PrerequisitesChecker = func(ctx context.Context) (bool, error) { - return true, nil +// Route traffic and implement the Router interface. This method requires that +// all the acme components, authority, db, client, linker, and prerequisite +// checker to be present in the context. +func Route(r api.Router) { + route(r, nil) +} + +func route(r api.Router, opts *HandlerOptions) { + var withContext func(next nextHTTP) nextHTTP + + // For backward compatibility this block adds will add a new middleware that + // will set the ACME components to the context. + if opts != nil { + client := acme.NewClient() + linker := acme.NewLinker(opts.DNS, opts.Prefix) + + withContext = func(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil { + ctx = authority.NewContext(ctx, ca) + } + ctx = acme.NewContext(ctx, opts.DB, client, linker, opts.PrerequisitesChecker) + next(w, r.WithContext(ctx)) + } + } + } else { + withContext = func(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + next(w, r) + } } } - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, + commonMiddleware := func(next nextHTTP) nextHTTP { + return withContext(func(w http.ResponseWriter, r *http.Request) { + // Linker middleware gets the provisioner and current url from the + // request and sets them in the context. + linker := acme.MustLinkerFromContext(r.Context()) + linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r) + }) } - client := http.Client{ - Timeout: 30 * time.Second, - Transport: transport, - } - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - } - - opts.linker = NewLinker(opts.DNS, opts.Prefix) - opts.validateChallengeOptions = &acme.ValidateChallengeOptions{ - HTTPGet: client.Get, - LookupTxt: net.LookupTXT, - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.DialWithDialer(dialer, network, addr, config) - }, - } - - withOptions := func(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // For backward compatibility with NewHandler. - if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil { - ctx = authority.NewContext(ctx, ca) - } - if opts.DB != nil { - ctx = acme.NewContext(ctx, opts.DB) - } - - ctx = newOptionsContext(ctx, opts) - next(w, r.WithContext(ctx)) - } - } - validatingMiddleware := func(next nextHTTP) nextHTTP { - return withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))))) + return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next)))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { - return withOptions(validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))) + return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKid := func(next nextHTTP) nextHTTP { - return withOptions(validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))) + return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next))) } extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { - return withOptions(validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))) + return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next))) } - getPath := opts.linker.GetUnescapedPathSuffix + getPath := acme.GetUnescapedPathSuffix // Standard ACME API - r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), - withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) - r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), - withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), - withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), - withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) + r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"), + commonMiddleware(addNonce(addDirLink(GetNonce)))) + r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"), + commonMiddleware(addNonce(addDirLink(GetNonce)))) + r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"), + commonMiddleware(GetDirectory)) + r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"), + commonMiddleware(GetDirectory)) - r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), + r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(NewAccount)) - r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), + r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(GetOrUpdateAccount)) - r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), + r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(NotImplemented)) - r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), + r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(NewOrder)) - r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), + r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(isPostAsGet(GetOrder))) - r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), + r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(isPostAsGet(GetOrdersByAccountID))) - r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), + r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(FinalizeOrder)) - r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), + r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(isPostAsGet(GetAuthorization))) - r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), + r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(GetChallenge)) - r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), + r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(isPostAsGet(GetCertificate))) - r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), + r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(RevokeCert)) } @@ -251,20 +222,20 @@ func (d *Directory) ToLog() (interface{}, error) { // for client configuration. func GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - o := optionsFromContext(ctx) - acmeProv, err := acmeProvisionerFromContext(ctx) + fmt.Println(acmeProv, err) if err != nil { render.Error(w, err) return } + linker := acme.MustLinkerFromContext(ctx) render.JSON(w, &Directory{ - NewNonce: o.linker.GetLink(ctx, NewNonceLinkType), - NewAccount: o.linker.GetLink(ctx, NewAccountLinkType), - NewOrder: o.linker.GetLink(ctx, NewOrderLinkType), - RevokeCert: o.linker.GetLink(ctx, RevokeCertLinkType), - KeyChange: o.linker.GetLink(ctx, KeyChangeLinkType), + NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType), + NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType), + NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType), + RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType), + KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType), Meta: Meta{ ExternalAccountRequired: acmeProv.RequireEAB, }, @@ -280,8 +251,8 @@ func NotImplemented(w http.ResponseWriter, r *http.Request) { // GetAuthorization ACME api for retrieving an Authz. func GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - o := optionsFromContext(ctx) - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { @@ -303,17 +274,17 @@ func GetAuthorization(w http.ResponseWriter, r *http.Request) { return } - o.linker.LinkAuthorization(ctx, az) + linker.LinkAuthorization(ctx, az) - w.Header().Set("Location", o.linker.GetLink(ctx, AuthzLinkType, az.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID)) render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. func GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - o := optionsFromContext(ctx) - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { @@ -351,22 +322,22 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) { render.Error(w, err) return } - if err = ch.Validate(ctx, db, jwk, o.validateChallengeOptions); err != nil { + if err = ch.Validate(ctx, db, jwk); err != nil { render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } - o.linker.LinkChallenge(ctx, ch, azID) + linker.LinkChallenge(ctx, ch, azID) - w.Header().Add("Link", link(o.linker.GetLink(ctx, AuthzLinkType, azID), "up")) - w.Header().Set("Location", o.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up")) + w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID)) render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. func GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) acc, err := accountFromContext(ctx) if err != nil { diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 564a16f5..09e88b8d 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -31,39 +31,10 @@ func logNonce(w http.ResponseWriter, nonce string) { } } -// getBaseURLFromRequest determines the base URL which should be used for -// constructing link URLs in e.g. the ACME directory result by taking the -// request Host into consideration. -// -// If the Request.Host is an empty string, we return an empty string, to -// indicate that the configured URL values should be used instead. If this -// function returns a non-empty result, then this should be used in constructing -// ACME link URLs. -func getBaseURLFromRequest(r *http.Request) *url.URL { - // NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go - // for an implementation that allows HTTP requests using the x-forwarded-proto - // header. - - if r.Host == "" { - return nil - } - return &url.URL{Scheme: "https", Host: r.Host} -} - -// baseURLFromRequest is a middleware that extracts and caches the baseURL -// from the request. -// E.g. https://ca.smallstep.com/ -func baseURLFromRequest(next nextHTTP) nextHTTP { - return func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), baseURLContextKey, getBaseURLFromRequest(r)) - next(w, r.WithContext(ctx)) - } -} - // addNonce is a middleware that adds a nonce to the response header. func addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - db := acme.MustFromContext(r.Context()) + db := acme.MustDatabaseFromContext(r.Context()) nonce, err := db.CreateNonce(r.Context()) if err != nil { render.Error(w, err) @@ -81,9 +52,9 @@ func addNonce(next nextHTTP) nextHTTP { func addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - opts := optionsFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) - w.Header().Add("Link", link(opts.linker.GetLink(ctx, DirectoryLinkType), "index")) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) next(w, r) } } @@ -92,17 +63,12 @@ func addDirLink(next nextHTTP) nextHTTP { // application/jose+json. func verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - var expected []string - ctx := r.Context() - opts := optionsFromContext(ctx) - - p, err := provisionerFromContext(ctx) - if err != nil { - render.Error(w, err) - return + p := acme.MustProvisionerFromContext(r.Context()) + u := &url.URL{ + Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""), } - u := url.URL{Path: opts.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} + var expected []string if strings.Contains(r.URL.String(), u.EscapedPath()) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} @@ -159,7 +125,7 @@ func parseJWS(next nextHTTP) nextHTTP { func validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { @@ -247,7 +213,7 @@ func validateJWS(next nextHTTP) nextHTTP { func extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { @@ -325,18 +291,20 @@ func lookupProvisioner(next nextHTTP) nextHTTP { func checkPrerequisites(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - opts := optionsFromContext(ctx) - - ok, err := opts.PrerequisitesChecker(ctx) - if err != nil { - render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) - return + // If the function is not set assume that all prerequisites are met. + checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx) + if ok { + ok, err := checkFunc(ctx) + if err != nil { + render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + return + } + if !ok { + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + return + } } - if !ok { - render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) - return - } - next(w, r.WithContext(ctx)) + next(w, r) } } @@ -346,8 +314,8 @@ func checkPrerequisites(next nextHTTP) nextHTTP { func lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - opts := optionsFromContext(ctx) - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) jws, err := jwsFromContext(ctx) if err != nil { @@ -355,7 +323,7 @@ func lookupJWK(next nextHTTP) nextHTTP { return } - kidPrefix := opts.linker.GetLink(ctx, AccountLinkType, "") + kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { render.Error(w, acme.NewError(acme.ErrorMalformedType, @@ -527,32 +495,14 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { return val, nil } -// provisionerFromContext searches the context for a provisioner. Returns the -// provisioner or an error. -func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { - val := ctx.Value(provisionerContextKey) - if val == nil { - return nil, acme.NewErrorISE("provisioner expected in request context") - } - pval, ok := val.(acme.Provisioner) - if !ok || pval == nil { - return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") - } - return pval, nil -} - // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // pointer to an ACME provisioner or an error. func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { - prov, err := provisionerFromContext(ctx) - if err != nil { - return nil, err - } - acmeProv, ok := prov.(*provisioner.ACME) - if !ok || acmeProv == nil { + p, ok := acme.MustProvisionerFromContext(ctx).(*provisioner.ACME) + if !ok { return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") } - return acmeProv, nil + return p, nil } // payloadFromContext searches the context for a payload. Returns the payload diff --git a/acme/api/order.go b/acme/api/order.go index ebd0c7f5..2b9f912e 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -70,16 +70,15 @@ var defaultOrderBackdate = time.Minute // NewOrder ACME api for creating a new order. func NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + prov := acme.MustProvisionerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - prov, err := provisionerFromContext(ctx) - if err != nil { - render.Error(w, err) - return - } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -136,16 +135,14 @@ func NewOrder(w http.ResponseWriter, r *http.Request) { o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) } - db := acme.MustFromContext(ctx) if err := db.CreateOrder(ctx, o); err != nil { render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } - opts := optionsFromContext(ctx) - opts.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSONStatus(w, o, http.StatusCreated) } @@ -166,7 +163,7 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error { return acme.WrapErrorISE(err, "error generating random alphanumeric ID") } - db := acme.MustFromContext(ctx) + db := acme.MustDatabaseFromContext(ctx) az.Challenges = make([]*acme.Challenge, len(chTypes)) for i, typ := range chTypes { ch := &acme.Challenge{ @@ -190,18 +187,16 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error { // GetOrder ACME api for retrieving an order. func GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + prov := acme.MustProvisionerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - prov, err := provisionerFromContext(ctx) - if err != nil { - render.Error(w, err) - return - } - db := acme.MustFromContext(ctx) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) @@ -222,26 +217,24 @@ func GetOrder(w http.ResponseWriter, r *http.Request) { return } - opts := optionsFromContext(ctx) - opts.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. func FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + prov := acme.MustProvisionerFromContext(ctx) + acc, err := accountFromContext(ctx) if err != nil { render.Error(w, err) return } - prov, err := provisionerFromContext(ctx) - if err != nil { - render.Error(w, err) - return - } payload, err := payloadFromContext(ctx) if err != nil { render.Error(w, err) @@ -258,7 +251,6 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - db := acme.MustFromContext(ctx) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) @@ -281,10 +273,9 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) { return } - opts := optionsFromContext(ctx) - opts.linker.LinkOrder(ctx, o) + linker.LinkOrder(ctx, o) - w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) + w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID)) render.JSON(w, o) } diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 55774aea..584ed27e 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -28,13 +28,11 @@ type revokePayload struct { // RevokeCert attempts to revoke a certificate. func RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - jws, err := jwsFromContext(ctx) - if err != nil { - render.Error(w, err) - return - } + db := acme.MustDatabaseFromContext(ctx) + linker := acme.MustLinkerFromContext(ctx) + prov := acme.MustProvisionerFromContext(ctx) - prov, err := provisionerFromContext(ctx) + jws, err := jwsFromContext(ctx) if err != nil { render.Error(w, err) return @@ -67,7 +65,6 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { return } - db := acme.MustFromContext(ctx) serial := certToBeRevoked.SerialNumber.String() dbCert, err := db.GetCertificateBySerial(ctx, serial) if err != nil { @@ -138,8 +135,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) { } logRevoke(w, options) - o := optionsFromContext(ctx) - w.Header().Add("Link", link(o.linker.GetLink(ctx, DirectoryLinkType), "index")) + w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index")) w.Write(nil) } diff --git a/acme/challenge.go b/acme/challenge.go index 9f08bae5..8d8466bd 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -14,7 +14,6 @@ import ( "fmt" "io" "net" - "net/http" "net/url" "reflect" "strings" @@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) { // type using the DB interface. // satisfactorily validated, the 'status' and 'validated' attributes are // updated. -func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error { // If already valid or invalid then return without performing validation. if ch.Status != StatusPending { return nil } switch ch.Type { case HTTP01: - return http01Validate(ctx, ch, db, jwk, vo) + return http01Validate(ctx, ch, db, jwk) case DNS01: - return dns01Validate(ctx, ch, db, jwk, vo) + return dns01Validate(ctx, ch, db, jwk) case TLSALPN01: - return tlsalpn01Validate(ctx, ch, db, jwk, vo) + return tlsalpn01Validate(ctx, ch, db, jwk) default: return NewErrorISE("unexpected challenge type '%s'", ch.Type) } } -func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} - resp, err := vo.HTTPGet(u.String()) + vc := MustClientFromContext(ctx) + resp, err := vc.Get(u.String()) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, "error doing http GET for url %s", u)) @@ -141,7 +141,7 @@ func tlsAlert(err error) uint8 { return 0 } -func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { config := &tls.Config{ NextProtos: []string{"acme-tls/1"}, // https://tools.ietf.org/html/rfc8737#section-4 @@ -154,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON hostPort := net.JoinHostPort(ch.Value, "443") - conn, err := vo.TLSDial("tcp", hostPort, config) + vc := MustClientFromContext(ctx) + conn, err := vc.TLSDial("tcp", hostPort, config) if err != nil { // With Go 1.17+ tls.Dial fails if there's no overlap between configured // client and server protocols. When this happens the connection is @@ -253,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) } -func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { +func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { // Normalize domain for wildcard DNS names // This is done to avoid making TXT lookups for domains like // _acme-challenge.*.example.com // Instead perform txt lookup for _acme-challenge.example.com domain := strings.TrimPrefix(ch.Value, "*.") - txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) + vc := MustClientFromContext(ctx) + txtRecords, err := vc.LookupTxt("_acme-challenge." + domain) if err != nil { return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err, "error looking up TXT records for domain %s", domain)) @@ -376,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err } return nil } - -type httpGetter func(string) (*http.Response, error) -type lookupTxt func(string) ([]string, error) -type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) - -// ValidateChallengeOptions are ACME challenge validator functions. -type ValidateChallengeOptions struct { - HTTPGet httpGetter - LookupTxt lookupTxt - TLSDial tlsDialer -} diff --git a/acme/client.go b/acme/client.go new file mode 100644 index 00000000..2b200e45 --- /dev/null +++ b/acme/client.go @@ -0,0 +1,79 @@ +package acme + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "time" +) + +// Client is the interface used to verify ACME challenges. +type Client interface { + // Get issues an HTTP GET to the specified URL. + Get(url string) (*http.Response, error) + + // LookupTXT returns the DNS TXT records for the given domain name. + LookupTxt(name string) ([]string, error) + + // TLSDial connects to the given network address using net.Dialer and then + // initiates a TLS handshake, returning the resulting TLS connection. + TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +type clientKey struct{} + +// NewClientContext adds the given client to the context. +func NewClientContext(ctx context.Context, c Client) context.Context { + return context.WithValue(ctx, clientKey{}, c) +} + +// ClientFromContext returns the current client from the given context. +func ClientFromContext(ctx context.Context) (c Client, ok bool) { + c, ok = ctx.Value(clientKey{}).(Client) + return +} + +// MustClientFromContext returns the current client from the given context. It will +// return a new instance of the client if it does not exist. +func MustClientFromContext(ctx context.Context) Client { + if c, ok := ClientFromContext(ctx); !ok { + return NewClient() + } else { + return c + } +} + +type client struct { + http *http.Client + dialer *net.Dialer +} + +// NewClient returns an implementation of Client for verifying ACME challenges. +func NewClient() Client { + return &client{ + http: &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + dialer: &net.Dialer{ + Timeout: 30 * time.Second, + }, + } +} + +func (c *client) Get(url string) (*http.Response, error) { + return c.http.Get(url) +} + +func (c *client) LookupTxt(name string) ([]string, error) { + return net.LookupTXT(name) +} + +func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.DialWithDialer(c.dialer, network, addr, config) +} diff --git a/acme/common.go b/acme/common.go index 0c9e83dc..5290c06d 100644 --- a/acme/common.go +++ b/acme/common.go @@ -9,14 +9,6 @@ import ( "github.com/smallstep/certificates/authority/provisioner" ) -// CertificateAuthority is the interface implemented by a CA authority. -type CertificateAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - IsRevoked(sn string) (bool, error) - Revoke(context.Context, *authority.RevokeOptions) error - LoadProvisionerByName(string) (provisioner.Interface, error) -} - // Clock that returns time in UTC rounded to seconds. type Clock struct{} @@ -27,6 +19,51 @@ func (c *Clock) Now() time.Time { var clock Clock +// CertificateAuthority is the interface implemented by a CA authority. +type CertificateAuthority interface { + Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + IsRevoked(sn string) (bool, error) + Revoke(context.Context, *authority.RevokeOptions) error + LoadProvisionerByName(string) (provisioner.Interface, error) +} + +// NewContext adds the given acme components to the context. +func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context { + ctx = NewDatabaseContext(ctx, db) + ctx = NewClientContext(ctx, client) + ctx = NewLinkerContext(ctx, linker) + // Prerequisite checker is optional. + if fn != nil { + ctx = NewPrerequisitesCheckerContext(ctx, fn) + } + return ctx +} + +// PrerequisitesChecker is a function that checks if all prerequisites for +// serving ACME are met by the CA configuration. +type PrerequisitesChecker func(ctx context.Context) (bool, error) + +// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns +// always true. +func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) { + return true, nil +} + +type prerequisitesKey struct{} + +// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the +// context. +func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context { + return context.WithValue(ctx, prerequisitesKey{}, fn) +} + +// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the +// context. +func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) { + fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker) + return fn, ok && fn != nil +} + // Provisioner is an interface that implements a subset of the provisioner.Interface -- // only those methods required by the ACME api/authority. type Provisioner interface { @@ -38,6 +75,29 @@ type Provisioner interface { GetOptions() *provisioner.Options } +type provisionerKey struct{} + +// NewProvisionerContext adds the given provisioner to the context. +func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context { + return context.WithValue(ctx, provisionerKey{}, v) +} + +// ProvisionerFromContext returns the current provisioner from the given context. +func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) { + v, ok = ctx.Value(provisionerKey{}).(Provisioner) + return +} + +// MustLinkerFromContext returns the current provisioner from the given context. +// It will panic if it's not in the context. +func MustProvisionerFromContext(ctx context.Context) Provisioner { + if v, ok := ProvisionerFromContext(ctx); !ok { + panic("acme provisioner is not the context") + } else { + return v + } +} + // MockProvisioner for testing type MockProvisioner struct { Mret1 interface{} diff --git a/acme/db.go b/acme/db.go index a8637f57..3d781156 100644 --- a/acme/db.go +++ b/acme/db.go @@ -50,21 +50,21 @@ type DB interface { type dbKey struct{} -// NewContext adds the given acme database to the context. -func NewContext(ctx context.Context, db DB) context.Context { +// NewDatabaseContext adds the given acme database to the context. +func NewDatabaseContext(ctx context.Context, db DB) context.Context { return context.WithValue(ctx, dbKey{}, db) } -// FromContext returns the current acme database from the given context. -func FromContext(ctx context.Context) (db DB, ok bool) { +// DatabaseFromContext returns the current acme database from the given context. +func DatabaseFromContext(ctx context.Context) (db DB, ok bool) { db, ok = ctx.Value(dbKey{}).(DB) return } -// MustFromContext returns the current database from the given context. It -// will panic if it's not in the context. -func MustFromContext(ctx context.Context) DB { - if db, ok := FromContext(ctx); !ok { +// MustDatabaseFromContext returns the current database from the given context. +// It will panic if it's not in the context. +func MustDatabaseFromContext(ctx context.Context) DB { + if db, ok := DatabaseFromContext(ctx); !ok { panic("acme database is not in the context") } else { return db diff --git a/acme/linker.go b/acme/linker.go index 8dc87b14..6e9110c2 100644 --- a/acme/linker.go +++ b/acme/linker.go @@ -4,120 +4,16 @@ import ( "context" "fmt" "net" + "net/http" "net/url" "strings" - "github.com/smallstep/certificates/acme" + "github.com/go-chi/chi" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" ) -// NewLinker returns a new Directory type. -func NewLinker(dns, prefix string) Linker { - _, _, err := net.SplitHostPort(dns) - if err != nil && strings.Contains(err.Error(), "too many colons in address") { - // this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334 - // in case a port was appended to this wrong format, we try to extract the port, then check if it's - // still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of - // these cases, then the input dns is not changed. - lastIndex := strings.LastIndex(dns, ":") - hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:] - if ip := net.ParseIP(hostPart); ip != nil { - dns = "[" + hostPart + "]:" + portPart - } else if ip := net.ParseIP(dns); ip != nil { - dns = "[" + dns + "]" - } - } - return &linker{prefix: prefix, dns: dns} -} - -// Linker interface for generating links for ACME resources. -type Linker interface { - GetLink(ctx context.Context, typ LinkType, inputs ...string) string - GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string - - LinkOrder(ctx context.Context, o *acme.Order) - LinkAccount(ctx context.Context, o *acme.Account) - LinkChallenge(ctx context.Context, o *acme.Challenge, azID string) - LinkAuthorization(ctx context.Context, o *acme.Authorization) - LinkOrdersByAccountID(ctx context.Context, orders []string) -} - -type linkerKey struct{} - -// NewLinkerContext adds the given linker to the context. -func NewLinkerContext(ctx context.Context, v Linker) context.Context { - return context.WithValue(ctx, linkerKey{}, v) -} - -// LinkerFromContext returns the current linker from the given context. -func LinkerFromContext(ctx context.Context) (v Linker, ok bool) { - v, ok = ctx.Value(linkerKey{}).(Linker) - return -} - -// MustLinkerFromContext returns the current linker from the given context. It -// will panic if it's not in the context. -func MustLinkerFromContext(ctx context.Context) Linker { - if v, ok := LinkerFromContext(ctx); !ok { - panic("acme linker is not the context") - } else { - return v - } -} - -// linker generates ACME links. -type linker struct { - prefix string - dns string -} - -func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string { - switch typ { - case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: - return fmt.Sprintf("/%s/%s", provisionerName, typ) - case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: - return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) - case ChallengeLinkType: - return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) - case OrdersByAccountLinkType: - return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) - case FinalizeLinkType: - return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) - default: - return "" - } -} - -// GetLink is a helper for GetLinkExplicit -func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { - var ( - provName string - baseURL = baseURLFromContext(ctx) - u = url.URL{} - ) - if p, err := provisionerFromContext(ctx); err == nil && p != nil { - provName = p.GetName() - } - // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 - if baseURL != nil { - u = *baseURL - } - - u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...) - - // If no Scheme is set, then default to https. - if u.Scheme == "" { - u.Scheme = "https" - } - - // If no Host is set, then use the default (first DNS attr in the ca.json). - if u.Host == "" { - u.Host = l.dns - } - - u.Path = l.prefix + u.Path - return u.String() -} - // LinkType captures the link type. type LinkType int @@ -183,8 +79,151 @@ func (l LinkType) String() string { } } +func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string { + switch typ { + case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: + return fmt.Sprintf("/%s/%s", provisionerName, typ) + case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: + return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) + case ChallengeLinkType: + return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) + case OrdersByAccountLinkType: + return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) + case FinalizeLinkType: + return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) + default: + return "" + } +} + +// NewLinker returns a new Directory type. +func NewLinker(dns, prefix string) Linker { + _, _, err := net.SplitHostPort(dns) + if err != nil && strings.Contains(err.Error(), "too many colons in address") { + // this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + // in case a port was appended to this wrong format, we try to extract the port, then check if it's + // still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of + // these cases, then the input dns is not changed. + lastIndex := strings.LastIndex(dns, ":") + hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:] + if ip := net.ParseIP(hostPart); ip != nil { + dns = "[" + hostPart + "]:" + portPart + } else if ip := net.ParseIP(dns); ip != nil { + dns = "[" + dns + "]" + } + } + return &linker{prefix: prefix, dns: dns} +} + +// Linker interface for generating links for ACME resources. +type Linker interface { + GetLink(ctx context.Context, typ LinkType, inputs ...string) string + Middleware(http.Handler) http.Handler + LinkOrder(ctx context.Context, o *Order) + LinkAccount(ctx context.Context, o *Account) + LinkChallenge(ctx context.Context, o *Challenge, azID string) + LinkAuthorization(ctx context.Context, o *Authorization) + LinkOrdersByAccountID(ctx context.Context, orders []string) +} + +type linkerKey struct{} + +// NewLinkerContext adds the given linker to the context. +func NewLinkerContext(ctx context.Context, v Linker) context.Context { + return context.WithValue(ctx, linkerKey{}, v) +} + +// LinkerFromContext returns the current linker from the given context. +func LinkerFromContext(ctx context.Context) (v Linker, ok bool) { + v, ok = ctx.Value(linkerKey{}).(Linker) + return +} + +// MustLinkerFromContext returns the current linker from the given context. It +// will panic if it's not in the context. +func MustLinkerFromContext(ctx context.Context) Linker { + if v, ok := LinkerFromContext(ctx); !ok { + panic("acme linker is not the context") + } else { + return v + } +} + +type baseURLKey struct{} + +func newBaseURLContext(ctx context.Context, r *http.Request) context.Context { + var u *url.URL + if r.Host != "" { + u = &url.URL{Scheme: "https", Host: r.Host} + } + return context.WithValue(ctx, baseURLKey{}, u) +} + +func baseURLFromContext(ctx context.Context) *url.URL { + if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok { + return u + } + return nil +} + +// linker generates ACME links. +type linker struct { + prefix string + dns string +} + +// Middleware gets the provisioner and current url from the request and sets +// them in the context so we can use the linker to create ACME links. +func (l *linker) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Add base url to the context. + ctx := newBaseURLContext(r.Context(), r) + + // Add provisioner to the context. + nameEscaped := chi.URLParam(r, "provisionerID") + name, err := url.PathUnescape(nameEscaped) + if err != nil { + render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) + return + } + + p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name) + if err != nil { + render.Error(w, err) + return + } + + acmeProv, ok := p.(*provisioner.ACME) + if !ok { + render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) + return + } + + ctx = NewProvisionerContext(ctx, Provisioner(acmeProv)) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetLink is a helper for GetLinkExplicit. +func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { + var u url.URL + if baseURL := baseURLFromContext(ctx); baseURL != nil { + u = *baseURL + } + if u.Scheme == "" { + u.Scheme = "https" + } + if u.Host == "" { + u.Host = l.dns + } + + p := MustProvisionerFromContext(ctx) + u.Path = l.prefix + GetUnescapedPathSuffix(typ, p.GetName(), inputs...) + return u.String() +} + // LinkOrder sets the ACME links required by an ACME order. -func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { +func (l *linker) LinkOrder(ctx context.Context, o *Order) { o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) for i, azID := range o.AuthorizationIDs { o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID) @@ -196,17 +235,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { } // LinkAccount sets the ACME links required by an ACME account. -func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { +func (l *linker) LinkAccount(ctx context.Context, acc *Account) { acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID) } // LinkChallenge sets the ACME links required by an ACME challenge. -func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { +func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) { ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID) } // LinkAuthorization sets the ACME links required by an ACME authorization. -func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { +func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) { for _, ch := range az.Challenges { l.LinkChallenge(ctx, ch, az.ID) } diff --git a/acme/linker_test.go b/acme/linker_test.go index a8612e6b..1946dd88 100644 --- a/acme/linker_test.go +++ b/acme/linker_test.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/smallstep/assert" - "github.com/smallstep/certificates/acme" ) func TestLinker_GetUnescapedPathSuffix(t *testing.T) { @@ -173,27 +172,27 @@ func TestLinker_LinkOrder(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - o *acme.Order - validate func(o *acme.Order) + o *Order + validate func(o *Order) } var tests = map[string]test{ "no-authz-and-no-cert": { - o: &acme.Order{ + o: &Order{ ID: oid, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{}) assert.Equals(t, o.CertificateURL, "") }, }, "one-authz-and-cert": { - o: &acme.Order{ + o: &Order{ ID: oid, CertificateID: certID, AuthorizationIDs: []string{"foo"}, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), @@ -202,12 +201,12 @@ func TestLinker_LinkOrder(t *testing.T) { }, }, "many-authz": { - o: &acme.Order{ + o: &Order{ ID: oid, CertificateID: certID, AuthorizationIDs: []string{"foo", "bar", "zap"}, }, - validate: func(o *acme.Order) { + validate: func(o *Order) { assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.AuthorizationURLs, []string{ fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), @@ -237,15 +236,15 @@ func TestLinker_LinkAccount(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - a *acme.Account - validate func(o *acme.Account) + a *Account + validate func(o *Account) } var tests = map[string]test{ "ok": { - a: &acme.Account{ + a: &Account{ ID: accID, }, - validate: func(a *acme.Account) { + validate: func(a *Account) { assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID)) }, }, @@ -270,15 +269,15 @@ func TestLinker_LinkChallenge(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - ch *acme.Challenge - validate func(o *acme.Challenge) + ch *Challenge + validate func(o *Challenge) } var tests = map[string]test{ "ok": { - ch: &acme.Challenge{ + ch: &Challenge{ ID: chID, }, - validate: func(ch *acme.Challenge) { + validate: func(ch *Challenge) { assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID)) }, }, @@ -305,20 +304,20 @@ func TestLinker_LinkAuthorization(t *testing.T) { linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix) type test struct { - az *acme.Authorization - validate func(o *acme.Authorization) + az *Authorization + validate func(o *Authorization) } var tests = map[string]test{ "ok": { - az: &acme.Authorization{ + az: &Authorization{ ID: azID, - Challenges: []*acme.Challenge{ + Challenges: []*Challenge{ {ID: chID0}, {ID: chID1}, {ID: chID2}, }, }, - validate: func(az *acme.Authorization) { + validate: func(az *Authorization) { assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0)) assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1)) assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2)) diff --git a/ca/ca.go b/ca/ca.go index a8ecbb05..e910da74 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -189,30 +189,24 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { dns = fmt.Sprintf("%s:%s", dns, port) } - // ACME Router - prefix := "acme" + // ACME Router is only available if we have a database. var acmeDB acme.DB - if cfg.DB == nil { - acmeDB = nil - } else { + var acmeLinker acme.Linker + if cfg.DB != nil { acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) if err != nil { return nil, errors.Wrap(err, "error configuring ACME DB interface") } + acmeLinker = acme.NewLinker(dns, "acme") + mux.Route("/acme", func(r chi.Router) { + acmeAPI.Route(r) + }) + // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 + // of the ACME spec. + mux.Route("/2.0/acme", func(r chi.Router) { + acmeAPI.Route(r) + }) } - acmeOptions := &acmeAPI.HandlerOptions{ - Backdate: *cfg.AuthorityConfig.Backdate, - DNS: dns, - Prefix: prefix, - } - mux.Route("/"+prefix, func(r chi.Router) { - acmeAPI.Route(r, acmeOptions) - }) - // Use 2.0 because, at the moment, our ACME api is only compatible with v2.0 - // of the ACME spec. - mux.Route("/2.0/"+prefix, func(r chi.Router) { - acmeAPI.Route(r, acmeOptions) - }) // Admin API Router if cfg.AuthorityConfig.EnableAdmin { @@ -280,7 +274,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // Create context with all the necessary values. - baseContext := buildContext(auth, scepAuthority, acmeDB) + baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv.BaseContext = func(net.Listener) context.Context { @@ -304,7 +298,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // buildContext builds the server base context. -func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB) context.Context { +func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context { ctx := authority.NewContext(context.Background(), a) if authDB := a.GetDatabase(); authDB != nil { ctx = db.NewContext(ctx, authDB) @@ -316,7 +310,7 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB ctx = scep.NewContext(ctx, scepAuthority) } if acmeDB != nil { - ctx = acme.NewContext(ctx, acmeDB) + ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil) } return ctx }