diff --git a/CHANGELOG.md b/CHANGELOG.md index 73c338f8..49e4b15e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased - 0.18.3] - DATE ### Added - Added support for renew after expiry using the claim `allowRenewAfterExpiry`. +- Added support for `extraNames` in X.509 templates. ### Changed - Made SCEP CA URL paths dynamic - Support two latest versions of Go (1.17, 1.18) diff --git a/acme/api/account.go b/acme/api/account.go index 0dc8ab40..ade51aef 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -5,8 +5,9 @@ import ( "net/http" "github.com/go-chi/chi" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/logging" ) @@ -70,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := nar.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := acmeProvisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -96,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { acmeErr, ok := err.(*acme.Error) if !ok || acmeErr.Status != http.StatusBadRequest { // Something went wrong ... - api.WriteError(w, err) + render.Error(w, err) return } // Account does not exist // if nar.OnlyReturnExisting { - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist")) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } eak, err := h.validateExternalAccountBinding(ctx, &nar) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -125,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { Status: acme.StatusValid, } if err := h.db.CreateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) + render.Error(w, acme.WrapErrorISE(err, "error creating account")) return } if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response err := eak.BindTo(acc) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key")) + render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) return } acc.ExternalAccountBinding = nar.ExternalAccountBinding @@ -149,7 +150,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) - api.JSONStatus(w, acc, httpStatus) + render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. @@ -157,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -171,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := uar.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if len(uar.Status) > 0 || len(uar.Contact) > 0 { @@ -187,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { } if err := h.db.UpdateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) + render.Error(w, acme.WrapErrorISE(err, "error updating account")) return } } @@ -196,7 +197,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) - api.JSON(w, acc) + render.JSON(w, acc) } func logOrdersByAccount(w http.ResponseWriter, oids []string) { @@ -213,22 +214,22 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } accID := chi.URLParam(r, "accID") if acc.ID != accID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } h.linker.LinkOrdersByAccountID(ctx, orders) - api.JSON(w, orders) + render.JSON(w, orders) logOrdersByAccount(w, orders) } diff --git a/acme/api/handler.go b/acme/api/handler.go index c3a481f9..10eb22cb 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -12,8 +12,10 @@ import ( "time" "github.com/go-chi/chi" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" ) @@ -181,11 +183,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } - api.JSON(w, &Directory{ + render.JSON(w, &Directory{ NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), @@ -200,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. @@ -208,28 +210,28 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) return } if acc.ID != az.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } if err = az.UpdateStatus(ctx, h.db); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) + render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) return } h.linker.LinkAuthorization(ctx, az) w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) - api.JSON(w, az) + render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. @@ -237,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // Just verify that the payload was set, since we're not strictly adhering // to ACME V2 spec for reasons specified below. _, err = payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -257,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { azID := chi.URLParam(r, "authzID") ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) return } ch.AuthorizationID = azID if acc.ID != ch.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) + render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } @@ -280,7 +282,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) - api.JSON(w, ch) + render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. @@ -288,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } certID := chi.URLParam(r, "certID") cert, err := h.db.GetCertificate(ctx, certID) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) return } if cert.AccountID != acc.ID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own certificate '%s'", acc.ID, certID)) return } diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 0cdeaabb..10f7841f 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -10,13 +10,14 @@ import ( "strings" "github.com/go-chi/chi" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/nosql" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" ) type nextHTTP = func(http.ResponseWriter, *http.Request) @@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { nonce, err := h.db.CreateNonce(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } w.Header().Set("Replay-Nonce", string(nonce)) @@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { var expected []string p, err := provisionerFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return } } - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected content-type to be in %s, but got %s", expected, ct)) } } @@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) + render.Error(w, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } ctx := context.WithValue(r.Context(), jwsContextKey, jws) @@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if len(jws.Signatures) == 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } @@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { len(uh.Algorithm) > 0 || len(uh.Nonce) > 0 || len(uh.ExtraHeaders) > 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected @@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least %d bits (%d bytes) in size", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match")) return } @@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: - api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) + render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // Check that the JWS url matches the requested url. jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && hdr.KeyID == "" { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) @@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } // Overwrite KeyID with the JWK thumbprint. jwk.KeyID, err = acme.KeyToID(jwk) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) + render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } @@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { // For NewAccount and Revoke requests ... break case err != nil: - api.WriteError(w, err) + render.Error(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) @@ -290,17 +291,17 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { nameEscaped := chi.URLParam(r, "provisionerID") name, err := url.PathUnescape(nameEscaped) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) + render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } p, err := h.ca.LoadProvisionerByName(name) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } acmeProv, ok := p.(*provisioner.ACME) if !ok { - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) return } ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) @@ -315,11 +316,11 @@ func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { ctx := r.Context() ok, err := h.prerequisitesChecker(ctx) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) return } if !ok { - api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) return } next(w, r.WithContext(ctx)) @@ -334,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got %s", kidPrefix, kid)) return @@ -351,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { acc, err := h.db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: - api.WriteError(w, err) + render.Error(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) @@ -376,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -412,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } payload, err := jws.Verify(jwk) if err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ @@ -443,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if !payload.isPostAsGet { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) diff --git a/acme/api/order.go b/acme/api/order.go index 9cf2c1eb..99eb0e95 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -11,9 +11,11 @@ import ( "time" "github.com/go-chi/chi" - "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "go.step.sm/crypto/randutil" + + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api/render" ) // NewOrderRequest represents the body for a NewOrder request. @@ -70,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-order request payload")) return } if err := nor.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -116,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { Status: acme.StatusPending, } if err := h.newAuthorization(ctx, az); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o.AuthorizationIDs[i] = az.ID @@ -135,14 +137,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { } if err := h.db.CreateOrder(ctx, o); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) + render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSONStatus(w, o, http.StatusCreated) + render.JSONStatus(w, o, http.StatusCreated) } func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { @@ -186,38 +188,38 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.UpdateStatus(ctx, h.db); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) + render.Error(w, acme.WrapErrorISE(err, "error updating order status")) return } h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSON(w, o) + render.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. @@ -225,54 +227,54 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var fr FinalizeRequest if err := json.Unmarshal(payload.value, &fr); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal finalize-order request payload")) return } if err := fr.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) + render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) return } h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSON(w, o) + render.JSON(w, o) } // challengeTypes determines the types of challenges that should be used diff --git a/acme/api/revoke.go b/acme/api/revoke.go index d01e401c..4b71bc22 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -10,13 +10,14 @@ import ( "net/http" "strings" + "go.step.sm/crypto/jose" + "golang.org/x/crypto/ocsp" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" - "go.step.sm/crypto/jose" - "golang.org/x/crypto/ocsp" ) type revokePayload struct { @@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var p revokePayload err = json.Unmarshal(payload.value, &p) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload")) + render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload")) return } certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) if err != nil { // in this case the most likely cause is a client that didn't properly encode the certificate - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) return } certToBeRevoked, err := x509.ParseCertificate(certBytes) if err != nil { // in this case a client may have encoded something different than a certificate - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) return } serial := certToBeRevoked.SerialNumber.String() dbCert, err := h.db.GetCertificateBySerial(ctx, serial) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return } if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { // this should never happen - api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal")) + render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal")) return } if shouldCheckAccountFrom(jws) { account, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { - api.WriteError(w, acmeErr) + render.Error(w, acmeErr) return } } else { @@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { _, err := jws.Verify(certToBeRevoked.PublicKey) if err != nil { // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? - api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) + render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) return } } hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return } if hasBeenRevokedBefore { - api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) + render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) return } reasonCode := p.ReasonCode acmeErr := validateReasonCode(reasonCode) if acmeErr != nil { - api.WriteError(w, acmeErr) + render.Error(w, acmeErr) return } @@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) err = prov.AuthorizeRevoke(ctx, "") if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) + render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) return } options := revokeOptions(serial, certToBeRevoked, reasonCode) err = h.ca.Revoke(ctx, options) if err != nil { - api.WriteError(w, wrapRevokeErr(err)) + render.Error(w, wrapRevokeErr(err)) return } diff --git a/acme/challenge.go b/acme/challenge.go index 0e1994e4..9f08bae5 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -79,7 +79,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, } func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { - u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} + u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} resp, err := vo.HTTPGet(u.String()) if err != nil { @@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb return nil } +// http01ChallengeHost checks if a Challenge value is an IPv6 address +// and adds square brackets if that's the case, so that it can be used +// as a hostname. Returns the original Challenge value as the host to +// use in other cases. +func http01ChallengeHost(value string) string { + if ip := net.ParseIP(value); ip != nil && ip.To4() == nil { + value = "[" + value + "]" + } + return value +} + func tlsAlert(err error) uint8 { var opErr *net.OpError if errors.As(err, &opErr) { diff --git a/acme/challenge_test.go b/acme/challenge_test.go index d8ce4d76..c05b25e7 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -13,6 +13,7 @@ import ( "encoding/asn1" "encoding/base64" "encoding/hex" + "errors" "fmt" "io" "math/big" @@ -23,9 +24,9 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" ) func Test_storeError(t *testing.T) { @@ -2350,3 +2351,34 @@ func Test_serverName(t *testing.T) { }) } } + +func Test_http01ChallengeHost(t *testing.T) { + tests := []struct { + name string + value string + want string + }{ + { + name: "dns", + value: "www.example.com", + want: "www.example.com", + }, + { + name: "ipv4", + value: "127.0.0.1", + want: "127.0.0.1", + }, + { + name: "ipv6", + value: "::1", + want: "[::1]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := http01ChallengeHost(tt.value); got != tt.want { + t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/acme/errors.go b/acme/errors.go index a5c820ba..05888c24 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -3,13 +3,10 @@ package acme import ( "encoding/json" "fmt" - "log" "net/http" - "os" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/api/render" ) // ProblemType is the type of the ACME problem. @@ -353,26 +350,8 @@ func (e *Error) ToLog() (interface{}, error) { return string(b), nil } -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err *Error) { +// Render implements render.RenderableError for Error. +func (e *Error) Render(w http.ResponseWriter) { w.Header().Set("Content-Type", "application/problem+json") - w.WriteHeader(err.StatusCode()) - - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err.Err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.Err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } - - if err := json.NewEncoder(w).Encode(err); err != nil { - log.Println(err) - } + render.JSONStatus(w, e, e.StatusCode()) } diff --git a/api/api.go b/api/api.go index 1c47c03d..da6309fd 100644 --- a/api/api.go +++ b/api/api.go @@ -22,6 +22,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" @@ -284,7 +285,7 @@ func (h *caHandler) Route(r Router) { // Version is an HTTP handler that returns the version of the server. func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { v := h.Authority.Version() - JSON(w, VersionResponse{ + render.JSON(w, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, }) @@ -292,7 +293,7 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { // Health is an HTTP handler that returns the status of the server. func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { - JSON(w, HealthResponse{Status: "ok"}) + render.JSON(w, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root @@ -303,11 +304,11 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { // Load root certificate with the cert, err := h.Authority.Root(sum) if err != nil { - WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) + render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return } - JSON(w, &RootResponse{RootPEM: Certificate{cert}}) + render.JSON(w, &RootResponse{RootPEM: Certificate{cert}}) } func certChainToPEM(certChain []*x509.Certificate) []Certificate { @@ -322,16 +323,16 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { - WriteError(w, err) + render.Error(w, err) return } p, next, err := h.Authority.GetProvisioners(cursor, limit) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &ProvisionersResponse{ + render.JSON(w, &ProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -342,17 +343,17 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") key, err := h.Authority.GetEncryptedKey(kid) if err != nil { - WriteError(w, errs.NotFoundErr(err)) + render.Error(w, errs.NotFoundErr(err)) return } - JSON(w, &ProvisionerKeyResponse{key}) + render.JSON(w, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { roots, err := h.Authority.GetRoots() if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error getting roots")) + render.Error(w, errs.ForbiddenErr(err, "error getting roots")) return } @@ -361,7 +362,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{roots[i]} } - JSONStatus(w, &RootsResponse{ + render.JSONStatus(w, &RootsResponse{ Certificates: certs, }, http.StatusCreated) } @@ -370,7 +371,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { roots, err := h.Authority.GetRoots() if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } @@ -393,7 +394,7 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { federated, err := h.Authority.GetFederation() if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error getting federated roots")) + render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) return } @@ -402,7 +403,7 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{federated[i]} } - JSONStatus(w, &FederationResponse{ + render.JSONStatus(w, &FederationResponse{ Certificates: certs, }, http.StatusCreated) } diff --git a/api/errors.go b/api/errors.go deleted file mode 100644 index 680e6578..00000000 --- a/api/errors.go +++ /dev/null @@ -1,62 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "net/http" - "os" - - "github.com/pkg/errors" - - "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api/log" - "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" -) - -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err error) { - switch k := err.(type) { - case *acme.Error: - acme.WriteError(w, k) - return - case *admin.Error: - admin.WriteError(w, k) - return - } - - cause := errors.Cause(err) - - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } else if e, ok := cause.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } - - code := http.StatusInternalServerError - if sc, ok := err.(errs.StatusCoder); ok { - code = sc.StatusCode() - } else if sc, ok := cause.(errs.StatusCoder); ok { - code = sc.StatusCode() - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - - if err := json.NewEncoder(w).Encode(err); err != nil { - log.Error(w, err) - } -} diff --git a/api/log/log.go b/api/log/log.go index 78dae506..cb31410b 100644 --- a/api/log/log.go +++ b/api/log/log.go @@ -2,22 +2,54 @@ package log import ( + "fmt" "log" "net/http" + "os" + + "github.com/pkg/errors" "github.com/smallstep/certificates/logging" ) +// StackTracedError is the set of errors implementing the StackTrace function. +// +// Errors implementing this interface have their stack traces logged when passed +// to the Error function of this package. +type StackTracedError interface { + error + + StackTrace() errors.StackTrace +} + // Error adds to the response writer the given error if it implements // logging.ResponseLogger. If it does not implement it, then writes the error // using the log package. func Error(rw http.ResponseWriter, err error) { - if rl, ok := rw.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err, - }) - } else { + rl, ok := rw.(logging.ResponseLogger) + if !ok { log.Println(err) + + return + } + + rl.WithFields(map[string]interface{}{ + "error": err, + }) + + if os.Getenv("STEPDEBUG") != "1" { + return + } + + e, ok := err.(StackTracedError) + if !ok { + e, ok = errors.Cause(err).(StackTracedError) + } + + if ok { + rl.WithFields(map[string]interface{}{ + "stack-trace": fmt.Sprintf("%+v", e.StackTrace()), + }) } } diff --git a/api/rekey.go b/api/rekey.go index 269086bb..3116cf74 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) @@ -28,24 +29,24 @@ func (s *RekeyRequest) Validate() error { // Rekey is similar to renew except that the certificate will be renewed with new key from csr. func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing client certificate")) + render.Error(w, errs.BadRequest("missing client certificate")) return } var body RekeyRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) + render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return } certChainPEM := certChainToPEM(certChain) @@ -55,7 +56,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/render/render.go b/api/render/render.go new file mode 100644 index 00000000..9df4c791 --- /dev/null +++ b/api/render/render.go @@ -0,0 +1,122 @@ +// Package render implements functionality related to response rendering. +package render + +import ( + "bytes" + "encoding/json" + "net/http" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/smallstep/certificates/api/log" +) + +// JSON is shorthand for JSONStatus(w, v, http.StatusOK). +func JSON(w http.ResponseWriter, v interface{}) { + JSONStatus(w, v, http.StatusOK) +} + +// JSONStatus marshals v into w. It additionally sets the status code of +// w to the given one. +// +// JSONStatus sets the Content-Type of w to application/json unless one is +// specified. +func JSONStatus(w http.ResponseWriter, v interface{}, status int) { + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(v); err != nil { + panic(err) + } + + setContentTypeUnlessPresent(w, "application/json") + w.WriteHeader(status) + _, _ = b.WriteTo(w) + + log.EnabledResponse(w, v) +} + +// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK). +func ProtoJSON(w http.ResponseWriter, m proto.Message) { + ProtoJSONStatus(w, m, http.StatusOK) +} + +// ProtoJSONStatus writes the given value into the http.ResponseWriter and the +// given status is written as the status code of the response. +func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { + b, err := protojson.Marshal(m) + if err != nil { + panic(err) + } + + setContentTypeUnlessPresent(w, "application/json") + w.WriteHeader(status) + _, _ = w.Write(b) +} + +func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) { + const header = "Content-Type" + + h := w.Header() + if _, ok := h[header]; !ok { + h.Set(header, contentType) + } +} + +// RenderableError is the set of errors that implement the basic Render method. +// +// Errors that implement this interface will use their own Render method when +// being rendered into responses. +type RenderableError interface { + error + + Render(http.ResponseWriter) +} + +// Error marshals the JSON representation of err to w. In case err implements +// RenderableError its own Render method will be called instead. +func Error(w http.ResponseWriter, err error) { + log.Error(w, err) + + if e, ok := err.(RenderableError); ok { + e.Render(w) + + return + } + + JSONStatus(w, err, statusCodeFromError(err)) +} + +// StatusCodedError is the set of errors that implement the basic StatusCode +// function. +// +// Errors that implement this interface will use the code reported by StatusCode +// as the HTTP response code when being rendered by this package. +type StatusCodedError interface { + error + + StatusCode() int +} + +func statusCodeFromError(err error) (code int) { + code = http.StatusInternalServerError + + type causer interface { + Cause() error + } + + for err != nil { + if sc, ok := err.(StatusCodedError); ok { + code = sc.StatusCode() + + break + } + + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + + return +} diff --git a/api/render/render_test.go b/api/render/render_test.go new file mode 100644 index 00000000..06d092d3 --- /dev/null +++ b/api/render/render_test.go @@ -0,0 +1,115 @@ +package render + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/smallstep/certificates/logging" +) + +func TestJSON(t *testing.T) { + rec := httptest.NewRecorder() + rw := logging.NewResponseLogger(rec) + + JSON(rw, map[string]interface{}{"foo": "bar"}) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String()) + + assert.Empty(t, rw.Fields()) +} + +func TestJSONPanics(t *testing.T) { + assert.Panics(t, func() { + JSON(httptest.NewRecorder(), make(chan struct{})) + }) +} + +type renderableError struct { + Code int `json:"-"` + Message string `json:"message"` +} + +func (err renderableError) Error() string { + return err.Message +} + +func (err renderableError) Render(w http.ResponseWriter) { + w.Header().Set("Content-Type", "something/custom") + + JSONStatus(w, err, err.Code) +} + +type statusedError struct { + Contents string +} + +func (err statusedError) Error() string { return err.Contents } + +func (statusedError) StatusCode() int { return 432 } + +func TestError(t *testing.T) { + cases := []struct { + err error + code int + body string + header string + }{ + 0: { + err: renderableError{532, "some string"}, + code: 532, + body: "{\"message\":\"some string\"}\n", + header: "something/custom", + }, + 1: { + err: statusedError{"123"}, + code: 432, + body: "{\"Contents\":\"123\"}\n", + header: "application/json", + }, + } + + for caseIndex := range cases { + kase := cases[caseIndex] + + t.Run(strconv.Itoa(caseIndex), func(t *testing.T) { + rec := httptest.NewRecorder() + + Error(rec, kase.err) + + assert.Equal(t, kase.code, rec.Result().StatusCode) + assert.Equal(t, kase.body, rec.Body.String()) + assert.Equal(t, kase.header, rec.Header().Get("Content-Type")) + }) + } +} + +type causedError struct { + cause error +} + +func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) } +func (err causedError) Cause() error { return err.cause } + +func TestStatusCodeFromError(t *testing.T) { + cases := []struct { + err error + exp int + }{ + 0: {nil, http.StatusInternalServerError}, + 1: {io.EOF, http.StatusInternalServerError}, + 2: {statusedError{"123"}, 432}, + 3: {causedError{statusedError{"432"}}, 432}, + } + + for caseIndex, kase := range cases { + assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex) + } +} diff --git a/api/renew.go b/api/renew.go index 408d91a3..9c4bff32 100644 --- a/api/renew.go +++ b/api/renew.go @@ -5,6 +5,7 @@ import ( "net/http" "strings" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) @@ -18,13 +19,13 @@ const ( func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { cert, err := h.getPeerCertificate(r) if err != nil { - WriteError(w, err) + render.Error(w, err) return } certChain, err := h.Authority.Renew(cert) if err != nil { - WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) + render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return } certChainPEM := certChainToPEM(certChain) @@ -34,7 +35,7 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/revoke.go b/api/revoke.go index 49822e6d..c9da2c18 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ocsp" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -51,12 +52,12 @@ func (r *RevokeRequest) Validate() (err error) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -73,7 +74,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { if len(body.OTT) > 0 { logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT @@ -82,12 +83,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing ott or client certificate")) + render.Error(w, errs.BadRequest("missing ott or client certificate")) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - WriteError(w, errs.BadRequest("serial number in client certificate different than body")) + render.Error(w, errs.BadRequest("serial number in client certificate different than body")) return } // TODO: should probably be checking if the certificate was revoked here. @@ -98,12 +99,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { } if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err, "error revoking certificate")) + render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) return } logRevoke(w, opts) - JSON(w, &RevokeResponse{Status: "ok"}) + render.JSON(w, &RevokeResponse{Status: "ok"}) } func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/api/sign.go b/api/sign.go index b2eef45d..b6bfcc8b 100644 --- a/api/sign.go +++ b/api/sign.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -51,13 +52,13 @@ type SignResponse struct { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -69,13 +70,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { signOpts, err := h.Authority.AuthorizeSign(body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return } certChainPEM := certChainToPEM(certChain) @@ -84,7 +85,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { caPEM = certChainPEM[1] } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/ssh.go b/api/ssh.go index fc185d07..3b0de7c1 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -12,6 +12,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" @@ -252,19 +253,19 @@ type SSHBastionResponse struct { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) return } @@ -272,7 +273,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if body.AddUserPublicKey != nil { addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) return } } @@ -289,13 +290,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } @@ -303,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } addUserCertificate = &SSHCertificate{addUserCert} @@ -316,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } @@ -328,13 +329,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return } identityCertificate = certChainToPEM(certChain) } - JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, &SSHSignResponse{ Certificate: SSHCertificate{cert}, AddUserCertificate: addUserCertificate, IdentityCertificate: identityCertificate, @@ -346,12 +347,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHRoots(r.Context()) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound("no keys found")) + render.Error(w, errs.NotFound("no keys found")) return } @@ -363,7 +364,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - JSON(w, resp) + render.JSON(w, resp) } // SSHFederation is an HTTP handler that returns the federated SSH public keys @@ -371,12 +372,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHFederation(r.Context()) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound("no keys found")) + render.Error(w, errs.NotFound("no keys found")) return } @@ -388,7 +389,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - JSON(w, resp) + render.JSON(w, resp) } // SSHConfig is an HTTP handler that returns rendered templates for ssh clients @@ -396,17 +397,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } @@ -417,31 +418,31 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { case provisioner.SSHHostCert: cfg.HostTemplates = ts default: - WriteError(w, errs.InternalServer("it should hot get here")) + render.Error(w, errs.InternalServer("it should hot get here")) return } - JSON(w, cfg) + render.JSON(w, cfg) } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHCheckPrincipalResponse{ + render.JSON(w, &SSHCheckPrincipalResponse{ Exists: exists, }) } @@ -455,10 +456,10 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHGetHostsResponse{ + render.JSON(w, &SSHGetHostsResponse{ Hosts: hosts, }) } @@ -467,21 +468,21 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHBastionResponse{ + render.JSON(w, &SSHBastionResponse{ Hostname: body.Hostname, Bastion: bastion, }) diff --git a/api/sshRekey.go b/api/sshRekey.go index b7581749..92278950 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -41,36 +42,37 @@ type SSHRekeyResponse struct { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) + return } newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return } @@ -80,11 +82,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } - JSONStatus(w, &SSHRekeyResponse{ + render.JSONStatus(w, &SSHRekeyResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRenew.go b/api/sshRenew.go index b98466bf..78d16fa6 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -39,30 +40,31 @@ type SSHRenewResponse struct { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) _, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) + return } newCert, err := h.Authority.RenewSSH(ctx, oldCert) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) return } @@ -72,11 +74,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } - JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, &SSHSignResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRevoke.go b/api/sshRevoke.go index 2d2da1f7..a33082cd 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -6,6 +6,7 @@ import ( "golang.org/x/crypto/ocsp" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -50,12 +51,12 @@ func (r *SSHRevokeRequest) Validate() (err error) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := read.JSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -71,18 +72,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } logSSHRevoke(w, opts) - JSON(w, &SSHRevokeResponse{Status: "ok"}) + render.JSON(w, &SSHRevokeResponse{Status: "ok"}) } func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/api/utils.go b/api/utils.go deleted file mode 100644 index e3fcc9c4..00000000 --- a/api/utils.go +++ /dev/null @@ -1,57 +0,0 @@ -package api - -import ( - "encoding/json" - "net/http" - - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" - - "github.com/smallstep/certificates/api/log" -) - -// JSON writes the passed value into the http.ResponseWriter. -func JSON(w http.ResponseWriter, v interface{}) { - JSONStatus(w, v, http.StatusOK) -} - -// JSONStatus writes the given value into the http.ResponseWriter and the -// given status is written as the status code of the response. -func JSONStatus(w http.ResponseWriter, v interface{}, status int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - if err := json.NewEncoder(w).Encode(v); err != nil { - log.Error(w, err) - - return - } - - log.EnabledResponse(w, v) -} - -// ProtoJSON writes the passed value into the http.ResponseWriter. -func ProtoJSON(w http.ResponseWriter, m proto.Message) { - ProtoJSONStatus(w, m, http.StatusOK) -} - -// ProtoJSONStatus writes the given value into the http.ResponseWriter and the -// given status is written as the status code of the response. -func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - - b, err := protojson.Marshal(m) - if err != nil { - log.Error(w, err) - - return - } - - if _, err := w.Write(b); err != nil { - log.Error(w, err) - - return - } - - // log.EnabledResponse(w, v) -} diff --git a/api/utils_test.go b/api/utils_test.go deleted file mode 100644 index f5e1e1cb..00000000 --- a/api/utils_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package api - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/smallstep/certificates/logging" -) - -func TestJSON(t *testing.T) { - type args struct { - rw http.ResponseWriter - v interface{} - } - tests := []struct { - name string - args args - ok bool - }{ - {"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true}, - {"fail", args{httptest.NewRecorder(), make(chan int)}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rw := logging.NewResponseLogger(tt.args.rw) - JSON(rw, tt.args.v) - - rr, ok := tt.args.rw.(*httptest.ResponseRecorder) - if !ok { - t.Error("ResponseWriter does not implement *httptest.ResponseRecorder") - return - } - - fields := rw.Fields() - if tt.ok { - if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" { - t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body) - } - if len(fields) != 0 { - t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields) - } - } else { - if body := rr.Body.String(); body != "" { - t.Errorf("Unexpected body = %s, want empty string", body) - } - if len(fields) != 1 { - t.Errorf("ResponseLogger fields = %v, wants 1 element", fields) - } - } - }) - } -} diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 27c3ba6f..21a7229d 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -6,10 +6,12 @@ import ( "net/http" "github.com/go-chi/chi" - "github.com/smallstep/certificates/api" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" ) const ( @@ -44,11 +46,11 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { provName := chi.URLParam(r, "provisionerName") eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if !eabEnabled { - api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) + render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) return } ctx = context.WithValue(ctx, provisionerContextKey, prov) @@ -101,15 +103,15 @@ func NewACMEAdminResponder() *ACMEAdminResponder { // GetExternalAccountKeys writes the response for the EAB keys GET endpoint func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // CreateExternalAccountKey writes the response for the EAB key POST endpoint func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 43607c52..5e4b9c30 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -10,6 +10,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" ) @@ -85,28 +86,28 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { adm, ok := h.auth.LoadAdminByID(id) if !ok { - api.WriteError(w, admin.NewError(admin.ErrorNotFoundType, + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) return } - api.ProtoJSON(w, adm) + render.ProtoJSON(w, adm) } // GetAdmins returns a segment of admins associated with the authority. func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) + render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) return } - api.JSON(w, &GetAdminsResponse{ + render.JSON(w, &GetAdminsResponse{ Admins: admins, NextCursor: nextCursor, }) @@ -116,18 +117,18 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest if err := read.JSON(r.Body, &body); err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } p, err := h.auth.LoadProvisionerByName(body.Provisioner) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return } adm := &linkedca.Admin{ @@ -137,11 +138,11 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { } // Store to authority collection. if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error storing admin")) + render.Error(w, admin.WrapErrorISE(err, "error storing admin")) return } - api.ProtoJSONStatus(w, adm, http.StatusCreated) + render.ProtoJSONStatus(w, adm, http.StatusCreated) } // DeleteAdmin deletes admin. @@ -149,23 +150,23 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } - api.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, &DeleteResponse{Status: "ok"}) } // UpdateAdmin updates an existing admin. func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest if err := read.JSON(r.Body, &body); err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -173,9 +174,9 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) return } - api.ProtoJSON(w, adm) + render.ProtoJSON(w, adm) } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index 19025a9d..b57dd6eb 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -4,7 +4,7 @@ import ( "context" "net/http" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" ) @@ -15,7 +15,7 @@ type nextHTTP = func(http.ResponseWriter, *http.Request) func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { if !h.auth.IsAdminAPIEnabled() { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } @@ -28,14 +28,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { tok := r.Header.Get("Authorization") if tok == "" { - api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, + render.Error(w, admin.NewError(admin.ErrorUnauthorizedType, "missing authorization header token")) return } adm, err := h.auth.AuthorizeAdminToken(r, tok) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index 2106733d..1cad62dd 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -4,10 +4,12 @@ import ( "net/http" "github.com/go-chi/chi" + "go.step.sm/linkedca" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" @@ -33,39 +35,39 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { ) if len(id) > 0 { if p, err = h.auth.LoadProvisionerByID(id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = h.auth.LoadProvisionerByName(name); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } - api.ProtoJSON(w, prov) + render.ProtoJSON(w, prov) } // GetProvisioners returns the given segment of provisioners associated with the authority. func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } p, next, err := h.auth.GetProvisioners(cursor, limit) if err != nil { - api.WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - api.JSON(w, &GetProvisionersResponse{ + render.JSON(w, &GetProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -75,21 +77,21 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, prov); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // TODO: Validate inputs if err := authority.ValidateClaims(prov.Claims); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) + render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } - api.ProtoJSONStatus(w, prov, http.StatusCreated) + render.ProtoJSONStatus(w, prov, http.StatusCreated) } // DeleteProvisioner deletes a provisioner. @@ -103,75 +105,75 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { ) if len(id) > 0 { if p, err = h.auth.LoadProvisionerByID(id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = h.auth.LoadProvisionerByName(name); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) + render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } - api.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, &DeleteResponse{Status: "ok"}) } // UpdateProvisioner updates an existing prov. func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) if err := read.ProtoJSON(r.Body, nu); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } name := chi.URLParam(r, "name") _old, err := h.auth.LoadProvisionerByName(name) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) return } if nu.Id != old.Id { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID")) + render.Error(w, admin.NewErrorISE("cannot change provisioner ID")) return } if nu.Type != old.Type { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner type")) + render.Error(w, admin.NewErrorISE("cannot change provisioner type")) return } if nu.AuthorityId != old.AuthorityId { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID")) + render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID")) return } if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt")) + render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt")) return } if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt")) + render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt")) return } // TODO: Validate inputs if err := authority.ValidateClaims(nu.Claims); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } - api.ProtoJSON(w, nu) + render.ProtoJSON(w, nu) } diff --git a/authority/admin/errors.go b/authority/admin/errors.go index 217227ca..baa32dd9 100644 --- a/authority/admin/errors.go +++ b/authority/admin/errors.go @@ -3,13 +3,10 @@ package admin import ( "encoding/json" "fmt" - "log" "net/http" - "os" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/api/render" ) // ProblemType is the type of the Admin problem. @@ -197,27 +194,9 @@ func (e *Error) ToLog() (interface{}, error) { return string(b), nil } -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err *Error) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(err.StatusCode()) +// Render implements render.RenderableError for Error. +func (e *Error) Render(w http.ResponseWriter) { + e.Message = e.Err.Error() - err.Message = err.Err.Error() - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err.Err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.Err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } - - if err := json.NewEncoder(w).Encode(err); err != nil { - log.Println(err) - } + render.JSONStatus(w, e, e.StatusCode()) } diff --git a/authority/authority.go b/authority/authority.go index 1aac594a..9db38e14 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -80,6 +80,14 @@ type Authority struct { adminMutex sync.RWMutex } +type Info struct { + StartTime time.Time + RootX509Certs []*x509.Certificate + SSHCAUserPublicKey []byte + SSHCAHostPublicKey []byte + DNSNames []string +} + // New creates and initiates a new Authority type. func New(cfg *config.Config, opts ...Option) (*Authority, error) { err := cfg.Validate() @@ -325,8 +333,6 @@ func (a *Authority) init() error { return err } a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate) - sum := sha256.Sum256(resp.RootCertificate.Raw) - log.Printf("Using root fingerprint '%s'", hex.EncodeToString(sum[:])) } } @@ -581,6 +587,21 @@ func (a *Authority) GetAdminDatabase() admin.DB { return a.adminDB } +func (a *Authority) GetInfo() Info { + ai := Info{ + StartTime: a.startTime, + RootX509Certs: a.rootX509Certs, + DNSNames: a.config.DNSNames, + } + if a.sshCAUserCertSignKey != nil { + ai.SSHCAUserPublicKey = ssh.MarshalAuthorizedKey(a.sshCAUserCertSignKey.PublicKey()) + } + if a.sshCAHostCertSignKey != nil { + ai.SSHCAHostPublicKey = ssh.MarshalAuthorizedKey(a.sshCAHostCertSignKey.PublicKey()) + } + return ai +} + // IsAdminAPIEnabled returns a boolean indicating whether the Admin API has // been enabled. func (a *Authority) IsAdminAPIEnabled() bool { diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 6a35c4e1..a7bec277 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -9,6 +9,7 @@ import ( "crypto/x509/pkix" "encoding/asn1" "encoding/base64" + "errors" "fmt" "net/http" "reflect" @@ -16,16 +17,18 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" ) var testAudiences = provisioner.Audiences{ @@ -310,8 +313,8 @@ func TestAuthority_authorizeToken(t *testing.T) { p, err := tc.auth.authorizeToken(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -396,8 +399,8 @@ func TestAuthority_authorizeRevoke(t *testing.T) { if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -481,8 +484,8 @@ func TestAuthority_authorizeSign(t *testing.T) { got, err := tc.auth.authorizeSign(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -740,8 +743,8 @@ func TestAuthority_Authorize(t *testing.T) { if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, got) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -853,7 +856,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { err := tc.auth.authorizeRenew(tc.cert) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -1001,8 +1004,8 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1118,8 +1121,8 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1218,8 +1221,8 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) { if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1311,8 +1314,8 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/options.go b/authority/options.go index a1238b1d..1c154577 100644 --- a/authority/options.go +++ b/authority/options.go @@ -163,6 +163,22 @@ func WithX509Signer(crt *x509.Certificate, s crypto.Signer) Option { } } +// WithX509SignerFunc defines the function used to get the chain of certificates +// and signer used when we sign X.509 certificates. +func WithX509SignerFunc(fn func() ([]*x509.Certificate, crypto.Signer, error)) Option { + return func(a *Authority) error { + srv, err := cas.New(context.Background(), casapi.Options{ + Type: casapi.SoftCAS, + CertificateSigner: fn, + }) + if err != nil { + return err + } + a.x509CAService = srv + return nil + } +} + // WithSSHUserSigner defines the signer used to sign SSH user certificates. func WithSSHUserSigner(s crypto.Signer) Option { return func(a *Authority) error { diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index 8ab29b26..1c9a88cc 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -3,13 +3,14 @@ package provisioner import ( "context" "crypto/x509" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/api/render" ) func TestACME_Getters(t *testing.T) { @@ -114,7 +115,7 @@ func TestACME_AuthorizeRenew(t *testing.T) { NotAfter: now.Add(time.Hour), }, code: http.StatusUnauthorized, - err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), + err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -133,8 +134,8 @@ func TestACME_AuthorizeRenew(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -168,8 +169,8 @@ func TestACME_AuthorizeSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -193,7 +194,7 @@ func TestACME_AuthorizeSign(t *testing.T) { assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index f17ec231..7027a446 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -9,6 +9,7 @@ import ( "crypto/x509" "encoding/hex" "encoding/pem" + "errors" "fmt" "net" "net/http" @@ -17,10 +18,10 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestAWS_Getters(t *testing.T) { @@ -521,8 +522,8 @@ func TestAWS_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -668,8 +669,8 @@ func TestAWS_AuthorizeSign(t *testing.T) { t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Len(t, tt.wantLen, got) @@ -699,7 +700,7 @@ func TestAWS_AuthorizeSign(t *testing.T) { case dnsNamesValidator: assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -803,8 +804,8 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -861,8 +862,8 @@ func TestAWS_AuthorizeRenew(t *testing.T) { if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 28de642c..a8a0a271 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "errors" "fmt" "net/http" "net/http/httptest" @@ -15,10 +16,10 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestAzure_Getters(t *testing.T) { @@ -335,8 +336,8 @@ func TestAzure_authorizeToken(t *testing.T) { tc := tt(t) if claims, name, group, subscriptionID, objectID, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -497,8 +498,8 @@ func TestAzure_AuthorizeSign(t *testing.T) { t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Len(t, tt.wantLen, got) @@ -528,7 +529,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { case dnsNamesValidator: assert.Equals(t, []string(v), []string{"virtualMachine"}) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -573,8 +574,8 @@ func TestAzure_AuthorizeRenew(t *testing.T) { if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -669,8 +670,8 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 7d3bd073..dfb9a329 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "errors" "fmt" "net/http" "net/http/httptest" @@ -16,10 +17,10 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestGCP_Getters(t *testing.T) { @@ -390,8 +391,8 @@ func TestGCP_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -540,8 +541,8 @@ func TestGCP_AuthorizeSign(t *testing.T) { t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Len(t, tt.wantLen, got) @@ -571,7 +572,7 @@ func TestGCP_AuthorizeSign(t *testing.T) { case dnsNamesValidator: assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -678,8 +679,8 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -735,7 +736,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.code) } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index f9000961..926f9d68 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -6,15 +6,17 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "errors" + "fmt" "net/http" "strings" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestJWK_Getters(t *testing.T) { @@ -183,8 +185,8 @@ func TestJWK_authorizeToken(t *testing.T) { t.Run(tt.name, func(t *testing.T) { if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tt.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -223,8 +225,8 @@ func TestJWK_AuthorizeRevoke(t *testing.T) { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil { if assert.NotNil(t, tt.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -288,8 +290,8 @@ func TestJWK_AuthorizeSign(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil { if assert.NotNil(t, tt.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -316,7 +318,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { case defaultSANsValidator: assert.Equals(t, []string(v), tt.sans) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -362,8 +364,8 @@ func TestJWK_AuthorizeRenew(t *testing.T) { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -456,8 +458,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -621,8 +623,8 @@ func TestJWK_AuthorizeSSHRevoke(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index e6305f6e..e98b6f48 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -3,14 +3,16 @@ package provisioner import ( "context" "crypto/x509" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestK8sSA_Getters(t *testing.T) { @@ -116,8 +118,8 @@ func TestK8sSA_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -165,8 +167,8 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -202,7 +204,7 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { NotAfter: now.Add(time.Hour), }, code: http.StatusUnauthorized, - err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), + err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -221,8 +223,8 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -270,8 +272,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -295,7 +297,7 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } @@ -327,7 +329,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { p: p, token: "foo", code: http.StatusUnauthorized, - err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()), + err: fmt.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()), } }, "fail/invalid-token": func(t *testing.T) test { @@ -358,8 +360,8 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -379,7 +381,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case *sshDefaultDuration: assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index bfa46027..548c4dc8 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -6,16 +6,17 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "errors" "fmt" "net/http" "strings" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func Test_openIDConfiguration_Validate(t *testing.T) { @@ -246,8 +247,8 @@ func TestOIDC_authorizeToken(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else { @@ -317,8 +318,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -341,7 +342,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) { case emailOnlyIdentity: assert.Equals(t, string(v), "name@smallstep.com") default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -403,8 +404,8 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) { t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) return } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -449,8 +450,8 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -605,8 +606,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -673,8 +674,8 @@ func TestOIDC_AuthorizeSSHRevoke(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) diff --git a/authority/provisioner/provisioner_test.go b/authority/provisioner/provisioner_test.go index 330d1b57..9678a20b 100644 --- a/authority/provisioner/provisioner_test.go +++ b/authority/provisioner/provisioner_test.go @@ -2,13 +2,14 @@ package provisioner import ( "context" + "errors" "net/http" "testing" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestType_String(t *testing.T) { @@ -240,8 +241,8 @@ func TestUnimplementedMethods(t *testing.T) { default: t.Errorf("unexpected method %s", tt.method) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized) assert.Equals(t, err.Error(), msg) }) diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index b548fe71..13294866 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -5,16 +5,19 @@ import ( "crypto" "crypto/rand" "encoding/base64" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" - "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestSSHPOP_Getters(t *testing.T) { @@ -215,8 +218,8 @@ func TestSSHPOP_authorizeToken(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -286,8 +289,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -367,8 +370,8 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { tc := tt(t) if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -449,8 +452,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { tc := tt(t) if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -464,7 +467,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { case *sshCertValidityValidator: assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } assert.Equals(t, tc.cert.Nonce, cert.Nonce) diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 91e789a1..a3308f00 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -3,16 +3,18 @@ package provisioner import ( "context" "crypto/x509" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestX5C_Getters(t *testing.T) { @@ -71,7 +73,7 @@ func TestX5C_Init(t *testing.T) { "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")}, - err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), + err: errors.New("no x509 certificates found in roots attribute for provisioner 'foo'"), } }, "fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest { @@ -122,7 +124,7 @@ M46l92gdOozT // check the number of certificates in the pool. numCerts := len(p.rootPool.Subjects()) if numCerts != 2 { - return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts) + return fmt.Errorf("unexpected number of certs: want 2, but got %d", numCerts) } return nil }, @@ -387,8 +389,8 @@ lgsqsR63is+0YQ== tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -458,7 +460,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -491,7 +493,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -542,8 +544,8 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -573,7 +575,7 @@ func TestX5C_AuthorizeRenew(t *testing.T) { return test{ p: p, code: http.StatusUnauthorized, - err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()), + err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -592,8 +594,8 @@ func TestX5C_AuthorizeRenew(t *testing.T) { NotAfter: now.Add(time.Hour), }); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -632,7 +634,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { p: p, token: "foo", code: http.StatusUnauthorized, - err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), + err: fmt.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), } }, "fail/invalid-token": func(t *testing.T) test { @@ -753,7 +755,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -788,7 +790,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 3975031b..81dc38bf 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -1,13 +1,13 @@ package authority import ( + "errors" "net/http" "testing" - "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/errs" ) func TestGetEncryptedKey(t *testing.T) { @@ -49,8 +49,8 @@ func TestGetEncryptedKey(t *testing.T) { ek, err := tc.a.GetEncryptedKey(tc.kid) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -90,8 +90,8 @@ func TestGetProvisioners(t *testing.T) { ps, next, err := tc.a.GetProvisioners("", 0) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/root_test.go b/authority/root_test.go index 6e5f1932..a1b08fac 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -2,14 +2,15 @@ package authority import ( "crypto/x509" + "errors" "net/http" "reflect" "testing" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/pemutil" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestRoot(t *testing.T) { @@ -31,7 +32,7 @@ func TestRoot(t *testing.T) { crt, err := a.Root(tc.sum) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) diff --git a/authority/ssh_test.go b/authority/ssh_test.go index c299b347..ce840fe1 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -7,21 +7,22 @@ import ( "crypto/rand" "crypto/x509" "encoding/base64" + "errors" "fmt" "net/http" "reflect" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/templates" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/templates" ) type sshTestModifier ssh.Certificate @@ -716,8 +717,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) { t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) return } else if err != nil { - _, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + _, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want) @@ -806,8 +807,8 @@ func TestAuthority_GetSSHHosts(t *testing.T) { hosts, err := auth.GetSSHHosts(context.Background(), tc.cert) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1033,8 +1034,8 @@ func TestAuthority_RekeySSH(t *testing.T) { cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/tls_test.go b/authority/tls_test.go index 6ccf02ca..e199e0c5 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -11,24 +11,26 @@ import ( "crypto/x509/pkix" "encoding/asn1" "encoding/pem" + "errors" "fmt" "net/http" "reflect" "testing" "time" - "github.com/smallstep/certificates/cas/softcas" + "gopkg.in/square/go-jose.v2/jwt" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" - "gopkg.in/square/go-jose.v2/jwt" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/cas/softcas" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" ) var ( @@ -187,14 +189,14 @@ func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) { func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { b, err := x509.MarshalPKIXPublicKey(pub) if err != nil { - return nil, errors.Wrap(err, "error marshaling public key") + return nil, fmt.Errorf("error marshaling public key: %w", err) } info := struct { Algorithm pkix.AlgorithmIdentifier SubjectPublicKey asn1.BitString }{} if _, err = asn1.Unmarshal(b, &info); err != nil { - return nil, errors.Wrap(err, "error unmarshaling public key") + return nil, fmt.Errorf("error unmarshaling public key: %w", err) } hash := sha1.Sum(info.SubjectPublicKey.Bytes) return hash[:], nil @@ -661,8 +663,8 @@ ZYtQ9Ot36qc= if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -860,8 +862,8 @@ func TestAuthority_Renew(t *testing.T) { if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -1067,8 +1069,8 @@ func TestAuthority_Rekey(t *testing.T) { if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -1456,8 +1458,8 @@ func TestAuthority_Revoke(t *testing.T) { ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) if err := tc.auth.Revoke(ctx, tc.opts); err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index ad5f2116..f17a2f7a 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -12,12 +12,13 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/pemutil" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" - "github.com/smallstep/certificates/api" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/pemutil" + "github.com/smallstep/certificates/api/render" ) func TestNewACMEClient(t *testing.T) { @@ -112,15 +113,15 @@ func TestNewACMEClient(t *testing.T) { assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header switch { case i == 0: - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ case i == 1: w.Header().Set("Replay-Nonce", "abc123") - api.JSONStatus(w, []byte{}, 200) + render.JSONStatus(w, []byte{}, 200) i++ default: w.Header().Set("Location", accLocation) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) } }) @@ -206,7 +207,7 @@ func TestACMEClient_GetNonce(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) }) if nonce, err := ac.GetNonce(); err != nil { @@ -315,7 +316,7 @@ func TestACMEClient_post(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -338,7 +339,7 @@ func TestACMEClient_post(t *testing.T) { assert.Equals(t, hdr.KeyID, ac.kid) } - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { @@ -455,7 +456,7 @@ func TestACMEClient_NewOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -477,7 +478,7 @@ func TestACMEClient_NewOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, norb) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.NewOrder(norb); err != nil { @@ -577,7 +578,7 @@ func TestACMEClient_GetOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -599,7 +600,7 @@ func TestACMEClient_GetOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetOrder(url); err != nil { @@ -699,7 +700,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -721,7 +722,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetAuthz(url); err != nil { @@ -821,7 +822,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -844,7 +845,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetChallenge(url); err != nil { @@ -944,7 +945,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -967,7 +968,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { assert.Equals(t, payload, []byte("{}")) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if err := ac.ValidateChallenge(url); err != nil { @@ -1071,7 +1072,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1093,7 +1094,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, frb) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if err := ac.FinalizeOrder(url, csr); err != nil { @@ -1200,7 +1201,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1222,7 +1223,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := tc.client.GetAccountOrders(); err != nil { @@ -1331,7 +1332,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1356,7 +1357,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { if tc.certBytes != nil { w.Write(tc.certBytes) } else { - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) } }) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 0e16bd7d..2332b4d4 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -14,11 +14,14 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/errs" ) func newLocalListener() net.Listener { @@ -79,7 +82,7 @@ func startCAServer(configFile string) (*CA, string, error) { func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/version" { - api.JSON(w, api.VersionResponse{ + render.JSON(w, api.VersionResponse{ Version: "test", RequireClientAuthentication: true, }) @@ -93,7 +96,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han } isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 if !isMTLS { - api.WriteError(w, errs.Unauthorized("missing peer certificate")) + render.Error(w, errs.Unauthorized("missing peer certificate")) } else { next.ServeHTTP(w, r) } diff --git a/ca/ca.go b/ca/ca.go index dfb82731..0d4f1578 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "reflect" + "strings" "sync" "github.com/go-chi/chi" @@ -26,11 +27,14 @@ import ( scepAPI "github.com/smallstep/certificates/scep/api" "github.com/smallstep/certificates/server" "github.com/smallstep/nosql" + "go.step.sm/cli-utils/step" + "go.step.sm/crypto/x509util" ) type options struct { configFile string linkedCAToken string + quiet bool password []byte issuerPassword []byte sshHostPassword []byte @@ -101,6 +105,13 @@ func WithLinkedCAToken(token string) Option { } } +// WithQuiet sets the quiet flag. +func WithQuiet(quiet bool) Option { + return func(o *options) { + o.quiet = quiet + } +} + // CA is the type used to build the complete certificate authority. It builds // the HTTP server, set ups the middlewares and the HTTP handlers. type CA struct { @@ -288,6 +299,35 @@ func (ca *CA) Run() error { var wg sync.WaitGroup errs := make(chan error, 1) + if !ca.opts.quiet { + authorityInfo := ca.auth.GetInfo() + log.Printf("Starting %s", step.Version()) + log.Printf("Documentation: https://u.step.sm/docs/ca") + log.Printf("Community Discord: https://u.step.sm/discord") + if step.Contexts().GetCurrent() != nil { + log.Printf("Current context: %s", step.Contexts().GetCurrent().Name) + } + log.Printf("Config file: %s", ca.opts.configFile) + baseURL := fmt.Sprintf("https://%s%s", + authorityInfo.DNSNames[0], + ca.config.Address[strings.LastIndex(ca.config.Address, ":"):]) + log.Printf("The primary server URL is %s", baseURL) + log.Printf("Root certificates are available at %s/roots.pem", baseURL) + if len(authorityInfo.DNSNames) > 1 { + log.Printf("Additional configured hostnames: %s", + strings.Join(authorityInfo.DNSNames[1:], ", ")) + } + for _, crt := range authorityInfo.RootX509Certs { + log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt)) + } + if authorityInfo.SSHCAHostPublicKey != nil { + log.Printf("SSH Host CA Key is %s\n", authorityInfo.SSHCAHostPublicKey) + } + if authorityInfo.SSHCAUserPublicKey != nil { + log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey) + } + } + if ca.insecureSrv != nil { wg.Add(1) go func() { @@ -355,6 +395,7 @@ func (ca *CA) Reload() error { WithSSHUserPassword(ca.opts.sshUserPassword), WithIssuerPassword(ca.opts.issuerPassword), WithLinkedCAToken(ca.opts.linkedCAToken), + WithQuiet(ca.opts.quiet), WithConfigFile(ca.opts.configFile), WithDatabase(ca.auth.GetDatabase()), ) diff --git a/ca/client_test.go b/ca/client_test.go index 4628d19b..48aa1488 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "errors" "fmt" "net/http" "net/http/httptest" @@ -16,14 +17,13 @@ import ( "testing" "time" - "github.com/pkg/errors" - "golang.org/x/crypto/ssh" - "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ssh" "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -182,7 +182,7 @@ func TestClient_Version(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Version() @@ -232,7 +232,7 @@ func TestClient_Health(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Health() @@ -290,7 +290,7 @@ func TestClient_Root(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Root(tt.shasum) @@ -360,7 +360,7 @@ func TestClient_Sign(t *testing.T) { if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") - api.WriteError(w, e) + render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -371,7 +371,7 @@ func TestClient_Sign(t *testing.T) { t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) } } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Sign(tt.request) @@ -432,7 +432,7 @@ func TestClient_Revoke(t *testing.T) { if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") - api.WriteError(w, e) + render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -443,7 +443,7 @@ func TestClient_Revoke(t *testing.T) { t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) } } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Revoke(tt.request, nil) @@ -503,7 +503,7 @@ func TestClient_Renew(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Renew(nil) @@ -519,8 +519,8 @@ func TestClient_Renew(t *testing.T) { t.Errorf("Client.Renew() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -568,9 +568,9 @@ func TestClient_RenewWithToken(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.Header.Get("Authorization") != "Bearer token" { - api.JSONStatus(w, errs.InternalServer("force"), 500) + render.JSONStatus(w, errs.InternalServer("force"), 500) } else { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) } }) @@ -587,8 +587,8 @@ func TestClient_RenewWithToken(t *testing.T) { t.Errorf("Client.RenewWithToken() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -640,7 +640,7 @@ func TestClient_Rekey(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Rekey(tt.request, nil) @@ -656,8 +656,8 @@ func TestClient_Rekey(t *testing.T) { t.Errorf("Client.Renew() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -705,7 +705,7 @@ func TestClient_Provisioners(t *testing.T) { if req.RequestURI != tt.expectedURI { t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Provisioners(tt.args...) @@ -762,7 +762,7 @@ func TestClient_ProvisionerKey(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.ProvisionerKey(tt.kid) @@ -777,8 +777,8 @@ func TestClient_ProvisionerKey(t *testing.T) { t.Errorf("Client.ProvisionerKey() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, tt.err.Error(), err.Error()) default: @@ -821,7 +821,7 @@ func TestClient_Roots(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Roots() @@ -836,8 +836,8 @@ func TestClient_Roots(t *testing.T) { if got != nil { t.Errorf("Client.Roots() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -879,7 +879,7 @@ func TestClient_Federation(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Federation() @@ -894,8 +894,8 @@ func TestClient_Federation(t *testing.T) { if got != nil { t.Errorf("Client.Federation() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, tt.err.Error(), err.Error()) default: @@ -941,7 +941,7 @@ func TestClient_SSHRoots(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHRoots() @@ -956,8 +956,8 @@ func TestClient_SSHRoots(t *testing.T) { if got != nil { t.Errorf("Client.SSHKeys() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, tt.err.Error(), err.Error()) default: @@ -1041,7 +1041,7 @@ func TestClient_RootFingerprint(t *testing.T) { } tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() @@ -1102,7 +1102,7 @@ func TestClient_SSHBastion(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHBastion(tt.request) @@ -1118,8 +1118,8 @@ func TestClient_SSHBastion(t *testing.T) { t.Errorf("Client.SSHBastion() = %v, want nil", got) } if tt.responseCode != 200 { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) } diff --git a/cas/apiv1/options.go b/cas/apiv1/options.go index badad7fc..3fc34208 100644 --- a/cas/apiv1/options.go +++ b/cas/apiv1/options.go @@ -31,12 +31,20 @@ type Options struct { // https://cloud.google.com/docs/authentication. CredentialsFile string `json:"credentialsFile,omitempty"` - // Certificate and signer are the issuer certificate, along with any other - // bundled certificates to be returned in the chain for consumers, and - // signer used in SoftCAS. They are configured in ca.json crt and key - // properties. + // CertificateChain contains the issuer certificate, along with any other + // bundled certificates to be returned in the chain to consumers. It is used + // used in SoftCAS and it is configured in the crt property of the ca.json. CertificateChain []*x509.Certificate `json:"-"` - Signer crypto.Signer `json:"-"` + + // Signer is the private key or a KMS signer for the issuer certificate. It + // is used in SoftCAS and it is configured in the key property of the + // ca.json. + Signer crypto.Signer `json:"-"` + + // CertificateSigner combines CertificateChain and Signer in a callback that + // returns the chain of certificate and signer used to sign X.509 + // certificates in SoftCAS. + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) `json:"-"` // IsCreator is set to true when we're creating a certificate authority. It // is used to skip some validations when initializing a diff --git a/cas/softcas/softcas.go b/cas/softcas/softcas.go index 8e67d016..2a97145b 100644 --- a/cas/softcas/softcas.go +++ b/cas/softcas/softcas.go @@ -24,9 +24,10 @@ var now = time.Now // SoftCAS implements a Certificate Authority Service using Golang or KMS // crypto. This is the default CAS used in step-ca. type SoftCAS struct { - CertificateChain []*x509.Certificate - Signer crypto.Signer - KeyManager kms.KeyManager + CertificateChain []*x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) + KeyManager kms.KeyManager } // New creates a new CertificateAuthorityService implementation using Golang or KMS @@ -34,16 +35,17 @@ type SoftCAS struct { func New(ctx context.Context, opts apiv1.Options) (*SoftCAS, error) { if !opts.IsCreator { switch { - case len(opts.CertificateChain) == 0: + case len(opts.CertificateChain) == 0 && opts.CertificateSigner == nil: return nil, errors.New("softCAS 'CertificateChain' cannot be nil") - case opts.Signer == nil: + case opts.Signer == nil && opts.CertificateSigner == nil: return nil, errors.New("softCAS 'signer' cannot be nil") } } return &SoftCAS{ - CertificateChain: opts.CertificateChain, - Signer: opts.Signer, - KeyManager: opts.KeyManager, + CertificateChain: opts.CertificateChain, + Signer: opts.Signer, + CertificateSigner: opts.CertificateSigner, + KeyManager: opts.KeyManager, }, nil } @@ -57,6 +59,7 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 } t := now() + // Provisioners can also set specific values. if req.Template.NotBefore.IsZero() { req.Template.NotBefore = t.Add(-1 * req.Backdate) @@ -64,16 +67,21 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 if req.Template.NotAfter.IsZero() { req.Template.NotAfter = t.Add(req.Lifetime) } - req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + chain, signer, err := c.getCertSigner() + if err != nil { + return nil, err + } + req.Template.Issuer = chain[0].Subject + + cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer) if err != nil { return nil, err } return &apiv1.CreateCertificateResponse{ Certificate: cert, - CertificateChain: c.CertificateChain, + CertificateChain: chain, }, nil } @@ -89,16 +97,21 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R t := now() req.Template.NotBefore = t.Add(-1 * req.Backdate) req.Template.NotAfter = t.Add(req.Lifetime) - req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + chain, signer, err := c.getCertSigner() + if err != nil { + return nil, err + } + req.Template.Issuer = chain[0].Subject + + cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer) if err != nil { return nil, err } return &apiv1.RenewCertificateResponse{ Certificate: cert, - CertificateChain: c.CertificateChain, + CertificateChain: chain, }, nil } @@ -106,9 +119,13 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R // operation is a no-op as the actual revoke will happen when we store the entry // in the db. func (c *SoftCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { + chain, _, err := c.getCertSigner() + if err != nil { + return nil, err + } return &apiv1.RevokeCertificateResponse{ Certificate: req.Certificate, - CertificateChain: c.CertificateChain, + CertificateChain: chain, }, nil } @@ -179,7 +196,7 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori }, nil } -// initializeKeyManager initiazes the default key manager if was not given. +// initializeKeyManager initializes the default key manager if was not given. func (c *SoftCAS) initializeKeyManager() (err error) { if c.KeyManager == nil { c.KeyManager, err = kms.New(context.Background(), kmsapi.Options{ @@ -189,6 +206,15 @@ func (c *SoftCAS) initializeKeyManager() (err error) { return } +// getCertSigner returns the certificate chain and signer to use. +func (c *SoftCAS) getCertSigner() ([]*x509.Certificate, crypto.Signer, error) { + if c.CertificateSigner != nil { + return c.CertificateSigner() + } + return c.CertificateChain, c.Signer, nil + +} + // createKey uses the configured kms to create a key. func (c *SoftCAS) createKey(req *kmsapi.CreateKeyRequest) (*kmsapi.CreateKeyResponse, error) { if err := c.initializeKeyManager(); err != nil { diff --git a/cas/softcas/softcas_test.go b/cas/softcas/softcas_test.go index 7d3add4f..b4f5b440 100644 --- a/cas/softcas/softcas_test.go +++ b/cas/softcas/softcas_test.go @@ -73,6 +73,12 @@ var ( testSignedTemplate = mustSign(testTemplate, testIssuer, testNow, testNow.Add(24*time.Hour)) testSignedRootTemplate = mustSign(testRootTemplate, testRootTemplate, testNow, testNow.Add(24*time.Hour)) testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour)) + testCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) { + return []*x509.Certificate{testIssuer}, testSigner, nil + } + testFailCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) { + return nil, nil, errTest + } ) type signatureAlgorithmSigner struct { @@ -186,6 +192,10 @@ func setTeeReader(t *testing.T, w *bytes.Buffer) { } func TestNew(t *testing.T) { + assertEqual := func(x, y interface{}) bool { + return reflect.DeepEqual(x, y) || fmt.Sprintf("%#v", x) == fmt.Sprintf("%#v", y) + } + type args struct { ctx context.Context opts apiv1.Options @@ -197,6 +207,7 @@ func TestNew(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}}, &SoftCAS{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}, false}, + {"ok with callback", args{context.Background(), apiv1.Options{CertificateSigner: testCertificateSigner}}, &SoftCAS{CertificateSigner: testCertificateSigner}, false}, {"fail no issuer", args{context.Background(), apiv1.Options{Signer: testSigner}}, nil, true}, {"fail no signer", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}}}, nil, true}, } @@ -207,7 +218,7 @@ func TestNew(t *testing.T) { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { + if !assertEqual(got, tt.want) { t.Errorf("New() = %v, want %v", got, tt.want) } }) @@ -265,8 +276,9 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { } type fields struct { - Issuer *x509.Certificate - Signer crypto.Signer + Issuer *x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.CreateCertificateRequest @@ -278,43 +290,53 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { want *apiv1.CreateCertificateResponse wantErr bool }{ - {"ok", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &saTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok with notBefore", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNotBefore, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok with notBefore+notAfter", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok with notBefore+notAfter", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplWithLifetime, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"fail template", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, - {"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true}, - {"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.CreateCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.CreateCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, + {"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, + {"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true}, + {"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNoSerial, Lifetime: 24 * time.Hour, }}, nil, true}, + {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.CreateCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ - CertificateChain: []*x509.Certificate{tt.fields.Issuer}, - Signer: tt.fields.Signer, + CertificateChain: []*x509.Certificate{tt.fields.Issuer}, + Signer: tt.fields.Signer, + CertificateSigner: tt.fields.CertificateSigner, } got, err := c.CreateCertificate(tt.args.req) if (err != nil) != tt.wantErr { @@ -345,8 +367,9 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { } type fields struct { - Issuer *x509.Certificate - Signer crypto.Signer + Issuer *x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.RenewCertificateRequest @@ -358,30 +381,40 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { want *apiv1.RenewCertificateResponse wantErr bool }{ - {"ok", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ + {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.RenewCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.RenewCertificateRequest{ + {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.RenewCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, - {"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, - {"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ + {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RenewCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.RenewCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, + {"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, + {"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, + {"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: &tmplNoSerial, Lifetime: 24 * time.Hour, }}, nil, true}, + {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RenewCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ - CertificateChain: []*x509.Certificate{tt.fields.Issuer}, - Signer: tt.fields.Signer, + CertificateChain: []*x509.Certificate{tt.fields.Issuer}, + Signer: tt.fields.Signer, + CertificateSigner: tt.fields.CertificateSigner, } got, err := c.RenewCertificate(tt.args.req) if (err != nil) != tt.wantErr { @@ -397,8 +430,9 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { func TestSoftCAS_RevokeCertificate(t *testing.T) { type fields struct { - Issuer *x509.Certificate - Signer crypto.Signer + Issuer *x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.RevokeCertificateRequest @@ -410,7 +444,7 @@ func TestSoftCAS_RevokeCertificate(t *testing.T) { want *apiv1.RevokeCertificateResponse wantErr bool }{ - {"ok", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{ + {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{ Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, Reason: "test reason", ReasonCode: 1, @@ -418,23 +452,37 @@ func TestSoftCAS_RevokeCertificate(t *testing.T) { Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok no cert", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{ + {"ok no cert", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{ Reason: "test reason", ReasonCode: 1, }}, &apiv1.RevokeCertificateResponse{ Certificate: nil, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok empty", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{ + {"ok empty", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{ Certificate: nil, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, + {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RevokeCertificateRequest{ + Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, + Reason: "test reason", + ReasonCode: 1, + }}, &apiv1.RevokeCertificateResponse{ + Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, + {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RevokeCertificateRequest{ + Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, + Reason: "test reason", + ReasonCode: 1, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ - CertificateChain: []*x509.Certificate{tt.fields.Issuer}, - Signer: tt.fields.Signer, + CertificateChain: []*x509.Certificate{tt.fields.Issuer}, + Signer: tt.fields.Signer, + CertificateSigner: tt.fields.CertificateSigner, } got, err := c.RevokeCertificate(tt.args.req) if (err != nil) != tt.wantErr { @@ -609,3 +657,56 @@ func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { }) } } + +func TestSoftCAS_defaultKeyManager(t *testing.T) { + mockNow(t) + type args struct { + req *apiv1.CreateCertificateAuthorityRequest + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok root", args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: &x509.Certificate{ + Subject: pkix.Name{CommonName: "Test Root CA"}, + KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + SerialNumber: big.NewInt(1234), + }, + Lifetime: 24 * time.Hour, + }}, false}, + {"ok intermediate", args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Certificate: testSignedRootTemplate, + Signer: testSigner, + }, + }}, false}, + {"fail with default key manager", args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Certificate: testSignedRootTemplate, + Signer: &badSigner{}, + }, + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &SoftCAS{} + _, err := c.CreateCertificateAuthority(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("SoftCAS.CreateCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/commands/app.go b/commands/app.go index fc9cd15b..265610f2 100644 --- a/commands/app.go +++ b/commands/app.go @@ -58,6 +58,11 @@ certificate issuer private key used in the RA mode.`, Usage: "token used to enable the linked ca.", EnvVar: "STEP_CA_TOKEN", }, + cli.BoolFlag{ + Name: "quiet", + Usage: "disable startup information", + EnvVar: "STEP_CA_QUIET", + }, cli.StringFlag{ Name: "context", Usage: "The name of the authority's context.", @@ -74,6 +79,7 @@ func appAction(ctx *cli.Context) error { issuerPassFile := ctx.String("issuer-password-file") resolver := ctx.String("resolver") token := ctx.String("token") + quiet := ctx.Bool("quiet") if ctx.NArg() > 1 { return errs.TooManyArguments(ctx) @@ -155,7 +161,8 @@ To get a linked authority token: ca.WithSSHHostPassword(sshHostPassword), ca.WithSSHUserPassword(sshUserPassword), ca.WithIssuerPassword(issuerPassword), - ca.WithLinkedCAToken(token)) + ca.WithLinkedCAToken(token), + ca.WithQuiet(quiet)) if err != nil { fatal(err) } diff --git a/errs/error.go b/errs/error.go index 60da9e1f..c42e342d 100644 --- a/errs/error.go +++ b/errs/error.go @@ -6,18 +6,11 @@ import ( "net/http" "github.com/pkg/errors" + + "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/api/render" ) -// StatusCoder interface is used by errors that returns the HTTP response code. -type StatusCoder interface { - StatusCode() int -} - -// StackTracer must be by those errors that return an stack trace. -type StackTracer interface { - StackTrace() errors.StackTrace -} - // Option modifies the Error type. type Option func(e *Error) error @@ -257,7 +250,7 @@ func NewError(status int, err error, format string, args ...interface{}) error { return err } msg := fmt.Sprintf(format, args...) - if _, ok := err.(StackTracer); !ok { + if _, ok := err.(log.StackTracedError); !ok { err = errors.Wrap(err, msg) } return &Error{ @@ -275,11 +268,11 @@ func NewErr(status int, err error, opts ...Option) error { ok bool ) if e, ok = err.(*Error); !ok { - if sc, ok := err.(StatusCoder); ok { + if sc, ok := err.(render.StatusCodedError); ok { e = &Error{Status: sc.StatusCode(), Err: err} } else { cause := errors.Cause(err) - if sc, ok := cause.(StatusCoder); ok { + if sc, ok := cause.(render.StatusCodedError); ok { e = &Error{Status: sc.StatusCode(), Err: err} } else { e = &Error{Status: status, Err: err} diff --git a/go.mod b/go.mod index 85d21ceb..6a834552 100644 --- a/go.mod +++ b/go.mod @@ -33,10 +33,11 @@ require ( github.com/slackhq/nebula v1.5.2 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/nosql v0.4.0 + github.com/stretchr/testify v1.7.1 github.com/urfave/cli v1.22.4 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 - go.step.sm/crypto v0.15.3 + go.step.sm/crypto v0.16.1 go.step.sm/linkedca v0.12.0 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd diff --git a/go.sum b/go.sum index 47f84801..ba718b16 100644 --- a/go.sum +++ b/go.sum @@ -468,9 +468,11 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -707,8 +709,8 @@ go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqe go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= -go.step.sm/crypto v0.15.3 h1:f3GMl+aCydt294BZRjTYwpaXRqwwndvoTY2NLN4wu10= -go.step.sm/crypto v0.15.3/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= +go.step.sm/crypto v0.16.1 h1:4mnZk21cSxyMGxsEpJwZKKvJvDu1PN09UVrWWFNUBdk= +go.step.sm/crypto v0.16.1/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -1194,6 +1196,7 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=