Merge branch 'master' into fix/adminra

This commit is contained in:
Mariano Cano 2022-04-07 18:25:41 -07:00
commit 8abd568f03
57 changed files with 1188 additions and 823 deletions

View file

@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased - 0.18.3] - DATE ## [Unreleased - 0.18.3] - DATE
### Added ### Added
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`. - Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
- Added support for `extraNames` in X.509 templates.
### Changed ### Changed
- Made SCEP CA URL paths dynamic - Made SCEP CA URL paths dynamic
- Support two latest versions of Go (1.17, 1.18) - Support two latest versions of Go (1.17, 1.18)

View file

@ -5,8 +5,9 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
@ -70,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var nar NewAccountRequest var nar NewAccountRequest
if err := json.Unmarshal(payload.value, &nar); err != nil { if err := json.Unmarshal(payload.value, &nar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
if err := nar.Validate(); err != nil { if err := nar.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := acmeProvisionerFromContext(ctx) prov, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -96,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
acmeErr, ok := err.(*acme.Error) acmeErr, ok := err.(*acme.Error)
if !ok || acmeErr.Status != http.StatusBadRequest { if !ok || acmeErr.Status != http.StatusBadRequest {
// Something went wrong ... // Something went wrong ...
api.WriteError(w, err) render.Error(w, err)
return return
} }
// Account does not exist // // Account does not exist //
if nar.OnlyReturnExisting { if nar.OnlyReturnExisting {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
"account does not exist")) "account does not exist"))
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
eak, err := h.validateExternalAccountBinding(ctx, &nar) eak, err := h.validateExternalAccountBinding(ctx, &nar)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -125,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusValid, Status: acme.StatusValid,
} }
if err := h.db.CreateAccount(ctx, acc); err != nil { if err := h.db.CreateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) render.Error(w, acme.WrapErrorISE(err, "error creating account"))
return return
} }
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
err := eak.BindTo(acc) err := eak.BindTo(acc)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key")) render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
return return
} }
acc.ExternalAccountBinding = nar.ExternalAccountBinding acc.ExternalAccountBinding = nar.ExternalAccountBinding
@ -149,7 +150,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc) h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
api.JSONStatus(w, acc, httpStatus) render.JSONStatus(w, acc, httpStatus)
} }
// GetOrUpdateAccount is the api for updating an ACME account. // GetOrUpdateAccount is the api for updating an ACME account.
@ -157,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -171,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
if !payload.isPostAsGet { if !payload.isPostAsGet {
var uar UpdateAccountRequest var uar UpdateAccountRequest
if err := json.Unmarshal(payload.value, &uar); err != nil { if err := json.Unmarshal(payload.value, &uar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
if err := uar.Validate(); err != nil { if err := uar.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if len(uar.Status) > 0 || len(uar.Contact) > 0 { if len(uar.Status) > 0 || len(uar.Contact) > 0 {
@ -187,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
} }
if err := h.db.UpdateAccount(ctx, acc); err != nil { if err := h.db.UpdateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return return
} }
} }
@ -196,7 +197,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc) h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
api.JSON(w, acc) render.JSON(w, acc)
} }
func logOrdersByAccount(w http.ResponseWriter, oids []string) { func logOrdersByAccount(w http.ResponseWriter, oids []string) {
@ -213,22 +214,22 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
accID := chi.URLParam(r, "accID") accID := chi.URLParam(r, "accID")
if acc.ID != accID { if acc.ID != accID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return return
} }
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
h.linker.LinkOrdersByAccountID(ctx, orders) h.linker.LinkOrdersByAccountID(ctx, orders)
api.JSON(w, orders) render.JSON(w, orders)
logOrdersByAccount(w, orders) logOrdersByAccount(w, orders)
} }

View file

@ -12,8 +12,10 @@ import (
"time" "time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
@ -181,11 +183,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
api.JSON(w, &Directory{ render.JSON(w, &Directory{
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
@ -200,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
// NotImplemented returns a 501 and is generally a placeholder for functionality which // NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such. // MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
} }
// GetAuthorization ACME api for retrieving an Authz. // GetAuthorization ACME api for retrieving an Authz.
@ -208,28 +210,28 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
return return
} }
if acc.ID != az.AccountID { if acc.ID != az.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own authorization '%s'", acc.ID, az.ID)) "account '%s' does not own authorization '%s'", acc.ID, az.ID))
return return
} }
if err = az.UpdateStatus(ctx, h.db); err != nil { if err = az.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
return return
} }
h.linker.LinkAuthorization(ctx, az) h.linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
api.JSON(w, az) render.JSON(w, az)
} }
// GetChallenge ACME api for retrieving a Challenge. // GetChallenge ACME api for retrieving a Challenge.
@ -237,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
// Just verify that the payload was set, since we're not strictly adhering // Just verify that the payload was set, since we're not strictly adhering
// to ACME V2 spec for reasons specified below. // to ACME V2 spec for reasons specified below.
_, err = payloadFromContext(ctx) _, err = payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -257,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
azID := chi.URLParam(r, "authzID") azID := chi.URLParam(r, "authzID")
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
return return
} }
ch.AuthorizationID = azID ch.AuthorizationID = azID
if acc.ID != ch.AccountID { if acc.ID != ch.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own challenge '%s'", acc.ID, ch.ID)) "account '%s' does not own challenge '%s'", acc.ID, ch.ID))
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return return
} }
@ -280,7 +282,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
api.JSON(w, ch) render.JSON(w, ch)
} }
// GetCertificate ACME api for retrieving a Certificate. // GetCertificate ACME api for retrieving a Certificate.
@ -288,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
certID := chi.URLParam(r, "certID") certID := chi.URLParam(r, "certID")
cert, err := h.db.GetCertificate(ctx, certID) cert, err := h.db.GetCertificate(ctx, certID)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
return return
} }
if cert.AccountID != acc.ID { if cert.AccountID != acc.ID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own certificate '%s'", acc.ID, certID)) "account '%s' does not own certificate '%s'", acc.ID, certID))
return return
} }

View file

@ -10,13 +10,14 @@ import (
"strings" "strings"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
) )
type nextHTTP = func(http.ResponseWriter, *http.Request) type nextHTTP = func(http.ResponseWriter, *http.Request)
@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.db.CreateNonce(r.Context()) nonce, err := h.db.CreateNonce(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
w.Header().Set("Replay-Nonce", string(nonce)) w.Header().Set("Replay-Nonce", string(nonce))
@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
var expected []string var expected []string
p, err := provisionerFromContext(r.Context()) p, err := provisionerFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return return
} }
} }
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"expected content-type to be in %s, but got %s", expected, ct)) "expected content-type to be in %s, but got %s", expected, ct))
} }
} }
@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) render.Error(w, acme.WrapErrorISE(err, "failed to read request body"))
return return
} }
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
return return
} }
ctx := context.WithValue(r.Context(), jwsContextKey, jws) ctx := context.WithValue(r.Context(), jwsContextKey, jws)
@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) jws, err := jwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if len(jws.Signatures) == 0 { if len(jws.Signatures) == 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
return return
} }
if len(jws.Signatures) > 1 { if len(jws.Signatures) > 1 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
return return
} }
@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
len(uh.Algorithm) > 0 || len(uh.Algorithm) > 0 ||
len(uh.Nonce) > 0 || len(uh.Nonce) > 0 ||
len(uh.ExtraHeaders) > 0 { len(uh.ExtraHeaders) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return return
} }
hdr := sig.Protected hdr := sig.Protected
@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
switch k := hdr.JSONWebKey.Key.(type) { switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
if k.Size() < keyutil.MinRSAKeyBytes { if k.Size() < keyutil.MinRSAKeyBytes {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"rsa keys must be at least %d bits (%d bytes) in size", "rsa keys must be at least %d bits (%d bytes) in size",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
return return
} }
default: default:
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"jws key type and algorithm do not match")) "jws key type and algorithm do not match"))
return return
} }
@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good // we good
default: default:
api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
return return
} }
// Check the validity/freshness of the Nonce. // Check the validity/freshness of the Nonce.
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
// Check that the JWS url matches the requested url. // Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string) jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok { if !ok {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
return return
} }
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() { if jwsURL != reqURL.String() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
return return
} }
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return return
} }
if hdr.JSONWebKey == nil && hdr.KeyID == "" { if hdr.JSONWebKey == nil && hdr.KeyID == "" {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return return
} }
next(w, r) next(w, r)
@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) jws, err := jwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
jwk := jws.Signatures[0].Protected.JSONWebKey jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil { if jwk == nil {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
return return
} }
if !jwk.Valid() { if !jwk.Valid() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
return return
} }
// Overwrite KeyID with the JWK thumbprint. // Overwrite KeyID with the JWK thumbprint.
jwk.KeyID, err = acme.KeyToID(jwk) jwk.KeyID, err = acme.KeyToID(jwk)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
return return
} }
@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
// For NewAccount and Revoke requests ... // For NewAccount and Revoke requests ...
break break
case err != nil: case err != nil:
api.WriteError(w, err) render.Error(w, err)
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
@ -290,17 +291,17 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
nameEscaped := chi.URLParam(r, "provisionerID") nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped) name, err := url.PathUnescape(nameEscaped)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return return
} }
p, err := h.ca.LoadProvisionerByName(name) p, err := h.ca.LoadProvisionerByName(name)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
acmeProv, ok := p.(*provisioner.ACME) acmeProv, ok := p.(*provisioner.ACME)
if !ok { if !ok {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return return
} }
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
@ -315,11 +316,11 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
ok, err := h.prerequisitesChecker(ctx) ok, err := h.prerequisitesChecker(ctx)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return return
} }
if !ok { if !ok {
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return return
} }
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
@ -334,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"kid does not have required prefix; expected %s, but got %s", "kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid)) kidPrefix, kid))
return return
@ -351,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
acc, err := h.db.GetAccount(ctx, accID) acc, err := h.db.GetAccount(ctx, accID)
switch { switch {
case nosql.IsErrNotFound(err): case nosql.IsErrNotFound(err):
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return return
case err != nil: case err != nil:
api.WriteError(w, err) render.Error(w, err)
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
@ -376,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -412,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return return
} }
payload, err := jws.Verify(jwk) payload, err := jws.Verify(jwk)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
return return
} }
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
@ -443,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context()) payload, err := payloadFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if !payload.isPostAsGet { if !payload.isPostAsGet {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
return return
} }
next(w, r) next(w, r)

View file

@ -11,9 +11,11 @@ import (
"time" "time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
) )
// NewOrderRequest represents the body for a NewOrder request. // NewOrderRequest represents the body for a NewOrder request.
@ -70,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var nor NewOrderRequest var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil { if err := json.Unmarshal(payload.value, &nor); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-order request payload")) "failed to unmarshal new-order request payload"))
return return
} }
if err := nor.Validate(); err != nil { if err := nor.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -116,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusPending, Status: acme.StatusPending,
} }
if err := h.newAuthorization(ctx, az); err != nil { if err := h.newAuthorization(ctx, az); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
o.AuthorizationIDs[i] = az.ID o.AuthorizationIDs[i] = az.ID
@ -135,14 +137,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
} }
if err := h.db.CreateOrder(ctx, o); err != nil { if err := h.db.CreateOrder(ctx, o); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return return
} }
h.linker.LinkOrder(ctx, o) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSONStatus(w, o, http.StatusCreated) render.JSONStatus(w, o, http.StatusCreated)
} }
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
@ -186,38 +188,38 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
} }
if acc.ID != o.AccountID { if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID)) "account '%s' does not own order '%s'", acc.ID, o.ID))
return return
} }
if prov.GetID() != o.ProvisionerID { if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.UpdateStatus(ctx, h.db); err != nil { if err = o.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
return return
} }
h.linker.LinkOrder(ctx, o) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o) render.JSON(w, o)
} }
// FinalizeOrder attemptst to finalize an order and create a certificate. // FinalizeOrder attemptst to finalize an order and create a certificate.
@ -225,54 +227,54 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var fr FinalizeRequest var fr FinalizeRequest
if err := json.Unmarshal(payload.value, &fr); err != nil { if err := json.Unmarshal(payload.value, &fr); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal finalize-order request payload")) "failed to unmarshal finalize-order request payload"))
return return
} }
if err := fr.Validate(); err != nil { if err := fr.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
} }
if acc.ID != o.AccountID { if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID)) "account '%s' does not own order '%s'", acc.ID, o.ID))
return return
} }
if prov.GetID() != o.ProvisionerID { if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
return return
} }
h.linker.LinkOrder(ctx, o) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o) render.JSON(w, o)
} }
// challengeTypes determines the types of challenges that should be used // challengeTypes determines the types of challenges that should be used

View file

@ -10,13 +10,14 @@ import (
"net/http" "net/http"
"strings" "strings"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
) )
type revokePayload struct { type revokePayload struct {
@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var p revokePayload var p revokePayload
err = json.Unmarshal(payload.value, &p) err = json.Unmarshal(payload.value, &p)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload")) render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
return return
} }
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
if err != nil { if err != nil {
// in this case the most likely cause is a client that didn't properly encode the certificate // in this case the most likely cause is a client that didn't properly encode the certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
return return
} }
certToBeRevoked, err := x509.ParseCertificate(certBytes) certToBeRevoked, err := x509.ParseCertificate(certBytes)
if err != nil { if err != nil {
// in this case a client may have encoded something different than a certificate // in this case a client may have encoded something different than a certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
return return
} }
serial := certToBeRevoked.SerialNumber.String() serial := certToBeRevoked.SerialNumber.String()
dbCert, err := h.db.GetCertificateBySerial(ctx, serial) dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return return
} }
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
// this should never happen // this should never happen
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal")) render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal"))
return return
} }
if shouldCheckAccountFrom(jws) { if shouldCheckAccountFrom(jws) {
account, err := accountFromContext(ctx) account, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil { if acmeErr != nil {
api.WriteError(w, acmeErr) render.Error(w, acmeErr)
return return
} }
} else { } else {
@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
_, err := jws.Verify(certToBeRevoked.PublicKey) _, err := jws.Verify(certToBeRevoked.PublicKey)
if err != nil { if err != nil {
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
return return
} }
} }
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return return
} }
if hasBeenRevokedBefore { if hasBeenRevokedBefore {
api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
return return
} }
reasonCode := p.ReasonCode reasonCode := p.ReasonCode
acmeErr := validateReasonCode(reasonCode) acmeErr := validateReasonCode(reasonCode)
if acmeErr != nil { if acmeErr != nil {
api.WriteError(w, acmeErr) render.Error(w, acmeErr)
return return
} }
@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
err = prov.AuthorizeRevoke(ctx, "") err = prov.AuthorizeRevoke(ctx, "")
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
return return
} }
options := revokeOptions(serial, certToBeRevoked, reasonCode) options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = h.ca.Revoke(ctx, options) err = h.ca.Revoke(ctx, options)
if err != nil { if err != nil {
api.WriteError(w, wrapRevokeErr(err)) render.Error(w, wrapRevokeErr(err))
return return
} }

View file

@ -79,7 +79,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
} }
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(u.String()) resp, err := vo.HTTPGet(u.String())
if err != nil { if err != nil {
@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
return nil return nil
} }
// http01ChallengeHost checks if a Challenge value is an IPv6 address
// and adds square brackets if that's the case, so that it can be used
// as a hostname. Returns the original Challenge value as the host to
// use in other cases.
func http01ChallengeHost(value string) string {
if ip := net.ParseIP(value); ip != nil && ip.To4() == nil {
value = "[" + value + "]"
}
return value
}
func tlsAlert(err error) uint8 { func tlsAlert(err error) uint8 {
var opErr *net.OpError var opErr *net.OpError
if errors.As(err, &opErr) { if errors.As(err, &opErr) {

View file

@ -13,6 +13,7 @@ import (
"encoding/asn1" "encoding/asn1"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
@ -23,9 +24,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
) )
func Test_storeError(t *testing.T) { func Test_storeError(t *testing.T) {
@ -2350,3 +2351,34 @@ func Test_serverName(t *testing.T) {
}) })
} }
} }
func Test_http01ChallengeHost(t *testing.T) {
tests := []struct {
name string
value string
want string
}{
{
name: "dns",
value: "www.example.com",
want: "www.example.com",
},
{
name: "ipv4",
value: "127.0.0.1",
want: "127.0.0.1",
},
{
name: "ipv6",
value: "::1",
want: "[::1]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := http01ChallengeHost(tt.value); got != tt.want {
t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -3,13 +3,10 @@ package acme
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/logging"
) )
// ProblemType is the type of the ACME problem. // ProblemType is the type of the ACME problem.
@ -353,26 +350,8 @@ func (e *Error) ToLog() (interface{}, error) {
return string(b), nil return string(b), nil
} }
// WriteError writes to w a JSON representation of the given error. // Render implements render.RenderableError for Error.
func WriteError(w http.ResponseWriter, err *Error) { func (e *Error) Render(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/problem+json") w.Header().Set("Content-Type", "application/problem+json")
w.WriteHeader(err.StatusCode()) render.JSONStatus(w, e, e.StatusCode())
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err.Err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.Err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
log.Println(err)
}
} }

View file

@ -22,6 +22,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api/log" "github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -284,7 +285,7 @@ func (h *caHandler) Route(r Router) {
// Version is an HTTP handler that returns the version of the server. // Version is an HTTP handler that returns the version of the server.
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
v := h.Authority.Version() v := h.Authority.Version()
JSON(w, VersionResponse{ render.JSON(w, VersionResponse{
Version: v.Version, Version: v.Version,
RequireClientAuthentication: v.RequireClientAuthentication, RequireClientAuthentication: v.RequireClientAuthentication,
}) })
@ -292,7 +293,7 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
// Health is an HTTP handler that returns the status of the server. // Health is an HTTP handler that returns the status of the server.
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
JSON(w, HealthResponse{Status: "ok"}) render.JSON(w, HealthResponse{Status: "ok"})
} }
// Root is an HTTP handler that using the SHA256 from the URL, returns the root // Root is an HTTP handler that using the SHA256 from the URL, returns the root
@ -303,11 +304,11 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the // Load root certificate with the
cert, err := h.Authority.Root(sum) cert, err := h.Authority.Root(sum)
if err != nil { if err != nil {
WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return return
} }
JSON(w, &RootResponse{RootPEM: Certificate{cert}}) render.JSON(w, &RootResponse{RootPEM: Certificate{cert}})
} }
func certChainToPEM(certChain []*x509.Certificate) []Certificate { func certChainToPEM(certChain []*x509.Certificate) []Certificate {
@ -322,16 +323,16 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r) cursor, limit, err := ParseCursor(r)
if err != nil { if err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
p, next, err := h.Authority.GetProvisioners(cursor, limit) p, next, err := h.Authority.GetProvisioners(cursor, limit)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &ProvisionersResponse{ render.JSON(w, &ProvisionersResponse{
Provisioners: p, Provisioners: p,
NextCursor: next, NextCursor: next,
}) })
@ -342,17 +343,17 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid") kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid) key, err := h.Authority.GetEncryptedKey(kid)
if err != nil { if err != nil {
WriteError(w, errs.NotFoundErr(err)) render.Error(w, errs.NotFoundErr(err))
return return
} }
JSON(w, &ProvisionerKeyResponse{key}) render.JSON(w, &ProvisionerKeyResponse{key})
} }
// Roots returns all the root certificates for the CA. // Roots returns all the root certificates for the CA.
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := h.Authority.GetRoots()
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error getting roots")) render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
return return
} }
@ -361,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{roots[i]} certs[i] = Certificate{roots[i]}
} }
JSONStatus(w, &RootsResponse{ render.JSONStatus(w, &RootsResponse{
Certificates: certs, Certificates: certs,
}, http.StatusCreated) }, http.StatusCreated)
} }
@ -370,7 +371,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := h.Authority.GetRoots()
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
@ -393,7 +394,7 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation() federated, err := h.Authority.GetFederation()
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error getting federated roots")) render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
return return
} }
@ -402,7 +403,7 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{federated[i]} certs[i] = Certificate{federated[i]}
} }
JSONStatus(w, &FederationResponse{ render.JSONStatus(w, &FederationResponse{
Certificates: certs, Certificates: certs,
}, http.StatusCreated) }, http.StatusCreated)
} }

View file

@ -1,62 +0,0 @@
package api
import (
"encoding/json"
"fmt"
"net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
)
// WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err error) {
switch k := err.(type) {
case *acme.Error:
acme.WriteError(w, k)
return
case *admin.Error:
admin.WriteError(w, k)
return
}
cause := errors.Cause(err)
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
} else if e, ok := cause.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
code := http.StatusInternalServerError
if sc, ok := err.(errs.StatusCoder); ok {
code = sc.StatusCode()
} else if sc, ok := cause.(errs.StatusCoder); ok {
code = sc.StatusCode()
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
if err := json.NewEncoder(w).Encode(err); err != nil {
log.Error(w, err)
}
}

View file

@ -2,22 +2,54 @@
package log package log
import ( import (
"fmt"
"log" "log"
"net/http" "net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
// StackTracedError is the set of errors implementing the StackTrace function.
//
// Errors implementing this interface have their stack traces logged when passed
// to the Error function of this package.
type StackTracedError interface {
error
StackTrace() errors.StackTrace
}
// Error adds to the response writer the given error if it implements // Error adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error // logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package. // using the log package.
func Error(rw http.ResponseWriter, err error) { func Error(rw http.ResponseWriter, err error) {
if rl, ok := rw.(logging.ResponseLogger); ok { rl, ok := rw.(logging.ResponseLogger)
rl.WithFields(map[string]interface{}{ if !ok {
"error": err,
})
} else {
log.Println(err) log.Println(err)
return
}
rl.WithFields(map[string]interface{}{
"error": err,
})
if os.Getenv("STEPDEBUG") != "1" {
return
}
e, ok := err.(StackTracedError)
if !ok {
e, ok = errors.Cause(err).(StackTracedError)
}
if ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e.StackTrace()),
})
} }
} }

View file

@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -28,24 +29,24 @@ func (s *RekeyRequest) Validate() error {
// Rekey is similar to renew except that the certificate will be renewed with new key from csr. // Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing client certificate")) render.Error(w, errs.BadRequest("missing client certificate"))
return return
} }
var body RekeyRequest var body RekeyRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
@ -55,7 +56,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{ render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,

122
api/render/render.go Normal file
View file

@ -0,0 +1,122 @@
// Package render implements functionality related to response rendering.
package render
import (
"bytes"
"encoding/json"
"net/http"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/log"
)
// JSON is shorthand for JSONStatus(w, v, http.StatusOK).
func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK)
}
// JSONStatus marshals v into w. It additionally sets the status code of
// w to the given one.
//
// JSONStatus sets the Content-Type of w to application/json unless one is
// specified.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(v); err != nil {
panic(err)
}
setContentTypeUnlessPresent(w, "application/json")
w.WriteHeader(status)
_, _ = b.WriteTo(w)
log.EnabledResponse(w, v)
}
// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK).
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
ProtoJSONStatus(w, m, http.StatusOK)
}
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
b, err := protojson.Marshal(m)
if err != nil {
panic(err)
}
setContentTypeUnlessPresent(w, "application/json")
w.WriteHeader(status)
_, _ = w.Write(b)
}
func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) {
const header = "Content-Type"
h := w.Header()
if _, ok := h[header]; !ok {
h.Set(header, contentType)
}
}
// RenderableError is the set of errors that implement the basic Render method.
//
// Errors that implement this interface will use their own Render method when
// being rendered into responses.
type RenderableError interface {
error
Render(http.ResponseWriter)
}
// Error marshals the JSON representation of err to w. In case err implements
// RenderableError its own Render method will be called instead.
func Error(w http.ResponseWriter, err error) {
log.Error(w, err)
if e, ok := err.(RenderableError); ok {
e.Render(w)
return
}
JSONStatus(w, err, statusCodeFromError(err))
}
// StatusCodedError is the set of errors that implement the basic StatusCode
// function.
//
// Errors that implement this interface will use the code reported by StatusCode
// as the HTTP response code when being rendered by this package.
type StatusCodedError interface {
error
StatusCode() int
}
func statusCodeFromError(err error) (code int) {
code = http.StatusInternalServerError
type causer interface {
Cause() error
}
for err != nil {
if sc, ok := err.(StatusCodedError); ok {
code = sc.StatusCode()
break
}
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return
}

115
api/render/render_test.go Normal file
View file

@ -0,0 +1,115 @@
package render
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/smallstep/certificates/logging"
)
func TestJSON(t *testing.T) {
rec := httptest.NewRecorder()
rw := logging.NewResponseLogger(rec)
JSON(rw, map[string]interface{}{"foo": "bar"})
assert.Equal(t, http.StatusOK, rec.Result().StatusCode)
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String())
assert.Empty(t, rw.Fields())
}
func TestJSONPanics(t *testing.T) {
assert.Panics(t, func() {
JSON(httptest.NewRecorder(), make(chan struct{}))
})
}
type renderableError struct {
Code int `json:"-"`
Message string `json:"message"`
}
func (err renderableError) Error() string {
return err.Message
}
func (err renderableError) Render(w http.ResponseWriter) {
w.Header().Set("Content-Type", "something/custom")
JSONStatus(w, err, err.Code)
}
type statusedError struct {
Contents string
}
func (err statusedError) Error() string { return err.Contents }
func (statusedError) StatusCode() int { return 432 }
func TestError(t *testing.T) {
cases := []struct {
err error
code int
body string
header string
}{
0: {
err: renderableError{532, "some string"},
code: 532,
body: "{\"message\":\"some string\"}\n",
header: "something/custom",
},
1: {
err: statusedError{"123"},
code: 432,
body: "{\"Contents\":\"123\"}\n",
header: "application/json",
},
}
for caseIndex := range cases {
kase := cases[caseIndex]
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
rec := httptest.NewRecorder()
Error(rec, kase.err)
assert.Equal(t, kase.code, rec.Result().StatusCode)
assert.Equal(t, kase.body, rec.Body.String())
assert.Equal(t, kase.header, rec.Header().Get("Content-Type"))
})
}
}
type causedError struct {
cause error
}
func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) }
func (err causedError) Cause() error { return err.cause }
func TestStatusCodeFromError(t *testing.T) {
cases := []struct {
err error
exp int
}{
0: {nil, http.StatusInternalServerError},
1: {io.EOF, http.StatusInternalServerError},
2: {statusedError{"123"}, 432},
3: {causedError{statusedError{"432"}}, 432},
}
for caseIndex, kase := range cases {
assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex)
}
}

View file

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -18,13 +19,13 @@ const (
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
cert, err := h.getPeerCertificate(r) cert, err := h.getPeerCertificate(r)
if err != nil { if err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
certChain, err := h.Authority.Renew(cert) certChain, err := h.Authority.Renew(cert)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
@ -34,7 +35,7 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{ render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,

View file

@ -7,6 +7,7 @@ import (
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -51,12 +52,12 @@ func (r *RevokeRequest) Validate() (err error) {
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
@ -73,7 +74,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
if len(body.OTT) > 0 { if len(body.OTT) > 0 {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
@ -82,12 +83,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number // the client certificate Serial Number must match the serial number
// being revoked. // being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing ott or client certificate")) render.Error(w, errs.BadRequest("missing ott or client certificate"))
return return
} }
opts.Crt = r.TLS.PeerCertificates[0] opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial { if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest("serial number in client certificate different than body")) render.Error(w, errs.BadRequest("serial number in client certificate different than body"))
return return
} }
// TODO: should probably be checking if the certificate was revoked here. // TODO: should probably be checking if the certificate was revoked here.
@ -98,12 +99,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
} }
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err, "error revoking certificate")) render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
return return
} }
logRevoke(w, opts) logRevoke(w, opts)
JSON(w, &RevokeResponse{Status: "ok"}) render.JSON(w, &RevokeResponse{Status: "ok"})
} }
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

View file

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -51,13 +52,13 @@ type SignResponse struct {
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
@ -69,13 +70,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
signOpts, err := h.Authority.AuthorizeSign(body.OTT) signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
@ -84,7 +85,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
caPEM = certChainPEM[1] caPEM = certChainPEM[1]
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{ render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,

View file

@ -12,6 +12,7 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -252,19 +253,19 @@ type SSHBastionResponse struct {
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }
@ -272,7 +273,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil { if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
return return
} }
} }
@ -289,13 +290,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
} }
@ -303,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
} }
addUserCertificate = &SSHCertificate{addUserCert} addUserCertificate = &SSHCertificate{addUserCert}
@ -316,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
@ -328,13 +329,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return return
} }
identityCertificate = certChainToPEM(certChain) identityCertificate = certChainToPEM(certChain)
} }
JSONStatus(w, &SSHSignResponse{ render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{cert}, Certificate: SSHCertificate{cert},
AddUserCertificate: addUserCertificate, AddUserCertificate: addUserCertificate,
IdentityCertificate: identityCertificate, IdentityCertificate: identityCertificate,
@ -346,12 +347,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots(r.Context()) keys, err := h.Authority.GetSSHRoots(r.Context())
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound("no keys found")) render.Error(w, errs.NotFound("no keys found"))
return return
} }
@ -363,7 +364,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
} }
JSON(w, resp) render.JSON(w, resp)
} }
// SSHFederation is an HTTP handler that returns the federated SSH public keys // SSHFederation is an HTTP handler that returns the federated SSH public keys
@ -371,12 +372,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation(r.Context()) keys, err := h.Authority.GetSSHFederation(r.Context())
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound("no keys found")) render.Error(w, errs.NotFound("no keys found"))
return return
} }
@ -388,7 +389,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
} }
JSON(w, resp) render.JSON(w, resp)
} }
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients // SSHConfig is an HTTP handler that returns rendered templates for ssh clients
@ -396,17 +397,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
@ -417,31 +418,31 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert: case provisioner.SSHHostCert:
cfg.HostTemplates = ts cfg.HostTemplates = ts
default: default:
WriteError(w, errs.InternalServer("it should hot get here")) render.Error(w, errs.InternalServer("it should hot get here"))
return return
} }
JSON(w, cfg) render.JSON(w, cfg)
} }
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest var body SSHCheckPrincipalRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHCheckPrincipalResponse{ render.JSON(w, &SSHCheckPrincipalResponse{
Exists: exists, Exists: exists,
}) })
} }
@ -455,10 +456,10 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHGetHostsResponse{ render.JSON(w, &SSHGetHostsResponse{
Hosts: hosts, Hosts: hosts,
}) })
} }
@ -467,21 +468,21 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHBastionResponse{ render.JSON(w, &SSHBastionResponse{
Hostname: body.Hostname, Hostname: body.Hostname,
Bastion: bastion, Bastion: bastion,
}) })

View file

@ -7,6 +7,7 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -41,36 +42,37 @@ type SSHRekeyResponse struct {
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return
} }
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
return return
} }
@ -80,11 +82,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
} }
JSONStatus(w, &SSHRekeyResponse{ render.JSONStatus(w, &SSHRekeyResponse{
Certificate: SSHCertificate{newCert}, Certificate: SSHCertificate{newCert},
IdentityCertificate: identity, IdentityCertificate: identity,
}, http.StatusCreated) }, http.StatusCreated)

View file

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -39,30 +40,31 @@ type SSHRenewResponse struct {
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT) _, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return
} }
newCert, err := h.Authority.RenewSSH(ctx, oldCert) newCert, err := h.Authority.RenewSSH(ctx, oldCert)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
return return
} }
@ -72,11 +74,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
} }
JSONStatus(w, &SSHSignResponse{ render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{newCert}, Certificate: SSHCertificate{newCert},
IdentityCertificate: identity, IdentityCertificate: identity,
}, http.StatusCreated) }, http.StatusCreated)

View file

@ -6,6 +6,7 @@ import (
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -50,12 +51,12 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
@ -71,18 +72,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
// otherwise it is assumed that the certificate is revoking itself over mTLS. // otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
return return
} }
logSSHRevoke(w, opts) logSSHRevoke(w, opts)
JSON(w, &SSHRevokeResponse{Status: "ok"}) render.JSON(w, &SSHRevokeResponse{Status: "ok"})
} }
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

View file

@ -1,57 +0,0 @@
package api
import (
"encoding/json"
"net/http"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/log"
)
// JSON writes the passed value into the http.ResponseWriter.
func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK)
}
// JSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil {
log.Error(w, err)
return
}
log.EnabledResponse(w, v)
}
// ProtoJSON writes the passed value into the http.ResponseWriter.
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
ProtoJSONStatus(w, m, http.StatusOK)
}
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
b, err := protojson.Marshal(m)
if err != nil {
log.Error(w, err)
return
}
if _, err := w.Write(b); err != nil {
log.Error(w, err)
return
}
// log.EnabledResponse(w, v)
}

View file

@ -1,53 +0,0 @@
package api
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/smallstep/certificates/logging"
)
func TestJSON(t *testing.T) {
type args struct {
rw http.ResponseWriter
v interface{}
}
tests := []struct {
name string
args args
ok bool
}{
{"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true},
{"fail", args{httptest.NewRecorder(), make(chan int)}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rw := logging.NewResponseLogger(tt.args.rw)
JSON(rw, tt.args.v)
rr, ok := tt.args.rw.(*httptest.ResponseRecorder)
if !ok {
t.Error("ResponseWriter does not implement *httptest.ResponseRecorder")
return
}
fields := rw.Fields()
if tt.ok {
if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" {
t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body)
}
if len(fields) != 0 {
t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields)
}
} else {
if body := rr.Body.String(); body != "" {
t.Errorf("Unexpected body = %s, want empty string", body)
}
if len(fields) != 1 {
t.Errorf("ResponseLogger fields = %v, wants 1 element", fields)
}
}
})
}
}

View file

@ -6,10 +6,12 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/api"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
) )
const ( const (
@ -44,11 +46,11 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
provName := chi.URLParam(r, "provisionerName") provName := chi.URLParam(r, "provisionerName")
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if !eabEnabled { if !eabEnabled {
api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
return return
} }
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, provisionerContextKey, prov)
@ -101,15 +103,15 @@ func NewACMEAdminResponder() *ACMEAdminResponder {
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint // GetExternalAccountKeys writes the response for the EAB keys GET endpoint
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// CreateExternalAccountKey writes the response for the EAB key POST endpoint // CreateExternalAccountKey writes the response for the EAB key POST endpoint
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }

View file

@ -10,6 +10,7 @@ import (
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
@ -85,28 +86,28 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
adm, ok := h.auth.LoadAdminByID(id) adm, ok := h.auth.LoadAdminByID(id)
if !ok { if !ok {
api.WriteError(w, admin.NewError(admin.ErrorNotFoundType, render.Error(w, admin.NewError(admin.ErrorNotFoundType,
"admin %s not found", id)) "admin %s not found", id))
return return
} }
api.ProtoJSON(w, adm) render.ProtoJSON(w, adm)
} }
// GetAdmins returns a segment of admins associated with the authority. // GetAdmins returns a segment of admins associated with the authority.
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params")) "error parsing cursor and limit from query params"))
return return
} }
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
return return
} }
api.JSON(w, &GetAdminsResponse{ render.JSON(w, &GetAdminsResponse{
Admins: admins, Admins: admins,
NextCursor: nextCursor, NextCursor: nextCursor,
}) })
@ -116,18 +117,18 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest var body CreateAdminRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
p, err := h.auth.LoadProvisionerByName(body.Provisioner) p, err := h.auth.LoadProvisionerByName(body.Provisioner)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
return return
} }
adm := &linkedca.Admin{ adm := &linkedca.Admin{
@ -137,11 +138,11 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
} }
// Store to authority collection. // Store to authority collection.
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error storing admin")) render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
return return
} }
api.ProtoJSONStatus(w, adm, http.StatusCreated) render.ProtoJSONStatus(w, adm, http.StatusCreated)
} }
// DeleteAdmin deletes admin. // DeleteAdmin deletes admin.
@ -149,23 +150,23 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
return return
} }
api.JSON(w, &DeleteResponse{Status: "ok"}) render.JSON(w, &DeleteResponse{Status: "ok"})
} }
// UpdateAdmin updates an existing admin. // UpdateAdmin updates an existing admin.
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest var body UpdateAdminRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -173,9 +174,9 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id)) render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
return return
} }
api.ProtoJSON(w, adm) render.ProtoJSON(w, adm)
} }

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
) )
@ -15,7 +15,7 @@ type nextHTTP = func(http.ResponseWriter, *http.Request)
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !h.auth.IsAdminAPIEnabled() { if !h.auth.IsAdminAPIEnabled() {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
"administration API not enabled")) "administration API not enabled"))
return return
} }
@ -28,14 +28,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization") tok := r.Header.Get("Authorization")
if tok == "" { if tok == "" {
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
"missing authorization header token")) "missing authorization header token"))
return return
} }
adm, err := h.auth.AuthorizeAdminToken(r, tok) adm, err := h.auth.AuthorizeAdminToken(r, tok)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }

View file

@ -4,10 +4,12 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -33,39 +35,39 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
) )
if len(id) > 0 { if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil { if p, err = h.auth.LoadProvisionerByID(id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil { if p, err = h.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
api.ProtoJSON(w, prov) render.ProtoJSON(w, prov)
} }
// GetProvisioners returns the given segment of provisioners associated with the authority. // GetProvisioners returns the given segment of provisioners associated with the authority.
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params")) "error parsing cursor and limit from query params"))
return return
} }
p, next, err := h.auth.GetProvisioners(cursor, limit) p, next, err := h.auth.GetProvisioners(cursor, limit)
if err != nil { if err != nil {
api.WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
api.JSON(w, &GetProvisionersResponse{ render.JSON(w, &GetProvisionersResponse{
Provisioners: p, Provisioners: p,
NextCursor: next, NextCursor: next,
}) })
@ -75,21 +77,21 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner) var prov = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, prov); err != nil { if err := read.ProtoJSON(r.Body, prov); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
// TODO: Validate inputs // TODO: Validate inputs
if err := authority.ValidateClaims(prov.Claims); err != nil { if err := authority.ValidateClaims(prov.Claims); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
return return
} }
api.ProtoJSONStatus(w, prov, http.StatusCreated) render.ProtoJSONStatus(w, prov, http.StatusCreated)
} }
// DeleteProvisioner deletes a provisioner. // DeleteProvisioner deletes a provisioner.
@ -103,75 +105,75 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
) )
if len(id) > 0 { if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil { if p, err = h.auth.LoadProvisionerByID(id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil { if p, err = h.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
return return
} }
api.JSON(w, &DeleteResponse{Status: "ok"}) render.JSON(w, &DeleteResponse{Status: "ok"})
} }
// UpdateProvisioner updates an existing prov. // UpdateProvisioner updates an existing prov.
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner) var nu = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, nu); err != nil { if err := read.ProtoJSON(r.Body, nu); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
name := chi.URLParam(r, "name") name := chi.URLParam(r, "name")
_old, err := h.auth.LoadProvisionerByName(name) _old, err := h.auth.LoadProvisionerByName(name)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
return return
} }
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
return return
} }
if nu.Id != old.Id { if nu.Id != old.Id {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID")) render.Error(w, admin.NewErrorISE("cannot change provisioner ID"))
return return
} }
if nu.Type != old.Type { if nu.Type != old.Type {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner type")) render.Error(w, admin.NewErrorISE("cannot change provisioner type"))
return return
} }
if nu.AuthorityId != old.AuthorityId { if nu.AuthorityId != old.AuthorityId {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID")) render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID"))
return return
} }
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt")) render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt"))
return return
} }
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt")) render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
return return
} }
// TODO: Validate inputs // TODO: Validate inputs
if err := authority.ValidateClaims(nu.Claims); err != nil { if err := authority.ValidateClaims(nu.Claims); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
api.ProtoJSON(w, nu) render.ProtoJSON(w, nu)
} }

View file

@ -3,13 +3,10 @@ package admin
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/logging"
) )
// ProblemType is the type of the Admin problem. // ProblemType is the type of the Admin problem.
@ -197,27 +194,9 @@ func (e *Error) ToLog() (interface{}, error) {
return string(b), nil return string(b), nil
} }
// WriteError writes to w a JSON representation of the given error. // Render implements render.RenderableError for Error.
func WriteError(w http.ResponseWriter, err *Error) { func (e *Error) Render(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json") e.Message = e.Err.Error()
w.WriteHeader(err.StatusCode())
err.Message = err.Err.Error() render.JSONStatus(w, e, e.StatusCode())
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err.Err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.Err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
log.Println(err)
}
} }

View file

@ -80,6 +80,14 @@ type Authority struct {
adminMutex sync.RWMutex adminMutex sync.RWMutex
} }
type Info struct {
StartTime time.Time
RootX509Certs []*x509.Certificate
SSHCAUserPublicKey []byte
SSHCAHostPublicKey []byte
DNSNames []string
}
// New creates and initiates a new Authority type. // New creates and initiates a new Authority type.
func New(cfg *config.Config, opts ...Option) (*Authority, error) { func New(cfg *config.Config, opts ...Option) (*Authority, error) {
err := cfg.Validate() err := cfg.Validate()
@ -325,8 +333,6 @@ func (a *Authority) init() error {
return err return err
} }
a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate) a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate)
sum := sha256.Sum256(resp.RootCertificate.Raw)
log.Printf("Using root fingerprint '%s'", hex.EncodeToString(sum[:]))
} }
} }
@ -581,6 +587,21 @@ func (a *Authority) GetAdminDatabase() admin.DB {
return a.adminDB return a.adminDB
} }
func (a *Authority) GetInfo() Info {
ai := Info{
StartTime: a.startTime,
RootX509Certs: a.rootX509Certs,
DNSNames: a.config.DNSNames,
}
if a.sshCAUserCertSignKey != nil {
ai.SSHCAUserPublicKey = ssh.MarshalAuthorizedKey(a.sshCAUserCertSignKey.PublicKey())
}
if a.sshCAHostCertSignKey != nil {
ai.SSHCAHostPublicKey = ssh.MarshalAuthorizedKey(a.sshCAHostCertSignKey.PublicKey())
}
return ai
}
// IsAdminAPIEnabled returns a boolean indicating whether the Admin API has // IsAdminAPIEnabled returns a boolean indicating whether the Admin API has
// been enabled. // been enabled.
func (a *Authority) IsAdminAPIEnabled() bool { func (a *Authority) IsAdminAPIEnabled() bool {

View file

@ -9,6 +9,7 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
@ -16,16 +17,18 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
) )
var testAudiences = provisioner.Audiences{ var testAudiences = provisioner.Audiences{
@ -310,8 +313,8 @@ func TestAuthority_authorizeToken(t *testing.T) {
p, err := tc.auth.authorizeToken(context.Background(), tc.token) p, err := tc.auth.authorizeToken(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -396,8 +399,8 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil { if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -481,8 +484,8 @@ func TestAuthority_authorizeSign(t *testing.T) {
got, err := tc.auth.authorizeSign(context.Background(), tc.token) got, err := tc.auth.authorizeSign(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -740,8 +743,8 @@ func TestAuthority_Authorize(t *testing.T) {
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, got) assert.Nil(t, got)
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -853,7 +856,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
err := tc.auth.authorizeRenew(tc.cert) err := tc.auth.authorizeRenew(tc.cert)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -1001,8 +1004,8 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token) got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -1118,8 +1121,8 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token) got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -1218,8 +1221,8 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) {
if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil { if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -1311,8 +1314,8 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token) cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -163,6 +163,22 @@ func WithX509Signer(crt *x509.Certificate, s crypto.Signer) Option {
} }
} }
// WithX509SignerFunc defines the function used to get the chain of certificates
// and signer used when we sign X.509 certificates.
func WithX509SignerFunc(fn func() ([]*x509.Certificate, crypto.Signer, error)) Option {
return func(a *Authority) error {
srv, err := cas.New(context.Background(), casapi.Options{
Type: casapi.SoftCAS,
CertificateSigner: fn,
})
if err != nil {
return err
}
a.x509CAService = srv
return nil
}
}
// WithSSHUserSigner defines the signer used to sign SSH user certificates. // WithSSHUserSigner defines the signer used to sign SSH user certificates.
func WithSSHUserSigner(s crypto.Signer) Option { func WithSSHUserSigner(s crypto.Signer) Option {
return func(a *Authority) error { return func(a *Authority) error {

View file

@ -3,13 +3,14 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"errors"
"fmt"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/api/render"
) )
func TestACME_Getters(t *testing.T) { func TestACME_Getters(t *testing.T) {
@ -114,7 +115,7 @@ func TestACME_AuthorizeRenew(t *testing.T) {
NotAfter: now.Add(time.Hour), NotAfter: now.Add(time.Hour),
}, },
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -133,8 +134,8 @@ func TestACME_AuthorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -168,8 +169,8 @@ func TestACME_AuthorizeSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -193,7 +194,7 @@ func TestACME_AuthorizeSign(t *testing.T) {
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }

View file

@ -9,6 +9,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -17,10 +18,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestAWS_Getters(t *testing.T) { func TestAWS_Getters(t *testing.T) {
@ -521,8 +522,8 @@ func TestAWS_authorizeToken(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil { if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -668,8 +669,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
case err != nil: case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
default: default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
@ -699,7 +700,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
case dnsNamesValidator: case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -803,8 +804,8 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
@ -861,8 +862,8 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })

View file

@ -8,6 +8,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -15,10 +16,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestAzure_Getters(t *testing.T) { func TestAzure_Getters(t *testing.T) {
@ -335,8 +336,8 @@ func TestAzure_authorizeToken(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, name, group, subscriptionID, objectID, err := tc.p.authorizeToken(tc.token); err != nil { if claims, name, group, subscriptionID, objectID, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -497,8 +498,8 @@ func TestAzure_AuthorizeSign(t *testing.T) {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
case err != nil: case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
default: default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
@ -528,7 +529,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
case dnsNamesValidator: case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"virtualMachine"}) assert.Equals(t, []string(v), []string{"virtualMachine"})
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -573,8 +574,8 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })
@ -669,8 +670,8 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {

View file

@ -8,6 +8,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -16,10 +17,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestGCP_Getters(t *testing.T) { func TestGCP_Getters(t *testing.T) {
@ -390,8 +391,8 @@ func TestGCP_authorizeToken(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil { if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -540,8 +541,8 @@ func TestGCP_AuthorizeSign(t *testing.T) {
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
case err != nil: case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
default: default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
@ -571,7 +572,7 @@ func TestGCP_AuthorizeSign(t *testing.T) {
case dnsNamesValidator: case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -678,8 +679,8 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
@ -735,7 +736,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }

View file

@ -6,15 +6,17 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"errors"
"fmt"
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestJWK_Getters(t *testing.T) { func TestJWK_Getters(t *testing.T) {
@ -183,8 +185,8 @@ func TestJWK_authorizeToken(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil { if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
@ -223,8 +225,8 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil { if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
@ -288,8 +290,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil { if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
@ -316,7 +318,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
case defaultSANsValidator: case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans) assert.Equals(t, []string(v), tt.sans)
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -362,8 +364,8 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })
@ -456,8 +458,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
@ -621,8 +623,8 @@ func TestJWK_AuthorizeSSHRevoke(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -3,14 +3,16 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"errors"
"fmt"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestK8sSA_Getters(t *testing.T) { func TestK8sSA_Getters(t *testing.T) {
@ -116,8 +118,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -165,8 +167,8 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -202,7 +204,7 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
NotAfter: now.Add(time.Hour), NotAfter: now.Add(time.Hour),
}, },
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -221,8 +223,8 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -270,8 +272,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -295,7 +297,7 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
tot++ tot++
} }
@ -327,7 +329,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
p: p, p: p,
token: "foo", token: "foo",
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()), err: fmt.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()),
} }
}, },
"fail/invalid-token": func(t *testing.T) test { "fail/invalid-token": func(t *testing.T) test {
@ -358,8 +360,8 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -379,7 +381,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case *sshDefaultDuration: case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
tot++ tot++
} }

View file

@ -6,16 +6,17 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func Test_openIDConfiguration_Validate(t *testing.T) { func Test_openIDConfiguration_Validate(t *testing.T) {
@ -246,8 +247,8 @@ func TestOIDC_authorizeToken(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else { } else {
@ -317,8 +318,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
@ -341,7 +342,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
case emailOnlyIdentity: case emailOnlyIdentity:
assert.Equals(t, string(v), "name@smallstep.com") assert.Equals(t, string(v), "name@smallstep.com")
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -403,8 +404,8 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })
@ -449,8 +450,8 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })
@ -605,8 +606,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
@ -673,8 +674,8 @@ func TestOIDC_AuthorizeSSHRevoke(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })

View file

@ -2,13 +2,14 @@ package provisioner
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"testing" "testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestType_String(t *testing.T) { func TestType_String(t *testing.T) {
@ -240,8 +241,8 @@ func TestUnimplementedMethods(t *testing.T) {
default: default:
t.Errorf("unexpected method %s", tt.method) t.Errorf("unexpected method %s", tt.method)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized) assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
assert.Equals(t, err.Error(), msg) assert.Equals(t, err.Error(), msg)
}) })

View file

@ -5,16 +5,19 @@ import (
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"fmt"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestSSHPOP_Getters(t *testing.T) { func TestSSHPOP_Getters(t *testing.T) {
@ -215,8 +218,8 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -286,8 +289,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -367,8 +370,8 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
tc := tt(t) tc := tt(t)
if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil { if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -449,8 +452,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
tc := tt(t) tc := tt(t)
if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil { if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -464,7 +467,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
assert.Equals(t, tc.cert.Nonce, cert.Nonce) assert.Equals(t, tc.cert.Nonce, cert.Nonce)

View file

@ -3,16 +3,18 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"errors"
"fmt"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestX5C_Getters(t *testing.T) { func TestX5C_Getters(t *testing.T) {
@ -71,7 +73,7 @@ func TestX5C_Init(t *testing.T) {
"fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")}, p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")},
err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), err: errors.New("no x509 certificates found in roots attribute for provisioner 'foo'"),
} }
}, },
"fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest { "fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest {
@ -122,7 +124,7 @@ M46l92gdOozT
// check the number of certificates in the pool. // check the number of certificates in the pool.
numCerts := len(p.rootPool.Subjects()) numCerts := len(p.rootPool.Subjects())
if numCerts != 2 { if numCerts != 2 {
return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts) return fmt.Errorf("unexpected number of certs: want 2, but got %d", numCerts)
} }
return nil return nil
}, },
@ -387,8 +389,8 @@ lgsqsR63is+0YQ==
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -458,7 +460,7 @@ func TestX5C_AuthorizeSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -491,7 +493,7 @@ func TestX5C_AuthorizeSign(t *testing.T) {
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -542,8 +544,8 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -573,7 +575,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
return test{ return test{
p: p, p: p,
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -592,8 +594,8 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
NotAfter: now.Add(time.Hour), NotAfter: now.Add(time.Hour),
}); err != nil { }); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -632,7 +634,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
p: p, p: p,
token: "foo", token: "foo",
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), err: fmt.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()),
} }
}, },
"fail/invalid-token": func(t *testing.T) test { "fail/invalid-token": func(t *testing.T) test {
@ -753,7 +755,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -788,7 +790,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc:
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
tot++ tot++
} }

View file

@ -1,13 +1,13 @@
package authority package authority
import ( import (
"errors"
"net/http" "net/http"
"testing" "testing"
"github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
) )
func TestGetEncryptedKey(t *testing.T) { func TestGetEncryptedKey(t *testing.T) {
@ -49,8 +49,8 @@ func TestGetEncryptedKey(t *testing.T) {
ek, err := tc.a.GetEncryptedKey(tc.kid) ek, err := tc.a.GetEncryptedKey(tc.kid)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -90,8 +90,8 @@ func TestGetProvisioners(t *testing.T) {
ps, next, err := tc.a.GetProvisioners("", 0) ps, next, err := tc.a.GetProvisioners("", 0)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -2,14 +2,15 @@ package authority
import ( import (
"crypto/x509" "crypto/x509"
"errors"
"net/http" "net/http"
"reflect" "reflect"
"testing" "testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestRoot(t *testing.T) { func TestRoot(t *testing.T) {
@ -31,7 +32,7 @@ func TestRoot(t *testing.T) {
crt, err := a.Root(tc.sum) crt, err := a.Root(tc.sum)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())

View file

@ -7,21 +7,22 @@ import (
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/templates"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/templates"
) )
type sshTestModifier ssh.Certificate type sshTestModifier ssh.Certificate
@ -716,8 +717,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { } else if err != nil {
_, ok := err.(errs.StatusCoder) _, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want) t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want)
@ -806,8 +807,8 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
hosts, err := auth.GetSSHHosts(context.Background(), tc.cert) hosts, err := auth.GetSSHHosts(context.Background(), tc.cert)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -1033,8 +1034,8 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...) cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -11,24 +11,26 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/smallstep/certificates/cas/softcas" "gopkg.in/square/go-jose.v2/jwt"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas/softcas"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
) )
var ( var (
@ -187,14 +189,14 @@ func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) {
func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
b, err := x509.MarshalPKIXPublicKey(pub) b, err := x509.MarshalPKIXPublicKey(pub)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling public key") return nil, fmt.Errorf("error marshaling public key: %w", err)
} }
info := struct { info := struct {
Algorithm pkix.AlgorithmIdentifier Algorithm pkix.AlgorithmIdentifier
SubjectPublicKey asn1.BitString SubjectPublicKey asn1.BitString
}{} }{}
if _, err = asn1.Unmarshal(b, &info); err != nil { if _, err = asn1.Unmarshal(b, &info); err != nil {
return nil, errors.Wrap(err, "error unmarshaling public key") return nil, fmt.Errorf("error unmarshaling public key: %w", err)
} }
hash := sha1.Sum(info.SubjectPublicKey.Bytes) hash := sha1.Sum(info.SubjectPublicKey.Bytes)
return hash[:], nil return hash[:], nil
@ -661,8 +663,8 @@ ZYtQ9Ot36qc=
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -860,8 +862,8 @@ func TestAuthority_Renew(t *testing.T) {
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -1067,8 +1069,8 @@ func TestAuthority_Rekey(t *testing.T) {
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -1456,8 +1458,8 @@ func TestAuthority_Revoke(t *testing.T) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
if err := tc.auth.Revoke(ctx, tc.opts); err != nil { if err := tc.auth.Revoke(ctx, tc.opts); err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())

View file

@ -12,12 +12,13 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api" acmeAPI "github.com/smallstep/certificates/acme/api"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
) )
func TestNewACMEClient(t *testing.T) { func TestNewACMEClient(t *testing.T) {
@ -112,15 +113,15 @@ func TestNewACMEClient(t *testing.T) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
switch { switch {
case i == 0: case i == 0:
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
case i == 1: case i == 1:
w.Header().Set("Replay-Nonce", "abc123") w.Header().Set("Replay-Nonce", "abc123")
api.JSONStatus(w, []byte{}, 200) render.JSONStatus(w, []byte{}, 200)
i++ i++
default: default:
w.Header().Set("Location", accLocation) w.Header().Set("Location", accLocation)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
} }
}) })
@ -206,7 +207,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
}) })
if nonce, err := ac.GetNonce(); err != nil { if nonce, err := ac.GetNonce(); err != nil {
@ -315,7 +316,7 @@ func TestACMEClient_post(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -338,7 +339,7 @@ func TestACMEClient_post(t *testing.T) {
assert.Equals(t, hdr.KeyID, ac.kid) assert.Equals(t, hdr.KeyID, ac.kid)
} }
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil {
@ -455,7 +456,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -477,7 +478,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, payload, norb) assert.Equals(t, payload, norb)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := ac.NewOrder(norb); err != nil { if res, err := ac.NewOrder(norb); err != nil {
@ -577,7 +578,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -599,7 +600,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := ac.GetOrder(url); err != nil { if res, err := ac.GetOrder(url); err != nil {
@ -699,7 +700,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -721,7 +722,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := ac.GetAuthz(url); err != nil { if res, err := ac.GetAuthz(url); err != nil {
@ -821,7 +822,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -844,7 +845,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := ac.GetChallenge(url); err != nil { if res, err := ac.GetChallenge(url); err != nil {
@ -944,7 +945,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -967,7 +968,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
assert.Equals(t, payload, []byte("{}")) assert.Equals(t, payload, []byte("{}"))
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if err := ac.ValidateChallenge(url); err != nil { if err := ac.ValidateChallenge(url); err != nil {
@ -1071,7 +1072,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -1093,7 +1094,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, payload, frb) assert.Equals(t, payload, frb)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if err := ac.FinalizeOrder(url, csr); err != nil { if err := ac.FinalizeOrder(url, csr); err != nil {
@ -1200,7 +1201,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -1222,7 +1223,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := tc.client.GetAccountOrders(); err != nil { if res, err := tc.client.GetAccountOrders(); err != nil {
@ -1331,7 +1332,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -1356,7 +1357,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
if tc.certBytes != nil { if tc.certBytes != nil {
w.Write(tc.certBytes) w.Write(tc.certBytes)
} else { } else {
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
} }
}) })

View file

@ -14,11 +14,14 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
) )
func newLocalListener() net.Listener { func newLocalListener() net.Listener {
@ -79,7 +82,7 @@ func startCAServer(configFile string) (*CA, string, error) {
func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/version" { if r.URL.Path == "/version" {
api.JSON(w, api.VersionResponse{ render.JSON(w, api.VersionResponse{
Version: "test", Version: "test",
RequireClientAuthentication: true, RequireClientAuthentication: true,
}) })
@ -93,7 +96,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han
} }
isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0
if !isMTLS { if !isMTLS {
api.WriteError(w, errs.Unauthorized("missing peer certificate")) render.Error(w, errs.Unauthorized("missing peer certificate"))
} else { } else {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }

View file

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
"strings"
"sync" "sync"
"github.com/go-chi/chi" "github.com/go-chi/chi"
@ -26,11 +27,14 @@ import (
scepAPI "github.com/smallstep/certificates/scep/api" scepAPI "github.com/smallstep/certificates/scep/api"
"github.com/smallstep/certificates/server" "github.com/smallstep/certificates/server"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/cli-utils/step"
"go.step.sm/crypto/x509util"
) )
type options struct { type options struct {
configFile string configFile string
linkedCAToken string linkedCAToken string
quiet bool
password []byte password []byte
issuerPassword []byte issuerPassword []byte
sshHostPassword []byte sshHostPassword []byte
@ -101,6 +105,13 @@ func WithLinkedCAToken(token string) Option {
} }
} }
// WithQuiet sets the quiet flag.
func WithQuiet(quiet bool) Option {
return func(o *options) {
o.quiet = quiet
}
}
// CA is the type used to build the complete certificate authority. It builds // CA is the type used to build the complete certificate authority. It builds
// the HTTP server, set ups the middlewares and the HTTP handlers. // the HTTP server, set ups the middlewares and the HTTP handlers.
type CA struct { type CA struct {
@ -288,6 +299,35 @@ func (ca *CA) Run() error {
var wg sync.WaitGroup var wg sync.WaitGroup
errs := make(chan error, 1) errs := make(chan error, 1)
if !ca.opts.quiet {
authorityInfo := ca.auth.GetInfo()
log.Printf("Starting %s", step.Version())
log.Printf("Documentation: https://u.step.sm/docs/ca")
log.Printf("Community Discord: https://u.step.sm/discord")
if step.Contexts().GetCurrent() != nil {
log.Printf("Current context: %s", step.Contexts().GetCurrent().Name)
}
log.Printf("Config file: %s", ca.opts.configFile)
baseURL := fmt.Sprintf("https://%s%s",
authorityInfo.DNSNames[0],
ca.config.Address[strings.LastIndex(ca.config.Address, ":"):])
log.Printf("The primary server URL is %s", baseURL)
log.Printf("Root certificates are available at %s/roots.pem", baseURL)
if len(authorityInfo.DNSNames) > 1 {
log.Printf("Additional configured hostnames: %s",
strings.Join(authorityInfo.DNSNames[1:], ", "))
}
for _, crt := range authorityInfo.RootX509Certs {
log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt))
}
if authorityInfo.SSHCAHostPublicKey != nil {
log.Printf("SSH Host CA Key is %s\n", authorityInfo.SSHCAHostPublicKey)
}
if authorityInfo.SSHCAUserPublicKey != nil {
log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey)
}
}
if ca.insecureSrv != nil { if ca.insecureSrv != nil {
wg.Add(1) wg.Add(1)
go func() { go func() {
@ -355,6 +395,7 @@ func (ca *CA) Reload() error {
WithSSHUserPassword(ca.opts.sshUserPassword), WithSSHUserPassword(ca.opts.sshUserPassword),
WithIssuerPassword(ca.opts.issuerPassword), WithIssuerPassword(ca.opts.issuerPassword),
WithLinkedCAToken(ca.opts.linkedCAToken), WithLinkedCAToken(ca.opts.linkedCAToken),
WithQuiet(ca.opts.quiet),
WithConfigFile(ca.opts.configFile), WithConfigFile(ca.opts.configFile),
WithDatabase(ca.auth.GetDatabase()), WithDatabase(ca.auth.GetDatabase()),
) )

View file

@ -8,6 +8,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -16,14 +17,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -182,7 +182,7 @@ func TestClient_Version(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Version() got, err := c.Version()
@ -232,7 +232,7 @@ func TestClient_Health(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Health() got, err := c.Health()
@ -290,7 +290,7 @@ func TestClient_Root(t *testing.T) {
if req.RequestURI != expected { if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Root(tt.shasum) got, err := c.Root(tt.shasum)
@ -360,7 +360,7 @@ func TestClient_Sign(t *testing.T) {
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
if tt.request == nil { if tt.request == nil {
@ -371,7 +371,7 @@ func TestClient_Sign(t *testing.T) {
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
} }
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Sign(tt.request) got, err := c.Sign(tt.request)
@ -432,7 +432,7 @@ func TestClient_Revoke(t *testing.T) {
if err := read.JSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
if tt.request == nil { if tt.request == nil {
@ -443,7 +443,7 @@ func TestClient_Revoke(t *testing.T) {
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
} }
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Revoke(tt.request, nil) got, err := c.Revoke(tt.request, nil)
@ -503,7 +503,7 @@ func TestClient_Renew(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Renew(nil) got, err := c.Renew(nil)
@ -519,8 +519,8 @@ func TestClient_Renew(t *testing.T) {
t.Errorf("Client.Renew() = %v, want nil", got) t.Errorf("Client.Renew() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
@ -568,9 +568,9 @@ func TestClient_RenewWithToken(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer token" { if req.Header.Get("Authorization") != "Bearer token" {
api.JSONStatus(w, errs.InternalServer("force"), 500) render.JSONStatus(w, errs.InternalServer("force"), 500)
} else { } else {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
} }
}) })
@ -587,8 +587,8 @@ func TestClient_RenewWithToken(t *testing.T) {
t.Errorf("Client.RenewWithToken() = %v, want nil", got) t.Errorf("Client.RenewWithToken() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
@ -640,7 +640,7 @@ func TestClient_Rekey(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Rekey(tt.request, nil) got, err := c.Rekey(tt.request, nil)
@ -656,8 +656,8 @@ func TestClient_Rekey(t *testing.T) {
t.Errorf("Client.Renew() = %v, want nil", got) t.Errorf("Client.Renew() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
@ -705,7 +705,7 @@ func TestClient_Provisioners(t *testing.T) {
if req.RequestURI != tt.expectedURI { if req.RequestURI != tt.expectedURI {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Provisioners(tt.args...) got, err := c.Provisioners(tt.args...)
@ -762,7 +762,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
if req.RequestURI != expected { if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.ProvisionerKey(tt.kid) got, err := c.ProvisionerKey(tt.kid)
@ -777,8 +777,8 @@ func TestClient_ProvisionerKey(t *testing.T) {
t.Errorf("Client.ProvisionerKey() = %v, want nil", got) t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
@ -821,7 +821,7 @@ func TestClient_Roots(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Roots() got, err := c.Roots()
@ -836,8 +836,8 @@ func TestClient_Roots(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Roots() = %v, want nil", got) t.Errorf("Client.Roots() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
@ -879,7 +879,7 @@ func TestClient_Federation(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Federation() got, err := c.Federation()
@ -894,8 +894,8 @@ func TestClient_Federation(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got) t.Errorf("Client.Federation() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
@ -941,7 +941,7 @@ func TestClient_SSHRoots(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.SSHRoots() got, err := c.SSHRoots()
@ -956,8 +956,8 @@ func TestClient_SSHRoots(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.SSHKeys() = %v, want nil", got) t.Errorf("Client.SSHKeys() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
@ -1041,7 +1041,7 @@ func TestClient_RootFingerprint(t *testing.T) {
} }
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.RootFingerprint() got, err := c.RootFingerprint()
@ -1102,7 +1102,7 @@ func TestClient_SSHBastion(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.SSHBastion(tt.request) got, err := c.SSHBastion(tt.request)
@ -1118,8 +1118,8 @@ func TestClient_SSHBastion(t *testing.T) {
t.Errorf("Client.SSHBastion() = %v, want nil", got) t.Errorf("Client.SSHBastion() = %v, want nil", got)
} }
if tt.responseCode != 200 { if tt.responseCode != 200 {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }

View file

@ -31,12 +31,20 @@ type Options struct {
// https://cloud.google.com/docs/authentication. // https://cloud.google.com/docs/authentication.
CredentialsFile string `json:"credentialsFile,omitempty"` CredentialsFile string `json:"credentialsFile,omitempty"`
// Certificate and signer are the issuer certificate, along with any other // CertificateChain contains the issuer certificate, along with any other
// bundled certificates to be returned in the chain for consumers, and // bundled certificates to be returned in the chain to consumers. It is used
// signer used in SoftCAS. They are configured in ca.json crt and key // used in SoftCAS and it is configured in the crt property of the ca.json.
// properties.
CertificateChain []*x509.Certificate `json:"-"` CertificateChain []*x509.Certificate `json:"-"`
Signer crypto.Signer `json:"-"`
// Signer is the private key or a KMS signer for the issuer certificate. It
// is used in SoftCAS and it is configured in the key property of the
// ca.json.
Signer crypto.Signer `json:"-"`
// CertificateSigner combines CertificateChain and Signer in a callback that
// returns the chain of certificate and signer used to sign X.509
// certificates in SoftCAS.
CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) `json:"-"`
// IsCreator is set to true when we're creating a certificate authority. It // IsCreator is set to true when we're creating a certificate authority. It
// is used to skip some validations when initializing a // is used to skip some validations when initializing a

View file

@ -24,9 +24,10 @@ var now = time.Now
// SoftCAS implements a Certificate Authority Service using Golang or KMS // SoftCAS implements a Certificate Authority Service using Golang or KMS
// crypto. This is the default CAS used in step-ca. // crypto. This is the default CAS used in step-ca.
type SoftCAS struct { type SoftCAS struct {
CertificateChain []*x509.Certificate CertificateChain []*x509.Certificate
Signer crypto.Signer Signer crypto.Signer
KeyManager kms.KeyManager CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error)
KeyManager kms.KeyManager
} }
// New creates a new CertificateAuthorityService implementation using Golang or KMS // New creates a new CertificateAuthorityService implementation using Golang or KMS
@ -34,16 +35,17 @@ type SoftCAS struct {
func New(ctx context.Context, opts apiv1.Options) (*SoftCAS, error) { func New(ctx context.Context, opts apiv1.Options) (*SoftCAS, error) {
if !opts.IsCreator { if !opts.IsCreator {
switch { switch {
case len(opts.CertificateChain) == 0: case len(opts.CertificateChain) == 0 && opts.CertificateSigner == nil:
return nil, errors.New("softCAS 'CertificateChain' cannot be nil") return nil, errors.New("softCAS 'CertificateChain' cannot be nil")
case opts.Signer == nil: case opts.Signer == nil && opts.CertificateSigner == nil:
return nil, errors.New("softCAS 'signer' cannot be nil") return nil, errors.New("softCAS 'signer' cannot be nil")
} }
} }
return &SoftCAS{ return &SoftCAS{
CertificateChain: opts.CertificateChain, CertificateChain: opts.CertificateChain,
Signer: opts.Signer, Signer: opts.Signer,
KeyManager: opts.KeyManager, CertificateSigner: opts.CertificateSigner,
KeyManager: opts.KeyManager,
}, nil }, nil
} }
@ -57,6 +59,7 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
} }
t := now() t := now()
// Provisioners can also set specific values. // Provisioners can also set specific values.
if req.Template.NotBefore.IsZero() { if req.Template.NotBefore.IsZero() {
req.Template.NotBefore = t.Add(-1 * req.Backdate) req.Template.NotBefore = t.Add(-1 * req.Backdate)
@ -64,16 +67,21 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
if req.Template.NotAfter.IsZero() { if req.Template.NotAfter.IsZero() {
req.Template.NotAfter = t.Add(req.Lifetime) req.Template.NotAfter = t.Add(req.Lifetime)
} }
req.Template.Issuer = c.CertificateChain[0].Subject
cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) chain, signer, err := c.getCertSigner()
if err != nil {
return nil, err
}
req.Template.Issuer = chain[0].Subject
cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &apiv1.CreateCertificateResponse{ return &apiv1.CreateCertificateResponse{
Certificate: cert, Certificate: cert,
CertificateChain: c.CertificateChain, CertificateChain: chain,
}, nil }, nil
} }
@ -89,16 +97,21 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R
t := now() t := now()
req.Template.NotBefore = t.Add(-1 * req.Backdate) req.Template.NotBefore = t.Add(-1 * req.Backdate)
req.Template.NotAfter = t.Add(req.Lifetime) req.Template.NotAfter = t.Add(req.Lifetime)
req.Template.Issuer = c.CertificateChain[0].Subject
cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) chain, signer, err := c.getCertSigner()
if err != nil {
return nil, err
}
req.Template.Issuer = chain[0].Subject
cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &apiv1.RenewCertificateResponse{ return &apiv1.RenewCertificateResponse{
Certificate: cert, Certificate: cert,
CertificateChain: c.CertificateChain, CertificateChain: chain,
}, nil }, nil
} }
@ -106,9 +119,13 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R
// operation is a no-op as the actual revoke will happen when we store the entry // operation is a no-op as the actual revoke will happen when we store the entry
// in the db. // in the db.
func (c *SoftCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { func (c *SoftCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
chain, _, err := c.getCertSigner()
if err != nil {
return nil, err
}
return &apiv1.RevokeCertificateResponse{ return &apiv1.RevokeCertificateResponse{
Certificate: req.Certificate, Certificate: req.Certificate,
CertificateChain: c.CertificateChain, CertificateChain: chain,
}, nil }, nil
} }
@ -179,7 +196,7 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori
}, nil }, nil
} }
// initializeKeyManager initiazes the default key manager if was not given. // initializeKeyManager initializes the default key manager if was not given.
func (c *SoftCAS) initializeKeyManager() (err error) { func (c *SoftCAS) initializeKeyManager() (err error) {
if c.KeyManager == nil { if c.KeyManager == nil {
c.KeyManager, err = kms.New(context.Background(), kmsapi.Options{ c.KeyManager, err = kms.New(context.Background(), kmsapi.Options{
@ -189,6 +206,15 @@ func (c *SoftCAS) initializeKeyManager() (err error) {
return return
} }
// getCertSigner returns the certificate chain and signer to use.
func (c *SoftCAS) getCertSigner() ([]*x509.Certificate, crypto.Signer, error) {
if c.CertificateSigner != nil {
return c.CertificateSigner()
}
return c.CertificateChain, c.Signer, nil
}
// createKey uses the configured kms to create a key. // createKey uses the configured kms to create a key.
func (c *SoftCAS) createKey(req *kmsapi.CreateKeyRequest) (*kmsapi.CreateKeyResponse, error) { func (c *SoftCAS) createKey(req *kmsapi.CreateKeyRequest) (*kmsapi.CreateKeyResponse, error) {
if err := c.initializeKeyManager(); err != nil { if err := c.initializeKeyManager(); err != nil {

View file

@ -73,6 +73,12 @@ var (
testSignedTemplate = mustSign(testTemplate, testIssuer, testNow, testNow.Add(24*time.Hour)) testSignedTemplate = mustSign(testTemplate, testIssuer, testNow, testNow.Add(24*time.Hour))
testSignedRootTemplate = mustSign(testRootTemplate, testRootTemplate, testNow, testNow.Add(24*time.Hour)) testSignedRootTemplate = mustSign(testRootTemplate, testRootTemplate, testNow, testNow.Add(24*time.Hour))
testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour)) testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour))
testCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) {
return []*x509.Certificate{testIssuer}, testSigner, nil
}
testFailCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) {
return nil, nil, errTest
}
) )
type signatureAlgorithmSigner struct { type signatureAlgorithmSigner struct {
@ -186,6 +192,10 @@ func setTeeReader(t *testing.T, w *bytes.Buffer) {
} }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
assertEqual := func(x, y interface{}) bool {
return reflect.DeepEqual(x, y) || fmt.Sprintf("%#v", x) == fmt.Sprintf("%#v", y)
}
type args struct { type args struct {
ctx context.Context ctx context.Context
opts apiv1.Options opts apiv1.Options
@ -197,6 +207,7 @@ func TestNew(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"ok", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}}, &SoftCAS{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}, false}, {"ok", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}}, &SoftCAS{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}, false},
{"ok with callback", args{context.Background(), apiv1.Options{CertificateSigner: testCertificateSigner}}, &SoftCAS{CertificateSigner: testCertificateSigner}, false},
{"fail no issuer", args{context.Background(), apiv1.Options{Signer: testSigner}}, nil, true}, {"fail no issuer", args{context.Background(), apiv1.Options{Signer: testSigner}}, nil, true},
{"fail no signer", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}}}, nil, true}, {"fail no signer", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}}}, nil, true},
} }
@ -207,7 +218,7 @@ func TestNew(t *testing.T) {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !assertEqual(got, tt.want) {
t.Errorf("New() = %v, want %v", got, tt.want) t.Errorf("New() = %v, want %v", got, tt.want)
} }
}) })
@ -265,8 +276,9 @@ func TestSoftCAS_CreateCertificate(t *testing.T) {
} }
type fields struct { type fields struct {
Issuer *x509.Certificate Issuer *x509.Certificate
Signer crypto.Signer Signer crypto.Signer
CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error)
} }
type args struct { type args struct {
req *apiv1.CreateCertificateRequest req *apiv1.CreateCertificateRequest
@ -278,43 +290,53 @@ func TestSoftCAS_CreateCertificate(t *testing.T) {
want *apiv1.CreateCertificateResponse want *apiv1.CreateCertificateResponse
wantErr bool wantErr bool
}{ }{
{"ok", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{
Template: testTemplate, Lifetime: 24 * time.Hour, Template: testTemplate, Lifetime: 24 * time.Hour,
}}, &apiv1.CreateCertificateResponse{ }}, &apiv1.CreateCertificateResponse{
Certificate: testSignedTemplate, Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.CreateCertificateRequest{ {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.CreateCertificateRequest{
Template: &saTemplate, Lifetime: 24 * time.Hour, Template: &saTemplate, Lifetime: 24 * time.Hour,
}}, &apiv1.CreateCertificateResponse{ }}, &apiv1.CreateCertificateResponse{
Certificate: testSignedTemplate, Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ {"ok with notBefore", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{
Template: &tmplNotBefore, Lifetime: 24 * time.Hour, Template: &tmplNotBefore, Lifetime: 24 * time.Hour,
}}, &apiv1.CreateCertificateResponse{ }}, &apiv1.CreateCertificateResponse{
Certificate: testSignedTemplate, Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"ok with notBefore+notAfter", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ {"ok with notBefore+notAfter", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{
Template: &tmplWithLifetime, Lifetime: 24 * time.Hour, Template: &tmplWithLifetime, Lifetime: 24 * time.Hour,
}}, &apiv1.CreateCertificateResponse{ }}, &apiv1.CreateCertificateResponse{
Certificate: testSignedTemplate, Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"fail template", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.CreateCertificateRequest{
{"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true}, Template: testTemplate, Lifetime: 24 * time.Hour,
{"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ }}, &apiv1.CreateCertificateResponse{
Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer},
}, false},
{"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true},
{"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true},
{"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{
Template: &tmplNoSerial, Template: &tmplNoSerial,
Lifetime: 24 * time.Hour, Lifetime: 24 * time.Hour,
}}, nil, true}, }}, nil, true},
{"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.CreateCertificateRequest{
Template: testTemplate, Lifetime: 24 * time.Hour,
}}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &SoftCAS{ c := &SoftCAS{
CertificateChain: []*x509.Certificate{tt.fields.Issuer}, CertificateChain: []*x509.Certificate{tt.fields.Issuer},
Signer: tt.fields.Signer, Signer: tt.fields.Signer,
CertificateSigner: tt.fields.CertificateSigner,
} }
got, err := c.CreateCertificate(tt.args.req) got, err := c.CreateCertificate(tt.args.req)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -345,8 +367,9 @@ func TestSoftCAS_RenewCertificate(t *testing.T) {
} }
type fields struct { type fields struct {
Issuer *x509.Certificate Issuer *x509.Certificate
Signer crypto.Signer Signer crypto.Signer
CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error)
} }
type args struct { type args struct {
req *apiv1.RenewCertificateRequest req *apiv1.RenewCertificateRequest
@ -358,30 +381,40 @@ func TestSoftCAS_RenewCertificate(t *testing.T) {
want *apiv1.RenewCertificateResponse want *apiv1.RenewCertificateResponse
wantErr bool wantErr bool
}{ }{
{"ok", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{
Template: testTemplate, Lifetime: 24 * time.Hour, Template: testTemplate, Lifetime: 24 * time.Hour,
}}, &apiv1.RenewCertificateResponse{ }}, &apiv1.RenewCertificateResponse{
Certificate: testSignedTemplate, Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.RenewCertificateRequest{ {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.RenewCertificateRequest{
Template: testTemplate, Lifetime: 24 * time.Hour, Template: testTemplate, Lifetime: 24 * time.Hour,
}}, &apiv1.RenewCertificateResponse{ }}, &apiv1.RenewCertificateResponse{
Certificate: testSignedTemplate, Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RenewCertificateRequest{
{"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, Template: testTemplate, Lifetime: 24 * time.Hour,
{"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ }}, &apiv1.RenewCertificateResponse{
Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer},
}, false},
{"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true},
{"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true},
{"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{
Template: &tmplNoSerial, Template: &tmplNoSerial,
Lifetime: 24 * time.Hour, Lifetime: 24 * time.Hour,
}}, nil, true}, }}, nil, true},
{"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RenewCertificateRequest{
Template: testTemplate, Lifetime: 24 * time.Hour,
}}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &SoftCAS{ c := &SoftCAS{
CertificateChain: []*x509.Certificate{tt.fields.Issuer}, CertificateChain: []*x509.Certificate{tt.fields.Issuer},
Signer: tt.fields.Signer, Signer: tt.fields.Signer,
CertificateSigner: tt.fields.CertificateSigner,
} }
got, err := c.RenewCertificate(tt.args.req) got, err := c.RenewCertificate(tt.args.req)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -397,8 +430,9 @@ func TestSoftCAS_RenewCertificate(t *testing.T) {
func TestSoftCAS_RevokeCertificate(t *testing.T) { func TestSoftCAS_RevokeCertificate(t *testing.T) {
type fields struct { type fields struct {
Issuer *x509.Certificate Issuer *x509.Certificate
Signer crypto.Signer Signer crypto.Signer
CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error)
} }
type args struct { type args struct {
req *apiv1.RevokeCertificateRequest req *apiv1.RevokeCertificateRequest
@ -410,7 +444,7 @@ func TestSoftCAS_RevokeCertificate(t *testing.T) {
want *apiv1.RevokeCertificateResponse want *apiv1.RevokeCertificateResponse
wantErr bool wantErr bool
}{ }{
{"ok", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{ {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{
Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}},
Reason: "test reason", Reason: "test reason",
ReasonCode: 1, ReasonCode: 1,
@ -418,23 +452,37 @@ func TestSoftCAS_RevokeCertificate(t *testing.T) {
Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}},
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"ok no cert", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{ {"ok no cert", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{
Reason: "test reason", Reason: "test reason",
ReasonCode: 1, ReasonCode: 1,
}}, &apiv1.RevokeCertificateResponse{ }}, &apiv1.RevokeCertificateResponse{
Certificate: nil, Certificate: nil,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"ok empty", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{ {"ok empty", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{
Certificate: nil, Certificate: nil,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RevokeCertificateRequest{
Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}},
Reason: "test reason",
ReasonCode: 1,
}}, &apiv1.RevokeCertificateResponse{
Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}},
CertificateChain: []*x509.Certificate{testIssuer},
}, false},
{"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RevokeCertificateRequest{
Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}},
Reason: "test reason",
ReasonCode: 1,
}}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &SoftCAS{ c := &SoftCAS{
CertificateChain: []*x509.Certificate{tt.fields.Issuer}, CertificateChain: []*x509.Certificate{tt.fields.Issuer},
Signer: tt.fields.Signer, Signer: tt.fields.Signer,
CertificateSigner: tt.fields.CertificateSigner,
} }
got, err := c.RevokeCertificate(tt.args.req) got, err := c.RevokeCertificate(tt.args.req)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -609,3 +657,56 @@ func TestSoftCAS_CreateCertificateAuthority(t *testing.T) {
}) })
} }
} }
func TestSoftCAS_defaultKeyManager(t *testing.T) {
mockNow(t)
type args struct {
req *apiv1.CreateCertificateAuthorityRequest
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok root", args{&apiv1.CreateCertificateAuthorityRequest{
Type: apiv1.RootCA,
Template: &x509.Certificate{
Subject: pkix.Name{CommonName: "Test Root CA"},
KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 1,
SerialNumber: big.NewInt(1234),
},
Lifetime: 24 * time.Hour,
}}, false},
{"ok intermediate", args{&apiv1.CreateCertificateAuthorityRequest{
Type: apiv1.IntermediateCA,
Template: testIntermediateTemplate,
Lifetime: 24 * time.Hour,
Parent: &apiv1.CreateCertificateAuthorityResponse{
Certificate: testSignedRootTemplate,
Signer: testSigner,
},
}}, false},
{"fail with default key manager", args{&apiv1.CreateCertificateAuthorityRequest{
Type: apiv1.IntermediateCA,
Template: testIntermediateTemplate,
Lifetime: 24 * time.Hour,
Parent: &apiv1.CreateCertificateAuthorityResponse{
Certificate: testSignedRootTemplate,
Signer: &badSigner{},
},
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &SoftCAS{}
_, err := c.CreateCertificateAuthority(tt.args.req)
if (err != nil) != tt.wantErr {
t.Errorf("SoftCAS.CreateCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -58,6 +58,11 @@ certificate issuer private key used in the RA mode.`,
Usage: "token used to enable the linked ca.", Usage: "token used to enable the linked ca.",
EnvVar: "STEP_CA_TOKEN", EnvVar: "STEP_CA_TOKEN",
}, },
cli.BoolFlag{
Name: "quiet",
Usage: "disable startup information",
EnvVar: "STEP_CA_QUIET",
},
cli.StringFlag{ cli.StringFlag{
Name: "context", Name: "context",
Usage: "The name of the authority's context.", Usage: "The name of the authority's context.",
@ -74,6 +79,7 @@ func appAction(ctx *cli.Context) error {
issuerPassFile := ctx.String("issuer-password-file") issuerPassFile := ctx.String("issuer-password-file")
resolver := ctx.String("resolver") resolver := ctx.String("resolver")
token := ctx.String("token") token := ctx.String("token")
quiet := ctx.Bool("quiet")
if ctx.NArg() > 1 { if ctx.NArg() > 1 {
return errs.TooManyArguments(ctx) return errs.TooManyArguments(ctx)
@ -155,7 +161,8 @@ To get a linked authority token:
ca.WithSSHHostPassword(sshHostPassword), ca.WithSSHHostPassword(sshHostPassword),
ca.WithSSHUserPassword(sshUserPassword), ca.WithSSHUserPassword(sshUserPassword),
ca.WithIssuerPassword(issuerPassword), ca.WithIssuerPassword(issuerPassword),
ca.WithLinkedCAToken(token)) ca.WithLinkedCAToken(token),
ca.WithQuiet(quiet))
if err != nil { if err != nil {
fatal(err) fatal(err)
} }

View file

@ -6,18 +6,11 @@ import (
"net/http" "net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/api/render"
) )
// StatusCoder interface is used by errors that returns the HTTP response code.
type StatusCoder interface {
StatusCode() int
}
// StackTracer must be by those errors that return an stack trace.
type StackTracer interface {
StackTrace() errors.StackTrace
}
// Option modifies the Error type. // Option modifies the Error type.
type Option func(e *Error) error type Option func(e *Error) error
@ -257,7 +250,7 @@ func NewError(status int, err error, format string, args ...interface{}) error {
return err return err
} }
msg := fmt.Sprintf(format, args...) msg := fmt.Sprintf(format, args...)
if _, ok := err.(StackTracer); !ok { if _, ok := err.(log.StackTracedError); !ok {
err = errors.Wrap(err, msg) err = errors.Wrap(err, msg)
} }
return &Error{ return &Error{
@ -275,11 +268,11 @@ func NewErr(status int, err error, opts ...Option) error {
ok bool ok bool
) )
if e, ok = err.(*Error); !ok { if e, ok = err.(*Error); !ok {
if sc, ok := err.(StatusCoder); ok { if sc, ok := err.(render.StatusCodedError); ok {
e = &Error{Status: sc.StatusCode(), Err: err} e = &Error{Status: sc.StatusCode(), Err: err}
} else { } else {
cause := errors.Cause(err) cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok { if sc, ok := cause.(render.StatusCodedError); ok {
e = &Error{Status: sc.StatusCode(), Err: err} e = &Error{Status: sc.StatusCode(), Err: err}
} else { } else {
e = &Error{Status: status, Err: err} e = &Error{Status: status, Err: err}

3
go.mod
View file

@ -33,10 +33,11 @@ require (
github.com/slackhq/nebula v1.5.2 github.com/slackhq/nebula v1.5.2
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262
github.com/smallstep/nosql v0.4.0 github.com/smallstep/nosql v0.4.0
github.com/stretchr/testify v1.7.1
github.com/urfave/cli v1.22.4 github.com/urfave/cli v1.22.4
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.7.0 go.step.sm/cli-utils v0.7.0
go.step.sm/crypto v0.15.3 go.step.sm/crypto v0.16.1
go.step.sm/linkedca v0.12.0 go.step.sm/linkedca v0.12.0
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd

7
go.sum
View file

@ -468,9 +468,11 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
@ -707,8 +709,8 @@ go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqe
go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds= go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds=
go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E=
go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0=
go.step.sm/crypto v0.15.3 h1:f3GMl+aCydt294BZRjTYwpaXRqwwndvoTY2NLN4wu10= go.step.sm/crypto v0.16.1 h1:4mnZk21cSxyMGxsEpJwZKKvJvDu1PN09UVrWWFNUBdk=
go.step.sm/crypto v0.15.3/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= go.step.sm/crypto v0.16.1/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
@ -1194,6 +1196,7 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=