[acme db interface] wip
This commit is contained in:
parent
491c188a5e
commit
80a6640103
21 changed files with 787 additions and 975 deletions
|
@ -9,7 +9,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
|
@ -54,7 +53,7 @@ func baseURLFromRequest(r *http.Request) *url.URL {
|
|||
// E.g. https://ca.smallstep.com/
|
||||
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.WithValue(r.Context(), acme.BaseURLContextKey, baseURLFromRequest(r))
|
||||
ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
@ -62,14 +61,14 @@ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
|
|||
// addNonce is a middleware that adds a nonce to the response header.
|
||||
func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
nonce, err := h.Auth.NewNonce()
|
||||
nonce, err := h.db.CreateNonce(r.Context())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Replay-Nonce", nonce)
|
||||
w.Header().Set("Replay-Nonce", string(nonce))
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
logNonce(w, nonce)
|
||||
logNonce(w, string(nonce))
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
@ -78,8 +77,8 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
|||
// directory index url.
|
||||
func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Link", link(h.Auth.GetLink(r.Context(),
|
||||
acme.DirectoryLink, true), "index"))
|
||||
w.Header().Add("Link", link(h.linker.GetLink(r.Context(),
|
||||
DirectoryLinkType, true), "index"))
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
@ -90,7 +89,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ct := r.Header.Get("Content-Type")
|
||||
var expected []string
|
||||
if strings.Contains(r.URL.Path, h.Auth.GetLink(r.Context(), acme.CertificateLink, false, "")) {
|
||||
if strings.Contains(r.URL.Path, h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) {
|
||||
// GET /certificate requests allow a greater range of content types.
|
||||
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
|
||||
} else {
|
||||
|
@ -103,8 +102,8 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
|||
return
|
||||
}
|
||||
}
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf(
|
||||
"expected content-type to be in %s, but got %s", expected, ct)))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"expected content-type to be in %s, but got %s", expected, ct))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -113,15 +112,15 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body")))
|
||||
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
|
||||
return
|
||||
}
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body")))
|
||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), acme.JwsContextKey, jws)
|
||||
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
@ -143,17 +142,18 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
|||
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
|
||||
func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
jws, err := acme.JwsFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(r.Context())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) > 1 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -164,7 +164,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
|||
len(uh.Algorithm) > 0 ||
|
||||
len(uh.Nonce) > 0 ||
|
||||
len(uh.ExtraHeaders) > 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
|
||||
return
|
||||
}
|
||||
hdr := sig.Protected
|
||||
|
@ -174,25 +174,26 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
|||
switch k := hdr.JSONWebKey.Key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if k.Size() < keyutil.MinRSAKeyBytes {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+
|
||||
"keys must be at least %d bits (%d bytes) in size",
|
||||
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"rsa keys must be at least %d bits (%d bytes) in size",
|
||||
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
|
||||
return
|
||||
}
|
||||
default:
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"jws key type and algorithm do not match"))
|
||||
return
|
||||
}
|
||||
}
|
||||
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
||||
// we good
|
||||
default:
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm)))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: %s", hdr.Algorithm))
|
||||
return
|
||||
}
|
||||
|
||||
// Check the validity/freshness of the Nonce.
|
||||
if err := h.Auth.UseNonce(hdr.Nonce); err != nil {
|
||||
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
@ -200,21 +201,22 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
|||
// Check that the JWS url matches the requested url.
|
||||
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
||||
if !ok {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
|
||||
return
|
||||
}
|
||||
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
||||
if jwsURL != reqURL.String() {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
||||
return
|
||||
}
|
||||
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
|
@ -227,22 +229,27 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
|||
func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := acme.JwsFromContext(r.Context())
|
||||
jws, err := jwsFromContext(r.Context())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jwk := jws.Signatures[0].Protected.JSONWebKey
|
||||
if jwk == nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
|
||||
return
|
||||
}
|
||||
if !jwk.Valid() {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
|
||||
acc, err := h.Auth.GetAccountByKey(ctx, jwk)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||
kid, err := acme.KeyToID(jwk)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
|
||||
return
|
||||
}
|
||||
acc, err := h.db.GetAccountByKeyID(ctx, kid)
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
// For NewAccount requests ...
|
||||
|
@ -252,10 +259,10 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
|||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
}
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
|
@ -270,20 +277,20 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
|||
name := chi.URLParam(r, "provisionerID")
|
||||
provID, err := url.PathUnescape(name)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name)))
|
||||
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner id '%s'", name))
|
||||
return
|
||||
}
|
||||
p, err := h.Auth.LoadProvisionerByID("acme/" + provID)
|
||||
p, err := h.ca.LoadProvisionerByID("acme/" + provID)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, acme.ProvisionerContextKey, acme.Provisioner(acmeProv))
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
@ -294,36 +301,37 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
|||
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
jws, err := acme.JwsFromContext(ctx)
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
kidPrefix := h.Auth.GetLink(ctx, acme.AccountLink, true, "")
|
||||
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, true, "")
|
||||
kid := jws.Signatures[0].Protected.KeyID
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+
|
||||
"required prefix; expected %s, but got %s", kidPrefix, kid)))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
||||
"kid does not have required prefix; expected %s, but got %s",
|
||||
kidPrefix, kid))
|
||||
return
|
||||
}
|
||||
|
||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
||||
acc, err := h.Auth.GetAccount(r.Context(), accID)
|
||||
acc, err := h.db.GetAccount(ctx, accID)
|
||||
switch {
|
||||
case nosql.IsErrNotFound(err):
|
||||
api.WriteError(w, acme.AccountDoesNotExistErr(nil))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||
return
|
||||
case err != nil:
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, acme.AccContextKey, acc)
|
||||
ctx = context.WithValue(ctx, acme.JwkContextKey, acc.Key)
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
@ -334,26 +342,27 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
|||
// Make sure to parse and validate the JWS before running this middleware.
|
||||
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
jws, err := acme.JwsFromContext(r.Context())
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
jwk, err := acme.JwkFromContext(r.Context())
|
||||
jwk, err := jwkFromContext(ctx)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||
api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
||||
return
|
||||
}
|
||||
payload, err := jws.Verify(jwk)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws")))
|
||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), acme.PayloadContextKey, &payloadInfo{
|
||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
|
||||
value: payload,
|
||||
isPostAsGet: string(payload) == "",
|
||||
isEmptyJSON: string(payload) == "{}",
|
||||
|
@ -371,9 +380,89 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
|||
return
|
||||
}
|
||||
if !payload.isPostAsGet {
|
||||
api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET")))
|
||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// ContextKey is the key type for storing and searching for ACME request
|
||||
// essentials in the context of a request.
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// accContextKey account key
|
||||
accContextKey = ContextKey("acc")
|
||||
// baseURLContextKey baseURL key
|
||||
baseURLContextKey = ContextKey("baseURL")
|
||||
// jwsContextKey jws key
|
||||
jwsContextKey = ContextKey("jws")
|
||||
// jwkContextKey jwk key
|
||||
jwkContextKey = ContextKey("jwk")
|
||||
// payloadContextKey payload key
|
||||
payloadContextKey = ContextKey("payload")
|
||||
// provisionerContextKey provisioner key
|
||||
provisionerContextKey = ContextKey("provisioner")
|
||||
)
|
||||
|
||||
// accountFromContext searches the context for an ACME account. Returns the
|
||||
// account or an error.
|
||||
func accountFromContext(ctx context.Context) (*acme.Account, error) {
|
||||
val, ok := ctx.Value(accContextKey).(*acme.Account)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.NewErrorISE("account not in context")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// baseURLFromContext returns the baseURL if one is stored in the context.
|
||||
func baseURLFromContext(ctx context.Context) *url.URL {
|
||||
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
|
||||
if !ok || val == nil {
|
||||
return nil
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
|
||||
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
|
||||
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.NewErrorISE("jwk expected in request context")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// jwsFromContext searches the context for a JWS. Returns the JWS or an error.
|
||||
func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
|
||||
val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.NewErrorISE("jws expected in request context")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// provisionerFromContext searches the context for a provisioner. Returns the
|
||||
// provisioner or an error.
|
||||
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
|
||||
val := ctx.Value(provisionerContextKey)
|
||||
if val == nil {
|
||||
return nil, acme.NewErrorISE("provisioner expected in request context")
|
||||
}
|
||||
pval, ok := val.(acme.Provisioner)
|
||||
if !ok || pval == nil {
|
||||
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
|
||||
}
|
||||
return pval, nil
|
||||
}
|
||||
|
||||
// payloadFromContext searches the context for a payload. Returns the payload
|
||||
// or an error.
|
||||
func payloadFromContext(ctx context.Context) (*payloadInfo, error) {
|
||||
val, ok := ctx.Value(payloadContextKey).(*payloadInfo)
|
||||
if !ok || val == nil {
|
||||
return nil, acme.NewErrorISE("payload expected in request context")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue