acme/api: refactored to support api/render.Error

This commit is contained in:
Panagiotis Siatras 2022-03-18 17:35:23 +02:00
parent 9b6c1f608e
commit 6636e87fc7
No known key found for this signature in database
GPG key ID: 529695F03A572804
5 changed files with 120 additions and 120 deletions

View file

@ -7,7 +7,6 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
@ -72,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var nar NewAccountRequest var nar NewAccountRequest
if err := json.Unmarshal(payload.value, &nar); err != nil { if err := json.Unmarshal(payload.value, &nar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
if err := nar.Validate(); err != nil { if err := nar.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := acmeProvisionerFromContext(ctx) prov, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -98,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
acmeErr, ok := err.(*acme.Error) acmeErr, ok := err.(*acme.Error)
if !ok || acmeErr.Status != http.StatusBadRequest { if !ok || acmeErr.Status != http.StatusBadRequest {
// Something went wrong ... // Something went wrong ...
api.WriteError(w, err) render.Error(w, err)
return return
} }
// Account does not exist // // Account does not exist //
if nar.OnlyReturnExisting { if nar.OnlyReturnExisting {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
"account does not exist")) "account does not exist"))
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
eak, err := h.validateExternalAccountBinding(ctx, &nar) eak, err := h.validateExternalAccountBinding(ctx, &nar)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -127,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusValid, Status: acme.StatusValid,
} }
if err := h.db.CreateAccount(ctx, acc); err != nil { if err := h.db.CreateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) render.Error(w, acme.WrapErrorISE(err, "error creating account"))
return return
} }
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
err := eak.BindTo(acc) err := eak.BindTo(acc)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key")) render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
return return
} }
acc.ExternalAccountBinding = nar.ExternalAccountBinding acc.ExternalAccountBinding = nar.ExternalAccountBinding
@ -159,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -173,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
if !payload.isPostAsGet { if !payload.isPostAsGet {
var uar UpdateAccountRequest var uar UpdateAccountRequest
if err := json.Unmarshal(payload.value, &uar); err != nil { if err := json.Unmarshal(payload.value, &uar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
if err := uar.Validate(); err != nil { if err := uar.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if len(uar.Status) > 0 || len(uar.Contact) > 0 { if len(uar.Status) > 0 || len(uar.Contact) > 0 {
@ -189,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
} }
if err := h.db.UpdateAccount(ctx, acc); err != nil { if err := h.db.UpdateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return return
} }
} }
@ -215,17 +214,17 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
accID := chi.URLParam(r, "accID") accID := chi.URLParam(r, "accID")
if acc.ID != accID { if acc.ID != accID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return return
} }
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }

View file

@ -183,7 +183,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -202,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
// NotImplemented returns a 501 and is generally a placeholder for functionality which // NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such. // MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
} }
// GetAuthorization ACME api for retrieving an Authz. // GetAuthorization ACME api for retrieving an Authz.
@ -210,21 +210,21 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
return return
} }
if acc.ID != az.AccountID { if acc.ID != az.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own authorization '%s'", acc.ID, az.ID)) "account '%s' does not own authorization '%s'", acc.ID, az.ID))
return return
} }
if err = az.UpdateStatus(ctx, h.db); err != nil { if err = az.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
return return
} }
@ -239,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
// Just verify that the payload was set, since we're not strictly adhering // Just verify that the payload was set, since we're not strictly adhering
// to ACME V2 spec for reasons specified below. // to ACME V2 spec for reasons specified below.
_, err = payloadFromContext(ctx) _, err = payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -259,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
azID := chi.URLParam(r, "authzID") azID := chi.URLParam(r, "authzID")
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
return return
} }
ch.AuthorizationID = azID ch.AuthorizationID = azID
if acc.ID != ch.AccountID { if acc.ID != ch.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own challenge '%s'", acc.ID, ch.ID)) "account '%s' does not own challenge '%s'", acc.ID, ch.ID))
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return return
} }
@ -290,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
certID := chi.URLParam(r, "certID") certID := chi.URLParam(r, "certID")
cert, err := h.db.GetCertificate(ctx, certID) cert, err := h.db.GetCertificate(ctx, certID)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
return return
} }
if cert.AccountID != acc.ID { if cert.AccountID != acc.ID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own certificate '%s'", acc.ID, certID)) "account '%s' does not own certificate '%s'", acc.ID, certID))
return return
} }

View file

@ -10,13 +10,14 @@ import (
"strings" "strings"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
) )
type nextHTTP = func(http.ResponseWriter, *http.Request) type nextHTTP = func(http.ResponseWriter, *http.Request)
@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.db.CreateNonce(r.Context()) nonce, err := h.db.CreateNonce(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
w.Header().Set("Replay-Nonce", string(nonce)) w.Header().Set("Replay-Nonce", string(nonce))
@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
var expected []string var expected []string
p, err := provisionerFromContext(r.Context()) p, err := provisionerFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return return
} }
} }
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"expected content-type to be in %s, but got %s", expected, ct)) "expected content-type to be in %s, but got %s", expected, ct))
} }
} }
@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) render.Error(w, acme.WrapErrorISE(err, "failed to read request body"))
return return
} }
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
return return
} }
ctx := context.WithValue(r.Context(), jwsContextKey, jws) ctx := context.WithValue(r.Context(), jwsContextKey, jws)
@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) jws, err := jwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if len(jws.Signatures) == 0 { if len(jws.Signatures) == 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
return return
} }
if len(jws.Signatures) > 1 { if len(jws.Signatures) > 1 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
return return
} }
@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
len(uh.Algorithm) > 0 || len(uh.Algorithm) > 0 ||
len(uh.Nonce) > 0 || len(uh.Nonce) > 0 ||
len(uh.ExtraHeaders) > 0 { len(uh.ExtraHeaders) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return return
} }
hdr := sig.Protected hdr := sig.Protected
@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
switch k := hdr.JSONWebKey.Key.(type) { switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
if k.Size() < keyutil.MinRSAKeyBytes { if k.Size() < keyutil.MinRSAKeyBytes {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"rsa keys must be at least %d bits (%d bytes) in size", "rsa keys must be at least %d bits (%d bytes) in size",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
return return
} }
default: default:
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"jws key type and algorithm do not match")) "jws key type and algorithm do not match"))
return return
} }
@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good // we good
default: default:
api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
return return
} }
// Check the validity/freshness of the Nonce. // Check the validity/freshness of the Nonce.
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
// Check that the JWS url matches the requested url. // Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string) jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok { if !ok {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
return return
} }
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() { if jwsURL != reqURL.String() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
return return
} }
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return return
} }
if hdr.JSONWebKey == nil && hdr.KeyID == "" { if hdr.JSONWebKey == nil && hdr.KeyID == "" {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return return
} }
next(w, r) next(w, r)
@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) jws, err := jwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
jwk := jws.Signatures[0].Protected.JSONWebKey jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil { if jwk == nil {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
return return
} }
if !jwk.Valid() { if !jwk.Valid() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
return return
} }
// Overwrite KeyID with the JWK thumbprint. // Overwrite KeyID with the JWK thumbprint.
jwk.KeyID, err = acme.KeyToID(jwk) jwk.KeyID, err = acme.KeyToID(jwk)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
return return
} }
@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
// For NewAccount and Revoke requests ... // For NewAccount and Revoke requests ...
break break
case err != nil: case err != nil:
api.WriteError(w, err) render.Error(w, err)
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
@ -290,17 +291,17 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
nameEscaped := chi.URLParam(r, "provisionerID") nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped) name, err := url.PathUnescape(nameEscaped)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return return
} }
p, err := h.ca.LoadProvisionerByName(name) p, err := h.ca.LoadProvisionerByName(name)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
acmeProv, ok := p.(*provisioner.ACME) acmeProv, ok := p.(*provisioner.ACME)
if !ok { if !ok {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return return
} }
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
@ -315,11 +316,11 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
ok, err := h.prerequisitesChecker(ctx) ok, err := h.prerequisitesChecker(ctx)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return return
} }
if !ok { if !ok {
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return return
} }
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
@ -334,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"kid does not have required prefix; expected %s, but got %s", "kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid)) kidPrefix, kid))
return return
@ -351,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
acc, err := h.db.GetAccount(ctx, accID) acc, err := h.db.GetAccount(ctx, accID)
switch { switch {
case nosql.IsErrNotFound(err): case nosql.IsErrNotFound(err):
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return return
case err != nil: case err != nil:
api.WriteError(w, err) render.Error(w, err)
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
@ -376,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -412,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return return
} }
payload, err := jws.Verify(jwk) payload, err := jws.Verify(jwk)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
return return
} }
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
@ -443,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context()) payload, err := payloadFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if !payload.isPostAsGet { if !payload.isPostAsGet {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
return return
} }
next(w, r) next(w, r)

View file

@ -15,7 +15,6 @@ import (
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
) )
@ -73,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var nor NewOrderRequest var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil { if err := json.Unmarshal(payload.value, &nor); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-order request payload")) "failed to unmarshal new-order request payload"))
return return
} }
if err := nor.Validate(); err != nil { if err := nor.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -119,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusPending, Status: acme.StatusPending,
} }
if err := h.newAuthorization(ctx, az); err != nil { if err := h.newAuthorization(ctx, az); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
o.AuthorizationIDs[i] = az.ID o.AuthorizationIDs[i] = az.ID
@ -138,7 +137,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
} }
if err := h.db.CreateOrder(ctx, o); err != nil { if err := h.db.CreateOrder(ctx, o); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return return
} }
@ -189,31 +188,31 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
} }
if acc.ID != o.AccountID { if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID)) "account '%s' does not own order '%s'", acc.ID, o.ID))
return return
} }
if prov.GetID() != o.ProvisionerID { if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.UpdateStatus(ctx, h.db); err != nil { if err = o.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
return return
} }
@ -228,47 +227,47 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var fr FinalizeRequest var fr FinalizeRequest
if err := json.Unmarshal(payload.value, &fr); err != nil { if err := json.Unmarshal(payload.value, &fr); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal finalize-order request payload")) "failed to unmarshal finalize-order request payload"))
return return
} }
if err := fr.Validate(); err != nil { if err := fr.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
} }
if acc.ID != o.AccountID { if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID)) "account '%s' does not own order '%s'", acc.ID, o.ID))
return return
} }
if prov.GetID() != o.ProvisionerID { if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
return return
} }

View file

@ -10,13 +10,14 @@ import (
"net/http" "net/http"
"strings" "strings"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
) )
type revokePayload struct { type revokePayload struct {
@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var p revokePayload var p revokePayload
err = json.Unmarshal(payload.value, &p) err = json.Unmarshal(payload.value, &p)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload")) render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
return return
} }
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
if err != nil { if err != nil {
// in this case the most likely cause is a client that didn't properly encode the certificate // in this case the most likely cause is a client that didn't properly encode the certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
return return
} }
certToBeRevoked, err := x509.ParseCertificate(certBytes) certToBeRevoked, err := x509.ParseCertificate(certBytes)
if err != nil { if err != nil {
// in this case a client may have encoded something different than a certificate // in this case a client may have encoded something different than a certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
return return
} }
serial := certToBeRevoked.SerialNumber.String() serial := certToBeRevoked.SerialNumber.String()
dbCert, err := h.db.GetCertificateBySerial(ctx, serial) dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return return
} }
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
// this should never happen // this should never happen
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal")) render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal"))
return return
} }
if shouldCheckAccountFrom(jws) { if shouldCheckAccountFrom(jws) {
account, err := accountFromContext(ctx) account, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil { if acmeErr != nil {
api.WriteError(w, acmeErr) render.Error(w, acmeErr)
return return
} }
} else { } else {
@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
_, err := jws.Verify(certToBeRevoked.PublicKey) _, err := jws.Verify(certToBeRevoked.PublicKey)
if err != nil { if err != nil {
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
return return
} }
} }
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return return
} }
if hasBeenRevokedBefore { if hasBeenRevokedBefore {
api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
return return
} }
reasonCode := p.ReasonCode reasonCode := p.ReasonCode
acmeErr := validateReasonCode(reasonCode) acmeErr := validateReasonCode(reasonCode)
if acmeErr != nil { if acmeErr != nil {
api.WriteError(w, acmeErr) render.Error(w, acmeErr)
return return
} }
@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
err = prov.AuthorizeRevoke(ctx, "") err = prov.AuthorizeRevoke(ctx, "")
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
return return
} }
options := revokeOptions(serial, certToBeRevoked, reasonCode) options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = h.ca.Revoke(ctx, options) err = h.ca.Revoke(ctx, options)
if err != nil { if err != nil {
api.WriteError(w, wrapRevokeErr(err)) render.Error(w, wrapRevokeErr(err))
return return
} }