Refactor ACME api.

This commit is contained in:
Mariano Cano 2022-04-28 19:15:18 -07:00
parent fddd6f7d95
commit d1f75f1720
13 changed files with 510 additions and 438 deletions

View file

@ -69,6 +69,9 @@ func (u *UpdateAccountRequest) Validate() error {
// NewAccount is the handler resource for creating new ACME accounts. // NewAccount is the handler resource for creating new ACME accounts.
func NewAccount(w http.ResponseWriter, r *http.Request) { func NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -120,7 +123,6 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
return return
} }
db := acme.MustFromContext(ctx)
acc = &acme.Account{ acc = &acme.Account{
Key: jwk, Key: jwk,
Contact: nar.Contact, Contact: nar.Contact,
@ -148,16 +150,18 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
httpStatus = http.StatusOK httpStatus = http.StatusOK
} }
o := optionsFromContext(ctx) linker.LinkAccount(ctx, acc)
o.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", o.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
render.JSONStatus(w, acc, httpStatus) render.JSONStatus(w, acc, httpStatus)
} }
// GetOrUpdateAccount is the api for updating an ACME account. // GetOrUpdateAccount is the api for updating an ACME account.
func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -189,7 +193,6 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
acc.Contact = uar.Contact acc.Contact = uar.Contact
} }
db := acme.MustFromContext(ctx)
if err := db.UpdateAccount(ctx, acc); err != nil { if err := db.UpdateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating account")) render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return return
@ -197,10 +200,9 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
} }
} }
o := optionsFromContext(ctx) linker.LinkAccount(ctx, acc)
o.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", o.linker.GetLink(ctx, AccountLinkType, acc.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
render.JSON(w, acc) render.JSON(w, acc)
} }
@ -216,6 +218,9 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -227,15 +232,13 @@ func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
return return
} }
db := acme.MustFromContext(ctx)
orders, err := db.GetOrdersByAccountID(ctx, acc.ID) orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
o := optionsFromContext(ctx) linker.LinkOrdersByAccountID(ctx, orders)
o.linker.LinkOrdersByAccountID(ctx, orders)
render.JSON(w, orders) render.JSON(w, orders)
logOrdersByAccount(w, orders) logOrdersByAccount(w, orders)

View file

@ -47,7 +47,7 @@ func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest)
return nil, acmeErr return nil, acmeErr
} }
db := acme.MustFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
if err != nil { if err != nil {
if _, ok := err.(*acme.Error); ok { if _, ok := err.(*acme.Error); ok {
@ -103,7 +103,6 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
// o The "nonce" field MUST NOT be present // o The "nonce" field MUST NOT be present
// o The "url" field MUST be set to the same value as the outer JWS // o The "url" field MUST be set to the same value as the outer JWS
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) { func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
if jws == nil { if jws == nil {
return "", acme.NewErrorISE("no JWS provided") return "", acme.NewErrorISE("no JWS provided")
} }

View file

@ -2,12 +2,10 @@ package api
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net"
"net/http" "net/http"
"time" "time"
@ -70,144 +68,117 @@ type HandlerOptions struct {
// PrerequisitesChecker checks if all prerequisites for serving ACME are // PrerequisitesChecker checks if all prerequisites for serving ACME are
// met by the CA configuration. // met by the CA configuration.
PrerequisitesChecker func(ctx context.Context) (bool, error) PrerequisitesChecker func(ctx context.Context) (bool, error)
linker Linker
validateChallengeOptions *acme.ValidateChallengeOptions
}
type optionsKey struct{}
func newOptionsContext(ctx context.Context, o *HandlerOptions) context.Context {
return context.WithValue(ctx, optionsKey{}, o)
}
func optionsFromContext(ctx context.Context) *HandlerOptions {
o, ok := ctx.Value(optionsKey{}).(*HandlerOptions)
if !ok {
panic("acme options are not in the context")
}
return o
} }
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority { var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
return authority.MustFromContext(ctx) return authority.MustFromContext(ctx)
} }
// Handler is the ACME API request handler. // handler is the ACME API request handler.
type Handler struct { type handler struct {
opts *HandlerOptions opts *HandlerOptions
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
// func (h *handler) Route(r api.Router) {
// Deprecated: Use api.Route(r Router, opts *HandlerOptions) route(r, h.opts)
func (h *Handler) Route(r api.Router) {
Route(r, h.opts)
} }
// NewHandler returns a new ACME API handler. // NewHandler returns a new ACME API handler.
// func NewHandler(opts HandlerOptions) api.RouterHandler {
// Deprecated: Use api.Route(r Router, opts *HandlerOptions) return &handler{
func NewHandler(ops HandlerOptions) api.RouterHandler { opts: &opts,
return &Handler{
opts: &ops,
} }
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface. This method requires that
func Route(r api.Router, opts *HandlerOptions) { // all the acme components, authority, db, client, linker, and prerequisite
// by default all prerequisites are met // checker to be present in the context.
if opts.PrerequisitesChecker == nil { func Route(r api.Router) {
opts.PrerequisitesChecker = func(ctx context.Context) (bool, error) { route(r, nil)
return true, nil }
func route(r api.Router, opts *HandlerOptions) {
var withContext func(next nextHTTP) nextHTTP
// For backward compatibility this block adds will add a new middleware that
// will set the ACME components to the context.
if opts != nil {
client := acme.NewClient()
linker := acme.NewLinker(opts.DNS, opts.Prefix)
withContext = func(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil {
ctx = authority.NewContext(ctx, ca)
}
ctx = acme.NewContext(ctx, opts.DB, client, linker, opts.PrerequisitesChecker)
next(w, r.WithContext(ctx))
}
}
} else {
withContext = func(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
next(w, r)
}
} }
} }
transport := &http.Transport{ commonMiddleware := func(next nextHTTP) nextHTTP {
TLSClientConfig: &tls.Config{ return withContext(func(w http.ResponseWriter, r *http.Request) {
InsecureSkipVerify: true, // Linker middleware gets the provisioner and current url from the
}, // request and sets them in the context.
linker := acme.MustLinkerFromContext(r.Context())
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
})
} }
client := http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
}
opts.linker = NewLinker(opts.DNS, opts.Prefix)
opts.validateChallengeOptions = &acme.ValidateChallengeOptions{
HTTPGet: client.Get,
LookupTxt: net.LookupTXT,
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(dialer, network, addr, config)
},
}
withOptions := func(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// For backward compatibility with NewHandler.
if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil {
ctx = authority.NewContext(ctx, ca)
}
if opts.DB != nil {
ctx = acme.NewContext(ctx, opts.DB)
}
ctx = newOptionsContext(ctx, opts)
next(w, r.WithContext(ctx))
}
}
validatingMiddleware := func(next nextHTTP) nextHTTP { validatingMiddleware := func(next nextHTTP) nextHTTP {
return withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))))) return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
} }
extractPayloadByJWK := func(next nextHTTP) nextHTTP { extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return withOptions(validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))) return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
} }
extractPayloadByKid := func(next nextHTTP) nextHTTP { extractPayloadByKid := func(next nextHTTP) nextHTTP {
return withOptions(validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))) return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
} }
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
return withOptions(validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))) return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
} }
getPath := opts.linker.GetUnescapedPathSuffix getPath := acme.GetUnescapedPathSuffix
// Standard ACME API // Standard ACME API
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce))))))) commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) commonMiddleware(GetDirectory))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory))))) commonMiddleware(GetDirectory))
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
extractPayloadByJWK(NewAccount)) extractPayloadByJWK(NewAccount))
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(GetOrUpdateAccount)) extractPayloadByKid(GetOrUpdateAccount))
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(NotImplemented)) extractPayloadByKid(NotImplemented))
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
extractPayloadByKid(NewOrder)) extractPayloadByKid(NewOrder))
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(isPostAsGet(GetOrder))) extractPayloadByKid(isPostAsGet(GetOrder)))
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID))) extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(FinalizeOrder)) extractPayloadByKid(FinalizeOrder))
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
extractPayloadByKid(isPostAsGet(GetAuthorization))) extractPayloadByKid(isPostAsGet(GetAuthorization)))
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
extractPayloadByKid(GetChallenge)) extractPayloadByKid(GetChallenge))
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
extractPayloadByKid(isPostAsGet(GetCertificate))) extractPayloadByKid(isPostAsGet(GetCertificate)))
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
extractPayloadByKidOrJWK(RevokeCert)) extractPayloadByKidOrJWK(RevokeCert))
} }
@ -251,20 +222,20 @@ func (d *Directory) ToLog() (interface{}, error) {
// for client configuration. // for client configuration.
func GetDirectory(w http.ResponseWriter, r *http.Request) { func GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
o := optionsFromContext(ctx)
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
fmt.Println(acmeProv, err)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
linker := acme.MustLinkerFromContext(ctx)
render.JSON(w, &Directory{ render.JSON(w, &Directory{
NewNonce: o.linker.GetLink(ctx, NewNonceLinkType), NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
NewAccount: o.linker.GetLink(ctx, NewAccountLinkType), NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
NewOrder: o.linker.GetLink(ctx, NewOrderLinkType), NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
RevokeCert: o.linker.GetLink(ctx, RevokeCertLinkType), RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
KeyChange: o.linker.GetLink(ctx, KeyChangeLinkType), KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
Meta: Meta{ Meta: Meta{
ExternalAccountRequired: acmeProv.RequireEAB, ExternalAccountRequired: acmeProv.RequireEAB,
}, },
@ -280,8 +251,8 @@ func NotImplemented(w http.ResponseWriter, r *http.Request) {
// GetAuthorization ACME api for retrieving an Authz. // GetAuthorization ACME api for retrieving an Authz.
func GetAuthorization(w http.ResponseWriter, r *http.Request) { func GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
o := optionsFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
db := acme.MustFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
@ -303,17 +274,17 @@ func GetAuthorization(w http.ResponseWriter, r *http.Request) {
return return
} }
o.linker.LinkAuthorization(ctx, az) linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", o.linker.GetLink(ctx, AuthzLinkType, az.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
render.JSON(w, az) render.JSON(w, az)
} }
// GetChallenge ACME api for retrieving a Challenge. // GetChallenge ACME api for retrieving a Challenge.
func GetChallenge(w http.ResponseWriter, r *http.Request) { func GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
o := optionsFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
db := acme.MustFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
@ -351,22 +322,22 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
render.Error(w, err) render.Error(w, err)
return return
} }
if err = ch.Validate(ctx, db, jwk, o.validateChallengeOptions); err != nil { if err = ch.Validate(ctx, db, jwk); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return return
} }
o.linker.LinkChallenge(ctx, ch, azID) linker.LinkChallenge(ctx, ch, azID)
w.Header().Add("Link", link(o.linker.GetLink(ctx, AuthzLinkType, azID), "up")) w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
w.Header().Set("Location", o.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
render.JSON(w, ch) render.JSON(w, ch)
} }
// GetCertificate ACME api for retrieving a Certificate. // GetCertificate ACME api for retrieving a Certificate.
func GetCertificate(w http.ResponseWriter, r *http.Request) { func GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {

View file

@ -31,39 +31,10 @@ func logNonce(w http.ResponseWriter, nonce string) {
} }
} }
// getBaseURLFromRequest determines the base URL which should be used for
// constructing link URLs in e.g. the ACME directory result by taking the
// request Host into consideration.
//
// If the Request.Host is an empty string, we return an empty string, to
// indicate that the configured URL values should be used instead. If this
// function returns a non-empty result, then this should be used in constructing
// ACME link URLs.
func getBaseURLFromRequest(r *http.Request) *url.URL {
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
// for an implementation that allows HTTP requests using the x-forwarded-proto
// header.
if r.Host == "" {
return nil
}
return &url.URL{Scheme: "https", Host: r.Host}
}
// baseURLFromRequest is a middleware that extracts and caches the baseURL
// from the request.
// E.g. https://ca.smallstep.com/
func baseURLFromRequest(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), baseURLContextKey, getBaseURLFromRequest(r))
next(w, r.WithContext(ctx))
}
}
// addNonce is a middleware that adds a nonce to the response header. // addNonce is a middleware that adds a nonce to the response header.
func addNonce(next nextHTTP) nextHTTP { func addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
db := acme.MustFromContext(r.Context()) db := acme.MustDatabaseFromContext(r.Context())
nonce, err := db.CreateNonce(r.Context()) nonce, err := db.CreateNonce(r.Context())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -81,9 +52,9 @@ func addNonce(next nextHTTP) nextHTTP {
func addDirLink(next nextHTTP) nextHTTP { func addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
opts := optionsFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
w.Header().Add("Link", link(opts.linker.GetLink(ctx, DirectoryLinkType), "index")) w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
next(w, r) next(w, r)
} }
} }
@ -92,17 +63,12 @@ func addDirLink(next nextHTTP) nextHTTP {
// application/jose+json. // application/jose+json.
func verifyContentType(next nextHTTP) nextHTTP { func verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var expected []string p := acme.MustProvisionerFromContext(r.Context())
ctx := r.Context() u := &url.URL{
opts := optionsFromContext(ctx) Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
p, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
} }
u := url.URL{Path: opts.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} var expected []string
if strings.Contains(r.URL.String(), u.EscapedPath()) { if strings.Contains(r.URL.String(), u.EscapedPath()) {
// GET /certificate requests allow a greater range of content types. // GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
@ -159,7 +125,7 @@ func parseJWS(next nextHTTP) nextHTTP {
func validateJWS(next nextHTTP) nextHTTP { func validateJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
@ -247,7 +213,7 @@ func validateJWS(next nextHTTP) nextHTTP {
func extractJWK(next nextHTTP) nextHTTP { func extractJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
@ -325,18 +291,20 @@ func lookupProvisioner(next nextHTTP) nextHTTP {
func checkPrerequisites(next nextHTTP) nextHTTP { func checkPrerequisites(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
opts := optionsFromContext(ctx) // If the function is not set assume that all prerequisites are met.
checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
ok, err := opts.PrerequisitesChecker(ctx) if ok {
if err != nil { ok, err := checkFunc(ctx)
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) if err != nil {
return render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
} }
if !ok { next(w, r)
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
next(w, r.WithContext(ctx))
} }
} }
@ -346,8 +314,8 @@ func checkPrerequisites(next nextHTTP) nextHTTP {
func lookupJWK(next nextHTTP) nextHTTP { func lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
opts := optionsFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
db := acme.MustFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
@ -355,7 +323,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
return return
} }
kidPrefix := opts.linker.GetLink(ctx, AccountLinkType, "") kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
render.Error(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
@ -527,32 +495,14 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
return val, nil return val, nil
} }
// provisionerFromContext searches the context for a provisioner. Returns the
// provisioner or an error.
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
val := ctx.Value(provisionerContextKey)
if val == nil {
return nil, acme.NewErrorISE("provisioner expected in request context")
}
pval, ok := val.(acme.Provisioner)
if !ok || pval == nil {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
}
return pval, nil
}
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
// pointer to an ACME provisioner or an error. // pointer to an ACME provisioner or an error.
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
prov, err := provisionerFromContext(ctx) p, ok := acme.MustProvisionerFromContext(ctx).(*provisioner.ACME)
if err != nil { if !ok {
return nil, err
}
acmeProv, ok := prov.(*provisioner.ACME)
if !ok || acmeProv == nil {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
} }
return acmeProv, nil return p, nil
} }
// payloadFromContext searches the context for a payload. Returns the payload // payloadFromContext searches the context for a payload. Returns the payload

View file

@ -70,16 +70,15 @@ var defaultOrderBackdate = time.Minute
// NewOrder ACME api for creating a new order. // NewOrder ACME api for creating a new order.
func NewOrder(w http.ResponseWriter, r *http.Request) { func NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -136,16 +135,14 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
} }
db := acme.MustFromContext(ctx)
if err := db.CreateOrder(ctx, o); err != nil { if err := db.CreateOrder(ctx, o); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating order")) render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return return
} }
opts := optionsFromContext(ctx) linker.LinkOrder(ctx, o)
opts.linker.LinkOrder(ctx, o)
w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSONStatus(w, o, http.StatusCreated) render.JSONStatus(w, o, http.StatusCreated)
} }
@ -166,7 +163,7 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error {
return acme.WrapErrorISE(err, "error generating random alphanumeric ID") return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
} }
db := acme.MustFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
az.Challenges = make([]*acme.Challenge, len(chTypes)) az.Challenges = make([]*acme.Challenge, len(chTypes))
for i, typ := range chTypes { for i, typ := range chTypes {
ch := &acme.Challenge{ ch := &acme.Challenge{
@ -190,18 +187,16 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error {
// GetOrder ACME api for retrieving an order. // GetOrder ACME api for retrieving an order.
func GetOrder(w http.ResponseWriter, r *http.Request) { func GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
db := acme.MustFromContext(ctx)
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
@ -222,26 +217,24 @@ func GetOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
opts := optionsFromContext(ctx) linker.LinkOrder(ctx, o)
opts.linker.LinkOrder(ctx, o)
w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o) render.JSON(w, o)
} }
// FinalizeOrder attemptst to finalize an order and create a certificate. // FinalizeOrder attemptst to finalize an order and create a certificate.
func FinalizeOrder(w http.ResponseWriter, r *http.Request) { func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -258,7 +251,6 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
db := acme.MustFromContext(ctx)
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
@ -281,10 +273,9 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
opts := optionsFromContext(ctx) linker.LinkOrder(ctx, o)
opts.linker.LinkOrder(ctx, o)
w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o) render.JSON(w, o)
} }

View file

@ -28,13 +28,11 @@ type revokePayload struct {
// RevokeCert attempts to revoke a certificate. // RevokeCert attempts to revoke a certificate.
func RevokeCert(w http.ResponseWriter, r *http.Request) { func RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
if err != nil { linker := acme.MustLinkerFromContext(ctx)
render.Error(w, err) prov := acme.MustProvisionerFromContext(ctx)
return
}
prov, err := provisionerFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -67,7 +65,6 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
return return
} }
db := acme.MustFromContext(ctx)
serial := certToBeRevoked.SerialNumber.String() serial := certToBeRevoked.SerialNumber.String()
dbCert, err := db.GetCertificateBySerial(ctx, serial) dbCert, err := db.GetCertificateBySerial(ctx, serial)
if err != nil { if err != nil {
@ -138,8 +135,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
} }
logRevoke(w, options) logRevoke(w, options)
o := optionsFromContext(ctx) w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
w.Header().Add("Link", link(o.linker.GetLink(ctx, DirectoryLinkType), "index"))
w.Write(nil) w.Write(nil)
} }

View file

@ -14,7 +14,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) {
// type using the DB interface. // type using the DB interface.
// satisfactorily validated, the 'status' and 'validated' attributes are // satisfactorily validated, the 'status' and 'validated' attributes are
// updated. // updated.
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error {
// If already valid or invalid then return without performing validation. // If already valid or invalid then return without performing validation.
if ch.Status != StatusPending { if ch.Status != StatusPending {
return nil return nil
} }
switch ch.Type { switch ch.Type {
case HTTP01: case HTTP01:
return http01Validate(ctx, ch, db, jwk, vo) return http01Validate(ctx, ch, db, jwk)
case DNS01: case DNS01:
return dns01Validate(ctx, ch, db, jwk, vo) return dns01Validate(ctx, ch, db, jwk)
case TLSALPN01: case TLSALPN01:
return tlsalpn01Validate(ctx, ch, db, jwk, vo) return tlsalpn01Validate(ctx, ch, db, jwk)
default: default:
return NewErrorISE("unexpected challenge type '%s'", ch.Type) return NewErrorISE("unexpected challenge type '%s'", ch.Type)
} }
} }
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(u.String()) vc := MustClientFromContext(ctx)
resp, err := vc.Get(u.String())
if err != nil { if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
"error doing http GET for url %s", u)) "error doing http GET for url %s", u))
@ -141,7 +141,7 @@ func tlsAlert(err error) uint8 {
return 0 return 0
} }
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
config := &tls.Config{ config := &tls.Config{
NextProtos: []string{"acme-tls/1"}, NextProtos: []string{"acme-tls/1"},
// https://tools.ietf.org/html/rfc8737#section-4 // https://tools.ietf.org/html/rfc8737#section-4
@ -154,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
hostPort := net.JoinHostPort(ch.Value, "443") hostPort := net.JoinHostPort(ch.Value, "443")
conn, err := vo.TLSDial("tcp", hostPort, config) vc := MustClientFromContext(ctx)
conn, err := vc.TLSDial("tcp", hostPort, config)
if err != nil { if err != nil {
// With Go 1.17+ tls.Dial fails if there's no overlap between configured // With Go 1.17+ tls.Dial fails if there's no overlap between configured
// client and server protocols. When this happens the connection is // client and server protocols. When this happens the connection is
@ -253,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
} }
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
// Normalize domain for wildcard DNS names // Normalize domain for wildcard DNS names
// This is done to avoid making TXT lookups for domains like // This is done to avoid making TXT lookups for domains like
// _acme-challenge.*.example.com // _acme-challenge.*.example.com
// Instead perform txt lookup for _acme-challenge.example.com // Instead perform txt lookup for _acme-challenge.example.com
domain := strings.TrimPrefix(ch.Value, "*.") domain := strings.TrimPrefix(ch.Value, "*.")
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) vc := MustClientFromContext(ctx)
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
if err != nil { if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err, return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
"error looking up TXT records for domain %s", domain)) "error looking up TXT records for domain %s", domain))
@ -376,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
} }
return nil return nil
} }
type httpGetter func(string) (*http.Response, error)
type lookupTxt func(string) ([]string, error)
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
// ValidateChallengeOptions are ACME challenge validator functions.
type ValidateChallengeOptions struct {
HTTPGet httpGetter
LookupTxt lookupTxt
TLSDial tlsDialer
}

79
acme/client.go Normal file
View file

@ -0,0 +1,79 @@
package acme
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// Client is the interface used to verify ACME challenges.
type Client interface {
// Get issues an HTTP GET to the specified URL.
Get(url string) (*http.Response, error)
// LookupTXT returns the DNS TXT records for the given domain name.
LookupTxt(name string) ([]string, error)
// TLSDial connects to the given network address using net.Dialer and then
// initiates a TLS handshake, returning the resulting TLS connection.
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
}
type clientKey struct{}
// NewClientContext adds the given client to the context.
func NewClientContext(ctx context.Context, c Client) context.Context {
return context.WithValue(ctx, clientKey{}, c)
}
// ClientFromContext returns the current client from the given context.
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
c, ok = ctx.Value(clientKey{}).(Client)
return
}
// MustClientFromContext returns the current client from the given context. It will
// return a new instance of the client if it does not exist.
func MustClientFromContext(ctx context.Context) Client {
if c, ok := ClientFromContext(ctx); !ok {
return NewClient()
} else {
return c
}
}
type client struct {
http *http.Client
dialer *net.Dialer
}
// NewClient returns an implementation of Client for verifying ACME challenges.
func NewClient() Client {
return &client{
http: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
},
dialer: &net.Dialer{
Timeout: 30 * time.Second,
},
}
}
func (c *client) Get(url string) (*http.Response, error) {
return c.http.Get(url)
}
func (c *client) LookupTxt(name string) ([]string, error) {
return net.LookupTXT(name)
}
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(c.dialer, network, addr, config)
}

View file

@ -9,14 +9,6 @@ import (
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
LoadProvisionerByName(string) (provisioner.Interface, error)
}
// Clock that returns time in UTC rounded to seconds. // Clock that returns time in UTC rounded to seconds.
type Clock struct{} type Clock struct{}
@ -27,6 +19,51 @@ func (c *Clock) Now() time.Time {
var clock Clock var clock Clock
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
LoadProvisionerByName(string) (provisioner.Interface, error)
}
// NewContext adds the given acme components to the context.
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
ctx = NewDatabaseContext(ctx, db)
ctx = NewClientContext(ctx, client)
ctx = NewLinkerContext(ctx, linker)
// Prerequisite checker is optional.
if fn != nil {
ctx = NewPrerequisitesCheckerContext(ctx, fn)
}
return ctx
}
// PrerequisitesChecker is a function that checks if all prerequisites for
// serving ACME are met by the CA configuration.
type PrerequisitesChecker func(ctx context.Context) (bool, error)
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
// always true.
func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
return true, nil
}
type prerequisitesKey struct{}
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
// context.
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
return context.WithValue(ctx, prerequisitesKey{}, fn)
}
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
// context.
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
return fn, ok && fn != nil
}
// Provisioner is an interface that implements a subset of the provisioner.Interface -- // Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority. // only those methods required by the ACME api/authority.
type Provisioner interface { type Provisioner interface {
@ -38,6 +75,29 @@ type Provisioner interface {
GetOptions() *provisioner.Options GetOptions() *provisioner.Options
} }
type provisionerKey struct{}
// NewProvisionerContext adds the given provisioner to the context.
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
return context.WithValue(ctx, provisionerKey{}, v)
}
// ProvisionerFromContext returns the current provisioner from the given context.
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
return
}
// MustLinkerFromContext returns the current provisioner from the given context.
// It will panic if it's not in the context.
func MustProvisionerFromContext(ctx context.Context) Provisioner {
if v, ok := ProvisionerFromContext(ctx); !ok {
panic("acme provisioner is not the context")
} else {
return v
}
}
// MockProvisioner for testing // MockProvisioner for testing
type MockProvisioner struct { type MockProvisioner struct {
Mret1 interface{} Mret1 interface{}

View file

@ -50,21 +50,21 @@ type DB interface {
type dbKey struct{} type dbKey struct{}
// NewContext adds the given acme database to the context. // NewDatabaseContext adds the given acme database to the context.
func NewContext(ctx context.Context, db DB) context.Context { func NewDatabaseContext(ctx context.Context, db DB) context.Context {
return context.WithValue(ctx, dbKey{}, db) return context.WithValue(ctx, dbKey{}, db)
} }
// FromContext returns the current acme database from the given context. // DatabaseFromContext returns the current acme database from the given context.
func FromContext(ctx context.Context) (db DB, ok bool) { func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
db, ok = ctx.Value(dbKey{}).(DB) db, ok = ctx.Value(dbKey{}).(DB)
return return
} }
// MustFromContext returns the current database from the given context. It // MustDatabaseFromContext returns the current database from the given context.
// will panic if it's not in the context. // It will panic if it's not in the context.
func MustFromContext(ctx context.Context) DB { func MustDatabaseFromContext(ctx context.Context) DB {
if db, ok := FromContext(ctx); !ok { if db, ok := DatabaseFromContext(ctx); !ok {
panic("acme database is not in the context") panic("acme database is not in the context")
} else { } else {
return db return db

View file

@ -4,120 +4,16 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/smallstep/certificates/acme" "github.com/go-chi/chi"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
) )
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) Linker {
_, _, err := net.SplitHostPort(dns)
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
// these cases, then the input dns is not changed.
lastIndex := strings.LastIndex(dns, ":")
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
if ip := net.ParseIP(hostPart); ip != nil {
dns = "[" + hostPart + "]:" + portPart
} else if ip := net.ParseIP(dns); ip != nil {
dns = "[" + dns + "]"
}
}
return &linker{prefix: prefix, dns: dns}
}
// Linker interface for generating links for ACME resources.
type Linker interface {
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
LinkOrder(ctx context.Context, o *acme.Order)
LinkAccount(ctx context.Context, o *acme.Account)
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
LinkAuthorization(ctx context.Context, o *acme.Authorization)
LinkOrdersByAccountID(ctx context.Context, orders []string)
}
type linkerKey struct{}
// NewLinkerContext adds the given linker to the context.
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
return context.WithValue(ctx, linkerKey{}, v)
}
// LinkerFromContext returns the current linker from the given context.
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
v, ok = ctx.Value(linkerKey{}).(Linker)
return
}
// MustLinkerFromContext returns the current linker from the given context. It
// will panic if it's not in the context.
func MustLinkerFromContext(ctx context.Context) Linker {
if v, ok := LinkerFromContext(ctx); !ok {
panic("acme linker is not the context")
} else {
return v
}
}
// linker generates ACME links.
type linker struct {
prefix string
dns string
}
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// GetLink is a helper for GetLinkExplicit
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var (
provName string
baseURL = baseURLFromContext(ctx)
u = url.URL{}
)
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
provName = p.GetName()
}
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
if baseURL != nil {
u = *baseURL
}
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
// If no Scheme is set, then default to https.
if u.Scheme == "" {
u.Scheme = "https"
}
// If no Host is set, then use the default (first DNS attr in the ca.json).
if u.Host == "" {
u.Host = l.dns
}
u.Path = l.prefix + u.Path
return u.String()
}
// LinkType captures the link type. // LinkType captures the link type.
type LinkType int type LinkType int
@ -183,8 +79,151 @@ func (l LinkType) String() string {
} }
} }
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) Linker {
_, _, err := net.SplitHostPort(dns)
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
// these cases, then the input dns is not changed.
lastIndex := strings.LastIndex(dns, ":")
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
if ip := net.ParseIP(hostPart); ip != nil {
dns = "[" + hostPart + "]:" + portPart
} else if ip := net.ParseIP(dns); ip != nil {
dns = "[" + dns + "]"
}
}
return &linker{prefix: prefix, dns: dns}
}
// Linker interface for generating links for ACME resources.
type Linker interface {
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
Middleware(http.Handler) http.Handler
LinkOrder(ctx context.Context, o *Order)
LinkAccount(ctx context.Context, o *Account)
LinkChallenge(ctx context.Context, o *Challenge, azID string)
LinkAuthorization(ctx context.Context, o *Authorization)
LinkOrdersByAccountID(ctx context.Context, orders []string)
}
type linkerKey struct{}
// NewLinkerContext adds the given linker to the context.
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
return context.WithValue(ctx, linkerKey{}, v)
}
// LinkerFromContext returns the current linker from the given context.
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
v, ok = ctx.Value(linkerKey{}).(Linker)
return
}
// MustLinkerFromContext returns the current linker from the given context. It
// will panic if it's not in the context.
func MustLinkerFromContext(ctx context.Context) Linker {
if v, ok := LinkerFromContext(ctx); !ok {
panic("acme linker is not the context")
} else {
return v
}
}
type baseURLKey struct{}
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
var u *url.URL
if r.Host != "" {
u = &url.URL{Scheme: "https", Host: r.Host}
}
return context.WithValue(ctx, baseURLKey{}, u)
}
func baseURLFromContext(ctx context.Context) *url.URL {
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
return u
}
return nil
}
// linker generates ACME links.
type linker struct {
prefix string
dns string
}
// Middleware gets the provisioner and current url from the request and sets
// them in the context so we can use the linker to create ACME links.
func (l *linker) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add base url to the context.
ctx := newBaseURLContext(r.Context(), r)
// Add provisioner to the context.
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
if err != nil {
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetLink is a helper for GetLinkExplicit.
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var u url.URL
if baseURL := baseURLFromContext(ctx); baseURL != nil {
u = *baseURL
}
if u.Scheme == "" {
u.Scheme = "https"
}
if u.Host == "" {
u.Host = l.dns
}
p := MustProvisionerFromContext(ctx)
u.Path = l.prefix + GetUnescapedPathSuffix(typ, p.GetName(), inputs...)
return u.String()
}
// LinkOrder sets the ACME links required by an ACME order. // LinkOrder sets the ACME links required by an ACME order.
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { func (l *linker) LinkOrder(ctx context.Context, o *Order) {
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
for i, azID := range o.AuthorizationIDs { for i, azID := range o.AuthorizationIDs {
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID) o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
@ -196,17 +235,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
} }
// LinkAccount sets the ACME links required by an ACME account. // LinkAccount sets the ACME links required by an ACME account.
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID) acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
} }
// LinkChallenge sets the ACME links required by an ACME challenge. // LinkChallenge sets the ACME links required by an ACME challenge.
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID) ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
} }
// LinkAuthorization sets the ACME links required by an ACME authorization. // LinkAuthorization sets the ACME links required by an ACME authorization.
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
for _, ch := range az.Challenges { for _, ch := range az.Challenges {
l.LinkChallenge(ctx, ch, az.ID) l.LinkChallenge(ctx, ch, az.ID)
} }

View file

@ -7,7 +7,6 @@ import (
"testing" "testing"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
) )
func TestLinker_GetUnescapedPathSuffix(t *testing.T) { func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
@ -173,27 +172,27 @@ func TestLinker_LinkOrder(t *testing.T) {
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
o *acme.Order o *Order
validate func(o *acme.Order) validate func(o *Order)
} }
var tests = map[string]test{ var tests = map[string]test{
"no-authz-and-no-cert": { "no-authz-and-no-cert": {
o: &acme.Order{ o: &Order{
ID: oid, ID: oid,
}, },
validate: func(o *acme.Order) { validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{}) assert.Equals(t, o.AuthorizationURLs, []string{})
assert.Equals(t, o.CertificateURL, "") assert.Equals(t, o.CertificateURL, "")
}, },
}, },
"one-authz-and-cert": { "one-authz-and-cert": {
o: &acme.Order{ o: &Order{
ID: oid, ID: oid,
CertificateID: certID, CertificateID: certID,
AuthorizationIDs: []string{"foo"}, AuthorizationIDs: []string{"foo"},
}, },
validate: func(o *acme.Order) { validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{ assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -202,12 +201,12 @@ func TestLinker_LinkOrder(t *testing.T) {
}, },
}, },
"many-authz": { "many-authz": {
o: &acme.Order{ o: &Order{
ID: oid, ID: oid,
CertificateID: certID, CertificateID: certID,
AuthorizationIDs: []string{"foo", "bar", "zap"}, AuthorizationIDs: []string{"foo", "bar", "zap"},
}, },
validate: func(o *acme.Order) { validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{ assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -237,15 +236,15 @@ func TestLinker_LinkAccount(t *testing.T) {
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
a *acme.Account a *Account
validate func(o *acme.Account) validate func(o *Account)
} }
var tests = map[string]test{ var tests = map[string]test{
"ok": { "ok": {
a: &acme.Account{ a: &Account{
ID: accID, ID: accID,
}, },
validate: func(a *acme.Account) { validate: func(a *Account) {
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID)) assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
}, },
}, },
@ -270,15 +269,15 @@ func TestLinker_LinkChallenge(t *testing.T) {
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
ch *acme.Challenge ch *Challenge
validate func(o *acme.Challenge) validate func(o *Challenge)
} }
var tests = map[string]test{ var tests = map[string]test{
"ok": { "ok": {
ch: &acme.Challenge{ ch: &Challenge{
ID: chID, ID: chID,
}, },
validate: func(ch *acme.Challenge) { validate: func(ch *Challenge) {
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID)) assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
}, },
}, },
@ -305,20 +304,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
az *acme.Authorization az *Authorization
validate func(o *acme.Authorization) validate func(o *Authorization)
} }
var tests = map[string]test{ var tests = map[string]test{
"ok": { "ok": {
az: &acme.Authorization{ az: &Authorization{
ID: azID, ID: azID,
Challenges: []*acme.Challenge{ Challenges: []*Challenge{
{ID: chID0}, {ID: chID0},
{ID: chID1}, {ID: chID1},
{ID: chID2}, {ID: chID2},
}, },
}, },
validate: func(az *acme.Authorization) { validate: func(az *Authorization) {
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0)) assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1)) assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2)) assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))

View file

@ -189,30 +189,24 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
dns = fmt.Sprintf("%s:%s", dns, port) dns = fmt.Sprintf("%s:%s", dns, port)
} }
// ACME Router // ACME Router is only available if we have a database.
prefix := "acme"
var acmeDB acme.DB var acmeDB acme.DB
if cfg.DB == nil { var acmeLinker acme.Linker
acmeDB = nil if cfg.DB != nil {
} else {
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error configuring ACME DB interface") return nil, errors.Wrap(err, "error configuring ACME DB interface")
} }
acmeLinker = acme.NewLinker(dns, "acme")
mux.Route("/acme", func(r chi.Router) {
acmeAPI.Route(r)
})
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
// of the ACME spec.
mux.Route("/2.0/acme", func(r chi.Router) {
acmeAPI.Route(r)
})
} }
acmeOptions := &acmeAPI.HandlerOptions{
Backdate: *cfg.AuthorityConfig.Backdate,
DNS: dns,
Prefix: prefix,
}
mux.Route("/"+prefix, func(r chi.Router) {
acmeAPI.Route(r, acmeOptions)
})
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
// of the ACME spec.
mux.Route("/2.0/"+prefix, func(r chi.Router) {
acmeAPI.Route(r, acmeOptions)
})
// Admin API Router // Admin API Router
if cfg.AuthorityConfig.EnableAdmin { if cfg.AuthorityConfig.EnableAdmin {
@ -280,7 +274,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
} }
// Create context with all the necessary values. // Create context with all the necessary values.
baseContext := buildContext(auth, scepAuthority, acmeDB) baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv = server.New(cfg.Address, handler, tlsConfig)
ca.srv.BaseContext = func(net.Listener) context.Context { ca.srv.BaseContext = func(net.Listener) context.Context {
@ -304,7 +298,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
} }
// buildContext builds the server base context. // buildContext builds the server base context.
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB) context.Context { func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context {
ctx := authority.NewContext(context.Background(), a) ctx := authority.NewContext(context.Background(), a)
if authDB := a.GetDatabase(); authDB != nil { if authDB := a.GetDatabase(); authDB != nil {
ctx = db.NewContext(ctx, authDB) ctx = db.NewContext(ctx, authDB)
@ -316,7 +310,7 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB
ctx = scep.NewContext(ctx, scepAuthority) ctx = scep.NewContext(ctx, scepAuthority)
} }
if acmeDB != nil { if acmeDB != nil {
ctx = acme.NewContext(ctx, acmeDB) ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil)
} }
return ctx return ctx
} }