forked from TrueCloudLab/certificates
Merge pull request #836 from smallstep/herman/acme-eab
Add ACME configuration prerequisites check
This commit is contained in:
commit
ea454f9dfc
3 changed files with 125 additions and 7 deletions
|
@ -1,6 +1,7 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -43,6 +44,7 @@ type Handler struct {
|
||||||
ca acme.CertificateAuthority
|
ca acme.CertificateAuthority
|
||||||
linker Linker
|
linker Linker
|
||||||
validateChallengeOptions *acme.ValidateChallengeOptions
|
validateChallengeOptions *acme.ValidateChallengeOptions
|
||||||
|
prerequisitesChecker func(ctx context.Context) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlerOptions required to create a new ACME API request handler.
|
// HandlerOptions required to create a new ACME API request handler.
|
||||||
|
@ -60,6 +62,9 @@ type HandlerOptions struct {
|
||||||
// "acme" is the prefix from which the ACME api is accessed.
|
// "acme" is the prefix from which the ACME api is accessed.
|
||||||
Prefix string
|
Prefix string
|
||||||
CA acme.CertificateAuthority
|
CA acme.CertificateAuthority
|
||||||
|
// PrerequisitesChecker checks if all prerequisites for serving ACME are
|
||||||
|
// met by the CA configuration.
|
||||||
|
PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler returns a new ACME API handler.
|
// NewHandler returns a new ACME API handler.
|
||||||
|
@ -76,6 +81,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
prerequisitesChecker := func(ctx context.Context) (bool, error) {
|
||||||
|
// by default all prerequisites are met
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if ops.PrerequisitesChecker != nil {
|
||||||
|
prerequisitesChecker = ops.PrerequisitesChecker
|
||||||
|
}
|
||||||
return &Handler{
|
return &Handler{
|
||||||
ca: ops.CA,
|
ca: ops.CA,
|
||||||
db: ops.DB,
|
db: ops.DB,
|
||||||
|
@ -88,6 +100,7 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||||
return tls.DialWithDialer(dialer, network, addr, config)
|
return tls.DialWithDialer(dialer, network, addr, config)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
prerequisitesChecker: prerequisitesChecker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,13 +108,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||||
func (h *Handler) Route(r api.Router) {
|
func (h *Handler) Route(r api.Router) {
|
||||||
getPath := h.linker.GetUnescapedPathSuffix
|
getPath := h.linker.GetUnescapedPathSuffix
|
||||||
// Standard ACME API
|
// Standard ACME API
|
||||||
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
|
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||||
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
|
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||||
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
|
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||||
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
|
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||||
|
|
||||||
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
||||||
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))
|
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))))
|
||||||
}
|
}
|
||||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||||
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
|
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
|
||||||
|
|
|
@ -283,11 +283,10 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupProvisioner loads the provisioner associated with the request.
|
// lookupProvisioner loads the provisioner associated with the request.
|
||||||
// Responsds 404 if the provisioner does not exist.
|
// Responds 404 if the provisioner does not exist.
|
||||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||||
name, err := url.PathUnescape(nameEscaped)
|
name, err := url.PathUnescape(nameEscaped)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -309,6 +308,24 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkPrerequisites checks if all prerequisites for serving ACME
|
||||||
|
// are met by the CA configuration.
|
||||||
|
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
ok, err := h.prerequisitesChecker(ctx)
|
||||||
|
if err != nil {
|
||||||
|
api.WriteError(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// lookupJWK loads the JWK associated with the acme account referenced by the
|
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||||
// kid parameter of the signed payload.
|
// kid parameter of the signed payload.
|
||||||
// Make sure to parse and validate the JWS before running this middleware.
|
// Make sure to parse and validate the JWS before running this middleware.
|
||||||
|
|
|
@ -1656,3 +1656,91 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandler_checkPrerequisites(t *testing.T) {
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||||
|
baseURL, provName)
|
||||||
|
type test struct {
|
||||||
|
linker Linker
|
||||||
|
ctx context.Context
|
||||||
|
prerequisitesChecker func(context.Context) (bool, error)
|
||||||
|
next func(http.ResponseWriter, *http.Request)
|
||||||
|
err *acme.Error
|
||||||
|
statusCode int
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/error": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("dns", "acme"),
|
||||||
|
ctx: ctx,
|
||||||
|
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
err: acme.WrapErrorISE(errors.New("force"), "error checking acme provisioner prerequisites"),
|
||||||
|
statusCode: 500,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/prerequisites-nok": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("dns", "acme"),
|
||||||
|
ctx: ctx,
|
||||||
|
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
err: acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"),
|
||||||
|
statusCode: 501,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("dns", "acme"),
|
||||||
|
ctx: ctx,
|
||||||
|
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
|
||||||
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.checkPrerequisites(tc.next)(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
|
var ae acme.Error
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), testBody)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue