forked from TrueCloudLab/certificates
Refactor ACME api.
This commit is contained in:
parent
fddd6f7d95
commit
d1f75f1720
13 changed files with 510 additions and 438 deletions
|
@ -69,6 +69,9 @@ func (u *UpdateAccountRequest) Validate() error {
|
|||
// NewAccount is the handler resource for creating new ACME accounts.
|
||||
func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -120,7 +123,6 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := acme.MustFromContext(ctx)
|
||||
acc = &acme.Account{
|
||||
Key: jwk,
|
||||
Contact: nar.Contact,
|
||||
|
@ -148,16 +150,18 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
|||
httpStatus = http.StatusOK
|
||||
}
|
||||
|
||||
o := optionsFromContext(ctx)
|
||||
o.linker.LinkAccount(ctx, acc)
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", o.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
|
||||
w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
|
||||
render.JSONStatus(w, acc, httpStatus)
|
||||
}
|
||||
|
||||
// GetOrUpdateAccount is the api for updating an ACME account.
|
||||
func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -189,7 +193,6 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
|||
acc.Contact = uar.Contact
|
||||
}
|
||||
|
||||
db := acme.MustFromContext(ctx)
|
||||
if err := db.UpdateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
|
||||
return
|
||||
|
@ -197,10 +200,9 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
o := optionsFromContext(ctx)
|
||||
o.linker.LinkAccount(ctx, acc)
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", o.linker.GetLink(ctx, AccountLinkType, acc.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
|
||||
render.JSON(w, acc)
|
||||
}
|
||||
|
||||
|
@ -216,6 +218,9 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
|||
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
|
||||
func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -227,15 +232,13 @@ func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := acme.MustFromContext(ctx)
|
||||
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
o := optionsFromContext(ctx)
|
||||
o.linker.LinkOrdersByAccountID(ctx, orders)
|
||||
linker.LinkOrdersByAccountID(ctx, orders)
|
||||
|
||||
render.JSON(w, orders)
|
||||
logOrdersByAccount(w, orders)
|
||||
|
|
|
@ -47,7 +47,7 @@ func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest)
|
|||
return nil, acmeErr
|
||||
}
|
||||
|
||||
db := acme.MustFromContext(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
|
||||
if err != nil {
|
||||
if _, ok := err.(*acme.Error); ok {
|
||||
|
@ -103,7 +103,6 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
|
|||
// o The "nonce" field MUST NOT be present
|
||||
// o The "url" field MUST be set to the same value as the outer JWS
|
||||
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
|
||||
|
||||
if jws == nil {
|
||||
return "", acme.NewErrorISE("no JWS provided")
|
||||
}
|
||||
|
|
|
@ -2,12 +2,10 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -70,144 +68,117 @@ type HandlerOptions struct {
|
|||
// PrerequisitesChecker checks if all prerequisites for serving ACME are
|
||||
// met by the CA configuration.
|
||||
PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
|
||||
linker Linker
|
||||
validateChallengeOptions *acme.ValidateChallengeOptions
|
||||
}
|
||||
|
||||
type optionsKey struct{}
|
||||
|
||||
func newOptionsContext(ctx context.Context, o *HandlerOptions) context.Context {
|
||||
return context.WithValue(ctx, optionsKey{}, o)
|
||||
}
|
||||
|
||||
func optionsFromContext(ctx context.Context) *HandlerOptions {
|
||||
o, ok := ctx.Value(optionsKey{}).(*HandlerOptions)
|
||||
if !ok {
|
||||
panic("acme options are not in the context")
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
|
||||
return authority.MustFromContext(ctx)
|
||||
}
|
||||
|
||||
// Handler is the ACME API request handler.
|
||||
type Handler struct {
|
||||
// handler is the ACME API request handler.
|
||||
type handler struct {
|
||||
opts *HandlerOptions
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
//
|
||||
// Deprecated: Use api.Route(r Router, opts *HandlerOptions)
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
Route(r, h.opts)
|
||||
func (h *handler) Route(r api.Router) {
|
||||
route(r, h.opts)
|
||||
}
|
||||
|
||||
// NewHandler returns a new ACME API handler.
|
||||
//
|
||||
// Deprecated: Use api.Route(r Router, opts *HandlerOptions)
|
||||
func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||
return &Handler{
|
||||
opts: &ops,
|
||||
func NewHandler(opts HandlerOptions) api.RouterHandler {
|
||||
return &handler{
|
||||
opts: &opts,
|
||||
}
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func Route(r api.Router, opts *HandlerOptions) {
|
||||
// by default all prerequisites are met
|
||||
if opts.PrerequisitesChecker == nil {
|
||||
opts.PrerequisitesChecker = func(ctx context.Context) (bool, error) {
|
||||
return true, nil
|
||||
// Route traffic and implement the Router interface. This method requires that
|
||||
// all the acme components, authority, db, client, linker, and prerequisite
|
||||
// checker to be present in the context.
|
||||
func Route(r api.Router) {
|
||||
route(r, nil)
|
||||
}
|
||||
|
||||
func route(r api.Router, opts *HandlerOptions) {
|
||||
var withContext func(next nextHTTP) nextHTTP
|
||||
|
||||
// For backward compatibility this block adds will add a new middleware that
|
||||
// will set the ACME components to the context.
|
||||
if opts != nil {
|
||||
client := acme.NewClient()
|
||||
linker := acme.NewLinker(opts.DNS, opts.Prefix)
|
||||
|
||||
withContext = func(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil {
|
||||
ctx = authority.NewContext(ctx, ca)
|
||||
}
|
||||
ctx = acme.NewContext(ctx, opts.DB, client, linker, opts.PrerequisitesChecker)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
withContext = func(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
commonMiddleware := func(next nextHTTP) nextHTTP {
|
||||
return withContext(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Linker middleware gets the provisioner and current url from the
|
||||
// request and sets them in the context.
|
||||
linker := acme.MustLinkerFromContext(r.Context())
|
||||
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
opts.linker = NewLinker(opts.DNS, opts.Prefix)
|
||||
opts.validateChallengeOptions = &acme.ValidateChallengeOptions{
|
||||
HTTPGet: client.Get,
|
||||
LookupTxt: net.LookupTXT,
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(dialer, network, addr, config)
|
||||
},
|
||||
}
|
||||
|
||||
withOptions := func(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// For backward compatibility with NewHandler.
|
||||
if ca, ok := opts.CA.(*authority.Authority); ok && ca != nil {
|
||||
ctx = authority.NewContext(ctx, ca)
|
||||
}
|
||||
if opts.DB != nil {
|
||||
ctx = acme.NewContext(ctx, opts.DB)
|
||||
}
|
||||
|
||||
ctx = newOptionsContext(ctx, opts)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
||||
return withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next)))))))))
|
||||
return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
|
||||
}
|
||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||
return withOptions(validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next))))
|
||||
return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
extractPayloadByKid := func(next nextHTTP) nextHTTP {
|
||||
return withOptions(validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next))))
|
||||
return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
|
||||
return withOptions(validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next))))
|
||||
return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
|
||||
}
|
||||
|
||||
getPath := opts.linker.GetUnescapedPathSuffix
|
||||
getPath := acme.GetUnescapedPathSuffix
|
||||
|
||||
// Standard ACME API
|
||||
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"),
|
||||
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce)))))))
|
||||
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"),
|
||||
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(addNonce(addDirLink(GetNonce)))))))
|
||||
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"),
|
||||
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory)))))
|
||||
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"),
|
||||
withOptions(baseURLFromRequest(lookupProvisioner(checkPrerequisites(GetDirectory)))))
|
||||
r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
|
||||
commonMiddleware(addNonce(addDirLink(GetNonce))))
|
||||
r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
|
||||
commonMiddleware(GetDirectory))
|
||||
|
||||
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"),
|
||||
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
|
||||
extractPayloadByJWK(NewAccount))
|
||||
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"),
|
||||
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(GetOrUpdateAccount))
|
||||
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"),
|
||||
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(NotImplemented))
|
||||
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"),
|
||||
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
|
||||
extractPayloadByKid(NewOrder))
|
||||
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"),
|
||||
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrder)))
|
||||
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
|
||||
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
|
||||
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"),
|
||||
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
|
||||
extractPayloadByKid(FinalizeOrder))
|
||||
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"),
|
||||
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetAuthorization)))
|
||||
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
|
||||
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
|
||||
extractPayloadByKid(GetChallenge))
|
||||
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"),
|
||||
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
|
||||
extractPayloadByKid(isPostAsGet(GetCertificate)))
|
||||
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"),
|
||||
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
|
||||
extractPayloadByKidOrJWK(RevokeCert))
|
||||
}
|
||||
|
||||
|
@ -251,20 +222,20 @@ func (d *Directory) ToLog() (interface{}, error) {
|
|||
// for client configuration.
|
||||
func GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
o := optionsFromContext(ctx)
|
||||
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
fmt.Println(acmeProv, err)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
render.JSON(w, &Directory{
|
||||
NewNonce: o.linker.GetLink(ctx, NewNonceLinkType),
|
||||
NewAccount: o.linker.GetLink(ctx, NewAccountLinkType),
|
||||
NewOrder: o.linker.GetLink(ctx, NewOrderLinkType),
|
||||
RevokeCert: o.linker.GetLink(ctx, RevokeCertLinkType),
|
||||
KeyChange: o.linker.GetLink(ctx, KeyChangeLinkType),
|
||||
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
|
||||
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
|
||||
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
|
||||
RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
|
||||
KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
|
||||
Meta: Meta{
|
||||
ExternalAccountRequired: acmeProv.RequireEAB,
|
||||
},
|
||||
|
@ -280,8 +251,8 @@ func NotImplemented(w http.ResponseWriter, r *http.Request) {
|
|||
// GetAuthorization ACME api for retrieving an Authz.
|
||||
func GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
o := optionsFromContext(ctx)
|
||||
db := acme.MustFromContext(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
|
@ -303,17 +274,17 @@ func GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
o.linker.LinkAuthorization(ctx, az)
|
||||
linker.LinkAuthorization(ctx, az)
|
||||
|
||||
w.Header().Set("Location", o.linker.GetLink(ctx, AuthzLinkType, az.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
|
||||
render.JSON(w, az)
|
||||
}
|
||||
|
||||
// GetChallenge ACME api for retrieving a Challenge.
|
||||
func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
o := optionsFromContext(ctx)
|
||||
db := acme.MustFromContext(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
|
@ -351,22 +322,22 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
|||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
if err = ch.Validate(ctx, db, jwk, o.validateChallengeOptions); err != nil {
|
||||
if err = ch.Validate(ctx, db, jwk); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
|
||||
return
|
||||
}
|
||||
|
||||
o.linker.LinkChallenge(ctx, ch, azID)
|
||||
linker.LinkChallenge(ctx, ch, azID)
|
||||
|
||||
w.Header().Add("Link", link(o.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", o.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
|
||||
render.JSON(w, ch)
|
||||
}
|
||||
|
||||
// GetCertificate ACME api for retrieving a Certificate.
|
||||
func GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustFromContext(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
|
|
|
@ -31,39 +31,10 @@ func logNonce(w http.ResponseWriter, nonce string) {
|
|||
}
|
||||
}
|
||||
|
||||
// getBaseURLFromRequest determines the base URL which should be used for
|
||||
// constructing link URLs in e.g. the ACME directory result by taking the
|
||||
// request Host into consideration.
|
||||
//
|
||||
// If the Request.Host is an empty string, we return an empty string, to
|
||||
// indicate that the configured URL values should be used instead. If this
|
||||
// function returns a non-empty result, then this should be used in constructing
|
||||
// ACME link URLs.
|
||||
func getBaseURLFromRequest(r *http.Request) *url.URL {
|
||||
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
|
||||
// for an implementation that allows HTTP requests using the x-forwarded-proto
|
||||
// header.
|
||||
|
||||
if r.Host == "" {
|
||||
return nil
|
||||
}
|
||||
return &url.URL{Scheme: "https", Host: r.Host}
|
||||
}
|
||||
|
||||
// baseURLFromRequest is a middleware that extracts and caches the baseURL
|
||||
// from the request.
|
||||
// E.g. https://ca.smallstep.com/
|
||||
func baseURLFromRequest(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), baseURLContextKey, getBaseURLFromRequest(r))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// addNonce is a middleware that adds a nonce to the response header.
|
||||
func addNonce(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
db := acme.MustFromContext(r.Context())
|
||||
db := acme.MustDatabaseFromContext(r.Context())
|
||||
nonce, err := db.CreateNonce(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -81,9 +52,9 @@ func addNonce(next nextHTTP) nextHTTP {
|
|||
func addDirLink(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
opts := optionsFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
w.Header().Add("Link", link(opts.linker.GetLink(ctx, DirectoryLinkType), "index"))
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
@ -92,17 +63,12 @@ func addDirLink(next nextHTTP) nextHTTP {
|
|||
// application/jose+json.
|
||||
func verifyContentType(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var expected []string
|
||||
ctx := r.Context()
|
||||
opts := optionsFromContext(ctx)
|
||||
|
||||
p, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
p := acme.MustProvisionerFromContext(r.Context())
|
||||
u := &url.URL{
|
||||
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
|
||||
}
|
||||
|
||||
u := url.URL{Path: opts.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
|
||||
var expected []string
|
||||
if strings.Contains(r.URL.String(), u.EscapedPath()) {
|
||||
// GET /certificate requests allow a greater range of content types.
|
||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||
|
@ -159,7 +125,7 @@ func parseJWS(next nextHTTP) nextHTTP {
|
|||
func validateJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustFromContext(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
|
@ -247,7 +213,7 @@ func validateJWS(next nextHTTP) nextHTTP {
|
|||
func extractJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustFromContext(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
|
@ -325,18 +291,20 @@ func lookupProvisioner(next nextHTTP) nextHTTP {
|
|||
func checkPrerequisites(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
opts := optionsFromContext(ctx)
|
||||
|
||||
ok, err := opts.PrerequisitesChecker(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
return
|
||||
// If the function is not set assume that all prerequisites are met.
|
||||
checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
|
||||
if ok {
|
||||
ok, err := checkFunc(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
next(w, r.WithContext(ctx))
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -346,8 +314,8 @@ func checkPrerequisites(next nextHTTP) nextHTTP {
|
|||
func lookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
opts := optionsFromContext(ctx)
|
||||
db := acme.MustFromContext(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
|
@ -355,7 +323,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
|||
return
|
||||
}
|
||||
|
||||
kidPrefix := opts.linker.GetLink(ctx, AccountLinkType, "")
|
||||
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
|
||||
kid := jws.Signatures[0].Protected.KeyID
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
|
@ -527,32 +495,14 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
|
|||
return val, nil
|
||||
}
|
||||
|
||||
// provisionerFromContext searches the context for a provisioner. Returns the
|
||||
// provisioner or an error.
|
||||
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
|
||||
val := ctx.Value(provisionerContextKey)
|
||||
if val == nil {
|
||||
return nil, acme.NewErrorISE("provisioner expected in request context")
|
||||
}
|
||||
pval, ok := val.(acme.Provisioner)
|
||||
if !ok || pval == nil {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
return pval, nil
|
||||
}
|
||||
|
||||
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
|
||||
// pointer to an ACME provisioner or an error.
|
||||
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
acmeProv, ok := prov.(*provisioner.ACME)
|
||||
if !ok || acmeProv == nil {
|
||||
p, ok := acme.MustProvisionerFromContext(ctx).(*provisioner.ACME)
|
||||
if !ok {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
return acmeProv, nil
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// payloadFromContext searches the context for a payload. Returns the payload
|
||||
|
|
|
@ -70,16 +70,15 @@ var defaultOrderBackdate = time.Minute
|
|||
// NewOrder ACME api for creating a new order.
|
||||
func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
prov := acme.MustProvisionerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -136,16 +135,14 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
|
||||
}
|
||||
|
||||
db := acme.MustFromContext(ctx)
|
||||
if err := db.CreateOrder(ctx, o); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
|
||||
return
|
||||
}
|
||||
|
||||
opts := optionsFromContext(ctx)
|
||||
opts.linker.LinkOrder(ctx, o)
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSONStatus(w, o, http.StatusCreated)
|
||||
}
|
||||
|
||||
|
@ -166,7 +163,7 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
|||
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
|
||||
}
|
||||
|
||||
db := acme.MustFromContext(ctx)
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
az.Challenges = make([]*acme.Challenge, len(chTypes))
|
||||
for i, typ := range chTypes {
|
||||
ch := &acme.Challenge{
|
||||
|
@ -190,18 +187,16 @@ func newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
|||
// GetOrder ACME api for retrieving an order.
|
||||
func GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
prov := acme.MustProvisionerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
db := acme.MustFromContext(ctx)
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
|
@ -222,26 +217,24 @@ func GetOrder(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
opts := optionsFromContext(ctx)
|
||||
opts.linker.LinkOrder(ctx, o)
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
}
|
||||
|
||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||
func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
prov := acme.MustProvisionerFromContext(ctx)
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
|
@ -258,7 +251,6 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := acme.MustFromContext(ctx)
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
|
@ -281,10 +273,9 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
opts := optionsFromContext(ctx)
|
||||
opts.linker.LinkOrder(ctx, o)
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", opts.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
}
|
||||
|
||||
|
|
|
@ -28,13 +28,11 @@ type revokePayload struct {
|
|||
// RevokeCert attempts to revoke a certificate.
|
||||
func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
db := acme.MustDatabaseFromContext(ctx)
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
prov := acme.MustProvisionerFromContext(ctx)
|
||||
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
|
@ -67,7 +65,6 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
db := acme.MustFromContext(ctx)
|
||||
serial := certToBeRevoked.SerialNumber.String()
|
||||
dbCert, err := db.GetCertificateBySerial(ctx, serial)
|
||||
if err != nil {
|
||||
|
@ -138,8 +135,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
logRevoke(w, options)
|
||||
o := optionsFromContext(ctx)
|
||||
w.Header().Add("Link", link(o.linker.GetLink(ctx, DirectoryLinkType), "index"))
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
|
||||
w.Write(nil)
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) {
|
|||
// type using the DB interface.
|
||||
// satisfactorily validated, the 'status' and 'validated' attributes are
|
||||
// updated.
|
||||
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error {
|
||||
// If already valid or invalid then return without performing validation.
|
||||
if ch.Status != StatusPending {
|
||||
return nil
|
||||
}
|
||||
switch ch.Type {
|
||||
case HTTP01:
|
||||
return http01Validate(ctx, ch, db, jwk, vo)
|
||||
return http01Validate(ctx, ch, db, jwk)
|
||||
case DNS01:
|
||||
return dns01Validate(ctx, ch, db, jwk, vo)
|
||||
return dns01Validate(ctx, ch, db, jwk)
|
||||
case TLSALPN01:
|
||||
return tlsalpn01Validate(ctx, ch, db, jwk, vo)
|
||||
return tlsalpn01Validate(ctx, ch, db, jwk)
|
||||
default:
|
||||
return NewErrorISE("unexpected challenge type '%s'", ch.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||
|
||||
resp, err := vo.HTTPGet(u.String())
|
||||
vc := MustClientFromContext(ctx)
|
||||
resp, err := vc.Get(u.String())
|
||||
if err != nil {
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||
"error doing http GET for url %s", u))
|
||||
|
@ -141,7 +141,7 @@ func tlsAlert(err error) uint8 {
|
|||
return 0
|
||||
}
|
||||
|
||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
config := &tls.Config{
|
||||
NextProtos: []string{"acme-tls/1"},
|
||||
// https://tools.ietf.org/html/rfc8737#section-4
|
||||
|
@ -154,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
|
||||
hostPort := net.JoinHostPort(ch.Value, "443")
|
||||
|
||||
conn, err := vo.TLSDial("tcp", hostPort, config)
|
||||
vc := MustClientFromContext(ctx)
|
||||
conn, err := vc.TLSDial("tcp", hostPort, config)
|
||||
if err != nil {
|
||||
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
|
||||
// client and server protocols. When this happens the connection is
|
||||
|
@ -253,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
|||
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
|
||||
}
|
||||
|
||||
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
|
||||
// Normalize domain for wildcard DNS names
|
||||
// This is done to avoid making TXT lookups for domains like
|
||||
// _acme-challenge.*.example.com
|
||||
// Instead perform txt lookup for _acme-challenge.example.com
|
||||
domain := strings.TrimPrefix(ch.Value, "*.")
|
||||
|
||||
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
|
||||
vc := MustClientFromContext(ctx)
|
||||
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
|
||||
if err != nil {
|
||||
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
|
||||
"error looking up TXT records for domain %s", domain))
|
||||
|
@ -376,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type httpGetter func(string) (*http.Response, error)
|
||||
type lookupTxt func(string) ([]string, error)
|
||||
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
|
||||
// ValidateChallengeOptions are ACME challenge validator functions.
|
||||
type ValidateChallengeOptions struct {
|
||||
HTTPGet httpGetter
|
||||
LookupTxt lookupTxt
|
||||
TLSDial tlsDialer
|
||||
}
|
||||
|
|
79
acme/client.go
Normal file
79
acme/client.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client is the interface used to verify ACME challenges.
|
||||
type Client interface {
|
||||
// Get issues an HTTP GET to the specified URL.
|
||||
Get(url string) (*http.Response, error)
|
||||
|
||||
// LookupTXT returns the DNS TXT records for the given domain name.
|
||||
LookupTxt(name string) ([]string, error)
|
||||
|
||||
// TLSDial connects to the given network address using net.Dialer and then
|
||||
// initiates a TLS handshake, returning the resulting TLS connection.
|
||||
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
type clientKey struct{}
|
||||
|
||||
// NewClientContext adds the given client to the context.
|
||||
func NewClientContext(ctx context.Context, c Client) context.Context {
|
||||
return context.WithValue(ctx, clientKey{}, c)
|
||||
}
|
||||
|
||||
// ClientFromContext returns the current client from the given context.
|
||||
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
|
||||
c, ok = ctx.Value(clientKey{}).(Client)
|
||||
return
|
||||
}
|
||||
|
||||
// MustClientFromContext returns the current client from the given context. It will
|
||||
// return a new instance of the client if it does not exist.
|
||||
func MustClientFromContext(ctx context.Context) Client {
|
||||
if c, ok := ClientFromContext(ctx); !ok {
|
||||
return NewClient()
|
||||
} else {
|
||||
return c
|
||||
}
|
||||
}
|
||||
|
||||
type client struct {
|
||||
http *http.Client
|
||||
dialer *net.Dialer
|
||||
}
|
||||
|
||||
// NewClient returns an implementation of Client for verifying ACME challenges.
|
||||
func NewClient() Client {
|
||||
return &client{
|
||||
http: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
dialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Get(url string) (*http.Response, error) {
|
||||
return c.http.Get(url)
|
||||
}
|
||||
|
||||
func (c *client) LookupTxt(name string) ([]string, error) {
|
||||
return net.LookupTXT(name)
|
||||
}
|
||||
|
||||
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(c.dialer, network, addr, config)
|
||||
}
|
|
@ -9,14 +9,6 @@ import (
|
|||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
IsRevoked(sn string) (bool, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// Clock that returns time in UTC rounded to seconds.
|
||||
type Clock struct{}
|
||||
|
||||
|
@ -27,6 +19,51 @@ func (c *Clock) Now() time.Time {
|
|||
|
||||
var clock Clock
|
||||
|
||||
// CertificateAuthority is the interface implemented by a CA authority.
|
||||
type CertificateAuthority interface {
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
IsRevoked(sn string) (bool, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
LoadProvisionerByName(string) (provisioner.Interface, error)
|
||||
}
|
||||
|
||||
// NewContext adds the given acme components to the context.
|
||||
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
|
||||
ctx = NewDatabaseContext(ctx, db)
|
||||
ctx = NewClientContext(ctx, client)
|
||||
ctx = NewLinkerContext(ctx, linker)
|
||||
// Prerequisite checker is optional.
|
||||
if fn != nil {
|
||||
ctx = NewPrerequisitesCheckerContext(ctx, fn)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// PrerequisitesChecker is a function that checks if all prerequisites for
|
||||
// serving ACME are met by the CA configuration.
|
||||
type PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||
|
||||
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
|
||||
// always true.
|
||||
func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type prerequisitesKey struct{}
|
||||
|
||||
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
|
||||
// context.
|
||||
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
|
||||
return context.WithValue(ctx, prerequisitesKey{}, fn)
|
||||
}
|
||||
|
||||
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
|
||||
// context.
|
||||
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
|
||||
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
|
||||
return fn, ok && fn != nil
|
||||
}
|
||||
|
||||
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
||||
// only those methods required by the ACME api/authority.
|
||||
type Provisioner interface {
|
||||
|
@ -38,6 +75,29 @@ type Provisioner interface {
|
|||
GetOptions() *provisioner.Options
|
||||
}
|
||||
|
||||
type provisionerKey struct{}
|
||||
|
||||
// NewProvisionerContext adds the given provisioner to the context.
|
||||
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
|
||||
return context.WithValue(ctx, provisionerKey{}, v)
|
||||
}
|
||||
|
||||
// ProvisionerFromContext returns the current provisioner from the given context.
|
||||
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
|
||||
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current provisioner from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustProvisionerFromContext(ctx context.Context) Provisioner {
|
||||
if v, ok := ProvisionerFromContext(ctx); !ok {
|
||||
panic("acme provisioner is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// MockProvisioner for testing
|
||||
type MockProvisioner struct {
|
||||
Mret1 interface{}
|
||||
|
|
16
acme/db.go
16
acme/db.go
|
@ -50,21 +50,21 @@ type DB interface {
|
|||
|
||||
type dbKey struct{}
|
||||
|
||||
// NewContext adds the given acme database to the context.
|
||||
func NewContext(ctx context.Context, db DB) context.Context {
|
||||
// NewDatabaseContext adds the given acme database to the context.
|
||||
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
|
||||
return context.WithValue(ctx, dbKey{}, db)
|
||||
}
|
||||
|
||||
// FromContext returns the current acme database from the given context.
|
||||
func FromContext(ctx context.Context) (db DB, ok bool) {
|
||||
// DatabaseFromContext returns the current acme database from the given context.
|
||||
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
|
||||
db, ok = ctx.Value(dbKey{}).(DB)
|
||||
return
|
||||
}
|
||||
|
||||
// MustFromContext returns the current database from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustFromContext(ctx context.Context) DB {
|
||||
if db, ok := FromContext(ctx); !ok {
|
||||
// MustDatabaseFromContext returns the current database from the given context.
|
||||
// It will panic if it's not in the context.
|
||||
func MustDatabaseFromContext(ctx context.Context) DB {
|
||||
if db, ok := DatabaseFromContext(ctx); !ok {
|
||||
panic("acme database is not in the context")
|
||||
} else {
|
||||
return db
|
||||
|
|
265
acme/linker.go
265
acme/linker.go
|
@ -4,120 +4,16 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
_, _, err := net.SplitHostPort(dns)
|
||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||
// these cases, then the input dns is not changed.
|
||||
lastIndex := strings.LastIndex(dns, ":")
|
||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||
if ip := net.ParseIP(hostPart); ip != nil {
|
||||
dns = "[" + hostPart + "]:" + portPart
|
||||
} else if ip := net.ParseIP(dns); ip != nil {
|
||||
dns = "[" + dns + "]"
|
||||
}
|
||||
}
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
|
||||
|
||||
LinkOrder(ctx context.Context, o *acme.Order)
|
||||
LinkAccount(ctx context.Context, o *acme.Account)
|
||||
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *acme.Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
type linkerKey struct{}
|
||||
|
||||
// NewLinkerContext adds the given linker to the context.
|
||||
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
|
||||
return context.WithValue(ctx, linkerKey{}, v)
|
||||
}
|
||||
|
||||
// LinkerFromContext returns the current linker from the given context.
|
||||
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
|
||||
v, ok = ctx.Value(linkerKey{}).(Linker)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current linker from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustLinkerFromContext(ctx context.Context) Linker {
|
||||
if v, ok := LinkerFromContext(ctx); !ok {
|
||||
panic("acme linker is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var (
|
||||
provName string
|
||||
baseURL = baseURLFromContext(ctx)
|
||||
u = url.URL{}
|
||||
)
|
||||
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
|
||||
provName = p.GetName()
|
||||
}
|
||||
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
|
||||
if baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
|
||||
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
|
||||
|
||||
// If no Scheme is set, then default to https.
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
|
||||
// If no Host is set, then use the default (first DNS attr in the ca.json).
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
u.Path = l.prefix + u.Path
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// LinkType captures the link type.
|
||||
type LinkType int
|
||||
|
||||
|
@ -183,8 +79,151 @@ func (l LinkType) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
|
||||
switch typ {
|
||||
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
|
||||
return fmt.Sprintf("/%s/%s", provisionerName, typ)
|
||||
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
|
||||
case ChallengeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
|
||||
case OrdersByAccountLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
|
||||
case FinalizeLinkType:
|
||||
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// NewLinker returns a new Directory type.
|
||||
func NewLinker(dns, prefix string) Linker {
|
||||
_, _, err := net.SplitHostPort(dns)
|
||||
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
|
||||
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
|
||||
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
|
||||
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
|
||||
// these cases, then the input dns is not changed.
|
||||
lastIndex := strings.LastIndex(dns, ":")
|
||||
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
|
||||
if ip := net.ParseIP(hostPart); ip != nil {
|
||||
dns = "[" + hostPart + "]:" + portPart
|
||||
} else if ip := net.ParseIP(dns); ip != nil {
|
||||
dns = "[" + dns + "]"
|
||||
}
|
||||
}
|
||||
return &linker{prefix: prefix, dns: dns}
|
||||
}
|
||||
|
||||
// Linker interface for generating links for ACME resources.
|
||||
type Linker interface {
|
||||
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
|
||||
Middleware(http.Handler) http.Handler
|
||||
LinkOrder(ctx context.Context, o *Order)
|
||||
LinkAccount(ctx context.Context, o *Account)
|
||||
LinkChallenge(ctx context.Context, o *Challenge, azID string)
|
||||
LinkAuthorization(ctx context.Context, o *Authorization)
|
||||
LinkOrdersByAccountID(ctx context.Context, orders []string)
|
||||
}
|
||||
|
||||
type linkerKey struct{}
|
||||
|
||||
// NewLinkerContext adds the given linker to the context.
|
||||
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
|
||||
return context.WithValue(ctx, linkerKey{}, v)
|
||||
}
|
||||
|
||||
// LinkerFromContext returns the current linker from the given context.
|
||||
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
|
||||
v, ok = ctx.Value(linkerKey{}).(Linker)
|
||||
return
|
||||
}
|
||||
|
||||
// MustLinkerFromContext returns the current linker from the given context. It
|
||||
// will panic if it's not in the context.
|
||||
func MustLinkerFromContext(ctx context.Context) Linker {
|
||||
if v, ok := LinkerFromContext(ctx); !ok {
|
||||
panic("acme linker is not the context")
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
type baseURLKey struct{}
|
||||
|
||||
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
|
||||
var u *url.URL
|
||||
if r.Host != "" {
|
||||
u = &url.URL{Scheme: "https", Host: r.Host}
|
||||
}
|
||||
return context.WithValue(ctx, baseURLKey{}, u)
|
||||
}
|
||||
|
||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
|
||||
return u
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// linker generates ACME links.
|
||||
type linker struct {
|
||||
prefix string
|
||||
dns string
|
||||
}
|
||||
|
||||
// Middleware gets the provisioner and current url from the request and sets
|
||||
// them in the context so we can use the linker to create ACME links.
|
||||
func (l *linker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add base url to the context.
|
||||
ctx := newBaseURLContext(r.Context(), r)
|
||||
|
||||
// Add provisioner to the context.
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
|
||||
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// GetLink is a helper for GetLinkExplicit.
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var u url.URL
|
||||
if baseURL := baseURLFromContext(ctx); baseURL != nil {
|
||||
u = *baseURL
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
if u.Host == "" {
|
||||
u.Host = l.dns
|
||||
}
|
||||
|
||||
p := MustProvisionerFromContext(ctx)
|
||||
u.Path = l.prefix + GetUnescapedPathSuffix(typ, p.GetName(), inputs...)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// LinkOrder sets the ACME links required by an ACME order.
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
||||
func (l *linker) LinkOrder(ctx context.Context, o *Order) {
|
||||
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
|
||||
for i, azID := range o.AuthorizationIDs {
|
||||
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
|
||||
|
@ -196,17 +235,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
|
|||
}
|
||||
|
||||
// LinkAccount sets the ACME links required by an ACME account.
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
|
||||
func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
|
||||
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
|
||||
}
|
||||
|
||||
// LinkChallenge sets the ACME links required by an ACME challenge.
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
|
||||
func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
|
||||
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
|
||||
}
|
||||
|
||||
// LinkAuthorization sets the ACME links required by an ACME authorization.
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
|
||||
func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
|
||||
for _, ch := range az.Challenges {
|
||||
l.LinkChallenge(ctx, ch, az.ID)
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
)
|
||||
|
||||
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
||||
|
@ -173,27 +172,27 @@ func TestLinker_LinkOrder(t *testing.T) {
|
|||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
o *acme.Order
|
||||
validate func(o *acme.Order)
|
||||
o *Order
|
||||
validate func(o *Order)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"no-authz-and-no-cert": {
|
||||
o: &acme.Order{
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{})
|
||||
assert.Equals(t, o.CertificateURL, "")
|
||||
},
|
||||
},
|
||||
"one-authz-and-cert": {
|
||||
o: &acme.Order{
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo"},
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
|
@ -202,12 +201,12 @@ func TestLinker_LinkOrder(t *testing.T) {
|
|||
},
|
||||
},
|
||||
"many-authz": {
|
||||
o: &acme.Order{
|
||||
o: &Order{
|
||||
ID: oid,
|
||||
CertificateID: certID,
|
||||
AuthorizationIDs: []string{"foo", "bar", "zap"},
|
||||
},
|
||||
validate: func(o *acme.Order) {
|
||||
validate: func(o *Order) {
|
||||
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
|
||||
assert.Equals(t, o.AuthorizationURLs, []string{
|
||||
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
|
||||
|
@ -237,15 +236,15 @@ func TestLinker_LinkAccount(t *testing.T) {
|
|||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
a *acme.Account
|
||||
validate func(o *acme.Account)
|
||||
a *Account
|
||||
validate func(o *Account)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
a: &acme.Account{
|
||||
a: &Account{
|
||||
ID: accID,
|
||||
},
|
||||
validate: func(a *acme.Account) {
|
||||
validate: func(a *Account) {
|
||||
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
|
||||
},
|
||||
},
|
||||
|
@ -270,15 +269,15 @@ func TestLinker_LinkChallenge(t *testing.T) {
|
|||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
ch *acme.Challenge
|
||||
validate func(o *acme.Challenge)
|
||||
ch *Challenge
|
||||
validate func(o *Challenge)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
ch: &acme.Challenge{
|
||||
ch: &Challenge{
|
||||
ID: chID,
|
||||
},
|
||||
validate: func(ch *acme.Challenge) {
|
||||
validate: func(ch *Challenge) {
|
||||
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
|
||||
},
|
||||
},
|
||||
|
@ -305,20 +304,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
|||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
type test struct {
|
||||
az *acme.Authorization
|
||||
validate func(o *acme.Authorization)
|
||||
az *Authorization
|
||||
validate func(o *Authorization)
|
||||
}
|
||||
var tests = map[string]test{
|
||||
"ok": {
|
||||
az: &acme.Authorization{
|
||||
az: &Authorization{
|
||||
ID: azID,
|
||||
Challenges: []*acme.Challenge{
|
||||
Challenges: []*Challenge{
|
||||
{ID: chID0},
|
||||
{ID: chID1},
|
||||
{ID: chID2},
|
||||
},
|
||||
},
|
||||
validate: func(az *acme.Authorization) {
|
||||
validate: func(az *Authorization) {
|
||||
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
|
||||
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
|
||||
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
|
||||
|
|
36
ca/ca.go
36
ca/ca.go
|
@ -189,30 +189,24 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
dns = fmt.Sprintf("%s:%s", dns, port)
|
||||
}
|
||||
|
||||
// ACME Router
|
||||
prefix := "acme"
|
||||
// ACME Router is only available if we have a database.
|
||||
var acmeDB acme.DB
|
||||
if cfg.DB == nil {
|
||||
acmeDB = nil
|
||||
} else {
|
||||
var acmeLinker acme.Linker
|
||||
if cfg.DB != nil {
|
||||
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error configuring ACME DB interface")
|
||||
}
|
||||
acmeLinker = acme.NewLinker(dns, "acme")
|
||||
mux.Route("/acme", func(r chi.Router) {
|
||||
acmeAPI.Route(r)
|
||||
})
|
||||
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
|
||||
// of the ACME spec.
|
||||
mux.Route("/2.0/acme", func(r chi.Router) {
|
||||
acmeAPI.Route(r)
|
||||
})
|
||||
}
|
||||
acmeOptions := &acmeAPI.HandlerOptions{
|
||||
Backdate: *cfg.AuthorityConfig.Backdate,
|
||||
DNS: dns,
|
||||
Prefix: prefix,
|
||||
}
|
||||
mux.Route("/"+prefix, func(r chi.Router) {
|
||||
acmeAPI.Route(r, acmeOptions)
|
||||
})
|
||||
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
|
||||
// of the ACME spec.
|
||||
mux.Route("/2.0/"+prefix, func(r chi.Router) {
|
||||
acmeAPI.Route(r, acmeOptions)
|
||||
})
|
||||
|
||||
// Admin API Router
|
||||
if cfg.AuthorityConfig.EnableAdmin {
|
||||
|
@ -280,7 +274,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
}
|
||||
|
||||
// Create context with all the necessary values.
|
||||
baseContext := buildContext(auth, scepAuthority, acmeDB)
|
||||
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
|
||||
|
||||
ca.srv = server.New(cfg.Address, handler, tlsConfig)
|
||||
ca.srv.BaseContext = func(net.Listener) context.Context {
|
||||
|
@ -304,7 +298,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
}
|
||||
|
||||
// buildContext builds the server base context.
|
||||
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB) context.Context {
|
||||
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context {
|
||||
ctx := authority.NewContext(context.Background(), a)
|
||||
if authDB := a.GetDatabase(); authDB != nil {
|
||||
ctx = db.NewContext(ctx, authDB)
|
||||
|
@ -316,7 +310,7 @@ func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB
|
|||
ctx = scep.NewContext(ctx, scepAuthority)
|
||||
}
|
||||
if acmeDB != nil {
|
||||
ctx = acme.NewContext(ctx, acmeDB)
|
||||
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue