Compare commits

...

19 commits

Author SHA1 Message Date
Panagiotis Siatras
845e41967d
api/render: use default error messages if none specified 2022-03-22 19:15:23 +02:00
Panagiotis Siatras
c3cc60e211
go.mod: added dependency to stretchr/testify 2022-03-22 18:39:42 +02:00
Panagiotis Siatras
2e729ebb26
authority/admin/api: refactored to support read.AdminJSON 2022-03-22 18:39:13 +02:00
Panagiotis Siatras
a5c171e750
api: refactored to support the new signatures for read.JSON, read.AdminJSON & read.ProtoJSON 2022-03-22 18:38:46 +02:00
Panagiotis Siatras
a715e57d04
api/read: reworked JSON & ProtoJSON to use buffers & added AdminJSON 2022-03-22 18:38:03 +02:00
Panagiotis Siatras
2fd84227f0
internal/buffer: initial implementation of the package 2022-03-22 18:36:57 +02:00
Panagiotis Siatras
e82b21c1cb
api/render: implemented BadRequest 2022-03-22 15:36:05 +02:00
Panagiotis Siatras
4cdb38b2e8
ca: fixed broken tests 2022-03-22 14:35:22 +02:00
Panagiotis Siatras
23c81db95a
scep: refactored to support api/render.Error 2022-03-22 14:35:22 +02:00
Panagiotis Siatras
d49c00b0d7
ca: refactored to support api/render.Error 2022-03-22 14:35:21 +02:00
Panagiotis Siatras
098c2e1134
authority/admin: refactored to support api/render.Error 2022-03-22 14:35:21 +02:00
Panagiotis Siatras
6636e87fc7
acme/api: refactored to support api/render.Error 2022-03-22 14:35:20 +02:00
Panagiotis Siatras
9b6c1f608e
api: refactored to support api/render.Error 2022-03-22 14:35:20 +02:00
Panagiotis Siatras
3389e57c48
api/render: implemented Error 2022-03-22 14:35:19 +02:00
Panagiotis Siatras
9aa480a09a
api: refactored to support api/render 2022-03-22 14:35:19 +02:00
Panagiotis Siatras
833ea1e695
ca: refactored to support api/render 2022-03-22 14:35:18 +02:00
Panagiotis Siatras
b79af0456c
authority/admin: refactored to support api/render 2022-03-22 14:35:18 +02:00
Panagiotis Siatras
eae0211a3e
acme/api: refactored to support api/render 2022-03-22 14:35:18 +02:00
Panagiotis Siatras
a31feae6d4
api/render: initial implementation of the package 2022-03-22 14:35:17 +02:00
31 changed files with 679 additions and 500 deletions

View file

@ -5,8 +5,9 @@ import (
"net/http"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/logging"
)
@ -70,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
var nar NewAccountRequest
if err := json.Unmarshal(payload.value, &nar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload"))
return
}
if err := nar.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := acmeProvisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -96,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
acmeErr, ok := err.(*acme.Error)
if !ok || acmeErr.Status != http.StatusBadRequest {
// Something went wrong ...
api.WriteError(w, err)
render.Error(w, err)
return
}
// Account does not exist //
if nar.OnlyReturnExisting {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
"account does not exist"))
return
}
jwk, err := jwkFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
eak, err := h.validateExternalAccountBinding(ctx, &nar)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -125,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusValid,
}
if err := h.db.CreateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating account"))
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
return
}
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
err := eak.BindTo(acc)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key"))
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
return
}
acc.ExternalAccountBinding = nar.ExternalAccountBinding
@ -149,7 +150,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
api.JSONStatus(w, acc, httpStatus)
render.JSONStatus(w, acc, httpStatus)
}
// GetOrUpdateAccount is the api for updating an ACME account.
@ -157,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -171,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
if !payload.isPostAsGet {
var uar UpdateAccountRequest
if err := json.Unmarshal(payload.value, &uar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload"))
return
}
if err := uar.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if len(uar.Status) > 0 || len(uar.Contact) > 0 {
@ -187,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
}
if err := h.db.UpdateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating account"))
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return
}
}
@ -196,7 +197,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
api.JSON(w, acc)
render.JSON(w, acc)
}
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
@ -213,22 +214,22 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
accID := chi.URLParam(r, "accID")
if acc.ID != accID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return
}
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
h.linker.LinkOrdersByAccountID(ctx, orders)
api.JSON(w, orders)
render.JSON(w, orders)
logOrdersByAccount(w, orders)
}

View file

@ -12,8 +12,10 @@ import (
"time"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
)
@ -181,11 +183,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
api.JSON(w, &Directory{
render.JSON(w, &Directory{
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
@ -200,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
// NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
}
// GetAuthorization ACME api for retrieving an Authz.
@ -208,28 +210,28 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
return
}
if acc.ID != az.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
return
}
if err = az.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status"))
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
return
}
h.linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
api.JSON(w, az)
render.JSON(w, az)
}
// GetChallenge ACME api for retrieving a Challenge.
@ -237,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
// Just verify that the payload was set, since we're not strictly adhering
// to ACME V2 spec for reasons specified below.
_, err = payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -257,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
azID := chi.URLParam(r, "authzID")
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
return
}
ch.AuthorizationID = azID
if acc.ID != ch.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own challenge '%s'", acc.ID, ch.ID))
return
}
jwk, err := jwkFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge"))
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return
}
@ -280,7 +282,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
api.JSON(w, ch)
render.JSON(w, ch)
}
// GetCertificate ACME api for retrieving a Certificate.
@ -288,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
certID := chi.URLParam(r, "certID")
cert, err := h.db.GetCertificate(ctx, certID)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
return
}
if cert.AccountID != acc.ID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own certificate '%s'", acc.ID, certID))
return
}

View file

@ -10,13 +10,14 @@ import (
"strings"
"github.com/go-chi/chi"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/nosql"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
)
type nextHTTP = func(http.ResponseWriter, *http.Request)
@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.db.CreateNonce(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
w.Header().Set("Replay-Nonce", string(nonce))
@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
var expected []string
p, err := provisionerFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return
}
}
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"expected content-type to be in %s, but got %s", expected, ct))
}
}
@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
render.Error(w, acme.WrapErrorISE(err, "failed to read request body"))
return
}
jws, err := jose.ParseJWS(string(body))
if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
return
}
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if len(jws.Signatures) == 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
return
}
if len(jws.Signatures) > 1 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
return
}
@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
len(uh.Algorithm) > 0 ||
len(uh.Nonce) > 0 ||
len(uh.ExtraHeaders) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return
}
hdr := sig.Protected
@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey:
if k.Size() < keyutil.MinRSAKeyBytes {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"rsa keys must be at least %d bits (%d bytes) in size",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
return
}
default:
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"jws key type and algorithm do not match"))
return
}
@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good
default:
api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
return
}
// Check the validity/freshness of the Nonce.
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
// Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
return
}
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
return
}
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return
}
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return
}
next(w, r)
@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
return
}
if !jwk.Valid() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
return
}
// Overwrite KeyID with the JWK thumbprint.
jwk.KeyID, err = acme.KeyToID(jwk)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
return
}
@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
// For NewAccount and Revoke requests ...
break
case err != nil:
api.WriteError(w, err)
render.Error(w, err)
return
default:
if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return
}
ctx = context.WithValue(ctx, accContextKey, acc)
@ -290,17 +291,17 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := h.ca.LoadProvisionerByName(name)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
@ -315,11 +316,11 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
ctx := r.Context()
ok, err := h.prerequisitesChecker(ctx)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
next(w, r.WithContext(ctx))
@ -334,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid))
return
@ -351,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
acc, err := h.db.GetAccount(ctx, accID)
switch {
case nosql.IsErrNotFound(err):
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return
case err != nil:
api.WriteError(w, err)
render.Error(w, err)
return
default:
if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return
}
ctx = context.WithValue(ctx, accContextKey, acc)
@ -376,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -412,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
jwk, err := jwkFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return
}
payload, err := jws.Verify(jwk)
if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
return
}
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
@ -443,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if !payload.isPostAsGet {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
return
}
next(w, r)

View file

@ -11,9 +11,11 @@ import (
"time"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
)
// NewOrderRequest represents the body for a NewOrder request.
@ -70,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-order request payload"))
return
}
if err := nor.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -116,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusPending,
}
if err := h.newAuthorization(ctx, az); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
o.AuthorizationIDs[i] = az.ID
@ -135,14 +137,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
}
if err := h.db.CreateOrder(ctx, o); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating order"))
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return
}
h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSONStatus(w, o, http.StatusCreated)
render.JSONStatus(w, o, http.StatusCreated)
}
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
@ -186,38 +188,38 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return
}
if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID))
return
}
if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return
}
if err = o.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating order status"))
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
return
}
h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o)
render.JSON(w, o)
}
// FinalizeOrder attemptst to finalize an order and create a certificate.
@ -225,54 +227,54 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
var fr FinalizeRequest
if err := json.Unmarshal(payload.value, &fr); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal finalize-order request payload"))
return
}
if err := fr.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return
}
if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID))
return
}
if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return
}
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order"))
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
return
}
h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o)
render.JSON(w, o)
}
// challengeTypes determines the types of challenges that should be used

View file

@ -10,13 +10,14 @@ import (
"net/http"
"strings"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
)
type revokePayload struct {
@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
var p revokePayload
err = json.Unmarshal(payload.value, &p)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
return
}
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
if err != nil {
// in this case the most likely cause is a client that didn't properly encode the certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
return
}
certToBeRevoked, err := x509.ParseCertificate(certBytes)
if err != nil {
// in this case a client may have encoded something different than a certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
return
}
serial := certToBeRevoked.SerialNumber.String()
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return
}
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
// this should never happen
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal"))
render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal"))
return
}
if shouldCheckAccountFrom(jws) {
account, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil {
api.WriteError(w, acmeErr)
render.Error(w, acmeErr)
return
}
} else {
@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
_, err := jws.Verify(certToBeRevoked.PublicKey)
if err != nil {
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
return
}
}
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return
}
if hasBeenRevokedBefore {
api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
return
}
reasonCode := p.ReasonCode
acmeErr := validateReasonCode(reasonCode)
if acmeErr != nil {
api.WriteError(w, acmeErr)
render.Error(w, acmeErr)
return
}
@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
err = prov.AuthorizeRevoke(ctx, "")
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
return
}
options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = h.ca.Revoke(ctx, options)
if err != nil {
api.WriteError(w, wrapRevokeErr(err))
render.Error(w, wrapRevokeErr(err))
return
}

View file

@ -21,6 +21,7 @@ import (
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
@ -282,7 +283,7 @@ func (h *caHandler) Route(r Router) {
// Version is an HTTP handler that returns the version of the server.
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
v := h.Authority.Version()
JSON(w, VersionResponse{
render.JSON(w, VersionResponse{
Version: v.Version,
RequireClientAuthentication: v.RequireClientAuthentication,
})
@ -290,7 +291,7 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
// Health is an HTTP handler that returns the status of the server.
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
JSON(w, HealthResponse{Status: "ok"})
render.JSON(w, HealthResponse{Status: "ok"})
}
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
@ -301,11 +302,11 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the
cert, err := h.Authority.Root(sum)
if err != nil {
WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return
}
JSON(w, &RootResponse{RootPEM: Certificate{cert}})
render.JSON(w, &RootResponse{RootPEM: Certificate{cert}})
}
func certChainToPEM(certChain []*x509.Certificate) []Certificate {
@ -320,16 +321,16 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r)
if err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
p, next, err := h.Authority.GetProvisioners(cursor, limit)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
JSON(w, &ProvisionersResponse{
render.JSON(w, &ProvisionersResponse{
Provisioners: p,
NextCursor: next,
})
@ -340,17 +341,17 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid)
if err != nil {
WriteError(w, errs.NotFoundErr(err))
render.Error(w, errs.NotFoundErr(err))
return
}
JSON(w, &ProvisionerKeyResponse{key})
render.JSON(w, &ProvisionerKeyResponse{key})
}
// Roots returns all the root certificates for the CA.
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error getting roots"))
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
return
}
@ -359,7 +360,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{roots[i]}
}
JSONStatus(w, &RootsResponse{
render.JSONStatus(w, &RootsResponse{
Certificates: certs,
}, http.StatusCreated)
}
@ -368,7 +369,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation()
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error getting federated roots"))
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
return
}
@ -377,7 +378,7 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{federated[i]}
}
JSONStatus(w, &FederationResponse{
render.JSONStatus(w, &FederationResponse{
Certificates: certs,
}, http.StatusCreated)
}

View file

@ -1,66 +0,0 @@
package api
import (
"encoding/json"
"fmt"
"net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/scep"
)
// WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err error) {
switch k := err.(type) {
case *acme.Error:
acme.WriteError(w, k)
return
case *admin.Error:
admin.WriteError(w, k)
return
case *scep.Error:
w.Header().Set("Content-Type", "text/plain")
default:
w.Header().Set("Content-Type", "application/json")
}
cause := errors.Cause(err)
if sc, ok := err.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
if sc, ok := cause.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
w.WriteHeader(http.StatusInternalServerError)
}
}
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
} else if e, ok := cause.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
log.Error(w, err)
}
}

View file

@ -2,30 +2,91 @@
package read
import (
"bytes"
"encoding/json"
"io"
"fmt"
"net/http"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/internal/buffer"
)
// JSON reads JSON from the request body and stores it in the value
// pointed by v.
func JSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
// JSON unmarshals from the given request's JSON body into v. In case of an
// error a HTTP Bad Request error will be written to w.
func JSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
b := read(w, r)
if b == nil {
return false
}
return nil
defer buffer.Put(b)
if err := json.NewDecoder(b).Decode(v); err != nil {
err = fmt.Errorf("error decoding json: %w", err)
render.BadRequest(w, err)
return false
}
return true
}
// AdminJSON is obsolete; it's here for backwards compatibility.
//
// Please don't use.
func AdminJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
b := read(w, r)
if b == nil {
return false
}
defer buffer.Put(b)
if err := json.NewDecoder(b).Decode(v); err != nil {
e := admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")
admin.WriteError(w, e)
return false
}
return true
}
// ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
func ProtoJSON(w http.ResponseWriter, r *http.Request, m proto.Message) bool {
b := read(w, r)
if b == nil {
return false
}
return protojson.Unmarshal(data, m)
defer buffer.Put(b)
if err := protojson.Unmarshal(b.Bytes(), m); err != nil {
err = fmt.Errorf("error decoding proto json: %w", err)
render.BadRequest(w, err)
return false
}
return true
}
func read(w http.ResponseWriter, r *http.Request) *bytes.Buffer {
b := buffer.Get()
if _, err := b.ReadFrom(r.Body); err != nil {
buffer.Put(b)
err = fmt.Errorf("error reading request body: %w", err)
render.BadRequest(w, err)
return nil
}
return b
}

View file

@ -2,45 +2,56 @@ package read
import (
"io"
"reflect"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"testing/iotest"
"github.com/smallstep/certificates/errs"
"github.com/stretchr/testify/assert"
)
func TestJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
cases := []struct {
src io.Reader
exp interface{}
ok bool
code int
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := JSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
0: {
src: strings.NewReader(`{"foo":"bar"}`),
exp: map[string]interface{}{"foo": "bar"},
ok: true,
code: http.StatusOK,
},
1: {
src: strings.NewReader(`{"foo"}`),
code: http.StatusBadRequest,
},
2: {
src: io.MultiReader(
strings.NewReader(`{`),
iotest.ErrReader(assert.AnError),
strings.NewReader(`"foo":"bar"}`),
),
code: http.StatusBadRequest,
},
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
for caseIndex := range cases {
kase := cases[caseIndex]
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", kase.src)
rec := httptest.NewRecorder()
var body interface{}
got := JSON(rec, req, &body)
assert.Equal(t, kase.ok, got)
assert.Equal(t, kase.code, rec.Result().StatusCode)
assert.Equal(t, kase.exp, body)
})
}
}

View file

@ -4,6 +4,7 @@ import (
"net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs"
)
@ -28,24 +29,23 @@ func (s *RekeyRequest) Validate() error {
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing client certificate"))
render.Error(w, errs.BadRequest("missing client certificate"))
return
}
var body RekeyRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if !read.JSON(w, r, &body) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
return
}
certChainPEM := certChainToPEM(certChain)
@ -55,7 +55,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
}
LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,

148
api/render/render.go Normal file
View file

@ -0,0 +1,148 @@
// Package render implements functionality related to response rendering.
package render
import (
"encoding/json"
"fmt"
"net/http"
"os"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/scep"
)
// JSON writes the passed value into the http.ResponseWriter.
func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK)
}
// JSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil {
log.Error(w, err)
return
}
log.EnabledResponse(w, v)
}
// ProtoJSON writes the passed value into the http.ResponseWriter.
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
ProtoJSONStatus(w, m, http.StatusOK)
}
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
b, err := protojson.Marshal(m)
if err != nil {
log.Error(w, err)
return
}
if _, err := w.Write(b); err != nil {
log.Error(w, err)
return
}
// log.EnabledResponse(w, v)
}
// Error encodes the JSON representation of err to w.
func Error(w http.ResponseWriter, err error) {
switch k := err.(type) {
case *acme.Error:
acme.WriteError(w, k)
return
case *admin.Error:
admin.WriteError(w, k)
return
case *scep.Error:
w.Header().Set("Content-Type", "text/plain")
default:
w.Header().Set("Content-Type", "application/json")
}
cause := errors.Cause(err)
if sc, ok := err.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
if sc, ok := cause.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
w.WriteHeader(http.StatusInternalServerError)
}
}
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
} else if e, ok := cause.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
log.Error(w, err)
}
}
// BadRequest renders the JSON representation of err into w and sets its
// status code to http.StatusBadRequest.
//
// In case err is nil, a default error message will be used in its place.
func BadRequest(w http.ResponseWriter, err error) {
codedError(w, http.StatusBadRequest, err)
}
func codedError(w http.ResponseWriter, code int, err error) {
if err == nil {
err = errors.New(http.StatusText(code))
}
var wrapper = struct {
Status int `json:"status"`
Message string `json:"message"`
}{
Status: code,
Message: err.Error(),
}
data, err := json.Marshal(wrapper)
if err != nil {
log.Error(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
w.Write(data)
}

View file

@ -1,10 +1,13 @@
package api
package render
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/smallstep/certificates/logging"
)
@ -51,3 +54,40 @@ func TestJSON(t *testing.T) {
})
}
}
func TestErrors(t *testing.T) {
cases := []struct {
fn func(http.ResponseWriter, error) // helper
err error // error being render
code int // expected status code
body string // expected body
}{
// --- BadRequest
0: {
fn: BadRequest,
err: assert.AnError,
code: http.StatusBadRequest,
body: `{"status":400,"message":"assert.AnError general error for testing"}`,
},
1: {
fn: BadRequest,
code: http.StatusBadRequest,
body: `{"status":400,"message":"Bad Request"}`,
},
}
for caseIndex := range cases {
kase := cases[caseIndex]
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
rec := httptest.NewRecorder()
kase.fn(rec, kase.err)
ret := rec.Result()
assert.Equal(t, "application/json", ret.Header.Get("Content-Type"))
assert.Equal(t, kase.code, ret.StatusCode)
assert.Equal(t, kase.body, rec.Body.String())
})
}
}

View file

@ -5,6 +5,7 @@ import (
"net/http"
"strings"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs"
)
@ -18,13 +19,13 @@ const (
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
cert, err := h.getPeerCertificate(r)
if err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
certChain, err := h.Authority.Renew(cert)
if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return
}
certChainPEM := certChainToPEM(certChain)
@ -34,7 +35,7 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
}
LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,

View file

@ -7,6 +7,7 @@ import (
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
@ -50,13 +51,12 @@ func (r *RevokeRequest) Validate() (err error) {
// TODO: Add CRL and OCSP support.
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if !read.JSON(w, r, &body) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
@ -73,7 +73,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
if len(body.OTT) > 0 {
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
@ -82,12 +82,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number
// being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing ott or client certificate"))
render.Error(w, errs.BadRequest("missing ott or client certificate"))
return
}
opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest("serial number in client certificate different than body"))
render.Error(w, errs.BadRequest("serial number in client certificate different than body"))
return
}
// TODO: should probably be checking if the certificate was revoked here.
@ -98,12 +98,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
}
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err, "error revoking certificate"))
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
return
}
logRevoke(w, opts)
JSON(w, &RevokeResponse{Status: "ok"})
render.JSON(w, &RevokeResponse{Status: "ok"})
}
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

View file

@ -6,6 +6,7 @@ import (
"net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
@ -50,14 +51,13 @@ type SignResponse struct {
// information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if !read.JSON(w, r, &body) {
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
@ -69,13 +69,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing certificate"))
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return
}
certChainPEM := certChainToPEM(certChain)
@ -84,7 +84,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
caPEM = certChainPEM[1]
}
LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{
render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0],
CaPEM: caPEM,
CertChainPEM: certChainPEM,

View file

@ -12,6 +12,7 @@ import (
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
@ -251,20 +252,19 @@ type SSHBastionResponse struct {
// the request.
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if !read.JSON(w, r, &body) {
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
return
}
@ -272,7 +272,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
return
}
}
@ -289,13 +289,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return
}
@ -303,7 +303,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return
}
addUserCertificate = &SSHCertificate{addUserCert}
@ -316,7 +316,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
@ -328,13 +328,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate"))
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return
}
identityCertificate = certChainToPEM(certChain)
}
JSONStatus(w, &SSHSignResponse{
render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{cert},
AddUserCertificate: addUserCertificate,
IdentityCertificate: identityCertificate,
@ -346,12 +346,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots(r.Context())
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound("no keys found"))
render.Error(w, errs.NotFound("no keys found"))
return
}
@ -363,7 +363,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
}
JSON(w, resp)
render.JSON(w, resp)
}
// SSHFederation is an HTTP handler that returns the federated SSH public keys
@ -371,12 +371,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation(r.Context())
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound("no keys found"))
render.Error(w, errs.NotFound("no keys found"))
return
}
@ -388,25 +388,24 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
}
JSON(w, resp)
render.JSON(w, resp)
}
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
// and servers.
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if !read.JSON(w, r, &body) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
@ -417,31 +416,30 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert:
cfg.HostTemplates = ts
default:
WriteError(w, errs.InternalServer("it should hot get here"))
render.Error(w, errs.InternalServer("it should hot get here"))
return
}
JSON(w, cfg)
render.JSON(w, cfg)
}
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if !read.JSON(w, r, &body) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHCheckPrincipalResponse{
render.JSON(w, &SSHCheckPrincipalResponse{
Exists: exists,
})
}
@ -455,10 +453,10 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHGetHostsResponse{
render.JSON(w, &SSHGetHostsResponse{
Hosts: hosts,
})
}
@ -466,22 +464,21 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
// SSHBastion provides returns the bastion configured if any.
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if !read.JSON(w, r, &body) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
JSON(w, &SSHBastionResponse{
render.JSON(w, &SSHBastionResponse{
Hostname: body.Hostname,
Bastion: bastion,
})

View file

@ -7,6 +7,7 @@ import (
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
@ -40,37 +41,37 @@ type SSHRekeyResponse struct {
// the request.
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if !read.JSON(w, r, &body) {
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
return
}
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
return
}
@ -80,11 +81,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return
}
JSONStatus(w, &SSHRekeyResponse{
render.JSONStatus(w, &SSHRekeyResponse{
Certificate: SSHCertificate{newCert},
IdentityCertificate: identity,
}, http.StatusCreated)

View file

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
@ -38,31 +39,31 @@ type SSHRenewResponse struct {
// the request.
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if read.JSON(w, r, &body) {
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
return
}
@ -72,11 +73,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return
}
JSONStatus(w, &SSHSignResponse{
render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{newCert},
IdentityCertificate: identity,
}, http.StatusCreated)

View file

@ -6,6 +6,7 @@ import (
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
@ -49,13 +50,12 @@ func (r *SSHRevokeRequest) Validate() (err error) {
// NOTE: currently only Passive revocation is supported.
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest
if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
if !read.JSON(w, r, &body) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, err)
render.Error(w, err)
return
}
@ -71,18 +71,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
// otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.UnauthorizedErr(err))
render.Error(w, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
return
}
logSSHRevoke(w, opts)
JSON(w, &SSHRevokeResponse{Status: "ok"})
render.JSON(w, &SSHRevokeResponse{Status: "ok"})
}
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

View file

@ -1,57 +0,0 @@
package api
import (
"encoding/json"
"net/http"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/log"
)
// JSON writes the passed value into the http.ResponseWriter.
func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK)
}
// JSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil {
log.Error(w, err)
return
}
log.EnabledResponse(w, v)
}
// ProtoJSON writes the passed value into the http.ResponseWriter.
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
ProtoJSONStatus(w, m, http.StatusOK)
}
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
b, err := protojson.Marshal(m)
if err != nil {
log.Error(w, err)
return
}
if _, err := w.Write(b); err != nil {
log.Error(w, err)
return
}
// log.EnabledResponse(w, v)
}

View file

@ -6,10 +6,12 @@ import (
"net/http"
"github.com/go-chi/chi"
"github.com/smallstep/certificates/api"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
)
const (
@ -44,11 +46,11 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
provName := chi.URLParam(r, "provisionerName")
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if !eabEnabled {
api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, prov)
@ -101,15 +103,15 @@ func NewACMEAdminResponder() *ACMEAdminResponder {
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
}

View file

@ -10,6 +10,7 @@ import (
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
)
@ -85,28 +86,28 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
adm, ok := h.auth.LoadAdminByID(id)
if !ok {
api.WriteError(w, admin.NewError(admin.ErrorNotFoundType,
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
"admin %s not found", id))
return
}
api.ProtoJSON(w, adm)
render.ProtoJSON(w, adm)
}
// GetAdmins returns a segment of admins associated with the authority.
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r)
if err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params"))
return
}
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
return
}
api.JSON(w, &GetAdminsResponse{
render.JSON(w, &GetAdminsResponse{
Admins: admins,
NextCursor: nextCursor,
})
@ -115,19 +116,18 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
// CreateAdmin creates a new admin.
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest
if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
if !read.AdminJSON(w, r, &body) {
return
}
if err := body.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
return
}
adm := &linkedca.Admin{
@ -137,11 +137,11 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
}
// Store to authority collection.
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error storing admin"))
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
return
}
api.ProtoJSONStatus(w, adm, http.StatusCreated)
render.ProtoJSONStatus(w, adm, http.StatusCreated)
}
// DeleteAdmin deletes admin.
@ -149,23 +149,22 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
return
}
api.JSON(w, &DeleteResponse{Status: "ok"})
render.JSON(w, &DeleteResponse{Status: "ok"})
}
// UpdateAdmin updates an existing admin.
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest
if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
if !read.AdminJSON(w, r, &body) {
return
}
if err := body.Validate(); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
@ -173,9 +172,9 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id))
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
return
}
api.ProtoJSON(w, adm)
render.ProtoJSON(w, adm)
}

View file

@ -4,7 +4,7 @@ import (
"context"
"net/http"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
)
@ -15,7 +15,7 @@ type nextHTTP = func(http.ResponseWriter, *http.Request)
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
if !h.auth.IsAdminAPIEnabled() {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType,
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
"administration API not enabled"))
return
}
@ -28,14 +28,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization")
if tok == "" {
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
"missing authorization header token"))
return
}
adm, err := h.auth.AuthorizeAdminToken(r, tok)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}

View file

@ -4,10 +4,12 @@ import (
"net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
@ -33,39 +35,39 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
)
if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return
}
} else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
}
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
api.ProtoJSON(w, prov)
render.ProtoJSON(w, prov)
}
// GetProvisioners returns the given segment of provisioners associated with the authority.
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r)
if err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params"))
return
}
p, next, err := h.auth.GetProvisioners(cursor, limit)
if err != nil {
api.WriteError(w, errs.InternalServerErr(err))
render.Error(w, errs.InternalServerErr(err))
return
}
api.JSON(w, &GetProvisionersResponse{
render.JSON(w, &GetProvisionersResponse{
Provisioners: p,
NextCursor: next,
})
@ -74,22 +76,21 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
// CreateProvisioner creates a new prov.
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, prov); err != nil {
api.WriteError(w, err)
if !read.ProtoJSON(w, r, prov) {
return
}
// TODO: Validate inputs
if err := authority.ValidateClaims(prov.Claims); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
return
}
api.ProtoJSONStatus(w, prov, http.StatusCreated)
render.ProtoJSONStatus(w, prov, http.StatusCreated)
}
// DeleteProvisioner deletes a provisioner.
@ -103,75 +104,74 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
)
if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return
}
} else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
}
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
return
}
api.JSON(w, &DeleteResponse{Status: "ok"})
render.JSON(w, &DeleteResponse{Status: "ok"})
}
// UpdateProvisioner updates an existing prov.
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, nu); err != nil {
api.WriteError(w, err)
if !read.ProtoJSON(w, r, nu) {
return
}
name := chi.URLParam(r, "name")
_old, err := h.auth.LoadProvisionerByName(name)
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
return
}
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
return
}
if nu.Id != old.Id {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID"))
render.Error(w, admin.NewErrorISE("cannot change provisioner ID"))
return
}
if nu.Type != old.Type {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner type"))
render.Error(w, admin.NewErrorISE("cannot change provisioner type"))
return
}
if nu.AuthorityId != old.AuthorityId {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID"))
render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID"))
return
}
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt"))
render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt"))
return
}
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
return
}
// TODO: Validate inputs
if err := authority.ValidateClaims(nu.Claims); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
api.ProtoJSON(w, nu)
render.ProtoJSON(w, nu)
}

View file

@ -12,12 +12,13 @@ import (
"time"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api"
"github.com/smallstep/certificates/api"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/certificates/api/render"
)
func TestNewACMEClient(t *testing.T) {
@ -112,15 +113,15 @@ func TestNewACMEClient(t *testing.T) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
switch {
case i == 0:
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
case i == 1:
w.Header().Set("Replay-Nonce", "abc123")
api.JSONStatus(w, []byte{}, 200)
render.JSONStatus(w, []byte{}, 200)
i++
default:
w.Header().Set("Location", accLocation)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
}
})
@ -206,7 +207,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce)
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
})
if nonce, err := ac.GetNonce(); err != nil {
@ -315,7 +316,7 @@ func TestACMEClient_post(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -338,7 +339,7 @@ func TestACMEClient_post(t *testing.T) {
assert.Equals(t, hdr.KeyID, ac.kid)
}
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil {
@ -455,7 +456,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -477,7 +478,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, payload, norb)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := ac.NewOrder(norb); err != nil {
@ -577,7 +578,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -599,7 +600,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := ac.GetOrder(url); err != nil {
@ -699,7 +700,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -721,7 +722,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := ac.GetAuthz(url); err != nil {
@ -821,7 +822,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -844,7 +845,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := ac.GetChallenge(url); err != nil {
@ -944,7 +945,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -967,7 +968,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
assert.Equals(t, payload, []byte("{}"))
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if err := ac.ValidateChallenge(url); err != nil {
@ -1071,7 +1072,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -1093,7 +1094,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, payload, frb)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if err := ac.FinalizeOrder(url, csr); err != nil {
@ -1200,7 +1201,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -1222,7 +1223,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
})
if res, err := tc.client.GetAccountOrders(); err != nil {
@ -1331,7 +1332,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1)
render.JSONStatus(w, tc.r1, tc.rc1)
i++
return
}
@ -1356,7 +1357,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
if tc.certBytes != nil {
w.Write(tc.certBytes)
} else {
api.JSONStatus(w, tc.r2, tc.rc2)
render.JSONStatus(w, tc.r2, tc.rc2)
}
})

View file

@ -14,11 +14,14 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
)
func newLocalListener() net.Listener {
@ -79,7 +82,7 @@ func startCAServer(configFile string) (*CA, string, error) {
func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/version" {
api.JSON(w, api.VersionResponse{
render.JSON(w, api.VersionResponse{
Version: "test",
RequireClientAuthentication: true,
})
@ -93,7 +96,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han
}
isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0
if !isMTLS {
api.WriteError(w, errs.Unauthorized("missing peer certificate"))
render.Error(w, errs.Unauthorized("missing peer certificate"))
} else {
next.ServeHTTP(w, r)
}

View file

@ -24,6 +24,7 @@ import (
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
@ -182,7 +183,7 @@ func TestClient_Version(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Version()
@ -232,7 +233,7 @@ func TestClient_Health(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Health()
@ -290,7 +291,7 @@ func TestClient_Root(t *testing.T) {
if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Root(tt.shasum)
@ -360,7 +361,7 @@ func TestClient_Sign(t *testing.T) {
if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
render.Error(w, e)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
@ -371,7 +372,7 @@ func TestClient_Sign(t *testing.T) {
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
}
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Sign(tt.request)
@ -432,7 +433,7 @@ func TestClient_Revoke(t *testing.T) {
if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e)
render.Error(w, e)
return
} else if !equalJSON(t, body, tt.request) {
if tt.request == nil {
@ -443,7 +444,7 @@ func TestClient_Revoke(t *testing.T) {
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
}
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Revoke(tt.request, nil)
@ -503,7 +504,7 @@ func TestClient_Renew(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Renew(nil)
@ -568,9 +569,9 @@ func TestClient_RenewWithToken(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer token" {
api.JSONStatus(w, errs.InternalServer("force"), 500)
render.JSONStatus(w, errs.InternalServer("force"), 500)
} else {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
}
})
@ -640,7 +641,7 @@ func TestClient_Rekey(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Rekey(tt.request, nil)
@ -705,7 +706,7 @@ func TestClient_Provisioners(t *testing.T) {
if req.RequestURI != tt.expectedURI {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Provisioners(tt.args...)
@ -762,7 +763,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
}
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.ProvisionerKey(tt.kid)
@ -821,7 +822,7 @@ func TestClient_Roots(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Roots()
@ -879,7 +880,7 @@ func TestClient_Federation(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.Federation()
@ -941,7 +942,7 @@ func TestClient_SSHRoots(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.SSHRoots()
@ -1041,7 +1042,7 @@ func TestClient_RootFingerprint(t *testing.T) {
}
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.RootFingerprint()
@ -1102,7 +1103,7 @@ func TestClient_SSHBastion(t *testing.T) {
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode)
render.JSONStatus(w, tt.response, tt.responseCode)
})
got, err := c.SSHBastion(tt.request)

1
go.mod
View file

@ -31,6 +31,7 @@ require (
github.com/slackhq/nebula v1.5.2
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262
github.com/smallstep/nosql v0.4.0
github.com/stretchr/testify v1.7.0
github.com/urfave/cli v1.22.4
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.7.0

3
go.sum
View file

@ -445,9 +445,11 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
@ -1120,6 +1122,7 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=

23
internal/buffer/buffer.go Normal file
View file

@ -0,0 +1,23 @@
// Package buffer implements a reusable buffer pool.
package buffer
import (
"bytes"
"sync"
)
func Get() *bytes.Buffer {
return pool.Get().(*bytes.Buffer)
}
func Put(b *bytes.Buffer) {
b.Reset()
pool.Put(b)
}
var pool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}

View file

@ -16,6 +16,7 @@ import (
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/scep"
)
@ -189,19 +190,19 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
name := chi.URLParam(r, "provisionerName")
provisionerName, err := url.PathUnescape(name)
if err != nil {
api.WriteError(w, errors.Errorf("error url unescaping provisioner name '%s'", name))
render.Error(w, errors.Errorf("error url unescaping provisioner name '%s'", name))
return
}
p, err := h.Auth.LoadProvisionerByName(provisionerName)
if err != nil {
api.WriteError(w, err)
render.Error(w, err)
return
}
prov, ok := p.(*provisioner.SCEP)
if !ok {
api.WriteError(w, errors.New("provisioner must be of type SCEP"))
render.Error(w, errors.New("provisioner must be of type SCEP"))
return
}
@ -356,7 +357,7 @@ func writeError(w http.ResponseWriter, err error) {
Message: err.Error(),
Status: http.StatusInternalServerError, // TODO: make this a param?
}
api.WriteError(w, scepError)
render.Error(w, scepError)
}
func (h *Handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (SCEPResponse, error) {