diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5d0416ef..2ab7084d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: [ '1.15', '1.16', '1.17' ] + go: [ '1.17', '1.18' ] outputs: is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} steps: @@ -33,7 +33,7 @@ jobs: uses: golangci/golangci-lint-action@v2 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: 'v1.44.0' + version: 'v1.45.0' # Optional: working directory, useful for monorepos # working-directory: somedir @@ -106,7 +106,7 @@ jobs: name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.17 + go-version: 1.18 - name: APT Install id: aptInstall @@ -159,7 +159,7 @@ jobs: name: Setup Go uses: actions/setup-go@v2 with: - go-version: '1.17' + go-version: '1.18' - name: Install cosign uses: sigstore/cosign-installer@v1.1.0 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f36e78ef..64cb64cd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: [ '1.16', '1.17' ] + go: [ '1.17', '1.18' ] steps: - name: Checkout @@ -33,7 +33,7 @@ jobs: uses: golangci/golangci-lint-action@v2 with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version - version: 'v1.44.0' + version: 'v1.45.0' # Optional: working directory, useful for monorepos # working-directory: somedir @@ -58,7 +58,7 @@ jobs: run: V=1 make ci - name: Codecov - if: matrix.go == '1.17' + if: matrix.go == '1.18' uses: codecov/codecov-action@v1.2.1 with: file: ./coverage.out # optional diff --git a/.gitignore b/.gitignore index d87786b0..299a2c16 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ coverage.txt output vendor .idea +.envrc diff --git a/.goreleaser.yml b/.goreleaser.yml index 207c75bd..441d5785 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -19,6 +19,7 @@ builds: - linux_386 - linux_amd64 - linux_arm64 + - linux_arm_5 - linux_arm_6 - linux_arm_7 - windows_amd64 @@ -39,6 +40,7 @@ builds: - linux_386 - linux_amd64 - linux_arm64 + - linux_arm_5 - linux_arm_6 - linux_arm_7 - windows_amd64 @@ -59,6 +61,7 @@ builds: - linux_386 - linux_amd64 - linux_arm64 + - linux_arm_5 - linux_arm_6 - linux_arm_7 - windows_amd64 diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a0618ab..49e4b15e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,16 +4,32 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). -## [Unreleased - 0.18.2] - DATE +## [Unreleased - 0.18.3] - DATE ### Added +- Added support for renew after expiry using the claim `allowRenewAfterExpiry`. +- Added support for `extraNames` in X.509 templates. ### Changed -- IPv6 addresses are normalized as IP addresses instead of hostnames. -- More descriptive JWK decryption error message. +- Made SCEP CA URL paths dynamic +- Support two latest versions of Go (1.17, 1.18) ### Deprecated ### Removed ### Fixed ### Security +## [0.18.2] - 2022-03-01 +### Added +- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner. +- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering + out database drivers using Go tags. For example, using the Go flag + `--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx + driver for PostgreSQL. +### Changed +- IPv6 addresses are normalized as IP addresses instead of hostnames. +- More descriptive JWK decryption error message. +- Make the X5C leaf certificate available to the templates using `{{ .AuthorizationCrt }}`. +### Fixed +- During provisioner add - validate provisioner configuration before storing to DB. + ## [0.18.1] - 2022-02-03 ### Added - Support for ACME revocation. diff --git a/acme/api/account.go b/acme/api/account.go index 0dc8ab40..ade51aef 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -5,8 +5,9 @@ import ( "net/http" "github.com/go-chi/chi" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/logging" ) @@ -70,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var nar NewAccountRequest if err := json.Unmarshal(payload.value, &nar); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := nar.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := acmeProvisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -96,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { acmeErr, ok := err.(*acme.Error) if !ok || acmeErr.Status != http.StatusBadRequest { // Something went wrong ... - api.WriteError(w, err) + render.Error(w, err) return } // Account does not exist // if nar.OnlyReturnExisting { - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist")) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } eak, err := h.validateExternalAccountBinding(ctx, &nar) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -125,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { Status: acme.StatusValid, } if err := h.db.CreateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) + render.Error(w, acme.WrapErrorISE(err, "error creating account")) return } if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response err := eak.BindTo(acc) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key")) + render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) return } acc.ExternalAccountBinding = nar.ExternalAccountBinding @@ -149,7 +150,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) - api.JSONStatus(w, acc, httpStatus) + render.JSONStatus(w, acc, httpStatus) } // GetOrUpdateAccount is the api for updating an ACME account. @@ -157,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -171,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-account request payload")) return } if err := uar.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if len(uar.Status) > 0 || len(uar.Contact) > 0 { @@ -187,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { } if err := h.db.UpdateAccount(ctx, acc); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) + render.Error(w, acme.WrapErrorISE(err, "error updating account")) return } } @@ -196,7 +197,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { h.linker.LinkAccount(ctx, acc) w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) - api.JSON(w, acc) + render.JSON(w, acc) } func logOrdersByAccount(w http.ResponseWriter, oids []string) { @@ -213,22 +214,22 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } accID := chi.URLParam(r, "accID") if acc.ID != accID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) return } orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } h.linker.LinkOrdersByAccountID(ctx, orders) - api.JSON(w, orders) + render.JSON(w, orders) logOrdersByAccount(w, orders) } diff --git a/acme/api/handler.go b/acme/api/handler.go index bd226e73..10eb22cb 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -11,8 +12,10 @@ import ( "time" "github.com/go-chi/chi" + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" ) @@ -43,6 +46,7 @@ type Handler struct { ca acme.CertificateAuthority linker Linker validateChallengeOptions *acme.ValidateChallengeOptions + prerequisitesChecker func(ctx context.Context) (bool, error) } // HandlerOptions required to create a new ACME API request handler. @@ -60,6 +64,9 @@ type HandlerOptions struct { // "acme" is the prefix from which the ACME api is accessed. Prefix string CA acme.CertificateAuthority + // PrerequisitesChecker checks if all prerequisites for serving ACME are + // met by the CA configuration. + PrerequisitesChecker func(ctx context.Context) (bool, error) } // NewHandler returns a new ACME API handler. @@ -76,6 +83,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { dialer := &net.Dialer{ Timeout: 30 * time.Second, } + prerequisitesChecker := func(ctx context.Context) (bool, error) { + // by default all prerequisites are met + return true, nil + } + if ops.PrerequisitesChecker != nil { + prerequisitesChecker = ops.PrerequisitesChecker + } return &Handler{ ca: ops.CA, db: ops.DB, @@ -88,6 +102,7 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { return tls.DialWithDialer(dialer, network, addr, config) }, }, + prerequisitesChecker: prerequisitesChecker, } } @@ -95,13 +110,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler { func (h *Handler) Route(r api.Router) { getPath := h.linker.GetUnescapedPathSuffix // Standard ACME API - r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) - r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) - r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) - r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) + r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) + r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) + r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) + r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory)))) validatingMiddleware := func(next nextHTTP) nextHTTP { - return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))) + return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))) } extractPayloadByJWK := func(next nextHTTP) nextHTTP { return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) @@ -168,11 +183,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acmeProv, err := acmeProvisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } - api.JSON(w, &Directory{ + render.JSON(w, &Directory{ NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), @@ -187,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { // NotImplemented returns a 501 and is generally a placeholder for functionality which // MAY be added at some point in the future but is not in any way a guarantee of such. func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } // GetAuthorization ACME api for retrieving an Authz. @@ -195,28 +210,28 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) return } if acc.ID != az.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own authorization '%s'", acc.ID, az.ID)) return } if err = az.UpdateStatus(ctx, h.db); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) + render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) return } h.linker.LinkAuthorization(ctx, az) w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) - api.JSON(w, az) + render.JSON(w, az) } // GetChallenge ACME api for retrieving a Challenge. @@ -224,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // Just verify that the payload was set, since we're not strictly adhering // to ACME V2 spec for reasons specified below. _, err = payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -244,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { azID := chi.URLParam(r, "authzID") ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) return } ch.AuthorizationID = azID if acc.ID != ch.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own challenge '%s'", acc.ID, ch.ID)) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) + render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) return } @@ -267,7 +282,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) - api.JSON(w, ch) + render.JSON(w, ch) } // GetCertificate ACME api for retrieving a Certificate. @@ -275,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } certID := chi.URLParam(r, "certID") cert, err := h.db.GetCertificate(ctx, certID) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) return } if cert.AccountID != acc.ID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own certificate '%s'", acc.ID, certID)) return } diff --git a/acme/api/middleware.go b/acme/api/middleware.go index de8614ee..10f7841f 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -10,13 +10,14 @@ import ( "strings" "github.com/go-chi/chi" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/nosql" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/keyutil" ) type nextHTTP = func(http.ResponseWriter, *http.Request) @@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { nonce, err := h.db.CreateNonce(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } w.Header().Set("Replay-Nonce", string(nonce)) @@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { var expected []string p, err := provisionerFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return } } - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected content-type to be in %s, but got %s", expected, ct)) } } @@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) + render.Error(w, acme.WrapErrorISE(err, "failed to read request body")) return } jws, err := jose.ParseJWS(string(body)) if err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) return } ctx := context.WithValue(r.Context(), jwsContextKey, jws) @@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if len(jws.Signatures) == 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) return } if len(jws.Signatures) > 1 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) return } @@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { len(uh.Algorithm) > 0 || len(uh.Nonce) > 0 || len(uh.ExtraHeaders) > 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) return } hdr := sig.Protected @@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { switch k := hdr.JSONWebKey.Key.(type) { case *rsa.PublicKey: if k.Size() < keyutil.MinRSAKeyBytes { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "rsa keys must be at least %d bits (%d bytes) in size", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) return } default: - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws key type and algorithm do not match")) return } @@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: // we good default: - api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) + render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) return } // Check the validity/freshness of the Nonce. if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } // Check that the JWS url matches the requested url. jwsURL, ok := hdr.ExtraHeaders["url"].(string) if !ok { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) return } reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} if jwsURL != reqURL.String() { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) return } if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) return } if hdr.JSONWebKey == nil && hdr.KeyID == "" { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) return } next(w, r) @@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } jwk := jws.Signatures[0].Protected.JSONWebKey if jwk == nil { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) return } if !jwk.Valid() { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } // Overwrite KeyID with the JWK thumbprint. jwk.KeyID, err = acme.KeyToID(jwk) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) + render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } @@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { // For NewAccount and Revoke requests ... break case err != nil: - api.WriteError(w, err) + render.Error(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) @@ -283,25 +284,24 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { } // lookupProvisioner loads the provisioner associated with the request. -// Responsds 404 if the provisioner does not exist. +// Responds 404 if the provisioner does not exist. func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - nameEscaped := chi.URLParam(r, "provisionerID") name, err := url.PathUnescape(nameEscaped) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) + render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) return } p, err := h.ca.LoadProvisionerByName(name) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } acmeProv, ok := p.(*provisioner.ACME) if !ok { - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) return } ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) @@ -309,6 +309,24 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { } } +// checkPrerequisites checks if all prerequisites for serving ACME +// are met by the CA configuration. +func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ok, err := h.prerequisitesChecker(ctx) + if err != nil { + render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) + return + } + if !ok { + render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) + return + } + next(w, r.WithContext(ctx)) + } +} + // lookupJWK loads the JWK associated with the acme account referenced by the // kid parameter of the signed payload. // Make sure to parse and validate the JWS before running this middleware. @@ -317,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, + render.Error(w, acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got %s", kidPrefix, kid)) return @@ -334,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { acc, err := h.db.GetAccount(ctx, accID) switch { case nosql.IsErrNotFound(err): - api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) + render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) return case err != nil: - api.WriteError(w, err) + render.Error(w, err) return default: if !acc.IsValid() { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) return } ctx = context.WithValue(ctx, accContextKey, acc) @@ -359,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -395,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } jwk, err := jwkFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) return } payload, err := jws.Verify(jwk) if err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) return } ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ @@ -426,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { payload, err := payloadFromContext(r.Context()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if !payload.isPostAsGet { - api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) + render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) return } next(w, r) diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 050b46a5..8003fa16 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -1656,3 +1656,91 @@ func TestHandler_extractOrLookupJWK(t *testing.T) { }) } } + +func TestHandler_checkPrerequisites(t *testing.T) { + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + u := fmt.Sprintf("%s/acme/%s/account/1234", + baseURL, provName) + type test struct { + linker Linker + ctx context.Context + prerequisitesChecker func(context.Context) (bool, error) + next func(http.ResponseWriter, *http.Request) + err *acme.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/error": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + linker: NewLinker("dns", "acme"), + ctx: ctx, + prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + err: acme.WrapErrorISE(errors.New("force"), "error checking acme provisioner prerequisites"), + statusCode: 500, + } + }, + "fail/prerequisites-nok": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + linker: NewLinker("dns", "acme"), + ctx: ctx, + prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + err: acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"), + statusCode: 501, + } + }, + "ok": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + linker: NewLinker("dns", "acme"), + ctx: ctx, + prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + statusCode: 200, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} + req := httptest.NewRequest("GET", u, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.checkPrerequisites(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} diff --git a/acme/api/order.go b/acme/api/order.go index 9cf2c1eb..99eb0e95 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -11,9 +11,11 @@ import ( "time" "github.com/go-chi/chi" - "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "go.step.sm/crypto/randutil" + + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api/render" ) // NewOrderRequest represents the body for a NewOrder request. @@ -70,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var nor NewOrderRequest if err := json.Unmarshal(payload.value, &nor); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal new-order request payload")) return } if err := nor.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -116,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { Status: acme.StatusPending, } if err := h.newAuthorization(ctx, az); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o.AuthorizationIDs[i] = az.ID @@ -135,14 +137,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { } if err := h.db.CreateOrder(ctx, o); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) + render.Error(w, acme.WrapErrorISE(err, "error creating order")) return } h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSONStatus(w, o, http.StatusCreated) + render.JSONStatus(w, o, http.StatusCreated) } func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { @@ -186,38 +188,38 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.UpdateStatus(ctx, h.db); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) + render.Error(w, acme.WrapErrorISE(err, "error updating order status")) return } h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSON(w, o) + render.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. @@ -225,54 +227,54 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var fr FinalizeRequest if err := json.Unmarshal(payload.value, &fr); err != nil { - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to unmarshal finalize-order request payload")) return } if err := fr.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) return } if acc.ID != o.AccountID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account '%s' does not own order '%s'", acc.ID, o.ID)) return } if prov.GetID() != o.ProvisionerID { - api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, + render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) return } if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) + render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) return } h.linker.LinkOrder(ctx, o) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) - api.JSON(w, o) + render.JSON(w, o) } // challengeTypes determines the types of challenges that should be used diff --git a/acme/api/revoke.go b/acme/api/revoke.go index d01e401c..4b71bc22 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -10,13 +10,14 @@ import ( "net/http" "strings" + "go.step.sm/crypto/jose" + "golang.org/x/crypto/ocsp" + "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" - "go.step.sm/crypto/jose" - "golang.org/x/crypto/ocsp" ) type revokePayload struct { @@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { ctx := r.Context() jws, err := jwsFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } prov, err := provisionerFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } payload, err := payloadFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } var p revokePayload err = json.Unmarshal(payload.value, &p) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload")) + render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload")) return } certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) if err != nil { // in this case the most likely cause is a client that didn't properly encode the certificate - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) return } certToBeRevoked, err := x509.ParseCertificate(certBytes) if err != nil { // in this case a client may have encoded something different than a certificate - api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) + render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) return } serial := certToBeRevoked.SerialNumber.String() dbCert, err := h.db.GetCertificateBySerial(ctx, serial) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return } if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { // this should never happen - api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal")) + render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal")) return } if shouldCheckAccountFrom(jws) { account, err := accountFromContext(ctx) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) if acmeErr != nil { - api.WriteError(w, acmeErr) + render.Error(w, acmeErr) return } } else { @@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { _, err := jws.Verify(certToBeRevoked.PublicKey) if err != nil { // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? - api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) + render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) return } } hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) + render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) return } if hasBeenRevokedBefore { - api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) + render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) return } reasonCode := p.ReasonCode acmeErr := validateReasonCode(reasonCode) if acmeErr != nil { - api.WriteError(w, acmeErr) + render.Error(w, acmeErr) return } @@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) err = prov.AuthorizeRevoke(ctx, "") if err != nil { - api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) + render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) return } options := revokeOptions(serial, certToBeRevoked, reasonCode) err = h.ca.Revoke(ctx, options) if err != nil { - api.WriteError(w, wrapRevokeErr(err)) + render.Error(w, wrapRevokeErr(err)) return } diff --git a/acme/challenge.go b/acme/challenge.go index 0e1994e4..9f08bae5 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -79,7 +79,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, } func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { - u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} + u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} resp, err := vo.HTTPGet(u.String()) if err != nil { @@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb return nil } +// http01ChallengeHost checks if a Challenge value is an IPv6 address +// and adds square brackets if that's the case, so that it can be used +// as a hostname. Returns the original Challenge value as the host to +// use in other cases. +func http01ChallengeHost(value string) string { + if ip := net.ParseIP(value); ip != nil && ip.To4() == nil { + value = "[" + value + "]" + } + return value +} + func tlsAlert(err error) uint8 { var opErr *net.OpError if errors.As(err, &opErr) { diff --git a/acme/challenge_test.go b/acme/challenge_test.go index d8ce4d76..c05b25e7 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -13,6 +13,7 @@ import ( "encoding/asn1" "encoding/base64" "encoding/hex" + "errors" "fmt" "io" "math/big" @@ -23,9 +24,9 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" ) func Test_storeError(t *testing.T) { @@ -2350,3 +2351,34 @@ func Test_serverName(t *testing.T) { }) } } + +func Test_http01ChallengeHost(t *testing.T) { + tests := []struct { + name string + value string + want string + }{ + { + name: "dns", + value: "www.example.com", + want: "www.example.com", + }, + { + name: "ipv4", + value: "127.0.0.1", + want: "127.0.0.1", + }, + { + name: "ipv6", + value: "::1", + want: "[::1]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := http01ChallengeHost(tt.value); got != tt.want { + t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/acme/errors.go b/acme/errors.go index a5c820ba..05888c24 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -3,13 +3,10 @@ package acme import ( "encoding/json" "fmt" - "log" "net/http" - "os" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/api/render" ) // ProblemType is the type of the ACME problem. @@ -353,26 +350,8 @@ func (e *Error) ToLog() (interface{}, error) { return string(b), nil } -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err *Error) { +// Render implements render.RenderableError for Error. +func (e *Error) Render(w http.ResponseWriter) { w.Header().Set("Content-Type", "application/problem+json") - w.WriteHeader(err.StatusCode()) - - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err.Err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.Err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } - - if err := json.NewEncoder(w).Encode(err); err != nil { - log.Println(err) - } + render.JSONStatus(w, e, e.StatusCode()) } diff --git a/api/api.go b/api/api.go index 16e24bb2..da6309fd 100644 --- a/api/api.go +++ b/api/api.go @@ -20,6 +20,9 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + + "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" @@ -33,6 +36,7 @@ type Authority interface { // context specifies the Authorize[Sign|Revoke|etc.] method. Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error) + AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) @@ -43,7 +47,7 @@ type Authority interface { GetProvisioners(cursor string, limit int) (provisioner.List, string, error) Revoke(context.Context, *authority.RevokeOptions) error GetEncryptedKey(kid string) (string, error) - GetRoots() (federation []*x509.Certificate, err error) + GetRoots() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error) Version() authority.Version } @@ -257,6 +261,7 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("GET", "/provisioners", h.Provisioners) r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) r.MethodFunc("GET", "/roots", h.Roots) + r.MethodFunc("GET", "/roots.pem", h.RootsPEM) r.MethodFunc("GET", "/federation", h.Federation) // SSH CA r.MethodFunc("POST", "/ssh/sign", h.SSHSign) @@ -280,7 +285,7 @@ func (h *caHandler) Route(r Router) { // Version is an HTTP handler that returns the version of the server. func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { v := h.Authority.Version() - JSON(w, VersionResponse{ + render.JSON(w, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, }) @@ -288,7 +293,7 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { // Health is an HTTP handler that returns the status of the server. func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { - JSON(w, HealthResponse{Status: "ok"}) + render.JSON(w, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root @@ -299,11 +304,11 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { // Load root certificate with the cert, err := h.Authority.Root(sum) if err != nil { - WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) + render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return } - JSON(w, &RootResponse{RootPEM: Certificate{cert}}) + render.JSON(w, &RootResponse{RootPEM: Certificate{cert}}) } func certChainToPEM(certChain []*x509.Certificate) []Certificate { @@ -318,16 +323,16 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { - WriteError(w, err) + render.Error(w, err) return } p, next, err := h.Authority.GetProvisioners(cursor, limit) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &ProvisionersResponse{ + render.JSON(w, &ProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -338,17 +343,17 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") key, err := h.Authority.GetEncryptedKey(kid) if err != nil { - WriteError(w, errs.NotFoundErr(err)) + render.Error(w, errs.NotFoundErr(err)) return } - JSON(w, &ProvisionerKeyResponse{key}) + render.JSON(w, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { roots, err := h.Authority.GetRoots() if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error getting roots")) + render.Error(w, errs.ForbiddenErr(err, "error getting roots")) return } @@ -357,16 +362,39 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{roots[i]} } - JSONStatus(w, &RootsResponse{ + render.JSONStatus(w, &RootsResponse{ Certificates: certs, }, http.StatusCreated) } +// RootsPEM returns all the root certificates for the CA in PEM format. +func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { + roots, err := h.Authority.GetRoots() + if err != nil { + render.Error(w, errs.InternalServerErr(err)) + return + } + + w.Header().Set("Content-Type", "application/x-pem-file") + + for _, root := range roots { + block := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: root.Raw, + }) + + if _, err := w.Write(block); err != nil { + log.Error(w, err) + return + } + } +} + // Federation returns all the public certificates in the federation. func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { federated, err := h.Authority.GetFederation() if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error getting federated roots")) + render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) return } @@ -375,7 +403,7 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { certs[i] = Certificate{federated[i]} } - JSONStatus(w, &FederationResponse{ + render.JSONStatus(w, &FederationResponse{ Certificates: certs, }, http.StatusCreated) } diff --git a/api/api_test.go b/api/api_test.go index c7528f9b..39c77de7 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "encoding/json" "encoding/pem" "fmt" @@ -27,14 +28,17 @@ import ( "github.com/go-chi/chi" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + + "go.step.sm/crypto/jose" + "go.step.sm/crypto/x509util" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" - "go.step.sm/crypto/jose" - "golang.org/x/crypto/ssh" ) const ( @@ -171,6 +175,7 @@ type mockAuthority struct { ret1, ret2 interface{} err error authorizeSign func(ott string) ([]provisioner.SignOption, error) + authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) @@ -208,6 +213,13 @@ func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, err return m.ret1.([]provisioner.SignOption), m.err } +func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) { + if m.authorizeRenewToken != nil { + return m.authorizeRenewToken(ctx, ott) + } + return m.ret1.(*x509.Certificate), m.err +} + func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { if m.getTLSOptions != nil { return m.getTLSOptions() @@ -920,48 +932,141 @@ func Test_caHandler_Renew(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, } + + // Prepare root and leaf for renew after expiry test. + now := time.Now() + rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + root := &x509.Certificate{ + Subject: pkix.Name{CommonName: "Test Root CA"}, + PublicKey: rootPub, + KeyUsage: x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + NotBefore: now.Add(-2 * time.Hour), + NotAfter: now.Add(time.Hour), + } + root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv) + if err != nil { + t.Fatal(err) + } + expiredLeaf := &x509.Certificate{ + Subject: pkix.Name{CommonName: "Leaf certificate"}, + PublicKey: leafPub, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + EmailAddresses: []string{"test@example.org"}, + } + expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv) + if err != nil { + t.Fatal(err) + } + + // Generate renew after expiry token + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)}) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so) + if err != nil { + t.Fatal(err) + } + generateX5cToken := func(claims jose.Claims) string { + s, err := jose.Signed(sig).Claims(claims).CompactSerialize() + if err != nil { + t.Fatal(err) + } + return s + } + tests := []struct { name string tls *tls.ConnectionState + header http.Header cert *x509.Certificate root *x509.Certificate err error statusCode int }{ - {"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, - {"no tls", nil, nil, nil, nil, http.StatusBadRequest}, - {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, - {"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, + {"ok", cs, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"ok renew after expiry", &tls.ConnectionState{}, http.Header{ + "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ + NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + })}, + }, expiredLeaf, root, nil, http.StatusCreated}, + {"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest}, + {"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest}, + {"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, + {"fail expired token", &tls.ConnectionState{}, http.Header{ + "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ + NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), + })}, + }, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized}, + {"fail invalid root", &tls.ConnectionState{}, http.Header{ + "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{ + NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)), + })}, + }, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized}, } - expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, + authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { + jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) + if err != nil { + return nil, errs.Unauthorized(err.Error()) + } + var claims jose.Claims + if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil { + return nil, errs.Unauthorized(err.Error()) + } + if err := claims.ValidateWithLeeway(jose.Expected{ + Time: now, + }, time.Minute); err != nil { + return nil, errs.Unauthorized(err.Error()) + } + return chain[0][0], nil + }, getTLSOptions: func() *authority.TLSOptions { return nil }, }).(*caHandler) req := httptest.NewRequest("POST", "http://example.com/renew", nil) req.TLS = tt.tls + req.Header = tt.header w := httptest.NewRecorder() h.Renew(logging.NewResponseLogger(w), req) - res := w.Result() - if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) - } + res := w.Result() + defer res.Body.Close() body, err := io.ReadAll(res.Body) - res.Body.Close() if err != nil { t.Errorf("caHandler.Renew unexpected error = %v", err) } + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + t.Errorf("%s", body) + } + if tt.statusCode < http.StatusBadRequest { + expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` + + `"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` + + `"certChain":["` + + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` + + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`) + if !bytes.Equal(bytes.TrimSpace(body), expected) { - t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) + t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected) } } }) @@ -1239,6 +1344,46 @@ func Test_caHandler_Roots(t *testing.T) { } } +func Test_caHandler_RootsPEM(t *testing.T) { + parsedRoot := parseCertificate(rootPEM) + tests := []struct { + name string + roots []*x509.Certificate + err error + statusCode int + expect string + }{ + {"one root", []*x509.Certificate{parsedRoot}, nil, http.StatusOK, rootPEM}, + {"two roots", []*x509.Certificate{parsedRoot, parsedRoot}, nil, http.StatusOK, rootPEM + "\n" + rootPEM}, + {"fail", nil, errors.New("an error"), http.StatusInternalServerError, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler) + req := httptest.NewRequest("GET", "https://example.com/roots", nil) + w := httptest.NewRecorder() + h.RootsPEM(w, req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.RootsPEM StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.RootsPEM unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), []byte(tt.expect)) { + t.Errorf("caHandler.RootsPEM Body = %s, wants %s", body, tt.expect) + } + } + }) + } +} + func Test_caHandler_Federation(t *testing.T) { cs := &tls.ConnectionState{ PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, diff --git a/api/errors.go b/api/errors.go deleted file mode 100644 index bff46b55..00000000 --- a/api/errors.go +++ /dev/null @@ -1,64 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "net/http" - "os" - - "github.com/pkg/errors" - "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/admin" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" - "github.com/smallstep/certificates/scep" -) - -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err error) { - switch k := err.(type) { - case *acme.Error: - acme.WriteError(w, k) - return - case *admin.Error: - admin.WriteError(w, k) - return - case *scep.Error: - w.Header().Set("Content-Type", "text/plain") - default: - w.Header().Set("Content-Type", "application/json") - } - - cause := errors.Cause(err) - if sc, ok := err.(errs.StatusCoder); ok { - w.WriteHeader(sc.StatusCode()) - } else { - if sc, ok := cause.(errs.StatusCoder); ok { - w.WriteHeader(sc.StatusCode()) - } else { - w.WriteHeader(http.StatusInternalServerError) - } - } - - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } else if e, ok := cause.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } - - if err := json.NewEncoder(w).Encode(err); err != nil { - LogError(w, err) - } -} diff --git a/api/log/log.go b/api/log/log.go new file mode 100644 index 00000000..cb31410b --- /dev/null +++ b/api/log/log.go @@ -0,0 +1,79 @@ +// Package log implements API-related logging helpers. +package log + +import ( + "fmt" + "log" + "net/http" + "os" + + "github.com/pkg/errors" + + "github.com/smallstep/certificates/logging" +) + +// StackTracedError is the set of errors implementing the StackTrace function. +// +// Errors implementing this interface have their stack traces logged when passed +// to the Error function of this package. +type StackTracedError interface { + error + + StackTrace() errors.StackTrace +} + +// Error adds to the response writer the given error if it implements +// logging.ResponseLogger. If it does not implement it, then writes the error +// using the log package. +func Error(rw http.ResponseWriter, err error) { + rl, ok := rw.(logging.ResponseLogger) + if !ok { + log.Println(err) + + return + } + + rl.WithFields(map[string]interface{}{ + "error": err, + }) + + if os.Getenv("STEPDEBUG") != "1" { + return + } + + e, ok := err.(StackTracedError) + if !ok { + e, ok = errors.Cause(err).(StackTracedError) + } + + if ok { + rl.WithFields(map[string]interface{}{ + "stack-trace": fmt.Sprintf("%+v", e.StackTrace()), + }) + } +} + +// EnabledResponse log the response object if it implements the EnableLogger +// interface. +func EnabledResponse(rw http.ResponseWriter, v interface{}) { + type enableLogger interface { + ToLog() (interface{}, error) + } + + if el, ok := v.(enableLogger); ok { + out, err := el.ToLog() + if err != nil { + Error(rw, err) + + return + } + + if rl, ok := rw.(logging.ResponseLogger); ok { + rl.WithFields(map[string]interface{}{ + "response": out, + }) + } else { + log.Println(out) + } + } +} diff --git a/api/log/log_test.go b/api/log/log_test.go new file mode 100644 index 00000000..fcd3ea2b --- /dev/null +++ b/api/log/log_test.go @@ -0,0 +1,44 @@ +package log + +import ( + "errors" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/smallstep/certificates/logging" +) + +func TestError(t *testing.T) { + theError := errors.New("the error") + + type args struct { + rw http.ResponseWriter + err error + } + tests := []struct { + name string + args args + withFields bool + }{ + {"normalLogger", args{httptest.NewRecorder(), theError}, false}, + {"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Error(tt.args.rw, tt.args.err) + if tt.withFields { + if rl, ok := tt.args.rw.(logging.ResponseLogger); ok { + fields := rl.Fields() + if !reflect.DeepEqual(fields["error"], theError) { + t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError) + } + } else { + t.Error("ResponseWriter does not implement logging.ResponseLogger") + } + } + }) + } +} diff --git a/api/read/read.go b/api/read/read.go new file mode 100644 index 00000000..de92c5d7 --- /dev/null +++ b/api/read/read.go @@ -0,0 +1,31 @@ +// Package read implements request object readers. +package read + +import ( + "encoding/json" + "io" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/smallstep/certificates/errs" +) + +// JSON reads JSON from the request body and stores it in the value +// pointed by v. +func JSON(r io.Reader, v interface{}) error { + if err := json.NewDecoder(r).Decode(v); err != nil { + return errs.BadRequestErr(err, "error decoding json") + } + return nil +} + +// ProtoJSON reads JSON from the request body and stores it in the value +// pointed by v. +func ProtoJSON(r io.Reader, m proto.Message) error { + data, err := io.ReadAll(r) + if err != nil { + return errs.BadRequestErr(err, "error reading request body") + } + return protojson.Unmarshal(data, m) +} diff --git a/api/read/read_test.go b/api/read/read_test.go new file mode 100644 index 00000000..f2eff1bc --- /dev/null +++ b/api/read/read_test.go @@ -0,0 +1,46 @@ +package read + +import ( + "io" + "reflect" + "strings" + "testing" + + "github.com/smallstep/certificates/errs" +) + +func TestJSON(t *testing.T) { + type args struct { + r io.Reader + v interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false}, + {"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := JSON(tt.args.r, &tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + e, ok := err.(*errs.Error) + if ok { + if code := e.StatusCode(); code != 400 { + t.Errorf("error.StatusCode() = %v, wants 400", code) + } + } else { + t.Errorf("error type = %T, wants *Error", err) + } + } else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) { + t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"}) + } + }) + } +} diff --git a/api/rekey.go b/api/rekey.go index b7958844..3116cf74 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -3,6 +3,8 @@ package api import ( "net/http" + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) @@ -27,24 +29,24 @@ func (s *RekeyRequest) Validate() error { // Rekey is similar to renew except that the certificate will be renewed with new key from csr. func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing client certificate")) + render.Error(w, errs.BadRequest("missing client certificate")) return } var body RekeyRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { - WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) + render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return } certChainPEM := certChainToPEM(certChain) @@ -54,7 +56,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/render/render.go b/api/render/render.go new file mode 100644 index 00000000..9df4c791 --- /dev/null +++ b/api/render/render.go @@ -0,0 +1,122 @@ +// Package render implements functionality related to response rendering. +package render + +import ( + "bytes" + "encoding/json" + "net/http" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/smallstep/certificates/api/log" +) + +// JSON is shorthand for JSONStatus(w, v, http.StatusOK). +func JSON(w http.ResponseWriter, v interface{}) { + JSONStatus(w, v, http.StatusOK) +} + +// JSONStatus marshals v into w. It additionally sets the status code of +// w to the given one. +// +// JSONStatus sets the Content-Type of w to application/json unless one is +// specified. +func JSONStatus(w http.ResponseWriter, v interface{}, status int) { + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(v); err != nil { + panic(err) + } + + setContentTypeUnlessPresent(w, "application/json") + w.WriteHeader(status) + _, _ = b.WriteTo(w) + + log.EnabledResponse(w, v) +} + +// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK). +func ProtoJSON(w http.ResponseWriter, m proto.Message) { + ProtoJSONStatus(w, m, http.StatusOK) +} + +// ProtoJSONStatus writes the given value into the http.ResponseWriter and the +// given status is written as the status code of the response. +func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { + b, err := protojson.Marshal(m) + if err != nil { + panic(err) + } + + setContentTypeUnlessPresent(w, "application/json") + w.WriteHeader(status) + _, _ = w.Write(b) +} + +func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) { + const header = "Content-Type" + + h := w.Header() + if _, ok := h[header]; !ok { + h.Set(header, contentType) + } +} + +// RenderableError is the set of errors that implement the basic Render method. +// +// Errors that implement this interface will use their own Render method when +// being rendered into responses. +type RenderableError interface { + error + + Render(http.ResponseWriter) +} + +// Error marshals the JSON representation of err to w. In case err implements +// RenderableError its own Render method will be called instead. +func Error(w http.ResponseWriter, err error) { + log.Error(w, err) + + if e, ok := err.(RenderableError); ok { + e.Render(w) + + return + } + + JSONStatus(w, err, statusCodeFromError(err)) +} + +// StatusCodedError is the set of errors that implement the basic StatusCode +// function. +// +// Errors that implement this interface will use the code reported by StatusCode +// as the HTTP response code when being rendered by this package. +type StatusCodedError interface { + error + + StatusCode() int +} + +func statusCodeFromError(err error) (code int) { + code = http.StatusInternalServerError + + type causer interface { + Cause() error + } + + for err != nil { + if sc, ok := err.(StatusCodedError); ok { + code = sc.StatusCode() + + break + } + + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + + return +} diff --git a/api/render/render_test.go b/api/render/render_test.go new file mode 100644 index 00000000..06d092d3 --- /dev/null +++ b/api/render/render_test.go @@ -0,0 +1,115 @@ +package render + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/smallstep/certificates/logging" +) + +func TestJSON(t *testing.T) { + rec := httptest.NewRecorder() + rw := logging.NewResponseLogger(rec) + + JSON(rw, map[string]interface{}{"foo": "bar"}) + + assert.Equal(t, http.StatusOK, rec.Result().StatusCode) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String()) + + assert.Empty(t, rw.Fields()) +} + +func TestJSONPanics(t *testing.T) { + assert.Panics(t, func() { + JSON(httptest.NewRecorder(), make(chan struct{})) + }) +} + +type renderableError struct { + Code int `json:"-"` + Message string `json:"message"` +} + +func (err renderableError) Error() string { + return err.Message +} + +func (err renderableError) Render(w http.ResponseWriter) { + w.Header().Set("Content-Type", "something/custom") + + JSONStatus(w, err, err.Code) +} + +type statusedError struct { + Contents string +} + +func (err statusedError) Error() string { return err.Contents } + +func (statusedError) StatusCode() int { return 432 } + +func TestError(t *testing.T) { + cases := []struct { + err error + code int + body string + header string + }{ + 0: { + err: renderableError{532, "some string"}, + code: 532, + body: "{\"message\":\"some string\"}\n", + header: "something/custom", + }, + 1: { + err: statusedError{"123"}, + code: 432, + body: "{\"Contents\":\"123\"}\n", + header: "application/json", + }, + } + + for caseIndex := range cases { + kase := cases[caseIndex] + + t.Run(strconv.Itoa(caseIndex), func(t *testing.T) { + rec := httptest.NewRecorder() + + Error(rec, kase.err) + + assert.Equal(t, kase.code, rec.Result().StatusCode) + assert.Equal(t, kase.body, rec.Body.String()) + assert.Equal(t, kase.header, rec.Header().Get("Content-Type")) + }) + } +} + +type causedError struct { + cause error +} + +func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) } +func (err causedError) Cause() error { return err.cause } + +func TestStatusCodeFromError(t *testing.T) { + cases := []struct { + err error + exp int + }{ + 0: {nil, http.StatusInternalServerError}, + 1: {io.EOF, http.StatusInternalServerError}, + 2: {statusedError{"123"}, 432}, + 3: {causedError{statusedError{"432"}}, 432}, + } + + for caseIndex, kase := range cases { + assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex) + } +} diff --git a/api/renew.go b/api/renew.go index 725322ee..9c4bff32 100644 --- a/api/renew.go +++ b/api/renew.go @@ -1,22 +1,31 @@ package api import ( + "crypto/x509" "net/http" + "strings" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/errs" ) +const ( + authorizationHeader = "Authorization" + bearerScheme = "Bearer" +) + // Renew uses the information of certificate in the TLS connection to create a // new one. func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing client certificate")) + cert, err := h.getPeerCertificate(r) + if err != nil { + render.Error(w, err) return } - certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) + certChain, err := h.Authority.Renew(cert) if err != nil { - WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) + render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return } certChainPEM := certChainToPEM(certChain) @@ -26,10 +35,22 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, TLSOptions: h.Authority.GetTLSOptions(), }, http.StatusCreated) } + +func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) { + if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { + return r.TLS.PeerCertificates[0], nil + } + if s := r.Header.Get(authorizationHeader); s != "" { + if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { + return h.Authority.AuthorizeRenewToken(r.Context(), parts[1]) + } + } + return nil, errs.BadRequest("missing client certificate") +} diff --git a/api/revoke.go b/api/revoke.go index 25520e3e..c9da2c18 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -4,11 +4,14 @@ import ( "context" "net/http" + "golang.org/x/crypto/ocsp" + + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" - "golang.org/x/crypto/ocsp" ) // RevokeResponse is the response object that returns the health of the server. @@ -48,13 +51,13 @@ func (r *RevokeRequest) Validate() (err error) { // TODO: Add CRL and OCSP support. func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -71,7 +74,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { if len(body.OTT) > 0 { logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT @@ -80,12 +83,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { // the client certificate Serial Number must match the serial number // being revoked. if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, errs.BadRequest("missing ott or client certificate")) + render.Error(w, errs.BadRequest("missing ott or client certificate")) return } opts.Crt = r.TLS.PeerCertificates[0] if opts.Crt.SerialNumber.String() != opts.Serial { - WriteError(w, errs.BadRequest("serial number in client certificate different than body")) + render.Error(w, errs.BadRequest("serial number in client certificate different than body")) return } // TODO: should probably be checking if the certificate was revoked here. @@ -96,12 +99,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { } if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err, "error revoking certificate")) + render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) return } logRevoke(w, opts) - JSON(w, &RevokeResponse{Status: "ok"}) + render.JSON(w, &RevokeResponse{Status: "ok"}) } func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/api/revoke_test.go b/api/revoke_test.go index 4ed4e3fe..7635ce68 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -13,6 +13,7 @@ import ( "testing" "github.com/pkg/errors" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" diff --git a/api/sign.go b/api/sign.go index 93c5f599..b6bfcc8b 100644 --- a/api/sign.go +++ b/api/sign.go @@ -5,6 +5,8 @@ import ( "encoding/json" "net/http" + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" @@ -49,14 +51,14 @@ type SignResponse struct { // information in the certificate request. func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -68,13 +70,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { signOpts, err := h.Authority.AuthorizeSign(body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return } certChainPEM := certChainToPEM(certChain) @@ -83,7 +85,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { caPEM = certChainPEM[1] } LogCertificate(w, certChain[0]) - JSONStatus(w, &SignResponse{ + render.JSONStatus(w, &SignResponse{ ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, diff --git a/api/ssh.go b/api/ssh.go index c9be1527..3b0de7c1 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -9,12 +9,15 @@ import ( "time" "github.com/pkg/errors" + "golang.org/x/crypto/ssh" + + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/templates" - "golang.org/x/crypto/ssh" ) // SSHAuthority is the interface implemented by a SSH CA authority. @@ -249,20 +252,20 @@ type SSHBastionResponse struct { // the request. func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) return } @@ -270,7 +273,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if body.AddUserPublicKey != nil { addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) return } } @@ -287,13 +290,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } @@ -301,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } addUserCertificate = &SSHCertificate{addUserCert} @@ -314,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } @@ -326,13 +329,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return } identityCertificate = certChainToPEM(certChain) } - JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, &SSHSignResponse{ Certificate: SSHCertificate{cert}, AddUserCertificate: addUserCertificate, IdentityCertificate: identityCertificate, @@ -344,12 +347,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHRoots(r.Context()) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound("no keys found")) + render.Error(w, errs.NotFound("no keys found")) return } @@ -361,7 +364,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - JSON(w, resp) + render.JSON(w, resp) } // SSHFederation is an HTTP handler that returns the federated SSH public keys @@ -369,12 +372,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { keys, err := h.Authority.GetSSHFederation(r.Context()) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { - WriteError(w, errs.NotFound("no keys found")) + render.Error(w, errs.NotFound("no keys found")) return } @@ -386,25 +389,25 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) } - JSON(w, resp) + render.JSON(w, resp) } // SSHConfig is an HTTP handler that returns rendered templates for ssh clients // and servers. func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } @@ -415,31 +418,31 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { case provisioner.SSHHostCert: cfg.HostTemplates = ts default: - WriteError(w, errs.InternalServer("it should hot get here")) + render.Error(w, errs.InternalServer("it should hot get here")) return } - JSON(w, cfg) + render.JSON(w, cfg) } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHCheckPrincipalResponse{ + render.JSON(w, &SSHCheckPrincipalResponse{ Exists: exists, }) } @@ -453,10 +456,10 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHGetHostsResponse{ + render.JSON(w, &SSHGetHostsResponse{ Hosts: hosts, }) } @@ -464,22 +467,22 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { // SSHBastion provides returns the bastion configured if any. func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - JSON(w, &SSHBastionResponse{ + render.JSON(w, &SSHBastionResponse{ Hostname: body.Hostname, Bastion: bastion, }) diff --git a/api/sshRekey.go b/api/sshRekey.go index 8670f0bd..92278950 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -4,9 +4,12 @@ import ( "net/http" "time" + "golang.org/x/crypto/ssh" + + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" - "golang.org/x/crypto/ssh" ) // SSHRekeyRequest is the request body of an SSH certificate request. @@ -38,37 +41,38 @@ type SSHRekeyResponse struct { // the request. func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } publicKey, err := ssh.ParsePublicKey(body.PublicKey) if err != nil { - WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) + render.Error(w, errs.BadRequestErr(err, "error parsing publicKey")) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) + return } newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return } @@ -78,11 +82,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } - JSONStatus(w, &SSHRekeyResponse{ + render.JSONStatus(w, &SSHRekeyResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRenew.go b/api/sshRenew.go index 57b6f432..78d16fa6 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -6,6 +6,9 @@ import ( "time" "github.com/pkg/errors" + + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -36,31 +39,32 @@ type SSHRenewResponse struct { // the request. func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) _, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) if err != nil { - WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) + return } newCert, err := h.Authority.RenewSSH(ctx, oldCert) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) return } @@ -70,11 +74,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) + render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } - JSONStatus(w, &SSHSignResponse{ + render.JSONStatus(w, &SSHSignResponse{ Certificate: SSHCertificate{newCert}, IdentityCertificate: identity, }, http.StatusCreated) diff --git a/api/sshRevoke.go b/api/sshRevoke.go index 60f44f2a..a33082cd 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -3,11 +3,14 @@ package api import ( "net/http" + "golang.org/x/crypto/ocsp" + + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" - "golang.org/x/crypto/ocsp" ) // SSHRevokeResponse is the response object that returns the health of the server. @@ -47,13 +50,13 @@ func (r *SSHRevokeRequest) Validate() (err error) { // NOTE: currently only Passive revocation is supported. func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest - if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.BadRequestErr(err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, errs.BadRequestErr(err, "error reading request body")) return } if err := body.Validate(); err != nil { - WriteError(w, err) + render.Error(w, err) return } @@ -69,18 +72,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { - WriteError(w, errs.UnauthorizedErr(err)) + render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) + render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } logSSHRevoke(w, opts) - JSON(w, &SSHRevokeResponse{Status: "ok"}) + render.JSON(w, &SSHRevokeResponse{Status: "ok"}) } func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { diff --git a/api/ssh_test.go b/api/ssh_test.go index a3d7da0d..88a301f5 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -18,12 +18,13 @@ import ( "testing" "time" + "golang.org/x/crypto/ssh" + "github.com/smallstep/assert" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" - "golang.org/x/crypto/ssh" ) var ( diff --git a/api/utils.go b/api/utils.go deleted file mode 100644 index a7f4bf58..00000000 --- a/api/utils.go +++ /dev/null @@ -1,109 +0,0 @@ -package api - -import ( - "encoding/json" - "io" - "log" - "net/http" - - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" -) - -// EnableLogger is an interface that enables response logging for an object. -type EnableLogger interface { - ToLog() (interface{}, error) -} - -// LogError adds to the response writer the given error if it implements -// logging.ResponseLogger. If it does not implement it, then writes the error -// using the log package. -func LogError(rw http.ResponseWriter, err error) { - if rl, ok := rw.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err, - }) - } else { - log.Println(err) - } -} - -// LogEnabledResponse log the response object if it implements the EnableLogger -// interface. -func LogEnabledResponse(rw http.ResponseWriter, v interface{}) { - if el, ok := v.(EnableLogger); ok { - out, err := el.ToLog() - if err != nil { - LogError(rw, err) - return - } - if rl, ok := rw.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "response": out, - }) - } else { - log.Println(out) - } - } -} - -// JSON writes the passed value into the http.ResponseWriter. -func JSON(w http.ResponseWriter, v interface{}) { - JSONStatus(w, v, http.StatusOK) -} - -// JSONStatus writes the given value into the http.ResponseWriter and the -// given status is written as the status code of the response. -func JSONStatus(w http.ResponseWriter, v interface{}, status int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - if err := json.NewEncoder(w).Encode(v); err != nil { - LogError(w, err) - return - } - LogEnabledResponse(w, v) -} - -// ProtoJSON writes the passed value into the http.ResponseWriter. -func ProtoJSON(w http.ResponseWriter, m proto.Message) { - ProtoJSONStatus(w, m, http.StatusOK) -} - -// ProtoJSONStatus writes the given value into the http.ResponseWriter and the -// given status is written as the status code of the response. -func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - - b, err := protojson.Marshal(m) - if err != nil { - LogError(w, err) - return - } - if _, err := w.Write(b); err != nil { - LogError(w, err) - return - } - //LogEnabledResponse(w, v) -} - -// ReadJSON reads JSON from the request body and stores it in the value -// pointed by v. -func ReadJSON(r io.Reader, v interface{}) error { - if err := json.NewDecoder(r).Decode(v); err != nil { - return errs.BadRequestErr(err, "error decoding json") - } - return nil -} - -// ReadProtoJSON reads JSON from the request body and stores it in the value -// pointed by v. -func ReadProtoJSON(r io.Reader, m proto.Message) error { - data, err := io.ReadAll(r) - if err != nil { - return errs.BadRequestErr(err, "error reading request body") - } - return protojson.Unmarshal(data, m) -} diff --git a/api/utils_test.go b/api/utils_test.go deleted file mode 100644 index 81146653..00000000 --- a/api/utils_test.go +++ /dev/null @@ -1,125 +0,0 @@ -package api - -import ( - "io" - "net/http" - "net/http/httptest" - "reflect" - "strings" - "testing" - - "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" -) - -func TestLogError(t *testing.T) { - theError := errors.New("the error") - type args struct { - rw http.ResponseWriter - err error - } - tests := []struct { - name string - args args - withFields bool - }{ - {"normalLogger", args{httptest.NewRecorder(), theError}, false}, - {"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - LogError(tt.args.rw, tt.args.err) - if tt.withFields { - if rl, ok := tt.args.rw.(logging.ResponseLogger); ok { - fields := rl.Fields() - if !reflect.DeepEqual(fields["error"], theError) { - t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError) - } - } else { - t.Error("ResponseWriter does not implement logging.ResponseLogger") - } - } - }) - } -} - -func TestJSON(t *testing.T) { - type args struct { - rw http.ResponseWriter - v interface{} - } - tests := []struct { - name string - args args - ok bool - }{ - {"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true}, - {"fail", args{httptest.NewRecorder(), make(chan int)}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rw := logging.NewResponseLogger(tt.args.rw) - JSON(rw, tt.args.v) - - rr, ok := tt.args.rw.(*httptest.ResponseRecorder) - if !ok { - t.Error("ResponseWriter does not implement *httptest.ResponseRecorder") - return - } - - fields := rw.Fields() - if tt.ok { - if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" { - t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body) - } - if len(fields) != 0 { - t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields) - } - } else { - if body := rr.Body.String(); body != "" { - t.Errorf("Unexpected body = %s, want empty string", body) - } - if len(fields) != 1 { - t.Errorf("ResponseLogger fields = %v, wants 1 element", fields) - } - } - }) - } -} - -func TestReadJSON(t *testing.T) { - type args struct { - r io.Reader - v interface{} - } - tests := []struct { - name string - args args - wantErr bool - }{ - {"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false}, - {"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ReadJSON(tt.args.r, &tt.args.v) - if (err != nil) != tt.wantErr { - t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.wantErr { - e, ok := err.(*errs.Error) - if ok { - if code := e.StatusCode(); code != 400 { - t.Errorf("error.StatusCode() = %v, wants 400", code) - } - } else { - t.Errorf("error type = %T, wants *Error", err) - } - } else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) { - t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"}) - } - }) - } -} diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 27c3ba6f..21a7229d 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -6,10 +6,12 @@ import ( "net/http" "github.com/go-chi/chi" - "github.com/smallstep/certificates/api" + + "go.step.sm/linkedca" + + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" ) const ( @@ -44,11 +46,11 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { provName := chi.URLParam(r, "provisionerName") eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if !eabEnabled { - api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) + render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) return } ctx = context.WithValue(ctx, provisionerContextKey, prov) @@ -101,15 +103,15 @@ func NewACMEAdminResponder() *ACMEAdminResponder { // GetExternalAccountKeys writes the response for the EAB keys GET endpoint func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // CreateExternalAccountKey writes the response for the EAB key POST endpoint func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) } diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index 7aa66d0f..5e4b9c30 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -5,10 +5,14 @@ import ( "net/http" "github.com/go-chi/chi" + + "go.step.sm/linkedca" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" - "go.step.sm/linkedca" ) type adminAuthority interface { @@ -82,28 +86,28 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { adm, ok := h.auth.LoadAdminByID(id) if !ok { - api.WriteError(w, admin.NewError(admin.ErrorNotFoundType, + render.Error(w, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)) return } - api.ProtoJSON(w, adm) + render.ProtoJSON(w, adm) } // GetAdmins returns a segment of admins associated with the authority. func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) + render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) return } - api.JSON(w, &GetAdminsResponse{ + render.JSON(w, &GetAdminsResponse{ Admins: admins, NextCursor: nextCursor, }) @@ -112,19 +116,19 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { // CreateAdmin creates a new admin. func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { var body CreateAdminRequest - if err := api.ReadJSON(r.Body, &body); err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } p, err := h.auth.LoadProvisionerByName(body.Provisioner) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) return } adm := &linkedca.Admin{ @@ -134,11 +138,11 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { } // Store to authority collection. if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error storing admin")) + render.Error(w, admin.WrapErrorISE(err, "error storing admin")) return } - api.ProtoJSONStatus(w, adm, http.StatusCreated) + render.ProtoJSONStatus(w, adm, http.StatusCreated) } // DeleteAdmin deletes admin. @@ -146,23 +150,23 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) return } - api.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, &DeleteResponse{Status: "ok"}) } // UpdateAdmin updates an existing admin. func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { var body UpdateAdminRequest - if err := api.ReadJSON(r.Body, &body); err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + if err := read.JSON(r.Body, &body); err != nil { + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) return } if err := body.Validate(); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } @@ -170,9 +174,9 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) return } - api.ProtoJSON(w, adm) + render.ProtoJSON(w, adm) } diff --git a/authority/admin/api/middleware.go b/authority/admin/api/middleware.go index 19025a9d..b57dd6eb 100644 --- a/authority/admin/api/middleware.go +++ b/authority/admin/api/middleware.go @@ -4,7 +4,7 @@ import ( "context" "net/http" - "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/admin" ) @@ -15,7 +15,7 @@ type nextHTTP = func(http.ResponseWriter, *http.Request) func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { if !h.auth.IsAdminAPIEnabled() { - api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, + render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled")) return } @@ -28,14 +28,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { tok := r.Header.Get("Authorization") if tok == "" { - api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, + render.Error(w, admin.NewError(admin.ErrorUnauthorizedType, "missing authorization header token")) return } adm, err := h.auth.AuthorizeAdminToken(r, tok) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index b8cc0f4c..1cad62dd 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -4,12 +4,16 @@ import ( "net/http" "github.com/go-chi/chi" + + "go.step.sm/linkedca" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" - "go.step.sm/linkedca" ) // GetProvisionersResponse is the type for GET /admin/provisioners responses. @@ -31,39 +35,39 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { ) if len(id) > 0 { if p, err = h.auth.LoadProvisionerByID(id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = h.auth.LoadProvisionerByName(name); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) if err != nil { - api.WriteError(w, err) + render.Error(w, err) return } - api.ProtoJSON(w, prov) + render.ProtoJSON(w, prov) } // GetProvisioners returns the given segment of provisioners associated with the authority. func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { - api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error parsing cursor and limit from query params")) return } p, next, err := h.auth.GetProvisioners(cursor, limit) if err != nil { - api.WriteError(w, errs.InternalServerErr(err)) + render.Error(w, errs.InternalServerErr(err)) return } - api.JSON(w, &GetProvisionersResponse{ + render.JSON(w, &GetProvisionersResponse{ Provisioners: p, NextCursor: next, }) @@ -72,22 +76,22 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { // CreateProvisioner creates a new prov. func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { var prov = new(linkedca.Provisioner) - if err := api.ReadProtoJSON(r.Body, prov); err != nil { - api.WriteError(w, err) + if err := read.ProtoJSON(r.Body, prov); err != nil { + render.Error(w, err) return } // TODO: Validate inputs if err := authority.ValidateClaims(prov.Claims); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) + render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) return } - api.ProtoJSONStatus(w, prov, http.StatusCreated) + render.ProtoJSONStatus(w, prov, http.StatusCreated) } // DeleteProvisioner deletes a provisioner. @@ -101,75 +105,75 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { ) if len(id) > 0 { if p, err = h.auth.LoadProvisionerByID(id); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) return } } else { if p, err = h.auth.LoadProvisionerByName(name); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) return } } if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) + render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) return } - api.JSON(w, &DeleteResponse{Status: "ok"}) + render.JSON(w, &DeleteResponse{Status: "ok"}) } // UpdateProvisioner updates an existing prov. func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { var nu = new(linkedca.Provisioner) - if err := api.ReadProtoJSON(r.Body, nu); err != nil { - api.WriteError(w, err) + if err := read.ProtoJSON(r.Body, nu); err != nil { + render.Error(w, err) return } name := chi.URLParam(r, "name") _old, err := h.auth.LoadProvisionerByName(name) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) return } old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) if err != nil { - api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) + render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) return } if nu.Id != old.Id { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID")) + render.Error(w, admin.NewErrorISE("cannot change provisioner ID")) return } if nu.Type != old.Type { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner type")) + render.Error(w, admin.NewErrorISE("cannot change provisioner type")) return } if nu.AuthorityId != old.AuthorityId { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID")) + render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID")) return } if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt")) + render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt")) return } if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { - api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt")) + render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt")) return } // TODO: Validate inputs if err := authority.ValidateClaims(nu.Claims); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { - api.WriteError(w, err) + render.Error(w, err) return } - api.ProtoJSON(w, nu) + render.ProtoJSON(w, nu) } diff --git a/authority/admin/errors.go b/authority/admin/errors.go index 217227ca..baa32dd9 100644 --- a/authority/admin/errors.go +++ b/authority/admin/errors.go @@ -3,13 +3,10 @@ package admin import ( "encoding/json" "fmt" - "log" "net/http" - "os" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/api/render" ) // ProblemType is the type of the Admin problem. @@ -197,27 +194,9 @@ func (e *Error) ToLog() (interface{}, error) { return string(b), nil } -// WriteError writes to w a JSON representation of the given error. -func WriteError(w http.ResponseWriter, err *Error) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(err.StatusCode()) +// Render implements render.RenderableError for Error. +func (e *Error) Render(w http.ResponseWriter) { + e.Message = e.Err.Error() - err.Message = err.Err.Error() - // Write errors in the response writer - if rl, ok := w.(logging.ResponseLogger); ok { - rl.WithFields(map[string]interface{}{ - "error": err.Err, - }) - if os.Getenv("STEPDEBUG") == "1" { - if e, ok := err.Err.(errs.StackTracer); ok { - rl.WithFields(map[string]interface{}{ - "stack-trace": fmt.Sprintf("%+v", e), - }) - } - } - } - - if err := json.NewEncoder(w).Encode(err); err != nil { - log.Println(err) - } + render.JSONStatus(w, e, e.StatusCode()) } diff --git a/authority/authority.go b/authority/authority.go index b10c3c33..9db38e14 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -70,14 +70,24 @@ type Authority struct { startTime time.Time // Custom functions - sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) - sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) - sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) - getIdentityFunc provisioner.GetIdentityFunc + sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) + sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) + sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) + getIdentityFunc provisioner.GetIdentityFunc + authorizeRenewFunc provisioner.AuthorizeRenewFunc + authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc adminMutex sync.RWMutex } +type Info struct { + StartTime time.Time + RootX509Certs []*x509.Certificate + SSHCAUserPublicKey []byte + SSHCAHostPublicKey []byte + DNSNames []string +} + // New creates and initiates a new Authority type. func New(cfg *config.Config, opts ...Option) (*Authority, error) { err := cfg.Validate() @@ -175,7 +185,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error { // Create provisioner collection. provClxn := provisioner.NewCollection(provisionerConfig.Audiences) for _, p := range provList { - if err := p.Init(*provisionerConfig); err != nil { + if err := p.Init(provisionerConfig); err != nil { return err } if err := provClxn.Store(p); err != nil { @@ -251,6 +261,21 @@ func (a *Authority) init() error { } } + // Initialize linkedca client if necessary. On a linked RA, the issuer + // configuration might come from majordomo. + var linkedcaClient *linkedCaClient + if a.config.AuthorityConfig.EnableAdmin && a.linkedCAToken != "" && a.adminDB == nil { + linkedcaClient, err = newLinkedCAClient(a.linkedCAToken) + if err != nil { + return err + } + // If authorityId is configured make sure it matches the one in the token + if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) { + return errors.New("error initializing linkedca: token authority and configured authority do not match") + } + linkedcaClient.Run() + } + // Initialize the X.509 CA Service if it has not been set in the options. if a.x509CAService == nil { var options casapi.Options @@ -258,6 +283,22 @@ func (a *Authority) init() error { options = *a.config.AuthorityConfig.Options } + // Configure linked RA + if linkedcaClient != nil && options.CertificateAuthority == "" { + conf, err := linkedcaClient.GetConfiguration(context.Background()) + if err != nil { + return err + } + if conf.RaConfig != nil { + options.CertificateAuthority = conf.RaConfig.CaUrl + options.CertificateAuthorityFingerprint = conf.RaConfig.Fingerprint + options.CertificateIssuer = &casapi.CertificateIssuer{ + Type: conf.RaConfig.Provisioner.Type.String(), + Provisioner: conf.RaConfig.Provisioner.Name, + } + } + } + // Set the issuer password if passed in the flags. if options.CertificateIssuer != nil && a.issuerPassword != nil { options.CertificateIssuer.Password = string(a.issuerPassword) @@ -292,8 +333,6 @@ func (a *Authority) init() error { return err } a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate) - sum := sha256.Sum256(resp.RootCertificate.Raw) - log.Printf("Using root fingerprint '%s'", hex.EncodeToString(sum[:])) } } @@ -479,24 +518,13 @@ func (a *Authority) init() error { // Initialize step-ca Admin Database if it's not already initialized using // WithAdminDB. if a.adminDB == nil { - if a.linkedCAToken == "" { - // Check if AuthConfig already exists + if linkedcaClient != nil { + a.adminDB = linkedcaClient + } else { a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID) if err != nil { return err } - } else { - // Use the linkedca client as the admindb. - client, err := newLinkedCAClient(a.linkedCAToken) - if err != nil { - return err - } - // If authorityId is configured make sure it matches the one in the token - if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) { - return errors.New("error initializing linkedca: token authority and configured authority do not match") - } - client.Run() - a.adminDB = client } } @@ -559,6 +587,21 @@ func (a *Authority) GetAdminDatabase() admin.DB { return a.adminDB } +func (a *Authority) GetInfo() Info { + ai := Info{ + StartTime: a.startTime, + RootX509Certs: a.rootX509Certs, + DNSNames: a.config.DNSNames, + } + if a.sshCAUserCertSignKey != nil { + ai.SSHCAUserPublicKey = ssh.MarshalAuthorizedKey(a.sshCAUserCertSignKey.PublicKey()) + } + if a.sshCAHostCertSignKey != nil { + ai.SSHCAHostPublicKey = ssh.MarshalAuthorizedKey(a.sshCAHostCertSignKey.PublicKey()) + } + return ai +} + // IsAdminAPIEnabled returns a boolean indicating whether the Admin API has // been enabled. func (a *Authority) IsAdminAPIEnabled() bool { diff --git a/authority/authorize.go b/authority/authorize.go index 5108f567..7c1c2ff6 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "encoding/hex" "net/http" + "net/url" "strconv" "strings" "time" @@ -276,6 +277,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error { serial := cert.SerialNumber.String() var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} + isRevoked, err := a.IsRevoked(serial) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) @@ -283,7 +285,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error { if isRevoked { return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) } - p, ok := a.provisioners.LoadByCertificate(cert) if !ok { return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) @@ -371,3 +372,80 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error } return nil } + +// AuthorizeRenewToken validates the renew token and returns the leaf +// certificate in the x5cInsecure header. +func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) { + var claims jose.Claims + jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs) + if err != nil { + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token")) + } + leaf := chain[0][0] + if err := jwt.Claims(leaf.PublicKey, &claims); err != nil { + return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token")) + } + + p, ok := a.provisioners.LoadByCertificate(leaf) + if !ok { + return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate") + } + if err := a.UseToken(ott, p); err != nil { + return nil, err + } + + if err := claims.ValidateWithLeeway(jose.Expected{ + Issuer: p.GetName(), + Subject: leaf.Subject.CommonName, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + switch err { + case jose.ErrInvalidIssuer: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)")) + case jose.ErrInvalidSubject: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)")) + case jose.ErrNotValidYet: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)")) + case jose.ErrExpired: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)")) + case jose.ErrIssuedInTheFuture: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)")) + default: + return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token")) + } + } + + audiences := a.config.GetAudiences().Renew + if !matchesAudience(claims.Audience, audiences) { + return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)")) + } + + return leaf, nil +} + +// matchesAudience returns true if A and B share at least one element. +func matchesAudience(as, bs []string) bool { + if len(bs) == 0 || len(as) == 0 { + return false + } + + for _, b := range bs { + for _, a := range as { + if b == a || stripPort(a) == stripPort(b) { + return true + } + } + } + return false +} + +// stripPort attempts to strip the port from the given url. If parsing the url +// produces errors it will just return the passed argument. +func stripPort(rawurl string) string { + u, err := url.Parse(rawurl) + if err != nil { + return rawurl + } + u.Host = u.Hostname() + return u.String() +} diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 6d524a25..a7bec277 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -3,24 +3,32 @@ package authority import ( "context" "crypto" + "crypto/ed25519" "crypto/rand" "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" "encoding/base64" + "errors" "fmt" "net/http" + "reflect" "strconv" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" - "golang.org/x/crypto/ssh" + "go.step.sm/crypto/x509util" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" ) var testAudiences = provisioner.Audiences{ @@ -305,8 +313,8 @@ func TestAuthority_authorizeToken(t *testing.T) { p, err := tc.auth.authorizeToken(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -391,8 +399,8 @@ func TestAuthority_authorizeRevoke(t *testing.T) { if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -476,14 +484,14 @@ func TestAuthority_authorizeSign(t *testing.T) { got, err := tc.auth.authorizeSign(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 7, got) + assert.Len(t, 8, got) } } }) @@ -735,8 +743,8 @@ func TestAuthority_Authorize(t *testing.T) { if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, got) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -753,6 +761,7 @@ func TestAuthority_Authorize(t *testing.T) { func TestAuthority_authorizeRenew(t *testing.T) { fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt") + fooCrt.NotAfter = time.Now().Add(time.Hour) assert.FatalError(t, err) renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt") @@ -822,7 +831,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { return &authorizeTest{ auth: a, cert: renewDisabledCrt, - err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"), + err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"), code: http.StatusUnauthorized, } }, @@ -847,7 +856,7 @@ func TestAuthority_authorizeRenew(t *testing.T) { err := tc.auth.authorizeRenew(tc.cert) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -909,6 +918,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner. } func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + now := time.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) if err != nil { return nil, nil, err @@ -917,6 +927,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now.Unix()) + } + if cert.ValidBefore == 0 { + cert.ValidBefore = uint64(now.Add(time.Hour).Unix()) + } if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } @@ -988,8 +1004,8 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1003,6 +1019,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { } func TestAuthority_authorizeSSHRenew(t *testing.T) { + now := time.Now().UTC() + sshpop := func(a *Authority) (*ssh.Certificate, string) { + p, ok := a.provisioners.Load("sshpop/sshpop") + assert.Fatal(t, ok, "sshpop provisioner not found in test authority") + key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") + assert.FatalError(t, err) + signer, ok := key.(crypto.Signer) + assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") + sshSigner, err := ssh.NewSignerFromSigner(signer) + assert.FatalError(t, err) + cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) + assert.FatalError(t, err) + token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert)) + assert.FatalError(t, err) + return cert, token + } + a := testAuthority(t) jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) @@ -1012,8 +1045,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) assert.FatalError(t, err) - now := time.Now().UTC() - validIssuer := "step-cli" type authorizeTest struct { @@ -1050,27 +1081,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { code: http.StatusUnauthorized, } }, + "fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest { + aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error { + return errs.Forbidden("forbidden") + })) + _, token := sshpop(aa) + return &authorizeTest{ + auth: aa, + token: token, + err: errors.New("authority.authorizeSSHRenew: forbidden"), + code: http.StatusForbidden, + } + }, "ok": func(t *testing.T) *authorizeTest { - key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") - assert.FatalError(t, err) - signer, ok := key.(crypto.Signer) - assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer") - sshSigner, err := ssh.NewSignerFromSigner(signer) - assert.FatalError(t, err) - - cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) - assert.FatalError(t, err) - - p, ok := a.provisioners.Load("sshpop/sshpop") - assert.Fatal(t, ok, "sshpop provisioner not found in test authority") - - tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", - []string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert)) - assert.FatalError(t, err) - + cert, token := sshpop(a) return &authorizeTest{ auth: a, - token: tok, + token: token, + cert: cert, + } + }, + "ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest { + aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error { + return nil + })) + cert, token := sshpop(aa) + return &authorizeTest{ + auth: aa, + token: token, cert: cert, } }, @@ -1083,8 +1121,8 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1183,8 +1221,8 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) { if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1276,8 +1314,8 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1290,3 +1328,283 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { }) } } + +func TestAuthority_AuthorizeRenewToken(t *testing.T) { + ctx := context.Background() + type stepProvisionerASN1 struct { + Type int + Name []byte + CredentialID []byte + KeyValuePairs []string `asn1:"optional,omitempty"` + } + + _, signer, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer) + if err != nil { + t.Fatal(err) + } + _, otherSigner, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) { + chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + if err != nil { + t.Fatal(err) + } + + var x5c []string + for _, c := range chain { + x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw)) + } + + so := new(jose.SignerOptions) + so.WithType("JWT") + so.WithHeader("x5cInsecure", x5c) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so) + if err != nil { + t.Fatal(err) + } + s, err := jose.Signed(sig).Claims(claims).CompactSerialize() + if err != nil { + t.Fatal(err) + } + return s, chain[0] + } + + now := time.Now() + a1 := testAuthority(t) + t1, c1 := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + t2, c2 := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + IssuedAt: jose.NewNumericDate(now), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now.Add(-time.Hour) + cert.NotAfter = now.Add(-time.Minute) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "test.example.com", + Issuer: "bad-issuer", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badSubject, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/renew"}, + Subject: "bad-subject", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)), + Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)), + Expiry: jose.NewNumericDate(now.Add(-time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + badAudience, _ := generateX5cToken(a1, signer, jose.Claims{ + Audience: []string{"https://example.com/1.0/sign"}, + Subject: "test.example.com", + Issuer: "step-cli", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), + }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + cert.NotBefore = now + cert.NotAfter = now.Add(time.Hour) + b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil}) + if err != nil { + return err + } + cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1}, + Value: b, + }) + return nil + })) + + type args struct { + ctx context.Context + ott string + } + tests := []struct { + name string + authority *Authority + args args + want *x509.Certificate + wantErr bool + }{ + {"ok", a1, args{ctx, t1}, c1, false}, + {"ok expired cert", a1, args{ctx, t2}, c2, false}, + {"fail token", a1, args{ctx, "not.a.token"}, nil, true}, + {"fail token reuse", a1, args{ctx, t1}, nil, true}, + {"fail token signature", a1, args{ctx, badSigner}, nil, true}, + {"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true}, + {"fail token iss", a1, args{ctx, badIssuer}, nil, true}, + {"fail token sub", a1, args{ctx, badSubject}, nil, true}, + {"fail token iat", a1, args{ctx, badNotBefore}, nil, true}, + {"fail token iat", a1, args{ctx, badExpiry}, nil, true}, + {"fail token iat", a1, args{ctx, badIssuedAt}, nil, true}, + {"fail token aud", a1, args{ctx, badAudience}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/config/config.go b/authority/config/config.go index 9fada6f1..1729a693 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -26,23 +26,27 @@ var ( DefaultBackdate = time.Minute // DefaultDisableRenewal disables renewals per provisioner. DefaultDisableRenewal = false + // DefaultAllowRenewAfterExpiry allows renewals even if the certificate is + // expired. + DefaultAllowRenewAfterExpiry = false // DefaultEnableSSHCA enable SSH CA features per provisioner or globally // for all provisioners. DefaultEnableSSHCA = false // GlobalProvisionerClaims default claims for the Authority. Can be overridden // by provisioner specific claims. GlobalProvisionerClaims = provisioner.Claims{ - MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs - MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DisableRenewal: &DefaultDisableRenewal, - MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &DefaultEnableSSHCA, + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &DefaultEnableSSHCA, + DisableRenewal: &DefaultDisableRenewal, + AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry, } ) @@ -273,28 +277,32 @@ func (c *Config) GetAudiences() provisioner.Audiences { } for _, name := range c.DNSNames { + hostname := toHostname(name) audiences.Sign = append(audiences.Sign, - fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), - fmt.Sprintf("https://%s/sign", toHostname(name)), - fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), - fmt.Sprintf("https://%s/ssh/sign", toHostname(name))) + fmt.Sprintf("https://%s/1.0/sign", hostname), + fmt.Sprintf("https://%s/sign", hostname), + fmt.Sprintf("https://%s/1.0/ssh/sign", hostname), + fmt.Sprintf("https://%s/ssh/sign", hostname)) + audiences.Renew = append(audiences.Renew, + fmt.Sprintf("https://%s/1.0/renew", hostname), + fmt.Sprintf("https://%s/renew", hostname)) audiences.Revoke = append(audiences.Revoke, - fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)), - fmt.Sprintf("https://%s/revoke", toHostname(name))) + fmt.Sprintf("https://%s/1.0/revoke", hostname), + fmt.Sprintf("https://%s/revoke", hostname)) audiences.SSHSign = append(audiences.SSHSign, - fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), - fmt.Sprintf("https://%s/ssh/sign", toHostname(name)), - fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), - fmt.Sprintf("https://%s/sign", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/sign", hostname), + fmt.Sprintf("https://%s/ssh/sign", hostname), + fmt.Sprintf("https://%s/1.0/sign", hostname), + fmt.Sprintf("https://%s/sign", hostname)) audiences.SSHRevoke = append(audiences.SSHRevoke, - fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)), - fmt.Sprintf("https://%s/ssh/revoke", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname), + fmt.Sprintf("https://%s/ssh/revoke", hostname)) audiences.SSHRenew = append(audiences.SSHRenew, - fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)), - fmt.Sprintf("https://%s/ssh/renew", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/renew", hostname), + fmt.Sprintf("https://%s/ssh/renew", hostname)) audiences.SSHRekey = append(audiences.SSHRekey, - fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)), - fmt.Sprintf("https://%s/ssh/rekey", toHostname(name))) + fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname), + fmt.Sprintf("https://%s/ssh/rekey", hostname)) } return audiences diff --git a/authority/linkedca.go b/authority/linkedca.go index b568dcbb..6a0800c2 100644 --- a/authority/linkedca.go +++ b/authority/linkedca.go @@ -15,6 +15,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" @@ -151,13 +152,21 @@ func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linked } func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) { + resp, err := c.GetConfiguration(ctx) + if err != nil { + return nil, err + } + return resp.Provisioners, nil +} + +func (c *linkedCaClient) GetConfiguration(ctx context.Context) (*linkedca.ConfigurationResponse, error) { resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ AuthorityId: c.authorityID, }) if err != nil { - return nil, errors.Wrap(err, "error getting provisioners") + return nil, errors.Wrap(err, "error getting configuration") } - return resp.Provisioners, nil + return resp, nil } func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { @@ -204,11 +213,9 @@ func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Adm } func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) { - resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ - AuthorityId: c.authorityID, - }) + resp, err := c.GetConfiguration(ctx) if err != nil { - return nil, errors.Wrap(err, "error getting admins") + return nil, err } return resp.Admins, nil } @@ -228,12 +235,13 @@ func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error { return errors.Wrap(err, "error deleting admin") } -func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error { +func (c *linkedCaClient) StoreCertificateChain(prov provisioner.Interface, fullchain ...*x509.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ PemCertificate: serializeCertificateChain(fullchain[0]), PemCertificateChain: serializeCertificateChain(fullchain[1:]...), + Provisioner: createProvisionerIdentity(prov), }) return errors.Wrap(err, "error posting certificate") } @@ -310,6 +318,17 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) { return resp.Status != linkedca.RevocationStatus_ACTIVE, nil } +func createProvisionerIdentity(prov provisioner.Interface) *linkedca.ProvisionerIdentity { + if prov == nil { + return nil + } + return &linkedca.ProvisionerIdentity{ + Id: prov.GetID(), + Type: linkedca.Provisioner_Type(prov.GetType()), + Name: prov.GetName(), + } +} + func serializeCertificate(crt *x509.Certificate) string { if crt == nil { return "" diff --git a/authority/options.go b/authority/options.go index f92db99b..1c154577 100644 --- a/authority/options.go +++ b/authority/options.go @@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e } } +// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of +// an X.509 certificate. +func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option { + return func(a *Authority) error { + a.authorizeRenewFunc = fn + return nil + } +} + +// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal +// of a SSH certificate. +func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option { + return func(a *Authority) error { + a.authorizeSSHRenewFunc = fn + return nil + } +} + // WithSSHBastionFunc sets a custom function to get the bastion for a // given user-host pair. func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option { @@ -145,6 +163,22 @@ func WithX509Signer(crt *x509.Certificate, s crypto.Signer) Option { } } +// WithX509SignerFunc defines the function used to get the chain of certificates +// and signer used when we sign X.509 certificates. +func WithX509SignerFunc(fn func() ([]*x509.Certificate, crypto.Signer, error)) Option { + return func(a *Authority) error { + srv, err := cas.New(context.Background(), casapi.Options{ + Type: casapi.SoftCAS, + CertificateSigner: fn, + }) + if err != nil { + return err + } + a.x509CAService = srv + return nil + } +} + // WithSSHUserSigner defines the signer used to sign SSH user certificates. func WithSSHUserSigner(s crypto.Signer) Option { return func(a *Authority) error { diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 21958d36..b5d806ab 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -6,7 +6,6 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/errs" ) // ACME is the acme provisioner type, an entity that can authorize the ACME @@ -24,7 +23,7 @@ type ACME struct { RequireEAB bool `json:"requireEAB,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -69,7 +68,7 @@ func (p *ACME) GetOptions() *Options { // DefaultTLSCertDuration returns the default TLS cert duration enforced by // the provisioner. func (p *ACME) DefaultTLSCertDuration() time.Duration { - return p.claimer.DefaultTLSCertDuration() + return p.ctl.Claimer.DefaultTLSCertDuration() } // Init initializes and validates the fields of a JWK type. @@ -81,12 +80,8 @@ func (p *ACME) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign does not do any validation, because all validation is handled @@ -94,13 +89,14 @@ func (p *ACME) Init(config Config) (err error) { // on the resulting certificate. func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { return []SignOption{ + p, // modifiers / withOptions newProvisionerExtensionOption(TypeACME, p.Name, ""), newForceCNOption(p.ForceCN), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -118,8 +114,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error { // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index bd173f87..1c9a88cc 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -3,13 +3,14 @@ package provisioner import ( "context" "crypto/x509" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" + "github.com/smallstep/certificates/api/render" ) func TestACME_Getters(t *testing.T) { @@ -91,6 +92,7 @@ func TestACME_Init(t *testing.T) { } func TestACME_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) type test struct { p *ACME cert *x509.Certificate @@ -104,21 +106,27 @@ func TestACME_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, code: http.StatusUnauthorized, - err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()), + err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateACME() assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, } }, } @@ -126,8 +134,8 @@ func TestACME_AuthorizeRenew(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -161,31 +169,32 @@ func TestACME_AuthorizeSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) && assert.NotNil(t, opts) { - assert.Len(t, 5, opts) + assert.Len(t, 6, opts) for _, o := range opts { switch v := o.(type) { + case *ACME: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeACME)) + assert.Equals(t, v.Type, TypeACME) assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case *forceCNOption: assert.Equals(t, v.ForceCN, tc.p.ForceCN) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index fdad7b4a..9d27e016 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -264,9 +264,8 @@ type AWS struct { IIDRoots string `json:"iidRoots,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *awsConfig - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -400,15 +399,11 @@ func (p *AWS) Init(config Config) (err error) { case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } + // Add default config if p.config, err = newAWSConfig(p.IIDRoots); err != nil { return err } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) // validate IMDS versions if len(p.IMDSVersions) == 0 { @@ -425,7 +420,9 @@ func (p *AWS) Init(config Config) (err error) { } } - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign validates the given token and returns the sign options that @@ -470,14 +467,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } return append(so, + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, commonNameValidator(payload.Claims.Subject), - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } @@ -486,10 +484,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized @@ -664,7 +659,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { } // validate audiences with the defaults - if !matchesAudience(payload.Audience, p.audiences.Sign) { + if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) { return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)") } @@ -704,7 +699,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) @@ -752,11 +747,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 0d2786db..7027a446 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -9,6 +9,7 @@ import ( "crypto/x509" "encoding/hex" "encoding/pem" + "errors" "fmt" "net" "net/http" @@ -17,10 +18,10 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestAWS_Getters(t *testing.T) { @@ -521,8 +522,8 @@ func TestAWS_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -641,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false}, - {"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false}, - {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false}, - {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false}, - {"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false}, + {"ok", p1, args{t1, "foo.local"}, 7, http.StatusOK, false}, + {"ok", p2, args{t2, "instance-id"}, 11, http.StatusOK, false}, + {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 11, http.StatusOK, false}, + {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 11, http.StatusOK, false}, + {"ok", p1, args{t4, "instance-id"}, 7, http.StatusOK, false}, {"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true}, {"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true}, @@ -668,27 +669,28 @@ func TestAWS_AuthorizeSign(t *testing.T) { t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Len(t, tt.wantLen, got) for _, o := range got { switch v := o.(type) { + case *AWS: case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeAWS)) + assert.Equals(t, v.Type, TypeAWS) assert.Equals(t, v.Name, tt.aws.GetName()) assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) assert.Len(t, 2, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), tt.args.cn) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) case emailAddressesValidator: @@ -698,7 +700,7 @@ func TestAWS_AuthorizeSign(t *testing.T) { case dnsNamesValidator: assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -726,7 +728,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com") @@ -747,7 +749,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), @@ -802,8 +804,8 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -824,6 +826,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) { } func TestAWS_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateAWS() assert.FatalError(t, err) p2, err := generateAWS() @@ -832,7 +835,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -845,16 +848,22 @@ func TestAWS_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 55d77f49..e6323e9f 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -30,7 +30,7 @@ const azureDefaultAudience = "https://management.azure.com/" // azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim. // Using case insensitive as resourceGroups appears as resourcegroups. -var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`) +var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`) type azureConfig struct { oidcDiscoveryURL string @@ -89,15 +89,17 @@ type Azure struct { Name string `json:"name"` TenantID string `json:"tenantID"` ResourceGroups []string `json:"resourceGroups"` + SubscriptionIDs []string `json:"subscriptionIDs"` + ObjectIDs []string `json:"objectIDs"` Audience string `json:"audience,omitempty"` DisableCustomSANs bool `json:"disableCustomSANs"` DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *azureConfig oidcConfig openIDConfiguration keyStore *keyStore + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -201,37 +203,34 @@ func (p *Azure) Init(config Config) (err error) { case p.Audience == "": // use default audience p.Audience = azureDefaultAudience } + // Initialize config p.assertConfig() - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - // Decode and validate openid-configuration endpoint - if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { - return err + if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { + return } if err := p.oidcConfig.Validate(); err != nil { return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL) } // Get JWK key set if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil { - return err + return } - return nil + p.ctl, err = NewController(p, p.Claims, config) + return } -// authorizeToken returns the claims, name, group, error. -func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) { +// authorizeToken returns the claims, name, group, subscription, identityObjectID, error. +func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, string, string, error) { jwt, err := jose.ParseSigned(token) if err != nil { - return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token") + return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token") } if len(jwt.Headers) == 0 { - return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header") + return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header") } var found bool @@ -244,7 +243,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err } } if !found { - return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token") + return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token") } if err := claims.ValidateWithLeeway(jose.Expected{ @@ -252,26 +251,30 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err Issuer: p.oidcConfig.Issuer, Time: time.Now(), }, 1*time.Minute); err != nil { - return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload") + return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload") } // Validate TenantID if claims.TenantID != p.TenantID { - return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)") + return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)") } re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) - if len(re) != 4 { - return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID) + if len(re) != 5 { + return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID) } - group, name := re[2], re[3] - return &claims, name, group, nil + + var subscription, group, name string + identityObjectID := claims.ObjectID + subscription, group, name = re[1], re[2], re[4] + + return &claims, name, group, subscription, identityObjectID, nil } // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - _, name, group, err := p.authorizeToken(token) + _, name, group, subscription, identityObjectID, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign") } @@ -290,6 +293,34 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, } } + // Filter by subscription id + if len(p.SubscriptionIDs) > 0 { + var found bool + for _, s := range p.SubscriptionIDs { + if s == subscription { + found = true + break + } + } + if !found { + return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid subscription id") + } + } + + // Filter by Azure AD identity object id + if len(p.ObjectIDs) > 0 { + var found bool + for _, i := range p.ObjectIDs { + if i == identityObjectID { + found = true + break + } + } + if !found { + return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid identity object id") + } + } + // Template options data := x509util.NewTemplateData() data.SetCommonName(name) @@ -321,13 +352,14 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, } return append(so, + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } @@ -336,19 +368,16 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) } - _, name, _, err := p.authorizeToken(token) + _, name, _, _, _, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign") } @@ -389,11 +418,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 7f8d6017..a8a0a271 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "errors" "fmt" "net/http" "net/http/httptest" @@ -15,10 +16,10 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestAzure_Getters(t *testing.T) { @@ -95,7 +96,7 @@ func TestAzure_GetIdentityToken(t *testing.T) { assert.FatalError(t, err) t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) @@ -237,7 +238,7 @@ func TestAzure_authorizeToken(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, - p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), jwk) assert.FatalError(t, err) return test{ @@ -252,7 +253,7 @@ func TestAzure_authorizeToken(t *testing.T) { assert.FatalError(t, err) defer srv.Close() tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, - p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ @@ -267,7 +268,7 @@ func TestAzure_authorizeToken(t *testing.T) { assert.FatalError(t, err) defer srv.Close() tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, - "foo", "subscriptionID", "resourceGroup", "virtualMachine", + "foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ @@ -321,7 +322,7 @@ func TestAzure_authorizeToken(t *testing.T) { assert.FatalError(t, err) defer srv.Close() tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, - p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p.keyStore.keySet.Keys[0]) assert.FatalError(t, err) return test{ @@ -333,10 +334,10 @@ func TestAzure_authorizeToken(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if claims, name, group, err := tc.p.authorizeToken(tc.token); err != nil { + if claims, name, group, subscriptionID, objectID, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -348,6 +349,8 @@ func TestAzure_authorizeToken(t *testing.T) { assert.Equals(t, name, "virtualMachine") assert.Equals(t, group, "resourceGroup") + assert.Equals(t, subscriptionID, "subscriptionID") + assert.Equals(t, objectID, "the-oid") } } }) @@ -382,6 +385,38 @@ func TestAzure_AuthorizeSign(t *testing.T) { p4.oidcConfig = p1.oidcConfig p4.keyStore = p1.keyStore + p5, err := generateAzure() + assert.FatalError(t, err) + p5.TenantID = p1.TenantID + p5.SubscriptionIDs = []string{"subscriptionID"} + p5.config = p1.config + p5.oidcConfig = p1.oidcConfig + p5.keyStore = p1.keyStore + + p6, err := generateAzure() + assert.FatalError(t, err) + p6.TenantID = p1.TenantID + p6.SubscriptionIDs = []string{"foobarzar"} + p6.config = p1.config + p6.oidcConfig = p1.oidcConfig + p6.keyStore = p1.keyStore + + p7, err := generateAzure() + assert.FatalError(t, err) + p7.TenantID = p1.TenantID + p7.ObjectIDs = []string{"the-oid"} + p7.config = p1.config + p7.oidcConfig = p1.oidcConfig + p7.keyStore = p1.keyStore + + p8, err := generateAzure() + assert.FatalError(t, err) + p8.TenantID = p1.TenantID + p8.ObjectIDs = []string{"foobarzar"} + p8.config = p1.config + p8.oidcConfig = p1.oidcConfig + p8.keyStore = p1.keyStore + badKey, err := generateJSONWebKey() assert.FatalError(t, err) @@ -393,30 +428,38 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.FatalError(t, err) t4, err := p4.GetIdentityToken("subject", "caURL") assert.FatalError(t, err) + t5, err := p5.GetIdentityToken("subject", "caURL") + assert.FatalError(t, err) + t6, err := p6.GetIdentityToken("subject", "caURL") + assert.FatalError(t, err) + t7, err := p6.GetIdentityToken("subject", "caURL") + assert.FatalError(t, err) + t8, err := p6.GetIdentityToken("subject", "caURL") + assert.FatalError(t, err) t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience", - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0]) assert.FatalError(t, err) failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, - p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", + p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), badKey) assert.FatalError(t, err) @@ -431,11 +474,15 @@ func TestAzure_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1}, 5, http.StatusOK, false}, - {"ok", p2, args{t2}, 10, http.StatusOK, false}, - {"ok", p1, args{t11}, 5, http.StatusOK, false}, + {"ok", p1, args{t1}, 6, http.StatusOK, false}, + {"ok", p2, args{t2}, 11, http.StatusOK, false}, + {"ok", p1, args{t11}, 6, http.StatusOK, false}, + {"ok", p5, args{t5}, 6, http.StatusOK, false}, + {"ok", p7, args{t7}, 6, http.StatusOK, false}, {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true}, {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true}, + {"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true}, + {"fail object id", p8, args{t8}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true}, {"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true}, @@ -451,27 +498,28 @@ func TestAzure_AuthorizeSign(t *testing.T) { t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Len(t, tt.wantLen, got) for _, o := range got { switch v := o.(type) { + case *Azure: case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeAzure)) + assert.Equals(t, v.Type, TypeAzure) assert.Equals(t, v.Name, tt.azure.GetName()) assert.Equals(t, v.CredentialID, tt.azure.TenantID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), "virtualMachine") case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: @@ -481,7 +529,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { case dnsNamesValidator: assert.Equals(t, []string(v), []string{"virtualMachine"}) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -490,6 +538,7 @@ func TestAzure_AuthorizeSign(t *testing.T) { } func TestAzure_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateAzure() assert.FatalError(t, err) p2, err := generateAzure() @@ -498,7 +547,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -511,16 +560,22 @@ func TestAzure_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -549,7 +604,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := p1.GetIdentityToken("subject", "caURL") @@ -570,7 +625,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"virtualMachine"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), @@ -615,8 +670,8 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index 629a313c..2a3e2c61 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -10,10 +10,10 @@ import ( // Claims so that individual provisioners can override global claims. type Claims struct { // TLS CA properties - MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` - MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` - DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` - DisableRenewal *bool `json:"disableRenewal,omitempty"` + MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` + MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` + DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` + // SSH CA properties MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"` MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"` @@ -22,6 +22,10 @@ type Claims struct { MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"` DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"` EnableSSHCA *bool `json:"enableSSHCA,omitempty"` + + // Renewal properties + DisableRenewal *bool `json:"disableRenewal,omitempty"` + AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"` } // Claimer is the type that controls claims. It provides an interface around the @@ -40,19 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) { // Claims returns the merge of the inner and global claims. func (c *Claimer) Claims() Claims { disableRenewal := c.IsDisableRenewal() + allowRenewAfterExpiry := c.AllowRenewAfterExpiry() enableSSHCA := c.IsSSHCAEnabled() + return Claims{ - MinTLSDur: &Duration{c.MinTLSCertDuration()}, - MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, - DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, - DisableRenewal: &disableRenewal, - MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, - MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, - DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, - MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, - MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, - DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, - EnableSSHCA: &enableSSHCA, + MinTLSDur: &Duration{c.MinTLSCertDuration()}, + MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, + DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, + MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, + MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, + DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, + MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, + MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, + DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, + EnableSSHCA: &enableSSHCA, + DisableRenewal: &disableRenewal, + AllowRenewAfterExpiry: &allowRenewAfterExpiry, } } @@ -102,6 +109,16 @@ func (c *Claimer) IsDisableRenewal() bool { return *c.claims.DisableRenewal } +// AllowRenewAfterExpiry returns if the renewal flow is authorized if the +// certificate is expired. If the property is not set within the provisioner +// then the global value from the authority configuration will be used. +func (c *Claimer) AllowRenewAfterExpiry() bool { + if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil { + return *c.global.AllowRenewAfterExpiry + } + return *c.claims.AllowRenewAfterExpiry +} + // DefaultSSHCertDuration returns the default SSH certificate duration for the // given certificate type. func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) { diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index 1bec8689..8bbace5f 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -152,8 +152,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) // proper id to load the provisioner. func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) { for _, e := range cert.Extensions { - if e.Id.Equal(stepOIDProvisioner) { - var provisioner stepProvisionerASN1 + if e.Id.Equal(StepOIDProvisioner) { + var provisioner extensionASN1 if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { return nil, false } diff --git a/authority/provisioner/collection_test.go b/authority/provisioner/collection_test.go index 348b797c..24db4593 100644 --- a/authority/provisioner/collection_test.go +++ b/authority/provisioner/collection_test.go @@ -147,6 +147,17 @@ func TestCollection_LoadByToken(t *testing.T) { } func TestCollection_LoadByCertificate(t *testing.T) { + mustExtension := func(typ Type, name, credentialID string) pkix.Extension { + e := Extension{ + Type: typ, Name: name, CredentialID: credentialID, + } + ext, err := e.ToExtension() + if err != nil { + t.Fatal(err) + } + return ext + } + p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateOIDC() @@ -159,30 +170,21 @@ func TestCollection_LoadByCertificate(t *testing.T) { byName.Store(p2.GetName(), p2) byName.Store(p3.GetName(), p3) - ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID) - assert.FatalError(t, err) - ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID) - assert.FatalError(t, err) - ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "") - assert.FatalError(t, err) - notFoundExt, err := createProvisionerExtension(1, "foo", "bar") - assert.FatalError(t, err) - ok1Cert := &x509.Certificate{ - Extensions: []pkix.Extension{ok1Ext}, + Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)}, } ok2Cert := &x509.Certificate{ - Extensions: []pkix.Extension{ok2Ext}, + Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)}, } ok3Cert := &x509.Certificate{ - Extensions: []pkix.Extension{ok3Ext}, + Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")}, } notFoundCert := &x509.Certificate{ - Extensions: []pkix.Extension{notFoundExt}, + Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")}, } badCert := &x509.Certificate{ Extensions: []pkix.Extension{ - {Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")}, + {Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")}, }, } diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go new file mode 100644 index 00000000..a91ebaac --- /dev/null +++ b/authority/provisioner/controller.go @@ -0,0 +1,194 @@ +package provisioner + +import ( + "context" + "crypto/x509" + "regexp" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" +) + +// Controller wraps a provisioner with other attributes useful in callback +// functions. +type Controller struct { + Interface + Audiences *Audiences + Claimer *Claimer + IdentityFunc GetIdentityFunc + AuthorizeRenewFunc AuthorizeRenewFunc + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc +} + +// NewController initializes a new provisioner controller. +func NewController(p Interface, claims *Claims, config Config) (*Controller, error) { + claimer, err := NewClaimer(claims, config.Claims) + if err != nil { + return nil, err + } + return &Controller{ + Interface: p, + Audiences: &config.Audiences, + Claimer: claimer, + IdentityFunc: config.GetIdentityFunc, + AuthorizeRenewFunc: config.AuthorizeRenewFunc, + AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc, + }, nil +} + +// GetIdentity returns the identity for a given email. +func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) { + if c.IdentityFunc != nil { + return c.IdentityFunc(ctx, c.Interface, email) + } + return DefaultIdentityFunc(ctx, c.Interface, email) +} + +// AuthorizeRenew returns nil if the given cert can be renewed, returns an error +// otherwise. +func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { + if c.AuthorizeRenewFunc != nil { + return c.AuthorizeRenewFunc(ctx, c, cert) + } + return DefaultAuthorizeRenew(ctx, c, cert) +} + +// AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an +// error otherwise. +func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error { + if c.AuthorizeSSHRenewFunc != nil { + return c.AuthorizeSSHRenewFunc(ctx, c, cert) + } + return DefaultAuthorizeSSHRenew(ctx, c, cert) +} + +// Identity is the type representing an externally supplied identity that is used +// by provisioners to populate certificate fields. +type Identity struct { + Usernames []string `json:"usernames"` + Permissions `json:"permissions"` +} + +// GetIdentityFunc is a function that returns an identity. +type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error) + +// AuthorizeRenewFunc is a function that returns nil if the renewal of a +// certificate is enabled. +type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error + +// AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the +// given SSH certificate is enabled. +type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error + +// DefaultIdentityFunc return a default identity depending on the provisioner +// type. For OIDC email is always present and the usernames might +// contain empty strings. +func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) { + switch k := p.(type) { + case *OIDC: + // OIDC principals would be: + // ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled + // 2. Sanitized local. + // 3. Raw local (if different). + // 4. Email address. + name := SanitizeSSHUserPrincipal(email) + if !sshUserRegex.MatchString(name) { + return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email) + } + usernames := []string{name} + if i := strings.LastIndex(email, "@"); i >= 0 { + usernames = append(usernames, email[:i]) + } + usernames = append(usernames, email) + return &Identity{ + Usernames: SanitizeStringSlices(usernames), + }, nil + default: + return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k) + } +} + +// DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It +// will return an error if the provisioner has the renewal disabled, if the +// certificate is not yet valid or if the certificate is expired and renew after +// expiry is disabled. +func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error { + if p.Claimer.IsDisableRenewal() { + return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) + } + + now := time.Now().Truncate(time.Second) + if now.Before(cert.NotBefore) { + return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano)) + } + if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() { + return errs.Unauthorized("certificate has expired") + } + + return nil +} + +// DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It +// will return an error if the provisioner has the renewal disabled, if the +// certificate is not yet valid or if the certificate is expired and renew after +// expiry is disabled. +func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + if p.Claimer.IsDisableRenewal() { + return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) + } + + unixNow := time.Now().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return errs.Unauthorized("certificate is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() { + return errs.Unauthorized("certificate has expired") + } + + return nil +} + +var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$") + +// SanitizeStringSlices removes duplicated an empty strings. +func SanitizeStringSlices(original []string) []string { + output := []string{} + seen := make(map[string]struct{}) + for _, entry := range original { + if entry == "" { + continue + } + if _, value := seen[entry]; !value { + seen[entry] = struct{}{} + output = append(output, entry) + } + } + return output +} + +// SanitizeSSHUserPrincipal grabs an email or a string with the format +// local@domain and returns a sanitized version of the local, valid to be used +// as a user name. If the email starts with a letter between a and z, the +// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`. +func SanitizeSSHUserPrincipal(email string) string { + if i := strings.LastIndex(email, "@"); i >= 0 { + email = email[:i] + } + return strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r + case r >= '0' && r <= '9': + return r + case r == '-': + return '-' + case r == '.': // drop dots + return -1 + default: + return '_' + } + }, strings.ToLower(email)) +} diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go new file mode 100644 index 00000000..9fb90e9d --- /dev/null +++ b/authority/provisioner/controller_test.go @@ -0,0 +1,391 @@ +package provisioner + +import ( + "context" + "crypto/x509" + "fmt" + "reflect" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +var trueValue = true + +func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer { + t.Helper() + c, err := NewClaimer(claims, global) + if err != nil { + t.Fatal(err) + } + return c +} +func mustDuration(t *testing.T, s string) *Duration { + t.Helper() + d, err := NewDuration(s) + if err != nil { + t.Fatal(err) + } + return d +} + +func TestNewController(t *testing.T) { + type args struct { + p Interface + claims *Claims + config Config + } + tests := []struct { + name string + args args + want *Controller + wantErr bool + }{ + {"ok", args{&JWK{}, nil, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, false}, + {"ok with claims", args{&JWK{}, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, globalProvisionerClaims), + }, false}, + {"fail claimer", args{&JWK{}, &Claims{ + MinTLSDur: mustDuration(t, "24h"), + MaxTLSDur: mustDuration(t, "2h"), + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewController(tt.args.p, tt.args.claims, tt.args.config) + if (err != nil) != tt.wantErr { + t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewController() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestController_GetIdentity(t *testing.T) { + ctx := context.Background() + type fields struct { + Interface Interface + IdentityFunc GetIdentityFunc + } + type args struct { + ctx context.Context + email string + } + tests := []struct { + name string + fields fields + args args + want *Identity + wantErr bool + }{ + {"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{ + Usernames: []string{"jane", "jane@doe.org"}, + }, false}, + {"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { + return &Identity{Usernames: []string{"jane"}}, nil + }}, args{ctx, "jane@doe.org"}, &Identity{ + Usernames: []string{"jane"}, + }, false}, + {"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true}, + {"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { + return nil, fmt.Errorf("an error") + }}, args{ctx, "jane@doe.org"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + IdentityFunc: tt.fields.IdentityFunc, + } + got, err := c.GetIdentity(tt.args.ctx, tt.args.email) + if (err != nil) != tt.wantErr { + t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestController_AuthorizeRenew(t *testing.T) { + ctx := context.Background() + now := time.Now().Truncate(time.Second) + type fields struct { + Interface Interface + Claimer *Claimer + AuthorizeRenewFunc AuthorizeRenewFunc + } + type args struct { + ctx context.Context + cert *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return nil + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return nil + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, false}, + {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(time.Hour), + NotAfter: now.Add(2 * time.Hour), + }}, true}, + {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, true}, + {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return fmt.Errorf("an error") + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + Claimer: tt.fields.Claimer, + AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc, + } + if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestController_AuthorizeSSHRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type fields struct { + Interface Interface + Claimer *Claimer + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc + } + type args struct { + ctx context.Context + cert *ssh.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return nil + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return nil + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, false}, + {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(time.Hour).Unix()), + ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), + }}, true}, + {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, true}, + {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return fmt.Errorf("an error") + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + Claimer: tt.fields.Claimer, + AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc, + } + if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefaultAuthorizeRenew(t *testing.T) { + ctx := context.Background() + now := time.Now().Truncate(time.Second) + type args struct { + ctx context.Context + p *Controller + cert *x509.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok renew after expiry", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, false}, + {"fail disabled", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + {"fail not yet valid", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(time.Hour), + NotAfter: now.Add(2 * time.Hour), + }}, true}, + {"fail expired", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefaultAuthorizeSSHRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type args struct { + ctx context.Context + p *Controller + cert *ssh.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok renew after expiry", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, false}, + {"fail disabled", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + {"fail not yet valid", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(time.Hour).Unix()), + ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), + }}, true}, + {"fail expired", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/provisioner/extension.go b/authority/provisioner/extension.go new file mode 100644 index 00000000..c316329d --- /dev/null +++ b/authority/provisioner/extension.go @@ -0,0 +1,73 @@ +package provisioner + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" +) + +var ( + // StepOIDRoot is the root OID for smallstep. + StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} + + // StepOIDProvisioner is the OID for the provisioner extension. + StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...) +) + +// Extension is the Go representation of the provisioner extension. +type Extension struct { + Type Type + Name string + CredentialID string + KeyValuePairs []string +} + +type extensionASN1 struct { + Type int + Name []byte + CredentialID []byte + KeyValuePairs []string `asn1:"optional,omitempty"` +} + +// Marshal marshals the extension using encoding/asn1. +func (e *Extension) Marshal() ([]byte, error) { + return asn1.Marshal(extensionASN1{ + Type: int(e.Type), + Name: []byte(e.Name), + CredentialID: []byte(e.CredentialID), + KeyValuePairs: e.KeyValuePairs, + }) +} + +// ToExtension returns the pkix.Extension representation of the provisioner +// extension. +func (e *Extension) ToExtension() (pkix.Extension, error) { + b, err := e.Marshal() + if err != nil { + return pkix.Extension{}, err + } + return pkix.Extension{ + Id: StepOIDProvisioner, + Value: b, + }, nil +} + +// GetProvisionerExtension goes through all the certificate extensions and +// returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1). +func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) { + for _, e := range cert.Extensions { + if e.Id.Equal(StepOIDProvisioner) { + var provisioner extensionASN1 + if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { + return nil, false + } + return &Extension{ + Type: Type(provisioner.Type), + Name: string(provisioner.Name), + CredentialID: string(provisioner.CredentialID), + KeyValuePairs: provisioner.KeyValuePairs, + }, true + } + } + return nil, false +} diff --git a/authority/provisioner/extension_test.go b/authority/provisioner/extension_test.go new file mode 100644 index 00000000..69be9e18 --- /dev/null +++ b/authority/provisioner/extension_test.go @@ -0,0 +1,158 @@ +package provisioner + +import ( + "crypto/x509" + "crypto/x509/pkix" + "reflect" + "testing" + + "go.step.sm/crypto/pemutil" +) + +func TestExtension_Marshal(t *testing.T) { + type fields struct { + Type Type + Name string + CredentialID string + KeyValuePairs []string + } + tests := []struct { + name string + fields fields + want []byte + wantErr bool + }{ + {"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{ + 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, + }, false}, + {"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{ + 0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f, + 0x13, 0x03, 0x62, 0x61, 0x72, + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Extension{ + Type: tt.fields.Type, + Name: tt.fields.Name, + CredentialID: tt.fields.CredentialID, + KeyValuePairs: tt.fields.KeyValuePairs, + } + got, err := e.Marshal() + if (err != nil) != tt.wantErr { + t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want) + } + }) + } +} + +func TestExtension_ToExtension(t *testing.T) { + type fields struct { + Type Type + Name string + CredentialID string + KeyValuePairs []string + } + tests := []struct { + name string + fields fields + want pkix.Extension + wantErr bool + }{ + {"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{ + Id: StepOIDProvisioner, + Value: []byte{ + 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, + }, + }, false}, + {"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{ + Id: StepOIDProvisioner, + Value: []byte{ + 0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, + }, + }, false}, + {"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{ + Id: StepOIDProvisioner, + Value: []byte{ + 0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49, + 0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f, + 0x13, 0x03, 0x62, 0x61, 0x72, + }, + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Extension{ + Type: tt.fields.Type, + Name: tt.fields.Name, + CredentialID: tt.fields.CredentialID, + KeyValuePairs: tt.fields.KeyValuePairs, + } + got, err := e.ToExtension() + if (err != nil) != tt.wantErr { + t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetProvisionerExtension(t *testing.T) { + mustCertificate := func(fn string) *x509.Certificate { + cert, err := pemutil.ReadCertificate(fn) + if err != nil { + t.Fatal(err) + } + return cert + } + + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + args args + want *Extension + want1 bool + }{ + {"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{ + Type: TypeJWK, + Name: "mariano@smallstep.com", + CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ", + }, true}, + {"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false}, + {"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := GetProvisionerExtension(tt.args.cert) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index e46f4ce4..69d909a2 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -88,10 +88,9 @@ type GCP struct { InstanceAge Duration `json:"instanceAge,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer config *gcpConfig keyStore *keyStore - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. The name should uniquely @@ -194,8 +193,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) { } // Init validates and initializes the GCP provisioner. -func (p *GCP) Init(config Config) error { - var err error +func (p *GCP) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -204,20 +202,18 @@ func (p *GCP) Init(config Config) error { case p.InstanceAge.Value() < 0: return errors.New("provisioner instanceAge cannot be negative") } + // Initialize config p.assertConfig() - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } + // Initialize key store - p.keyStore, err = newKeyStore(p.config.CertsURL) - if err != nil { - return err + if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil { + return } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // AuthorizeSign validates the given token and returns the sign options that @@ -266,22 +262,20 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } return append(so, + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), ), nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // assertConfig initializes the config if it has not been initialized. @@ -328,7 +322,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { } // validate audiences with the defaults - if !matchesAudience(claims.Audience, p.audiences.Sign) { + if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) { return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)") } @@ -383,7 +377,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) { // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) } claims, err := p.authorizeToken(token) @@ -431,11 +425,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Validate user SignSSHOptions. sshCertOptionsValidator(defaults), // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 5f6f9bc7..dfb9a329 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "errors" "fmt" "net/http" "net/http/httptest" @@ -16,10 +17,10 @@ import ( "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestGCP_Getters(t *testing.T) { @@ -390,8 +391,8 @@ func TestGCP_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -515,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{t1}, 5, http.StatusOK, false}, - {"ok", p2, args{t2}, 10, http.StatusOK, false}, - {"ok", p3, args{t3}, 5, http.StatusOK, false}, + {"ok", p1, args{t1}, 6, http.StatusOK, false}, + {"ok", p2, args{t2}, 11, http.StatusOK, false}, + {"ok", p3, args{t3}, 6, http.StatusOK, false}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, {"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true}, @@ -540,27 +541,28 @@ func TestGCP_AuthorizeSign(t *testing.T) { t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) return case err != nil: - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) default: assert.Len(t, tt.wantLen, got) for _, o := range got { switch v := o.(type) { + case *GCP: case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeGCP)) + assert.Equals(t, v.Type, TypeGCP) assert.Equals(t, v.Name, tt.gcp.GetName()) assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0]) assert.Len(t, 4, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration()) case commonNameSliceValidator: assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration()) case ipAddressesValidator: assert.Equals(t, v, nil) case emailAddressesValidator: @@ -570,7 +572,7 @@ func TestGCP_AuthorizeSign(t *testing.T) { case dnsNamesValidator: assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -595,7 +597,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p3.Claims = &Claims{EnableSSHCA: &disable} - p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) + p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) assert.FatalError(t, err) t1, err := generateGCPToken(p1.ServiceAccounts[0], @@ -622,7 +624,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedHostOptions := &SignSSHOptions{ CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), @@ -677,8 +679,8 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -698,6 +700,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) { } func TestGCP_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateGCP() assert.FatalError(t, err) p2, err := generateGCP() @@ -706,7 +709,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -719,15 +722,21 @@ func TestGCP_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renewal-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tt.code) } diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 137915c8..3c5032fb 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -35,8 +35,7 @@ type JWK struct { EncryptedKey string `json:"encryptedKey,omitempty"` Claims *Claims `json:"claims,omitempty"` Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences + ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id @@ -98,13 +97,8 @@ func (p *JWK) Init(config Config) (err error) { return errors.New("provisioner key cannot be empty") } - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -146,13 +140,13 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } @@ -176,15 +170,16 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er } return []SignOption{ + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators commonNameValidator(claims.Subject), defaultPublicKeyValidator{}, defaultSANsValidator(claims.SANs), - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -193,18 +188,15 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") } @@ -261,11 +253,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, return append(signOptions, // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, ), nil @@ -273,6 +265,6 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.SSHRevoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") } diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index deae8f7a..926f9d68 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -6,15 +6,17 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "errors" + "fmt" "net/http" "strings" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestJWK_Getters(t *testing.T) { @@ -76,13 +78,13 @@ func TestJWK_Init(t *testing.T) { }, "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, err: errors.New("claims: MinTLSCertDuration must be greater than 0"), } }, "ok": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences}, + p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}}, } }, } @@ -183,8 +185,8 @@ func TestJWK_authorizeToken(t *testing.T) { t.Run(tt.name, func(t *testing.T) { if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tt.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -223,8 +225,8 @@ func TestJWK_AuthorizeRevoke(t *testing.T) { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil { if assert.NotNil(t, tt.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -288,34 +290,35 @@ func TestJWK_AuthorizeSign(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil { if assert.NotNil(t, tt.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.HasPrefix(t, err.Error(), tt.err.Error()) } } else { if assert.NotNil(t, got) { - assert.Len(t, 7, got) + assert.Len(t, 8, got) for _, o := range got { switch v := o.(type) { + case *JWK: case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeJWK)) + assert.Equals(t, v.Type, TypeJWK) assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case commonNameValidator: assert.Equals(t, string(v), "subject") case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case defaultSANsValidator: assert.Equals(t, []string(v), tt.sans) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -325,6 +328,7 @@ func TestJWK_AuthorizeSign(t *testing.T) { } func TestJWK_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateJWK() assert.FatalError(t, err) p2, err := generateJWK() @@ -333,7 +337,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -346,16 +350,22 @@ func TestJWK_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -373,7 +383,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p2.Claims = &Claims{EnableSSHCA: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) jwk, err := decryptJSONWebKey(p1.EncryptedKey) @@ -402,8 +412,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), @@ -448,8 +458,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -485,8 +495,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) { signer, err := generateJSONWebKey() assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), @@ -613,8 +623,8 @@ func TestJWK_AuthorizeSSHRevoke(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index d260f5ec..083773e0 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -42,16 +42,15 @@ type k8sSAPayload struct { // entity trusted to make signature requests. type K8sSA struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - PubKeys []byte `json:"publicKeys,omitempty"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + PubKeys []byte `json:"publicKeys,omitempty"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` //kauthn kauthn.AuthenticationV1Interface pubKeys []interface{} + ctl *Controller } // GetID returns the provisioner unique identifier. The name and credential id @@ -138,13 +137,8 @@ func (p *K8sSA) Init(config Config) (err error) { p.kauthn = k8s.AuthenticationV1() */ - // Update claims with global ones - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences - return err + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -211,13 +205,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") } @@ -237,30 +231,28 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, } return []SignOption{ + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeK8sSA, p.Name, ""), - profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign validates an request for an SSH certificate. func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") } @@ -282,11 +274,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio // Require type, key-id and principals in the SignSSHOptions. &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}, // Set the validity bounds if not set. - &sshDefaultDuration{p.claimer}, + &sshDefaultDuration{p.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/k8sSA_test.go b/authority/provisioner/k8sSA_test.go index 176cdfd3..e98b6f48 100644 --- a/authority/provisioner/k8sSA_test.go +++ b/authority/provisioner/k8sSA_test.go @@ -3,14 +3,16 @@ package provisioner import ( "context" "crypto/x509" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestK8sSA_Getters(t *testing.T) { @@ -116,8 +118,8 @@ func TestK8sSA_authorizeToken(t *testing.T) { tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -165,8 +167,8 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -179,6 +181,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) { } func TestK8sSA_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) type test struct { p *K8sSA cert *x509.Certificate @@ -192,21 +195,27 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, code: http.StatusUnauthorized, - err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()), + err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { p, err := generateK8sSA(nil) assert.FatalError(t, err) return test{ - p: p, - cert: &x509.Certificate{}, + p: p, + cert: &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }, } }, } @@ -214,8 +223,8 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -263,8 +272,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -274,24 +283,25 @@ func TestK8sSA_AuthorizeSign(t *testing.T) { tot := 0 for _, o := range opts { switch v := o.(type) { + case *K8sSA: case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeK8sSA)) + assert.Equals(t, v.Type, TypeK8sSA) assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } - assert.Equals(t, tot, 5) + assert.Equals(t, tot, 6) } } } @@ -313,13 +323,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p.Claims = &Claims{EnableSSHCA: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, token: "foo", code: http.StatusUnauthorized, - err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()), + err: fmt.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()), } }, "fail/invalid-token": func(t *testing.T) test { @@ -350,8 +360,8 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -365,13 +375,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) { case *sshCertOptionsRequireValidator: assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}) case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshDefaultDuration: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index a77f4281..4216e997 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -34,19 +34,18 @@ const ( // https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by // go.step.sm/crypto/x25519. type Nebula struct { - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - Roots []byte `json:"roots"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - caPool *nebula.NebulaCAPool - audiences Audiences + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + Roots []byte `json:"roots"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` + caPool *nebula.NebulaCAPool + ctl *Controller } // Init verifies and initializes the Nebula provisioner. -func (p *Nebula) Init(config Config) error { +func (p *Nebula) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -56,19 +55,14 @@ func (p *Nebula) Init(config Config) error { return errors.New("provisioner root(s) cannot be empty") } - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots) if err != nil { return errs.InternalServer("failed to create ca pool: %v", err) } - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // GetID returns the provisioner id. @@ -120,7 +114,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) { // AuthorizeSign returns the list of SignOption for a Sign request. func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - crt, claims, err := p.authorizeToken(token, p.audiences.Sign) + crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, err } @@ -139,8 +133,9 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, data.SetToken(v) } - // The Nebula certificate will be available using the template variable Crt. - // For example {{ .Crt.Details.Groups }} can be used to get all the groups. + // The Nebula certificate will be available using the template variable + // AuthorizationCrt. For example {{ .AuthorizationCrt.Details.Groups }} can + // be used to get all the groups. data.SetAuthorizationCertificate(crt) templateOptions, err := TemplateOptions(p.Options, data) @@ -149,11 +144,12 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, } return []SignOption{ + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeNebula, p.Name, ""), profileLimitDuration{ - def: p.claimer.DefaultTLSCertDuration(), + def: p.ctl.Claimer.DefaultTLSCertDuration(), notBefore: crt.Details.NotBefore, notAfter: crt.Details.NotAfter, }, @@ -164,18 +160,18 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, IPs: crt.Details.Ips, }, defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // Currently the Nebula provisioner only grants host SSH certificates. func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) } - crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign) + crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, err } @@ -253,11 +249,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti return append(signOptions, templateOptions, // Checks the validity bounds, and set the validity if has not been set. - &sshLimitDuration{p.claimer, crt.Details.NotAfter}, + &sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter}, // Validate public key. &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil @@ -265,23 +261,20 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti // AuthorizeRenew returns an error if the renewal is disabled. func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, crt) } // AuthorizeRevoke returns an error if the token is not valid. func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error { - return p.validateToken(token, p.audiences.Revoke) + return p.validateToken(token, p.ctl.Audiences.Revoke) } // AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid. func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) } - if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil { + if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil { return err } return nil diff --git a/authority/provisioner/nebula_test.go b/authority/provisioner/nebula_test.go index bc539af1..b190d607 100644 --- a/authority/provisioner/nebula_test.go +++ b/authority/provisioner/nebula_test.go @@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) { func TestNebula_GetTokenID(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer) - t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) + t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) _, claims, err := parseToken(t1) if err != nil { t.Fatal(err) @@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) { ctx := context.TODO() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) - okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv) + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv) pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions.caPool = p.caPool @@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan", "10.1.0.1"}, }, crt, priv) - okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv) - okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv) + okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)), ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)), }, crt, priv) - failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "user", }, crt, priv) - failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ + failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan", "10.1.0.1", "foo.bar"}, @@ -549,6 +549,8 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) { func TestNebula_AuthorizeRenew(t *testing.T) { ctx := context.TODO() + now := time.Now().Truncate(time.Second) + // Ok provisioner p, _, _ := mustNebulaProvisioner(t) @@ -567,8 +569,14 @@ func TestNebula_AuthorizeRenew(t *testing.T) { args args wantErr bool }{ - {"ok", p, args{ctx, &x509.Certificate{}}, false}, - {"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true}, + {"ok", p, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"fail disabled", pDisabled, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -584,12 +592,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) - failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) + failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -618,12 +626,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) { // Ok provisioner p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv) // Fail different CA nc, signer := mustNebulaCA(t) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) - failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) + failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv) // Provisioner with SSH disabled var bFalse bool @@ -657,7 +665,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) { func TestNebula_AuthorizeSSHRenew(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -689,7 +697,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) { func TestNebula_AuthorizeSSHRekey(t *testing.T) { p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv) + t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv) type args struct { ctx context.Context @@ -726,20 +734,20 @@ func TestNebula_authorizeToken(t *testing.T) { t1 := now() p, ca, signer := mustNebulaProvisioner(t) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) - ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) - okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv) - okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{ + ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv) + okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{ CertType: "host", KeyID: "test.lan", Principals: []string{"test.lan"}, }, crt, priv) - okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv) + okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv) // Token with errors - failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) - failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) + failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv) - failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) // Not a nebula token jwk, err := generateJSONWebKey() @@ -761,7 +769,7 @@ func TestNebula_authorizeToken(t *testing.T) { IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), - Audience: []string{p.audiences.Sign[0]}, + Audience: []string{p.ctl.Audiences.Sign[0]}, } sshClaims := jose.Claims{ ID: "[REPLACEME]", @@ -770,7 +778,7 @@ func TestNebula_authorizeToken(t *testing.T) { IssuedAt: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), - Audience: []string{p.audiences.SSHSign[0]}, + Audience: []string{p.ctl.Audiences.SSHSign[0]}, } type args struct { @@ -785,14 +793,14 @@ func TestNebula_authorizeToken(t *testing.T) { want1 *jwtPayload wantErr bool }{ - {"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{ + {"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{ Claims: x509Claims, SANs: []string{"10.1.0.1"}, }, false}, - {"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{ + {"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{ Claims: x509Claims, }, false}, - {"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{ + {"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{ Claims: sshClaims, Step: &stepPayload{ SSH: &SignSSHOptions{ @@ -802,16 +810,16 @@ func TestNebula_authorizeToken(t *testing.T) { }, }, }, false}, - {"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{ + {"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{ Claims: sshClaims, }, false}, - {"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true}, - {"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true}, - {"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true}, - {"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true}, - {"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true}, - {"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true}, - {"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true}, + {"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true}, + {"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/noop.go b/authority/provisioner/noop.go index 1709fbca..39661e54 100644 --- a/authority/provisioner/noop.go +++ b/authority/provisioner/noop.go @@ -38,7 +38,7 @@ func (p *noop) Init(config Config) error { } func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - return []SignOption{}, nil + return []SignOption{p}, nil } func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { diff --git a/authority/provisioner/noop_test.go b/authority/provisioner/noop_test.go index 19e4d235..b10d1d29 100644 --- a/authority/provisioner/noop_test.go +++ b/authority/provisioner/noop_test.go @@ -24,6 +24,6 @@ func Test_noop(t *testing.T) { ctx := NewContextWithMethod(context.Background(), SignMethod) sigOptions, err := p.AuthorizeSign(ctx, "foo") - assert.Equals(t, []SignOption{}, sigOptions) + assert.Equals(t, []SignOption{&p}, sigOptions) assert.Equals(t, nil, err) } diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index ac1f2a25..3a9398a2 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -92,8 +92,7 @@ type OIDC struct { Options *Options `json:"options,omitempty"` configuration openIDConfiguration keyStore *keyStore - claimer *Claimer - getIdentityFunc GetIdentityFunc + ctl *Controller } func sanitizeEmail(email string) string { @@ -172,11 +171,6 @@ func (o *OIDC) Init(config Config) (err error) { } } - // Update claims with global ones - if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil { - return err - } - // Decode and validate openid-configuration endpoint u, err := url.Parse(o.ConfigurationEndpoint) if err != nil { @@ -201,13 +195,8 @@ func (o *OIDC) Init(config Config) (err error) { return err } - // Set the identity getter if it exists, otherwise use the default. - if config.GetIdentityFunc == nil { - o.getIdentityFunc = DefaultIdentityFunc - } else { - o.getIdentityFunc = config.GetIdentityFunc - } - return nil + o.ctl, err = NewController(o, o.Claims, config) + return } // ValidatePayload validates the given token payload. @@ -356,13 +345,14 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e } return []SignOption{ + o, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), - profileDefaultDuration(o.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()), // validators defaultPublicKeyValidator{}, - newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()), + newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()), }, nil } @@ -371,15 +361,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // revocation status. Just confirms that the provisioner that created the // certificate was configured to allow renewals. func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if o.claimer.IsDisableRenewal() { - return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName()) - } - return nil + return o.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !o.claimer.IsSSHCAEnabled() { + if !o.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName()) } claims, err := o.authorizeToken(token) @@ -394,7 +381,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Get the identity using either the default identityFunc or one injected // externally. Note that the PreferredUsername might be empty. // TBD: Would preferred_username present a safety issue here? - iden, err := o.getIdentityFunc(ctx, o, claims.Email) + iden, err := o.ctl.GetIdentity(ctx, claims.Email) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } @@ -445,11 +432,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption return append(signOptions, // Set the validity bounds if not set. - &sshDefaultDuration{o.claimer}, + &sshDefaultDuration{o.ctl.Claimer}, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{o.claimer}, + &sshCertValidityValidator{o.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 7bf6ad7a..548c4dc8 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -6,16 +6,17 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "errors" "fmt" "net/http" "strings" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func Test_openIDConfiguration_Validate(t *testing.T) { @@ -246,8 +247,8 @@ func TestOIDC_authorizeToken(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else { @@ -317,30 +318,31 @@ func TestOIDC_AuthorizeSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { - assert.Len(t, 5, got) + assert.Len(t, 6, got) for _, o := range got { switch v := o.(type) { + case *OIDC: case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeOIDC)) + assert.Equals(t, v.Type, TypeOIDC) assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.CredentialID, tt.prov.ClientID) assert.Len(t, 0, v.KeyValuePairs) case profileDefaultDuration: - assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) + assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration()) case defaultPublicKeyValidator: case *validityValidator: - assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) case emailOnlyIdentity: assert.Equals(t, string(v), "name@smallstep.com") default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -402,8 +404,8 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) { t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) return } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -411,6 +413,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) { } func TestOIDC_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) p1, err := generateOIDC() assert.FatalError(t, err) p2, err := generateOIDC() @@ -419,7 +422,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p2.Claims = &Claims{DisableRenewal: &disable} - p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) + p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) assert.FatalError(t, err) type args struct { @@ -432,8 +435,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { code int wantErr bool }{ - {"ok", p1, args{nil}, http.StatusOK, false}, - {"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, + {"ok", p1, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusOK, false}, + {"fail/renew-disabled", p2, args{&x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, http.StatusUnauthorized, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -441,8 +450,8 @@ func TestOIDC_AuthorizeRenew(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) @@ -478,7 +487,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { // disable sshCA disable := false p6.Claims = &Claims{EnableSSHCA: &disable} - p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) + p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) assert.FatalError(t, err) // Update configuration endpoints and initialize @@ -494,10 +503,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p5.Init(config)) - p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return &Identity{Usernames: []string{"max", "mariano"}}, nil } - p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return nil, errors.New("force") } // Additional test needed for empty usernames and duplicate email and usernames @@ -527,8 +536,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) assert.FatalError(t, err) - userDuration := p1.claimer.DefaultUserSSHCertDuration() - hostDuration := p1.claimer.DefaultHostSSHCertDuration() + userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration() + hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SignSSHOptions{ CertType: "user", Principals: []string{"name", "name@smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), @@ -597,8 +606,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { return } if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) assert.Nil(t, got) } else if assert.NotNil(t, got) { @@ -665,8 +674,8 @@ func TestOIDC_AuthorizeSSHRevoke(t *testing.T) { if (err != nil) != tt.wantErr { t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr) } else if err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.code) } }) diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 55ebe092..7438ea17 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -6,7 +6,6 @@ import ( "encoding/json" stderrors "errors" "net/url" - "regexp" "strings" "github.com/pkg/errors" @@ -47,6 +46,7 @@ var ErrAllowTokenReuse = stderrors.New("allow token reuse") // Audiences stores all supported audiences by request type. type Audiences struct { Sign []string + Renew []string Revoke []string SSHSign []string SSHRevoke []string @@ -57,6 +57,7 @@ type Audiences struct { // All returns all supported audiences across all request types in one list. func (a Audiences) All() (auds []string) { auds = a.Sign + auds = append(auds, a.Renew...) auds = append(auds, a.Revoke...) auds = append(auds, a.SSHSign...) auds = append(auds, a.SSHRevoke...) @@ -70,6 +71,7 @@ func (a Audiences) All() (auds []string) { func (a Audiences) WithFragment(fragment string) Audiences { ret := Audiences{ Sign: make([]string, len(a.Sign)), + Renew: make([]string, len(a.Renew)), Revoke: make([]string, len(a.Revoke)), SSHSign: make([]string, len(a.SSHSign)), SSHRevoke: make([]string, len(a.SSHRevoke)), @@ -83,6 +85,13 @@ func (a Audiences) WithFragment(fragment string) Audiences { ret.Sign[i] = s } } + for i, s := range a.Renew { + if u, err := url.Parse(s); err == nil { + ret.Renew[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() + } else { + ret.Renew[i] = s + } + } for i, s := range a.Revoke { if u, err := url.Parse(s); err == nil { ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() @@ -210,6 +219,12 @@ type Config struct { // GetIdentityFunc is a function that returns an identity that will be // used by the provisioner to populate certificate attributes. GetIdentityFunc GetIdentityFunc + // AuthorizeRenewFunc is a function that returns nil if a given X.509 + // certificate can be renewed. + AuthorizeRenewFunc AuthorizeRenewFunc + // AuthorizeSSHRenewFunc is a function that returns nil if a given SSH + // certificate can be renewed. + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc } type provisioner struct { @@ -278,32 +293,6 @@ func (l *List) UnmarshalJSON(data []byte) error { return nil } -var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$") - -// SanitizeSSHUserPrincipal grabs an email or a string with the format -// local@domain and returns a sanitized version of the local, valid to be used -// as a user name. If the email starts with a letter between a and z, the -// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`. -func SanitizeSSHUserPrincipal(email string) string { - if i := strings.LastIndex(email, "@"); i >= 0 { - email = email[:i] - } - return strings.Map(func(r rune) rune { - switch { - case r >= 'a' && r <= 'z': - return r - case r >= '0' && r <= '9': - return r - case r == '-': - return '-' - case r == '.': // drop dots - return -1 - default: - return '_' - } - }, strings.ToLower(email)) -} - type base struct{} // AuthorizeSign returns an unimplemented error. Provisioners should overwrite @@ -348,66 +337,12 @@ func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certif return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented") } -// Identity is the type representing an externally supplied identity that is used -// by provisioners to populate certificate fields. -type Identity struct { - Usernames []string `json:"usernames"` - Permissions `json:"permissions"` -} - // Permissions defines extra extensions and critical options to grant to an SSH certificate. type Permissions struct { Extensions map[string]string `json:"extensions"` CriticalOptions map[string]string `json:"criticalOptions"` } -// GetIdentityFunc is a function that returns an identity. -type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error) - -// DefaultIdentityFunc return a default identity depending on the provisioner -// type. For OIDC email is always present and the usernames might -// contain empty strings. -func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) { - switch k := p.(type) { - case *OIDC: - // OIDC principals would be: - // ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled - // 2. Sanitized local. - // 3. Raw local (if different). - // 4. Email address. - name := SanitizeSSHUserPrincipal(email) - if !sshUserRegex.MatchString(name) { - return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email) - } - usernames := []string{name} - if i := strings.LastIndex(email, "@"); i >= 0 { - usernames = append(usernames, email[:i]) - } - usernames = append(usernames, email) - return &Identity{ - Usernames: SanitizeStringSlices(usernames), - }, nil - default: - return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k) - } -} - -// SanitizeStringSlices removes duplicated an empty strings. -func SanitizeStringSlices(original []string) []string { - output := []string{} - seen := make(map[string]struct{}) - for _, entry := range original { - if entry == "" { - continue - } - if _, value := seen[entry]; !value { - seen[entry] = struct{}{} - output = append(output, entry) - } - } - return output -} - // MockProvisioner for testing type MockProvisioner struct { Mret1, Mret2, Mret3 interface{} diff --git a/authority/provisioner/provisioner_test.go b/authority/provisioner/provisioner_test.go index 330d1b57..9678a20b 100644 --- a/authority/provisioner/provisioner_test.go +++ b/authority/provisioner/provisioner_test.go @@ -2,13 +2,14 @@ package provisioner import ( "context" + "errors" "net/http" "testing" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestType_String(t *testing.T) { @@ -240,8 +241,8 @@ func TestUnimplementedMethods(t *testing.T) { default: t.Errorf("unexpected method %s", tt.method) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized) assert.Equals(t, err.Error(), msg) }) diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index 5d67762c..9dc1edd8 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -11,28 +11,30 @@ import ( // SCEP provisioning flow type SCEP struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` ForceCN bool `json:"forceCN,omitempty"` ChallengePassword string `json:"challenge,omitempty"` Capabilities []string `json:"capabilities,omitempty"` + // IncludeRoot makes the provisioner return the CA root in addition to the // intermediate in the GetCACerts response IncludeRoot bool `json:"includeRoot,omitempty"` + // MinimumPublicKeyLength is the minimum length for public keys in CSRs MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` + // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // Defaults to 0, being DES-CBC - EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` - Options *Options `json:"options,omitempty"` - Claims *Claims `json:"claims,omitempty"` - claimer *Claimer + EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` + Options *Options `json:"options,omitempty"` + Claims *Claims `json:"claims,omitempty"` secretChallengePassword string encryptionAlgorithm int + ctl *Controller } // GetID returns the provisioner unique identifier. @@ -77,7 +79,7 @@ func (s *SCEP) GetOptions() *Options { // DefaultTLSCertDuration returns the default TLS cert duration enforced by // the provisioner. func (s *SCEP) DefaultTLSCertDuration() time.Duration { - return s.claimer.DefaultTLSCertDuration() + return s.ctl.Claimer.DefaultTLSCertDuration() } // Init initializes and validates the fields of a SCEP type. @@ -90,11 +92,6 @@ func (s *SCEP) Init(config Config) (err error) { return errors.New("provisioner name cannot be empty") } - // Update claims with global ones - if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil { - return err - } - // Mask the actual challenge value, so it won't be marshaled s.secretChallengePassword = s.ChallengePassword s.ChallengePassword = "*** redacted ***" @@ -115,7 +112,8 @@ func (s *SCEP) Init(config Config) (err error) { // TODO: add other, SCEP specific, options? - return err + s.ctl, err = NewController(s, s.Claims, config) + return } // AuthorizeSign does not do any verification, because all verification is handled @@ -123,13 +121,14 @@ func (s *SCEP) Init(config Config) (err error) { // on the resulting certificate. func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { return []SignOption{ + s, // modifiers / withOptions newProvisionerExtensionOption(TypeSCEP, s.Name, ""), newForceCNOption(s.ForceCN), - profileDefaultDuration(s.claimer.DefaultTLSCertDuration()), + profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()), // validators newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), - newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()), + newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()), }, nil } diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 34b2e99b..80dfc66e 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -6,7 +6,6 @@ import ( "crypto/rsa" "crypto/x509" "crypto/x509/pkix" - "encoding/asn1" "encoding/json" "net" "net/http" @@ -14,7 +13,6 @@ import ( "reflect" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/errs" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" @@ -404,17 +402,12 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { return nil } -var ( - stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} - stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) -) - -type stepProvisionerASN1 struct { - Type int - Name []byte - CredentialID []byte - KeyValuePairs []string `asn1:"optional,omitempty"` -} +// type stepProvisionerASN1 struct { +// Type int +// Name []byte +// CredentialID []byte +// KeyValuePairs []string `asn1:"optional,omitempty"` +// } type forceCNOption struct { ForceCN bool @@ -441,23 +434,22 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error { } type provisionerExtensionOption struct { - Type int - Name string - CredentialID string - KeyValuePairs []string + Extension } func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption { return &provisionerExtensionOption{ - Type: int(typ), - Name: name, - CredentialID: credentialID, - KeyValuePairs: keyValuePairs, + Extension: Extension{ + Type: typ, + Name: name, + CredentialID: credentialID, + KeyValuePairs: keyValuePairs, + }, } } func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { - ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...) + ext, err := o.ToExtension() if err != nil { return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") } @@ -471,20 +463,3 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...) return nil } - -func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) { - b, err := asn1.Marshal(stepProvisionerASN1{ - Type: typ, - Name: []byte(name), - CredentialID: []byte(credentialID), - KeyValuePairs: keyValuePairs, - }) - if err != nil { - return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension") - } - return pkix.Extension{ - Id: stepOIDProvisioner, - Critical: false, - Value: b, - }, nil -} diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index 32b8e3c6..fc4d675a 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -636,18 +636,18 @@ func Test_newProvisionerExtension_Option(t *testing.T) { valid: func(cert *x509.Certificate) { if assert.Len(t, 1, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] - assert.Equals(t, ext.Id, stepOIDProvisioner) + assert.Equals(t, ext.Id, StepOIDProvisioner) } }, } }, "ok/prepend": func() test { return test{ - cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, + cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, valid: func(cert *x509.Certificate) { if assert.Len(t, 3, cert.ExtraExtensions) { ext := cert.ExtraExtensions[0] - assert.Equals(t, ext.Id, stepOIDProvisioner) + assert.Equals(t, ext.Id, StepOIDProvisioner) assert.False(t, ext.Critical) } }, diff --git a/authority/provisioner/sign_ssh_options_test.go b/authority/provisioner/sign_ssh_options_test.go index b59d6945..28a35639 100644 --- a/authority/provisioner/sign_ssh_options_test.go +++ b/authority/provisioner/sign_ssh_options_test.go @@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) { func Test_sshCertValidityValidator(t *testing.T) { p, err := generateX5C(nil) assert.FatalError(t, err) - v := sshCertValidityValidator{p.claimer} + v := sshCertValidityValidator{p.ctl.Claimer} n := now() tests := []struct { name string @@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) { tests := map[string]func() test{ "fail/type-not-set": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ ValidAfter: uint64(n.Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), @@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/type-not-recognized": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)}, cert: &ssh.Certificate{ CertType: 4, ValidAfter: uint64(n.Unix()), @@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/requested-validAfter-after-limit": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Add(2 * time.Hour).Unix()), @@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) { }, "fail/requested-validBefore-after-limit": func() test { return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: uint64(n.Unix()), @@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/no-limit": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, @@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/defaults": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer}, cert: &ssh.Certificate{ CertType: 1, }, @@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/valid-requested-validBefore": func() test { va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, @@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/empty-requested-validBefore-limit-after-default": func() test { va := uint64(n.Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, @@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) { "ok/empty-requested-validBefore-limit-before-default": func() test { va := uint64(n.Unix()) return test{ - svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, + svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)}, cert: &ssh.Certificate{ CertType: 1, ValidAfter: va, diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 3039d2a3..9de0fca2 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -29,8 +29,7 @@ type SSHPOP struct { Type string `json:"type"` Name string `json:"name"` Claims *Claims `json:"claims,omitempty"` - claimer *Claimer - audiences Audiences + ctl *Controller sshPubKeys *SSHKeys } @@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) { } // Init initializes and validates the fields of a SSHPOP type. -func (p *SSHPOP) Init(config Config) error { +func (p *SSHPOP) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -93,15 +92,11 @@ func (p *SSHPOP) Init(config Config) error { return errors.New("provisioner public SSH validation keys cannot be empty") } - // Update claims with global ones - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.sshPubKeys = config.SSHKeys - return nil + + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -109,7 +104,7 @@ func (p *SSHPOP) Init(config Config) error { // e.g. a Sign request will auth/validate different fields than a Revoke request. // // Checking for certificate revocation has been moved to the authority package. -func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { +func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) { sshCert, jwt, err := ExtractSSHPOPCert(token) if err != nil { return nil, errs.Wrap(http.StatusUnauthorized, err, @@ -117,13 +112,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa } // Check validity period of the certificate. - n := time.Now() - if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { - return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") - } - if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { - return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") + // + // Controller.AuthorizeSSHRenew will validate this on the renewal flow. + if checkValidity { + unixNow := time.Now().Unix() + if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) { + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") + } + if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) { + return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") + } } + sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) if !ok { return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey") @@ -186,7 +186,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa // AuthorizeSSHRevoke validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { - claims, err := p.authorizeToken(token, p.audiences.SSHRevoke) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } @@ -199,22 +199,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { // AuthorizeSSHRenew validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { - claims, err := p.authorizeToken(token, p.audiences.SSHRenew) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") } if claims.sshCert.CertType != ssh.HostCert { return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } - - return claims.sshCert, nil - + return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert) } // AuthorizeSSHRekey validates the authorization token and extracts/validates // the SSH certificate from the ssh-pop header. func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.SSHRekey) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true) if err != nil { return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") } @@ -225,11 +223,10 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require and validate all the default fields in the SSH certificate. &sshCertDefaultValidator{}, }, nil - } // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index da036864..13294866 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -5,16 +5,19 @@ import ( "crypto" "crypto/rand" "encoding/base64" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" + "golang.org/x/crypto/ssh" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" - "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestSSHPOP_Getters(t *testing.T) { @@ -38,6 +41,7 @@ func TestSSHPOP_Getters(t *testing.T) { } func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { + now := time.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) if err != nil { return nil, nil, err @@ -46,6 +50,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, if err != nil { return nil, nil, err } + if cert.ValidAfter == 0 { + cert.ValidAfter = uint64(now.Unix()) + } + if cert.ValidBefore == 0 { + cert.ValidBefore = uint64(now.Add(time.Hour).Unix()) + } if err := cert.SignCert(rand.Reader, signer); err != nil { return nil, nil, err } @@ -207,9 +217,9 @@ func TestSSHPOP_authorizeToken(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil { + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -279,8 +289,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { t.Run(name, func(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) if assert.NotNil(t, tc.err) { assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -360,8 +370,8 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { tc := tt(t) if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -442,8 +452,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { tc := tt(t) if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -455,9 +465,9 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } assert.Equals(t, tc.cert.Nonce, cert.Nonce) diff --git a/authority/provisioner/testdata/certs/bad-extension.crt b/authority/provisioner/testdata/certs/bad-extension.crt new file mode 100644 index 00000000..ecce0f28 --- /dev/null +++ b/authority/provisioner/testdata/certs/bad-extension.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi +MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy +MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs +aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg +U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0 +ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc +pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE +AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw +ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg +d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB +hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu +Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh +NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA +ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G +A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2 +LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49 +BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY +wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/authority/provisioner/testdata/certs/good-extension.crt b/authority/provisioner/testdata/certs/good-extension.crt new file mode 100644 index 00000000..103353a7 --- /dev/null +++ b/authority/provisioner/testdata/certs/good-extension.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi +MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx +MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs +aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg +U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0 +ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0 +4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE +AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw +ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg +d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB +hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu +Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh +NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA +ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G +A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2 +LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz +bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf +SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB +bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index fe2678fc..c55c58d2 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -24,20 +24,22 @@ import ( ) var ( - defaultDisableRenewal = false - defaultEnableSSHCA = true - globalProvisionerClaims = Claims{ - MinTLSDur: &Duration{5 * time.Minute}, - MaxTLSDur: &Duration{24 * time.Hour}, - DefaultTLSDur: &Duration{24 * time.Hour}, - DisableRenewal: &defaultDisableRenewal, - MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs - MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, - DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, - MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs - MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, - EnableSSHCA: &defaultEnableSSHCA, + defaultDisableRenewal = false + defaultAllowRenewAfterExpiry = false + defaultEnableSSHCA = true + globalProvisionerClaims = Claims{ + MinTLSDur: &Duration{5 * time.Minute}, + MaxTLSDur: &Duration{24 * time.Hour}, + DefaultTLSDur: &Duration{24 * time.Hour}, + MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs + MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, + DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, + MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs + MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, + EnableSSHCA: &defaultEnableSSHCA, + DisableRenewal: &defaultDisableRenewal, + AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry, } testAudiences = Audiences{ Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, @@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &JWK{ + + p := &JWK{ Name: name, Type: "JWK", Key: &public, EncryptedKey: encrypted, Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { @@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } pubKeys := []interface{}{fooPub, barPub} if inputPubKey != nil { pubKeys = append(pubKeys, inputPubKey) } - return &K8sSA{ - Name: K8sSAName, - Type: "K8sSA", - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - pubKeys: pubKeys, - }, nil + p := &K8sSA{ + Name: K8sSAName, + Type: "K8sSA", + Claims: &globalProvisionerClaims, + pubKeys: pubKeys, + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateSSHPOP() (*SSHPOP, error) { @@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub") if err != nil { return nil, err @@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) { return nil, err } - return &SSHPOP{ - Name: name, - Type: "SSHPOP", - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, + p := &SSHPOP{ + Name: name, + Type: "SSHPOP", + Claims: &globalProvisionerClaims, sshPubKeys: &SSHKeys{ UserKeys: []ssh.PublicKey{userKey}, HostKeys: []ssh.PublicKey{hostKey}, }, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateX5C(root []byte) (*X5C, error) { @@ -283,11 +279,6 @@ M46l92gdOozT if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - rootPool := x509.NewCertPool() var ( @@ -305,15 +296,17 @@ M46l92gdOozT } rootPool.AddCert(cert) } - return &X5C{ - Name: name, - Type: "X5C", - Roots: root, - Claims: &globalProvisionerClaims, - audiences: testAudiences, - claimer: claimer, - rootPool: rootPool, - }, nil + p := &X5C{ + Name: name, + Type: "X5C", + Roots: root, + Claims: &globalProvisionerClaims, + rootPool: rootPool, + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateOIDC() (*OIDC, error) { @@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &OIDC{ + p := &OIDC{ Name: name, Type: "OIDC", ClientID: clientID, @@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) { keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - claimer: claimer, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateGCP() (*GCP, error) { @@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } - return &GCP{ + p := &GCP{ Type: "GCP", Name: name, ServiceAccounts: []string{serviceAccount}, Claims: &globalProvisionerClaims, - claimer: claimer, config: newGCPConfig(), keyStore: &keyStore{ keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - audiences: testAudiences.WithFragment("gcp/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("gcp/" + name), + }) + return p, err } func generateAWS() (*AWS, error) { @@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } block, _ := pem.Decode([]byte(awsTestCertificate)) if block == nil || block.Type != "CERTIFICATE" { return nil, errors.New("error decoding AWS certificate") @@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) { if err != nil { return nil, errors.Wrap(err, "error parsing AWS certificate") } - return &AWS{ + p := &AWS{ Type: "AWS", Name: name, Accounts: []string{accountID}, Claims: &globalProvisionerClaims, IMDSVersions: []string{"v2", "v1"}, - claimer: claimer, config: &awsConfig{ identityURL: awsIdentityURL, signatureURL: awsSignatureURL, @@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) { certificates: []*x509.Certificate{cert}, signatureAlgorithm: awsSignatureAlgorithm, }, - audiences: testAudiences.WithFragment("aws/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("aws/" + name), + }) + return p, err } func generateAWSWithServer() (*AWS, *httptest.Server, error) { @@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } block, _ := pem.Decode([]byte(awsTestCertificate)) if block == nil || block.Type != "CERTIFICATE" { return nil, errors.New("error decoding AWS certificate") @@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) { if err != nil { return nil, errors.Wrap(err, "error parsing AWS certificate") } - return &AWS{ + p := &AWS{ Type: "AWS", Name: name, Accounts: []string{accountID}, Claims: &globalProvisionerClaims, IMDSVersions: []string{"v1"}, - claimer: claimer, config: &awsConfig{ identityURL: awsIdentityURL, signatureURL: awsSignatureURL, @@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) { certificates: []*x509.Certificate{cert}, signatureAlgorithm: awsSignatureAlgorithm, }, - audiences: testAudiences.WithFragment("aws/" + name), - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences.WithFragment("aws/" + name), + }) + return p, err } func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) { @@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) { if err != nil { return nil, err } - claimer, err := NewClaimer(nil, globalProvisionerClaims) - if err != nil { - return nil, err - } jwk, err := generateJSONWebKey() if err != nil { return nil, err } - return &Azure{ + p := &Azure{ Type: "Azure", Name: name, TenantID: tenantID, Audience: azureDefaultAudience, Claims: &globalProvisionerClaims, - claimer: claimer, config: newAzureConfig(tenantID), oidcConfig: openIDConfiguration{ Issuer: "https://sts.windows.net/" + tenantID + "/", @@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) { keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, - }, nil + } + p.ctl, err = NewController(p, p.Claims, Config{ + Audiences: testAudiences, + }) + return p, err } func generateAzureWithServer() (*Azure, *httptest.Server, error) { @@ -671,7 +656,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) { w.Header().Add("Cache-Control", "max-age=5") writeJSON(w, getPublic(az.keyStore.keySet)) case "/metadata/identity/oauth2/token": - tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0]) + tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0]) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } else { @@ -1009,7 +994,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r return jose.Signed(sig).Claims(claims).CompactSerialize() } -func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { +func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { sig, err := jose.NewSigner( jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), @@ -1017,6 +1002,12 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, if err != nil { return "", err } + var xmsMirID string + if resourceType == "vm" { + xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName) + } else if resourceType == "uai" { + xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName) + } claims := azurePayload{ Claims: jose.Claims{ @@ -1034,7 +1025,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, ObjectID: "the-oid", TenantID: tenantID, Version: "the-version", - XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine), + XMSMirID: xmsMirID, } return jose.Signed(sig).Claims(claims).CompactSerialize() } diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 8710acb5..295d81fb 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -26,15 +26,14 @@ type x5cPayload struct { // signature requests. type X5C struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - Roots []byte `json:"roots"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer - audiences Audiences - rootPool *x509.CertPool + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + Roots []byte `json:"roots"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` + ctl *Controller + rootPool *x509.CertPool } // GetID returns the provisioner unique identifier. The name and credential id @@ -86,7 +85,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) { } // Init initializes and validates the fields of a X5C type. -func (p *X5C) Init(config Config) error { +func (p *X5C) Init(config Config) (err error) { switch { case p.Type == "": return errors.New("provisioner type cannot be empty") @@ -101,6 +100,7 @@ func (p *X5C) Init(config Config) error { var ( block *pem.Block rest = p.Roots + count int ) for rest != nil { block, rest = pem.Decode(rest) @@ -111,22 +111,18 @@ func (p *X5C) Init(config Config) error { if err != nil { return errors.Wrap(err, "error parsing x509 certificate from PEM block") } + count++ p.rootPool.AddCert(cert) } // Verify that at least one root was found. - if len(p.rootPool.Subjects()) == 0 { + if count == 0 { return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) } - // Update claims with global ones - var err error - if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { - return err - } - - p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) - return nil + config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) + p.ctl, err = NewController(p, p.Claims, config) + return } // authorizeToken performs common jwt authorization actions and returns the @@ -189,13 +185,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err // AuthorizeRevoke returns an error if the provisioner does not have rights to // revoke the certificate with serial number in the `sub` property. func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error { - _, err := p.authorizeToken(token, p.audiences.Revoke) + _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") } // AuthorizeSign validates the given token. func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { - claims, err := p.authorizeToken(token, p.audiences.Sign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") } @@ -213,40 +209,45 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er data.SetToken(v) } + // The X509 certificate will be available using the template variable + // AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be + // used to get all the domains. + data.SetAuthorizationCertificate(claims.chains[0][0]) + templateOptions, err := TemplateOptions(p.Options, data) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") } return []SignOption{ + p, templateOptions, // modifiers / withOptions newProvisionerExtensionOption(TypeX5C, p.Name, ""), - profileLimitDuration{p.claimer.DefaultTLSCertDuration(), - claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter}, + profileLimitDuration{ + p.ctl.Claimer.DefaultTLSCertDuration(), + claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter, + }, // validators commonNameValidator(claims.Subject), defaultSANsValidator(claims.SANs), defaultPublicKeyValidator{}, - newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), + newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), }, nil } // AuthorizeRenew returns an error if the renewal is disabled. func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { - if p.claimer.IsDisableRenewal() { - return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()) - } - return nil + return p.ctl.AuthorizeRenew(ctx, cert) } // AuthorizeSSHSign returns the list of SignOption for a SignSSH request. func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { - if !p.claimer.IsSSHCAEnabled() { + if !p.ctl.Claimer.IsSSHCAEnabled() { return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()) } - claims, err := p.authorizeToken(token, p.audiences.SSHSign) + claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") } @@ -287,6 +288,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, data.SetToken(v) } + // The X509 certificate will be available using the template variable + // AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be + // used to get all the domains. + data.SetAuthorizationCertificate(claims.chains[0][0]) + templateOptions, err := TemplateSSHOptions(p.Options, data) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") @@ -304,11 +310,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, return append(signOptions, // Checks the validity bounds, and set the validity if has not been set. - &sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter}, + &sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter}, // Validate public key. &sshDefaultPublicKeyValidator{}, // Validate the validity period. - &sshCertValidityValidator{p.claimer}, + &sshCertValidityValidator{p.ctl.Claimer}, // Require all the fields in the SSH certificate &sshCertDefaultValidator{}, ), nil diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 2959f8c6..a3308f00 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -2,16 +2,19 @@ package provisioner import ( "context" + "crypto/x509" + "errors" + "fmt" "net/http" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/randutil" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestX5C_Getters(t *testing.T) { @@ -69,8 +72,8 @@ func TestX5C_Init(t *testing.T) { }, "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { return ProvisionerValidateTest{ - p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences}, - err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), + p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")}, + err: errors.New("no x509 certificates found in roots attribute for provisioner 'foo'"), } }, "fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest { @@ -117,9 +120,11 @@ M46l92gdOozT return ProvisionerValidateTest{ p: p, extraValid: func(p *X5C) error { + // nolint:staticcheck // We don't have a different way to + // check the number of certificates in the pool. numCerts := len(p.rootPool.Subjects()) if numCerts != 2 { - return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts) + return fmt.Errorf("unexpected number of certs: want 2, but got %d", numCerts) } return nil }, @@ -141,7 +146,7 @@ M46l92gdOozT } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID())) + assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID())) if tc.extraValid != nil { assert.Nil(t, tc.extraValid(tc.p)) } @@ -384,8 +389,8 @@ lgsqsR63is+0YQ== tc := tt(t) if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -455,7 +460,7 @@ func TestX5C_AuthorizeSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -463,19 +468,20 @@ func TestX5C_AuthorizeSign(t *testing.T) { } else { if assert.Nil(t, tc.err) { if assert.NotNil(t, opts) { - assert.Equals(t, len(opts), 7) + assert.Equals(t, len(opts), 8) for _, o := range opts { switch v := o.(type) { + case *X5C: case certificateOptionsFunc: case *provisionerExtensionOption: - assert.Equals(t, v.Type, int(TypeX5C)) + assert.Equals(t, v.Type, TypeX5C) assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.CredentialID, "") assert.Len(t, 0, v.KeyValuePairs) case profileLimitDuration: - assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration()) + assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration()) - claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign) + claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign) assert.FatalError(t, err) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) case commonNameValidator: @@ -484,10 +490,10 @@ func TestX5C_AuthorizeSign(t *testing.T) { case defaultSANsValidator: assert.Equals(t, []string(v), tc.sans) case *validityValidator: - assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) - assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) + assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) + assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } } } @@ -538,8 +544,8 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { tc := tt(t) if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -551,6 +557,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { } func TestX5C_AuthorizeRenew(t *testing.T) { + now := time.Now().Truncate(time.Second) type test struct { p *X5C code int @@ -563,12 +570,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) { // disable renewal disable := true p.Claims = &Claims{DisableRenewal: &disable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, code: http.StatusUnauthorized, - err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()), + err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()), } }, "ok": func(t *testing.T) test { @@ -582,10 +589,13 @@ func TestX5C_AuthorizeRenew(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil { + if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -618,13 +628,13 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { // disable sshCA enable := false p.Claims = &Claims{EnableSSHCA: &enable} - p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) + p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) assert.FatalError(t, err) return test{ p: p, token: "foo", code: http.StatusUnauthorized, - err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), + err: fmt.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), } }, "fail/invalid-token": func(t *testing.T) test { @@ -745,7 +755,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { tc := tt(t) if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -774,13 +784,13 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { case sshCertDefaultsModifier: assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert}) case *sshLimitDuration: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) case *sshCertValidityValidator: - assert.Equals(t, v.Claimer, tc.p.claimer) + assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: default: - assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) + assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) } tot++ } diff --git a/authority/provisioners.go b/authority/provisioners.go index 3b14657c..a6ac5aa8 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -87,20 +87,20 @@ func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, e return p, nil } -func (a *Authority) generateProvisionerConfig(ctx context.Context) (*provisioner.Config, error) { +func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.Config, error) { // Merge global and configuration claims claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims) if err != nil { - return nil, err + return provisioner.Config{}, err } // TODO: should we also be combining the ssh federated roots here? // If we rotate ssh roots keys, sshpop provisioner will lose ability to // validate old SSH certificates, unless they are added as federated certs. sshKeys, err := a.GetSSHRoots(ctx) if err != nil { - return nil, err + return provisioner.Config{}, err } - return &provisioner.Config{ + return provisioner.Config{ Claims: claimer.Claims(), Audiences: a.config.GetAudiences(), DB: a.db, @@ -108,7 +108,9 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (*provisioner UserKeys: sshKeys.UserKeys, HostKeys: sshKeys.HostKeys, }, - GetIdentityFunc: a.getIdentityFunc, + GetIdentityFunc: a.getIdentityFunc, + AuthorizeRenewFunc: a.authorizeRenewFunc, + AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc, }, nil } @@ -133,9 +135,18 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi "provisioner with token ID %s already exists", certProv.GetIDForToken()) } + provisionerConfig, err := a.generateProvisionerConfig(ctx) + if err != nil { + return admin.WrapErrorISE(err, "error generating provisioner config") + } + + if err := certProv.Init(provisionerConfig); err != nil { + return admin.WrapError(admin.ErrorBadRequestType, err, "error validating configuration for provisioner %s", prov.Name) + } + // Store to database -- this will set the ID. if err := a.adminDB.CreateProvisioner(ctx, prov); err != nil { - return admin.WrapErrorISE(err, "error creating admin") + return admin.WrapErrorISE(err, "error creating provisioner") } // We need a new conversion that has the newly set ID. @@ -145,12 +156,7 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi "error converting to certificates provisioner from linkedca provisioner") } - provisionerConfig, err := a.generateProvisionerConfig(ctx) - if err != nil { - return admin.WrapErrorISE(err, "error generating provisioner config") - } - - if err := certProv.Init(*provisionerConfig); err != nil { + if err := certProv.Init(provisionerConfig); err != nil { return admin.WrapErrorISE(err, "error initializing provisioner %s", prov.Name) } @@ -179,7 +185,7 @@ func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisio return admin.WrapErrorISE(err, "error generating provisioner config") } - if err := certProv.Init(*provisionerConfig); err != nil { + if err := certProv.Init(provisionerConfig); err != nil { return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name) } @@ -431,7 +437,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) { } pc := &provisioner.Claims{ - DisableRenewal: &c.DisableRenewal, + DisableRenewal: &c.DisableRenewal, + AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry, } var err error @@ -469,12 +476,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims { } disableRenewal := config.DefaultDisableRenewal + allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry + if c.DisableRenewal != nil { disableRenewal = *c.DisableRenewal } + if c.AllowRenewAfterExpiry != nil { + allowRenewAfterExpiry = *c.AllowRenewAfterExpiry + } lc := &linkedca.Claims{ - DisableRenewal: disableRenewal, + DisableRenewal: disableRenewal, + AllowRenewAfterExpiry: allowRenewAfterExpiry, } if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil { @@ -706,6 +719,8 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, Name: p.Name, TenantID: cfg.TenantId, ResourceGroups: cfg.ResourceGroups, + SubscriptionIDs: cfg.SubscriptionIds, + ObjectIDs: cfg.ObjectIds, Audience: cfg.Audience, DisableCustomSANs: cfg.DisableCustomSans, DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse, @@ -865,6 +880,8 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro Azure: &linkedca.AzureProvisioner{ TenantId: p.TenantID, ResourceGroups: p.ResourceGroups, + SubscriptionIds: p.SubscriptionIDs, + ObjectIds: p.ObjectIDs, Audience: p.Audience, DisableCustomSans: p.DisableCustomSANs, DisableTrustOnFirstUse: p.DisableTrustOnFirstUse, diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 3975031b..81dc38bf 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -1,13 +1,13 @@ package authority import ( + "errors" "net/http" "testing" - "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/errs" ) func TestGetEncryptedKey(t *testing.T) { @@ -49,8 +49,8 @@ func TestGetEncryptedKey(t *testing.T) { ek, err := tc.a.GetEncryptedKey(tc.kid) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -90,8 +90,8 @@ func TestGetProvisioners(t *testing.T) { ps, next, err := tc.a.GetProvisioners("", 0) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/root_test.go b/authority/root_test.go index 6e5f1932..a1b08fac 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -2,14 +2,15 @@ package authority import ( "crypto/x509" + "errors" "net/http" "reflect" "testing" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/pemutil" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" ) func TestRoot(t *testing.T) { @@ -31,7 +32,7 @@ func TestRoot(t *testing.T) { crt, err := a.Root(tc.sum) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) + sc, ok := err.(render.StatusCodedError) assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) diff --git a/authority/ssh_test.go b/authority/ssh_test.go index c299b347..ce840fe1 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -7,21 +7,22 @@ import ( "crypto/rand" "crypto/x509" "encoding/base64" + "errors" "fmt" "net/http" "reflect" "testing" "time" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" - "github.com/smallstep/certificates/templates" "go.step.sm/crypto/jose" "go.step.sm/crypto/sshutil" "golang.org/x/crypto/ssh" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/templates" ) type sshTestModifier ssh.Certificate @@ -716,8 +717,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) { t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) return } else if err != nil { - _, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + _, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") } if !reflect.DeepEqual(got, tt.want) { t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want) @@ -806,8 +807,8 @@ func TestAuthority_GetSSHHosts(t *testing.T) { hosts, err := auth.GetSSHHosts(context.Background(), tc.cert) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } @@ -1033,8 +1034,8 @@ func TestAuthority_RekeySSH(t *testing.T) { cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...) if err != nil { if assert.NotNil(t, tc.err) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) } diff --git a/authority/status/status.go b/authority/status/status.go deleted file mode 100644 index 49e4c0bb..00000000 --- a/authority/status/status.go +++ /dev/null @@ -1,11 +0,0 @@ -package status - -// Type is the type for status. -type Type string - -var ( - // Active active - Active = Type("active") - // Deleted deleted - Deleted = Type("deleted") -) diff --git a/authority/tls.go b/authority/tls.go index eb2b0001..dab8775e 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -89,8 +89,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign // Set backdate with the configured value signOpts.Backdate = a.config.AuthorityConfig.Backdate.Duration + var prov provisioner.Interface for _, op := range extraOpts { switch k := op.(type) { + // Capture current provisioner + case provisioner.Interface: + prov = k + // Adds new options to NewCertificate case provisioner.CertificateOptions: certOptions = append(certOptions, k.Options(signOpts)...) @@ -204,7 +209,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) - if err = a.storeCertificate(fullchain); err != nil { + if err = a.storeCertificate(prov, fullchain); err != nil { if err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error storing certificate in db", opts...) @@ -325,19 +330,28 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 // TODO: at some point we should replace the db.AuthDB interface to implement // `StoreCertificate(...*x509.Certificate) error` instead of just // `StoreCertificate(*x509.Certificate) error`. -func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { +func (a *Authority) storeCertificate(prov provisioner.Interface, fullchain []*x509.Certificate) error { + type linkedChainStorer interface { + StoreCertificateChain(provisioner.Interface, ...*x509.Certificate) error + } type certificateChainStorer interface { StoreCertificateChain(...*x509.Certificate) error } // Store certificate in linkedca - if s, ok := a.adminDB.(certificateChainStorer); ok { + switch s := a.adminDB.(type) { + case linkedChainStorer: + return s.StoreCertificateChain(prov, fullchain...) + case certificateChainStorer: return s.StoreCertificateChain(fullchain...) } + // Store certificate in local db - if s, ok := a.db.(certificateChainStorer); ok { + switch s := a.db.(type) { + case certificateChainStorer: return s.StoreCertificateChain(fullchain...) + default: + return a.db.StoreCertificate(fullchain[0]) } - return a.db.StoreCertificate(fullchain[0]) } // storeRenewedCertificate allows to use an extension of the db.AuthDB interface diff --git a/authority/tls_test.go b/authority/tls_test.go index aeadaf0f..e199e0c5 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -11,24 +11,26 @@ import ( "crypto/x509/pkix" "encoding/asn1" "encoding/pem" + "errors" "fmt" "net/http" "reflect" "testing" "time" - "github.com/smallstep/certificates/cas/softcas" + "gopkg.in/square/go-jose.v2/jwt" - "github.com/pkg/errors" - "github.com/smallstep/assert" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" - "github.com/smallstep/certificates/errs" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" - "gopkg.in/square/go-jose.v2/jwt" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/cas/softcas" + "github.com/smallstep/certificates/db" + "github.com/smallstep/certificates/errs" ) var ( @@ -187,14 +189,14 @@ func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) { func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { b, err := x509.MarshalPKIXPublicKey(pub) if err != nil { - return nil, errors.Wrap(err, "error marshaling public key") + return nil, fmt.Errorf("error marshaling public key: %w", err) } info := struct { Algorithm pkix.AlgorithmIdentifier SubjectPublicKey asn1.BitString }{} if _, err = asn1.Unmarshal(b, &info); err != nil { - return nil, errors.Wrap(err, "error unmarshaling public key") + return nil, fmt.Errorf("error unmarshaling public key: %w", err) } hash := sha1.Sum(info.SubjectPublicKey.Bytes) return hash[:], nil @@ -661,8 +663,8 @@ ZYtQ9Ot36qc= if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -757,7 +759,7 @@ func TestAuthority_Renew(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) - na1 := now + na1 := now.Add(time.Hour) so := &provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(nb1), NotAfter: provisioner.NewTimeDuration(na1), @@ -798,7 +800,20 @@ func TestAuthority_Renew(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), + err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), + code: http.StatusUnauthorized, + }, nil + }, + "fail/WithAuthorizeRenewFunc": func() (*renewTest, error) { + aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { + return errs.Unauthorized("not authorized") + })) + aa.x509CAService = a.x509CAService + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + return &renewTest{ + auth: aa, + cert: cert, + err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"), code: http.StatusUnauthorized, }, nil }, @@ -820,6 +835,17 @@ func TestAuthority_Renew(t *testing.T) { cert: cert, }, nil }, + "ok/WithAuthorizeRenewFunc": func() (*renewTest, error) { + aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error { + return nil + })) + aa.x509CAService = a.x509CAService + aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template + return &renewTest{ + auth: aa, + cert: cert, + }, nil + }, } for name, genTestCase := range tests { @@ -836,8 +862,8 @@ func TestAuthority_Renew(t *testing.T) { if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -856,7 +882,7 @@ func TestAuthority_Renew(t *testing.T) { expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) - assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) + assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equals(t, leaf.Subject.String(), @@ -956,7 +982,7 @@ func TestAuthority_Rekey(t *testing.T) { now := time.Now().UTC() nb1 := now.Add(-time.Minute * 7) - na1 := now + na1 := now.Add(time.Hour) so := &provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(nb1), NotAfter: provisioner.NewTimeDuration(na1), @@ -998,7 +1024,7 @@ func TestAuthority_Rekey(t *testing.T) { "fail/unauthorized": func() (*renewTest, error) { return &renewTest{ cert: certNoRenew, - err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), + err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"), code: http.StatusUnauthorized, }, nil }, @@ -1043,8 +1069,8 @@ func TestAuthority_Rekey(t *testing.T) { if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) @@ -1063,7 +1089,7 @@ func TestAuthority_Rekey(t *testing.T) { expiry := now.Add(time.Minute * 7) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) - assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) + assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour))) tmplt := a.config.AuthorityConfig.Template assert.Equals(t, leaf.Subject.String(), @@ -1432,8 +1458,8 @@ func TestAuthority_Revoke(t *testing.T) { ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) if err := tc.auth.Revoke(ctx, tc.opts); err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tc.code) assert.HasPrefix(t, err.Error(), tc.err.Error()) diff --git a/ca/acmeClient_test.go b/ca/acmeClient_test.go index ad5f2116..f17a2f7a 100644 --- a/ca/acmeClient_test.go +++ b/ca/acmeClient_test.go @@ -12,12 +12,13 @@ import ( "time" "github.com/pkg/errors" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/pemutil" + "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" acmeAPI "github.com/smallstep/certificates/acme/api" - "github.com/smallstep/certificates/api" - "go.step.sm/crypto/jose" - "go.step.sm/crypto/pemutil" + "github.com/smallstep/certificates/api/render" ) func TestNewACMEClient(t *testing.T) { @@ -112,15 +113,15 @@ func TestNewACMEClient(t *testing.T) { assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header switch { case i == 0: - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ case i == 1: w.Header().Set("Replay-Nonce", "abc123") - api.JSONStatus(w, []byte{}, 200) + render.JSONStatus(w, []byte{}, 200) i++ default: w.Header().Set("Location", accLocation) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) } }) @@ -206,7 +207,7 @@ func TestACMEClient_GetNonce(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header w.Header().Set("Replay-Nonce", expectedNonce) - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) }) if nonce, err := ac.GetNonce(); err != nil { @@ -315,7 +316,7 @@ func TestACMEClient_post(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -338,7 +339,7 @@ func TestACMEClient_post(t *testing.T) { assert.Equals(t, hdr.KeyID, ac.kid) } - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { @@ -455,7 +456,7 @@ func TestACMEClient_NewOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -477,7 +478,7 @@ func TestACMEClient_NewOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, norb) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.NewOrder(norb); err != nil { @@ -577,7 +578,7 @@ func TestACMEClient_GetOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -599,7 +600,7 @@ func TestACMEClient_GetOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetOrder(url); err != nil { @@ -699,7 +700,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -721,7 +722,7 @@ func TestACMEClient_GetAuthz(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetAuthz(url); err != nil { @@ -821,7 +822,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -844,7 +845,7 @@ func TestACMEClient_GetChallenge(t *testing.T) { assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := ac.GetChallenge(url); err != nil { @@ -944,7 +945,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -967,7 +968,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) { assert.Equals(t, payload, []byte("{}")) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if err := ac.ValidateChallenge(url); err != nil { @@ -1071,7 +1072,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1093,7 +1094,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, payload, frb) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if err := ac.FinalizeOrder(url, csr); err != nil { @@ -1200,7 +1201,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1222,7 +1223,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, len(payload), 0) - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) }) if res, err := tc.client.GetAccountOrders(); err != nil { @@ -1331,7 +1332,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { w.Header().Set("Replay-Nonce", expectedNonce) if i == 0 { - api.JSONStatus(w, tc.r1, tc.rc1) + render.JSONStatus(w, tc.r1, tc.rc1) i++ return } @@ -1356,7 +1357,7 @@ func TestACMEClient_GetCertificate(t *testing.T) { if tc.certBytes != nil { w.Write(tc.certBytes) } else { - api.JSONStatus(w, tc.r2, tc.rc2) + render.JSONStatus(w, tc.r2, tc.rc2) } }) diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 9482d657..2332b4d4 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -14,11 +14,14 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority" - "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/errs" ) func newLocalListener() net.Listener { @@ -79,7 +82,7 @@ func startCAServer(configFile string) (*CA, string, error) { func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/version" { - api.JSON(w, api.VersionResponse{ + render.JSON(w, api.VersionResponse{ Version: "test", RequireClientAuthentication: true, }) @@ -93,7 +96,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han } isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 if !isMTLS { - api.WriteError(w, errs.Unauthorized("missing peer certificate")) + render.Error(w, errs.Unauthorized("missing peer certificate")) } else { next.ServeHTTP(w, r) } @@ -408,6 +411,7 @@ func TestBootstrapClientServerRotation(t *testing.T) { server.ServeTLS(listener, "", "") }() defer server.Close() + time.Sleep(1 * time.Second) // Create bootstrap client token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") @@ -419,7 +423,6 @@ func TestBootstrapClientServerRotation(t *testing.T) { // doTest does a request that requires mTLS doTest := func(client *http.Client) error { - time.Sleep(1 * time.Second) // test with ca resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) if err != nil { diff --git a/ca/ca.go b/ca/ca.go index c95ba22f..0d4f1578 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "reflect" + "strings" "sync" "github.com/go-chi/chi" @@ -26,11 +27,14 @@ import ( scepAPI "github.com/smallstep/certificates/scep/api" "github.com/smallstep/certificates/server" "github.com/smallstep/nosql" + "go.step.sm/cli-utils/step" + "go.step.sm/crypto/x509util" ) type options struct { configFile string linkedCAToken string + quiet bool password []byte issuerPassword []byte sshHostPassword []byte @@ -101,6 +105,13 @@ func WithLinkedCAToken(token string) Option { } } +// WithQuiet sets the quiet flag. +func WithQuiet(quiet bool) Option { + return func(o *options) { + o.quiet = quiet + } +} + // CA is the type used to build the complete certificate authority. It builds // the HTTP server, set ups the middlewares and the HTTP handlers. type CA struct { @@ -288,6 +299,35 @@ func (ca *CA) Run() error { var wg sync.WaitGroup errs := make(chan error, 1) + if !ca.opts.quiet { + authorityInfo := ca.auth.GetInfo() + log.Printf("Starting %s", step.Version()) + log.Printf("Documentation: https://u.step.sm/docs/ca") + log.Printf("Community Discord: https://u.step.sm/discord") + if step.Contexts().GetCurrent() != nil { + log.Printf("Current context: %s", step.Contexts().GetCurrent().Name) + } + log.Printf("Config file: %s", ca.opts.configFile) + baseURL := fmt.Sprintf("https://%s%s", + authorityInfo.DNSNames[0], + ca.config.Address[strings.LastIndex(ca.config.Address, ":"):]) + log.Printf("The primary server URL is %s", baseURL) + log.Printf("Root certificates are available at %s/roots.pem", baseURL) + if len(authorityInfo.DNSNames) > 1 { + log.Printf("Additional configured hostnames: %s", + strings.Join(authorityInfo.DNSNames[1:], ", ")) + } + for _, crt := range authorityInfo.RootX509Certs { + log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt)) + } + if authorityInfo.SSHCAHostPublicKey != nil { + log.Printf("SSH Host CA Key is %s\n", authorityInfo.SSHCAHostPublicKey) + } + if authorityInfo.SSHCAUserPublicKey != nil { + log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey) + } + } + if ca.insecureSrv != nil { wg.Add(1) go func() { @@ -355,6 +395,7 @@ func (ca *CA) Reload() error { WithSSHUserPassword(ca.opts.sshUserPassword), WithIssuerPassword(ca.opts.issuerPassword), WithLinkedCAToken(ca.opts.linkedCAToken), + WithQuiet(ca.opts.quiet), WithConfigFile(ca.opts.configFile), WithDatabase(ca.auth.GetDatabase()), ) @@ -450,9 +491,6 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) { tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven tlsConfig.ClientCAs = certPool - // Use server's most preferred ciphersuite - tlsConfig.PreferServerCipherSuites = true - return tlsConfig, nil } diff --git a/ca/client.go b/ca/client.go index 6bc48a42..3a36fcd6 100644 --- a/ca/client.go +++ b/ca/client.go @@ -563,6 +563,11 @@ func (c *Client) retryOnError(r *http.Response) bool { return false } +// GetCaURL returns the configured CA url. +func (c *Client) GetCaURL() string { + return c.endpoint.String() +} + // GetRootCAs returns the RootCAs certificate pool from the configured // transport. func (c *Client) GetRootCAs() *x509.CertPool { @@ -723,6 +728,36 @@ retry: return &sign, nil } +// RenewWithToken performs the renew request to the CA with the given +// authorization token and returns the api.SignResponse struct. This method is +// generally used to renew an expired certificate. +func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) + req, err := http.NewRequest("POST", u.String(), http.NoBody) + if err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request") + } + req.Header.Add("Authorization", "Bearer "+token) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readError(resp.Body) + } + var sign api.SignResponse + if err := readJSON(resp.Body, &sign); err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u) + } + return &sign, nil +} + // Rekey performs the rekey request to the CA and returns the api.SignResponse // struct. func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { diff --git a/ca/client_test.go b/ca/client_test.go index 29a4848d..48aa1488 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "errors" "fmt" "net/http" "net/http/httptest" @@ -16,14 +17,16 @@ import ( "testing" "time" - "github.com/pkg/errors" + "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ssh" + "github.com/smallstep/assert" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/read" + "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" ) const ( @@ -179,7 +182,7 @@ func TestClient_Version(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Version() @@ -229,7 +232,7 @@ func TestClient_Health(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Health() @@ -287,7 +290,7 @@ func TestClient_Root(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Root(tt.shasum) @@ -354,10 +357,10 @@ func TestClient_Sign(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.SignRequest) - if err := api.ReadJSON(req.Body, body); err != nil { + if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") - api.WriteError(w, e) + render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -368,7 +371,7 @@ func TestClient_Sign(t *testing.T) { t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) } } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Sign(tt.request) @@ -426,10 +429,10 @@ func TestClient_Revoke(t *testing.T) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.RevokeRequest) - if err := api.ReadJSON(req.Body, body); err != nil { + if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) assert.Fatal(t, ok, "response expected to be error type") - api.WriteError(w, e) + render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { if tt.request == nil { @@ -440,7 +443,7 @@ func TestClient_Revoke(t *testing.T) { t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) } } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Revoke(tt.request, nil) @@ -500,7 +503,7 @@ func TestClient_Renew(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Renew(nil) @@ -516,8 +519,8 @@ func TestClient_Renew(t *testing.T) { t.Errorf("Client.Renew() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -529,6 +532,74 @@ func TestClient_Renew(t *testing.T) { } } +func TestClient_RenewWithToken(t *testing.T) { + ok := &api.SignResponse{ + ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + CertChainPEM: []api.Certificate{ + {Certificate: parseCertificate(certPEM)}, + {Certificate: parseCertificate(rootPEM)}, + }, + } + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + err error + }{ + {"ok", ok, 200, false, nil}, + {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.Header.Get("Authorization") != "Bearer token" { + render.JSONStatus(w, errs.InternalServer("force"), 500) + } else { + render.JSONStatus(w, tt.response, tt.responseCode) + } + }) + + got, err := c.RenewWithToken("token") + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.RenewWithToken() = %v, want nil", got) + } + + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, err.Error(), tt.err.Error()) + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response) + } + } + }) + } +} + func TestClient_Rekey(t *testing.T) { ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, @@ -569,7 +640,7 @@ func TestClient_Rekey(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Rekey(tt.request, nil) @@ -585,8 +656,8 @@ func TestClient_Rekey(t *testing.T) { t.Errorf("Client.Renew() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -634,7 +705,7 @@ func TestClient_Provisioners(t *testing.T) { if req.RequestURI != tt.expectedURI { t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Provisioners(tt.args...) @@ -691,7 +762,7 @@ func TestClient_ProvisionerKey(t *testing.T) { if req.RequestURI != expected { t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) } - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.ProvisionerKey(tt.kid) @@ -706,8 +777,8 @@ func TestClient_ProvisionerKey(t *testing.T) { t.Errorf("Client.ProvisionerKey() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, tt.err.Error(), err.Error()) default: @@ -750,7 +821,7 @@ func TestClient_Roots(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Roots() @@ -765,8 +836,8 @@ func TestClient_Roots(t *testing.T) { if got != nil { t.Errorf("Client.Roots() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) default: @@ -808,7 +879,7 @@ func TestClient_Federation(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Federation() @@ -823,8 +894,8 @@ func TestClient_Federation(t *testing.T) { if got != nil { t.Errorf("Client.Federation() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, tt.err.Error(), err.Error()) default: @@ -870,7 +941,7 @@ func TestClient_SSHRoots(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHRoots() @@ -885,8 +956,8 @@ func TestClient_SSHRoots(t *testing.T) { if got != nil { t.Errorf("Client.SSHKeys() = %v, want nil", got) } - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, tt.err.Error(), err.Error()) default: @@ -970,7 +1041,7 @@ func TestClient_RootFingerprint(t *testing.T) { } tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() @@ -1031,7 +1102,7 @@ func TestClient_SSHBastion(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - api.JSONStatus(w, tt.response, tt.responseCode) + render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHBastion(tt.request) @@ -1047,8 +1118,8 @@ func TestClient_SSHBastion(t *testing.T) { t.Errorf("Client.SSHBastion() = %v, want nil", got) } if tt.responseCode != 200 { - sc, ok := err.(errs.StatusCoder) - assert.Fatal(t, ok, "error does not implement StatusCoder interface") + sc, ok := err.(render.StatusCodedError) + assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.HasPrefix(t, err.Error(), tt.err.Error()) } @@ -1060,3 +1131,28 @@ func TestClient_SSHBastion(t *testing.T) { }) } } + +func TestClient_GetCaURL(t *testing.T) { + tests := []struct { + name string + caURL string + want string + }{ + {"ok", "https://ca.com", "https://ca.com"}, + {"ok no schema", "ca.com", "https://ca.com"}, + {"ok with port", "https://ca.com:9000", "https://ca.com:9000"}, + {"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + if got := c.GetCaURL(); got != tt.want { + t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ca/identity/client_test.go b/ca/identity/client_test.go index 0f1234e9..9660a3bd 100644 --- a/ca/identity/client_test.go +++ b/ca/identity/client_test.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "reflect" + "sort" "testing" ) @@ -196,7 +197,7 @@ func TestLoadClient(t *testing.T) { switch { case gotTransport.TLSClientConfig.GetClientCertificate == nil: t.Error("LoadClient() transport does not define GetClientCertificate") - case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()): + case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs): t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) default: crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) @@ -238,3 +239,23 @@ func Test_defaultsConfig_Validate(t *testing.T) { }) } } + +// nolint:staticcheck,gocritic +func equalPools(a, b *x509.CertPool) bool { + if reflect.DeepEqual(a, b) { + return true + } + subjects := a.Subjects() + sA := make([]string, len(subjects)) + for i := range subjects { + sA[i] = string(subjects[i]) + } + subjects = b.Subjects() + sB := make([]string, len(subjects)) + for i := range subjects { + sB[i] = string(subjects[i]) + } + sort.Strings(sA) + sort.Strings(sB) + return reflect.DeepEqual(sA, sB) +} diff --git a/ca/identity/identity_test.go b/ca/identity/identity_test.go index d3b1d541..55fc60fd 100644 --- a/ca/identity/identity_test.go +++ b/ca/identity/identity_test.go @@ -346,6 +346,8 @@ func TestIdentity_GetCertPool(t *testing.T) { return } if got != nil { + // nolint:staticcheck // we don't have a different way to check + // the certificates in the pool. subjects := got.Subjects() if !reflect.DeepEqual(subjects, tt.wantSubjects) { t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects) diff --git a/ca/renew.go b/ca/renew.go index 915be787..27898993 100644 --- a/ca/renew.go +++ b/ca/renew.go @@ -60,7 +60,10 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption } } - period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore) + // Use the current time to calculate the initial period. Using a notBefore + // in the past might set a renewBefore too large, causing continuous + // renewals due to the negative values in nextRenewDuration. + period := cert.Leaf.NotAfter.Sub(time.Now().Truncate(time.Second)) if period < minCertDuration { return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period) } @@ -181,7 +184,7 @@ func (r *TLSRenewer) renewCertificate() { } func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration { - d := time.Until(notAfter) - r.renewBefore + d := time.Until(notAfter).Truncate(time.Second) - r.renewBefore n := rand.Int63n(int64(r.renewJitter)) d -= time.Duration(n) if d < 0 { diff --git a/ca/testdata/ca.json b/ca/testdata/ca.json index d40325e8..2a336f24 100644 --- a/ca/testdata/ca.json +++ b/ca/testdata/ca.json @@ -6,7 +6,7 @@ "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.3, diff --git a/ca/testdata/federated-ca.json b/ca/testdata/federated-ca.json index 342adfcf..0b1c6c8d 100644 --- a/ca/testdata/federated-ca.json +++ b/ca/testdata/federated-ca.json @@ -6,7 +6,7 @@ "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-0.json b/ca/testdata/rotate-ca-0.json index 20dd603a..aa9353ed 100644 --- a/ca/testdata/rotate-ca-0.json +++ b/ca/testdata/rotate-ca-0.json @@ -5,7 +5,7 @@ "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-1.json b/ca/testdata/rotate-ca-1.json index b038f694..c78ba035 100644 --- a/ca/testdata/rotate-ca-1.json +++ b/ca/testdata/rotate-ca-1.json @@ -5,7 +5,7 @@ "password": "password", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-2.json b/ca/testdata/rotate-ca-2.json index 7ec965d0..2db1c992 100644 --- a/ca/testdata/rotate-ca-2.json +++ b/ca/testdata/rotate-ca-2.json @@ -5,7 +5,7 @@ "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/testdata/rotate-ca-3.json b/ca/testdata/rotate-ca-3.json index 968da6ba..50f4a118 100644 --- a/ca/testdata/rotate-ca-3.json +++ b/ca/testdata/rotate-ca-3.json @@ -5,7 +5,7 @@ "password": "asdf", "address": "127.0.0.1:0", "dnsNames": ["127.0.0.1"], - "logger": {"format": "text"}, + "_logger": {"format": "text"}, "tls": { "minVersion": 1.2, "maxVersion": 1.2, diff --git a/ca/tls.go b/ca/tls.go index 0738d0e0..7954cbdf 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -95,7 +95,6 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, // Note that with GetClientCertificate tlsConfig.Certificates is not used. // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() tlsConfig.GetClientCertificate = renewer.GetClientCertificate - tlsConfig.PreferServerCipherSuites = true // Apply options and initialize mutable tls.Config tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) @@ -137,7 +136,6 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() tlsConfig.GetCertificate = renewer.GetCertificate tlsConfig.GetClientCertificate = renewer.GetClientCertificate - tlsConfig.PreferServerCipherSuites = true tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert // Apply options and initialize mutable tls.Config diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 7d94926b..ca5f80b8 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -542,6 +542,7 @@ func TestAddFederationToCAs(t *testing.T) { } } +// nolint:staticcheck,gocritic func equalPools(a, b *x509.CertPool) bool { if reflect.DeepEqual(a, b) { return true diff --git a/cas/apiv1/options.go b/cas/apiv1/options.go index a39b4115..cc5998ae 100644 --- a/cas/apiv1/options.go +++ b/cas/apiv1/options.go @@ -33,12 +33,20 @@ type Options struct { // https://cloud.google.com/docs/authentication. CredentialsFile string `json:"credentialsFile,omitempty"` - // Certificate and signer are the issuer certificate, along with any other - // bundled certificates to be returned in the chain for consumers, and - // signer used in SoftCAS. They are configured in ca.json crt and key - // properties. + // CertificateChain contains the issuer certificate, along with any other + // bundled certificates to be returned in the chain to consumers. It is used + // used in SoftCAS and it is configured in the crt property of the ca.json. CertificateChain []*x509.Certificate `json:"-"` - Signer crypto.Signer `json:"-"` + + // Signer is the private key or a KMS signer for the issuer certificate. It + // is used in SoftCAS and it is configured in the key property of the + // ca.json. + Signer crypto.Signer `json:"-"` + + // CertificateSigner combines CertificateChain and Signer in a callback that + // returns the chain of certificate and signer used to sign X.509 + // certificates in SoftCAS. + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) `json:"-"` // IsCreator is set to true when we're creating a certificate authority. It // is used to skip some validations when initializing a diff --git a/cas/softcas/softcas.go b/cas/softcas/softcas.go index 8e67d016..2a97145b 100644 --- a/cas/softcas/softcas.go +++ b/cas/softcas/softcas.go @@ -24,9 +24,10 @@ var now = time.Now // SoftCAS implements a Certificate Authority Service using Golang or KMS // crypto. This is the default CAS used in step-ca. type SoftCAS struct { - CertificateChain []*x509.Certificate - Signer crypto.Signer - KeyManager kms.KeyManager + CertificateChain []*x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) + KeyManager kms.KeyManager } // New creates a new CertificateAuthorityService implementation using Golang or KMS @@ -34,16 +35,17 @@ type SoftCAS struct { func New(ctx context.Context, opts apiv1.Options) (*SoftCAS, error) { if !opts.IsCreator { switch { - case len(opts.CertificateChain) == 0: + case len(opts.CertificateChain) == 0 && opts.CertificateSigner == nil: return nil, errors.New("softCAS 'CertificateChain' cannot be nil") - case opts.Signer == nil: + case opts.Signer == nil && opts.CertificateSigner == nil: return nil, errors.New("softCAS 'signer' cannot be nil") } } return &SoftCAS{ - CertificateChain: opts.CertificateChain, - Signer: opts.Signer, - KeyManager: opts.KeyManager, + CertificateChain: opts.CertificateChain, + Signer: opts.Signer, + CertificateSigner: opts.CertificateSigner, + KeyManager: opts.KeyManager, }, nil } @@ -57,6 +59,7 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 } t := now() + // Provisioners can also set specific values. if req.Template.NotBefore.IsZero() { req.Template.NotBefore = t.Add(-1 * req.Backdate) @@ -64,16 +67,21 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1 if req.Template.NotAfter.IsZero() { req.Template.NotAfter = t.Add(req.Lifetime) } - req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + chain, signer, err := c.getCertSigner() + if err != nil { + return nil, err + } + req.Template.Issuer = chain[0].Subject + + cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer) if err != nil { return nil, err } return &apiv1.CreateCertificateResponse{ Certificate: cert, - CertificateChain: c.CertificateChain, + CertificateChain: chain, }, nil } @@ -89,16 +97,21 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R t := now() req.Template.NotBefore = t.Add(-1 * req.Backdate) req.Template.NotAfter = t.Add(req.Lifetime) - req.Template.Issuer = c.CertificateChain[0].Subject - cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) + chain, signer, err := c.getCertSigner() + if err != nil { + return nil, err + } + req.Template.Issuer = chain[0].Subject + + cert, err := createCertificate(req.Template, chain[0], req.Template.PublicKey, signer) if err != nil { return nil, err } return &apiv1.RenewCertificateResponse{ Certificate: cert, - CertificateChain: c.CertificateChain, + CertificateChain: chain, }, nil } @@ -106,9 +119,13 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R // operation is a no-op as the actual revoke will happen when we store the entry // in the db. func (c *SoftCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { + chain, _, err := c.getCertSigner() + if err != nil { + return nil, err + } return &apiv1.RevokeCertificateResponse{ Certificate: req.Certificate, - CertificateChain: c.CertificateChain, + CertificateChain: chain, }, nil } @@ -179,7 +196,7 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori }, nil } -// initializeKeyManager initiazes the default key manager if was not given. +// initializeKeyManager initializes the default key manager if was not given. func (c *SoftCAS) initializeKeyManager() (err error) { if c.KeyManager == nil { c.KeyManager, err = kms.New(context.Background(), kmsapi.Options{ @@ -189,6 +206,15 @@ func (c *SoftCAS) initializeKeyManager() (err error) { return } +// getCertSigner returns the certificate chain and signer to use. +func (c *SoftCAS) getCertSigner() ([]*x509.Certificate, crypto.Signer, error) { + if c.CertificateSigner != nil { + return c.CertificateSigner() + } + return c.CertificateChain, c.Signer, nil + +} + // createKey uses the configured kms to create a key. func (c *SoftCAS) createKey(req *kmsapi.CreateKeyRequest) (*kmsapi.CreateKeyResponse, error) { if err := c.initializeKeyManager(); err != nil { diff --git a/cas/softcas/softcas_test.go b/cas/softcas/softcas_test.go index 7d3add4f..b4f5b440 100644 --- a/cas/softcas/softcas_test.go +++ b/cas/softcas/softcas_test.go @@ -73,6 +73,12 @@ var ( testSignedTemplate = mustSign(testTemplate, testIssuer, testNow, testNow.Add(24*time.Hour)) testSignedRootTemplate = mustSign(testRootTemplate, testRootTemplate, testNow, testNow.Add(24*time.Hour)) testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour)) + testCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) { + return []*x509.Certificate{testIssuer}, testSigner, nil + } + testFailCertificateSigner = func() ([]*x509.Certificate, crypto.Signer, error) { + return nil, nil, errTest + } ) type signatureAlgorithmSigner struct { @@ -186,6 +192,10 @@ func setTeeReader(t *testing.T, w *bytes.Buffer) { } func TestNew(t *testing.T) { + assertEqual := func(x, y interface{}) bool { + return reflect.DeepEqual(x, y) || fmt.Sprintf("%#v", x) == fmt.Sprintf("%#v", y) + } + type args struct { ctx context.Context opts apiv1.Options @@ -197,6 +207,7 @@ func TestNew(t *testing.T) { wantErr bool }{ {"ok", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}}, &SoftCAS{CertificateChain: []*x509.Certificate{testIssuer}, Signer: testSigner}, false}, + {"ok with callback", args{context.Background(), apiv1.Options{CertificateSigner: testCertificateSigner}}, &SoftCAS{CertificateSigner: testCertificateSigner}, false}, {"fail no issuer", args{context.Background(), apiv1.Options{Signer: testSigner}}, nil, true}, {"fail no signer", args{context.Background(), apiv1.Options{CertificateChain: []*x509.Certificate{testIssuer}}}, nil, true}, } @@ -207,7 +218,7 @@ func TestNew(t *testing.T) { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { + if !assertEqual(got, tt.want) { t.Errorf("New() = %v, want %v", got, tt.want) } }) @@ -265,8 +276,9 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { } type fields struct { - Issuer *x509.Certificate - Signer crypto.Signer + Issuer *x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.CreateCertificateRequest @@ -278,43 +290,53 @@ func TestSoftCAS_CreateCertificate(t *testing.T) { want *apiv1.CreateCertificateResponse wantErr bool }{ - {"ok", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &saTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok with notBefore", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNotBefore, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok with notBefore+notAfter", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok with notBefore+notAfter", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplWithLifetime, Lifetime: 24 * time.Hour, }}, &apiv1.CreateCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"fail template", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, - {"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true}, - {"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ + {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.CreateCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.CreateCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, + {"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, + {"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{Template: testTemplate}}, nil, true}, + {"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.CreateCertificateRequest{ Template: &tmplNoSerial, Lifetime: 24 * time.Hour, }}, nil, true}, + {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.CreateCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ - CertificateChain: []*x509.Certificate{tt.fields.Issuer}, - Signer: tt.fields.Signer, + CertificateChain: []*x509.Certificate{tt.fields.Issuer}, + Signer: tt.fields.Signer, + CertificateSigner: tt.fields.CertificateSigner, } got, err := c.CreateCertificate(tt.args.req) if (err != nil) != tt.wantErr { @@ -345,8 +367,9 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { } type fields struct { - Issuer *x509.Certificate - Signer crypto.Signer + Issuer *x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.RenewCertificateRequest @@ -358,30 +381,40 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { want *apiv1.RenewCertificateResponse wantErr bool }{ - {"ok", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ + {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.RenewCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.RenewCertificateRequest{ + {"ok signature algorithm", fields{testIssuer, saSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: testTemplate, Lifetime: 24 * time.Hour, }}, &apiv1.RenewCertificateResponse{ Certificate: testSignedTemplate, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, - {"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, - {"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ + {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RenewCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, &apiv1.RenewCertificateResponse{ + Certificate: testSignedTemplate, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, + {"fail template", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, + {"fail lifetime", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, + {"fail CreateCertificate", fields{testIssuer, testSigner, nil}, args{&apiv1.RenewCertificateRequest{ Template: &tmplNoSerial, Lifetime: 24 * time.Hour, }}, nil, true}, + {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RenewCertificateRequest{ + Template: testTemplate, Lifetime: 24 * time.Hour, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ - CertificateChain: []*x509.Certificate{tt.fields.Issuer}, - Signer: tt.fields.Signer, + CertificateChain: []*x509.Certificate{tt.fields.Issuer}, + Signer: tt.fields.Signer, + CertificateSigner: tt.fields.CertificateSigner, } got, err := c.RenewCertificate(tt.args.req) if (err != nil) != tt.wantErr { @@ -397,8 +430,9 @@ func TestSoftCAS_RenewCertificate(t *testing.T) { func TestSoftCAS_RevokeCertificate(t *testing.T) { type fields struct { - Issuer *x509.Certificate - Signer crypto.Signer + Issuer *x509.Certificate + Signer crypto.Signer + CertificateSigner func() ([]*x509.Certificate, crypto.Signer, error) } type args struct { req *apiv1.RevokeCertificateRequest @@ -410,7 +444,7 @@ func TestSoftCAS_RevokeCertificate(t *testing.T) { want *apiv1.RevokeCertificateResponse wantErr bool }{ - {"ok", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{ + {"ok", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{ Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, Reason: "test reason", ReasonCode: 1, @@ -418,23 +452,37 @@ func TestSoftCAS_RevokeCertificate(t *testing.T) { Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok no cert", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{ + {"ok no cert", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{ Reason: "test reason", ReasonCode: 1, }}, &apiv1.RevokeCertificateResponse{ Certificate: nil, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, - {"ok empty", fields{testIssuer, testSigner}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{ + {"ok empty", fields{testIssuer, testSigner, nil}, args{&apiv1.RevokeCertificateRequest{}}, &apiv1.RevokeCertificateResponse{ Certificate: nil, CertificateChain: []*x509.Certificate{testIssuer}, }, false}, + {"ok with callback", fields{nil, nil, testCertificateSigner}, args{&apiv1.RevokeCertificateRequest{ + Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, + Reason: "test reason", + ReasonCode: 1, + }}, &apiv1.RevokeCertificateResponse{ + Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, + CertificateChain: []*x509.Certificate{testIssuer}, + }, false}, + {"fail with callback", fields{nil, nil, testFailCertificateSigner}, args{&apiv1.RevokeCertificateRequest{ + Certificate: &x509.Certificate{Subject: pkix.Name{CommonName: "fake"}}, + Reason: "test reason", + ReasonCode: 1, + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &SoftCAS{ - CertificateChain: []*x509.Certificate{tt.fields.Issuer}, - Signer: tt.fields.Signer, + CertificateChain: []*x509.Certificate{tt.fields.Issuer}, + Signer: tt.fields.Signer, + CertificateSigner: tt.fields.CertificateSigner, } got, err := c.RevokeCertificate(tt.args.req) if (err != nil) != tt.wantErr { @@ -609,3 +657,56 @@ func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { }) } } + +func TestSoftCAS_defaultKeyManager(t *testing.T) { + mockNow(t) + type args struct { + req *apiv1.CreateCertificateAuthorityRequest + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok root", args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.RootCA, + Template: &x509.Certificate{ + Subject: pkix.Name{CommonName: "Test Root CA"}, + KeyUsage: x509.KeyUsageCRLSign | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + SerialNumber: big.NewInt(1234), + }, + Lifetime: 24 * time.Hour, + }}, false}, + {"ok intermediate", args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Certificate: testSignedRootTemplate, + Signer: testSigner, + }, + }}, false}, + {"fail with default key manager", args{&apiv1.CreateCertificateAuthorityRequest{ + Type: apiv1.IntermediateCA, + Template: testIntermediateTemplate, + Lifetime: 24 * time.Hour, + Parent: &apiv1.CreateCertificateAuthorityResponse{ + Certificate: testSignedRootTemplate, + Signer: &badSigner{}, + }, + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &SoftCAS{} + _, err := c.CreateCertificateAuthority(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("SoftCAS.CreateCertificateAuthority() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/cmd/step-ca/main.go b/cmd/step-ca/main.go index f976b4b4..bc7bf2e3 100644 --- a/cmd/step-ca/main.go +++ b/cmd/step-ca/main.go @@ -118,7 +118,7 @@ func main() { app.HelpName = "step-ca" app.Version = step.Version() app.Usage = "an online certificate authority for secure automated certificate management" - app.UsageText = `**step-ca** [**--password-file**=] + app.UsageText = `**step-ca** [config] [**--context**=] [**--password-file**=] [**--ssh-host-password-file**=] [**--ssh-user-password-file**=] [**--issuer-password-file**=] [**--resolver**=] [**--help**] [**--version**]` app.Description = `**step-ca** runs the Step Online Certificate Authority @@ -134,6 +134,7 @@ This command will run indefinitely on success and return \>0 if any error occurs These examples assume that you have already initialized your PKI by running 'step ca init'. If you have not completed this step please see the 'Getting Started' section of the README. + Run the Step CA and prompt for password: ''' $ step-ca $STEPPATH/config/ca.json @@ -142,7 +143,26 @@ Run the Step CA and read the password from a file - this is useful for automating deployment: ''' $ step-ca $STEPPATH/config/ca.json --password-file ./password.txt -'''` +''' +Run the Step CA for the context selected with step and a custom password file: +''' +$ step context select ssh +$ step-ca --password-file ./password.txt +''' +Run the Step CA for the context named _mybiz_ and prompt for password: +''' +$ step-ca --context=mybiz +''' +Run the Step CA for the context named _mybiz_ and an alternate ca.json file: +''' +$ step-ca --context=mybiz other-ca.json +''' +Run the Step CA for the context named _mybiz_ and read the password from a file - this is useful for +automating deployment: +''' +$ step-ca --context=mybiz --password-file ./password.txt +''' +` app.Flags = append(app.Flags, commands.AppCommand.Flags...) app.Flags = append(app.Flags, cli.HelpFlag) app.Copyright = fmt.Sprintf("(c) 2018-%d Smallstep Labs, Inc.", time.Now().Year()) diff --git a/commands/app.go b/commands/app.go index 8c40de0e..265610f2 100644 --- a/commands/app.go +++ b/commands/app.go @@ -16,6 +16,7 @@ import ( "github.com/smallstep/certificates/pki" "github.com/urfave/cli" "go.step.sm/cli-utils/errs" + "go.step.sm/cli-utils/step" ) // AppCommand is the action used as the top action. @@ -57,6 +58,16 @@ certificate issuer private key used in the RA mode.`, Usage: "token used to enable the linked ca.", EnvVar: "STEP_CA_TOKEN", }, + cli.BoolFlag{ + Name: "quiet", + Usage: "disable startup information", + EnvVar: "STEP_CA_QUIET", + }, + cli.StringFlag{ + Name: "context", + Usage: "The name of the authority's context.", + EnvVar: "STEP_CA_CONTEXT", + }, }, } @@ -68,16 +79,25 @@ func appAction(ctx *cli.Context) error { issuerPassFile := ctx.String("issuer-password-file") resolver := ctx.String("resolver") token := ctx.String("token") + quiet := ctx.Bool("quiet") - // If zero cmd line args show help, if >1 cmd line args show error. - if ctx.NArg() == 0 { - return cli.ShowAppHelp(ctx) - } - if err := errs.NumberOfArguments(ctx, 1); err != nil { - return err + if ctx.NArg() > 1 { + return errs.TooManyArguments(ctx) + } + + if caCtx := ctx.String("context"); caCtx != "" { + if err := step.Contexts().SetCurrent(caCtx); err != nil { + return err + } + } + + var configFile string + if ctx.NArg() > 0 { + configFile = ctx.Args().Get(0) + } else { + configFile = step.CaConfigFile() } - configFile := ctx.Args().Get(0) cfg, err := config.LoadConfiguration(configFile) if err != nil { fatal(err) @@ -141,7 +161,8 @@ To get a linked authority token: ca.WithSSHHostPassword(sshHostPassword), ca.WithSSHUserPassword(sshUserPassword), ca.WithIssuerPassword(issuerPassword), - ca.WithLinkedCAToken(token)) + ca.WithLinkedCAToken(token), + ca.WithQuiet(quiet)) if err != nil { fatal(err) } diff --git a/errs/error.go b/errs/error.go index 60da9e1f..c42e342d 100644 --- a/errs/error.go +++ b/errs/error.go @@ -6,18 +6,11 @@ import ( "net/http" "github.com/pkg/errors" + + "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/api/render" ) -// StatusCoder interface is used by errors that returns the HTTP response code. -type StatusCoder interface { - StatusCode() int -} - -// StackTracer must be by those errors that return an stack trace. -type StackTracer interface { - StackTrace() errors.StackTrace -} - // Option modifies the Error type. type Option func(e *Error) error @@ -257,7 +250,7 @@ func NewError(status int, err error, format string, args ...interface{}) error { return err } msg := fmt.Sprintf(format, args...) - if _, ok := err.(StackTracer); !ok { + if _, ok := err.(log.StackTracedError); !ok { err = errors.Wrap(err, msg) } return &Error{ @@ -275,11 +268,11 @@ func NewErr(status int, err error, opts ...Option) error { ok bool ) if e, ok = err.(*Error); !ok { - if sc, ok := err.(StatusCoder); ok { + if sc, ok := err.(render.StatusCodedError); ok { e = &Error{Status: sc.StatusCode(), Err: err} } else { cause := errors.Cause(err) - if sc, ok := cause.(StatusCoder); ok { + if sc, ok := cause.(render.StatusCodedError); ok { e = &Error{Status: sc.StatusCode(), Err: err} } else { e = &Error{Status: status, Err: err} diff --git a/go.mod b/go.mod index 8f24e688..139c82e1 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module github.com/smallstep/certificates go 1.16 require ( - cloud.google.com/go v0.83.0 + cloud.google.com/go v0.100.2 + cloud.google.com/go/kms v1.4.0 + cloud.google.com/go/security v1.3.0 github.com/Azure/azure-sdk-for-go v58.0.0+incompatible github.com/Azure/go-autorest/autorest v0.11.17 github.com/Azure/go-autorest/autorest/azure/auth v0.5.8 @@ -18,9 +20,9 @@ require ( github.com/go-kit/kit v0.10.0 // indirect github.com/go-piv/piv-go v1.7.0 github.com/golang/mock v1.6.0 - github.com/google/go-cmp v0.5.6 + github.com/google/go-cmp v0.5.7 github.com/google/uuid v1.3.0 - github.com/googleapis/gax-go/v2 v2.0.5 + github.com/googleapis/gax-go/v2 v2.1.1 github.com/hashicorp/vault/api v1.3.1 github.com/hashicorp/vault/api/auth/approle v0.1.1 github.com/hashicorp/vault/sdk v0.3.0 @@ -33,18 +35,18 @@ require ( github.com/sirupsen/logrus v1.8.1 github.com/slackhq/nebula v1.5.2 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 - github.com/smallstep/nosql v0.3.9 + github.com/smallstep/nosql v0.4.0 + github.com/stretchr/testify v1.7.1 github.com/urfave/cli v1.22.4 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.step.sm/cli-utils v0.7.0 - go.step.sm/crypto v0.15.0 - go.step.sm/linkedca v0.9.2 + go.step.sm/crypto v0.16.1 + go.step.sm/linkedca v0.12.0 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 - golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d - golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect - google.golang.org/api v0.47.0 - google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 - google.golang.org/grpc v1.43.0 + golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd + google.golang.org/api v0.70.0 + google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf + google.golang.org/grpc v1.44.0 google.golang.org/protobuf v1.27.1 gopkg.in/square/go-jose.v2 v2.6.0 ) diff --git a/go.sum b/go.sum index b4b0c75d..0b08d836 100644 --- a/go.sum +++ b/go.sum @@ -18,20 +18,38 @@ cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmW cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg= cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8= cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0= -cloud.google.com/go v0.83.0 h1:bAMqZidYkmIsUqe6PtkEPT7Q+vfizScn+jfNA6jwK9c= cloud.google.com/go v0.83.0/go.mod h1:Z7MJUsANfY0pYPdw0lbnivPx4/vhy/e2FEkSkF7vAVY= +cloud.google.com/go v0.84.0/go.mod h1:RazrYuxIK6Kb7YrzzhPoLmCVzl7Sup4NrbKPg8KHSUM= +cloud.google.com/go v0.87.0/go.mod h1:TpDYlFy7vuLzZMMZ+B6iRiELaY7z/gJPaqbMx6mlWcY= +cloud.google.com/go v0.90.0/go.mod h1:kRX0mNRHe0e2rC6oNakvwQqzyDmg57xJ+SZU1eT2aDQ= +cloud.google.com/go v0.93.3/go.mod h1:8utlLll2EF5XMAV15woO4lSbWQlk8rer9aLOfLh7+YI= +cloud.google.com/go v0.94.1/go.mod h1:qAlAugsXlC+JWO+Bke5vCtc9ONxjQT3drlTTnAplMW4= +cloud.google.com/go v0.97.0/go.mod h1:GF7l59pYBVlXQIBLx3a761cZ41F9bBH3JUlihCt2Udc= +cloud.google.com/go v0.99.0/go.mod h1:w0Xx2nLzqWJPuozYQX+hFfCSI8WioryfRDzkoI/Y2ZA= +cloud.google.com/go v0.100.1/go.mod h1:fs4QogzfH5n2pBXBP9vRiU+eCny7lD2vmFZy79Iuw1U= +cloud.google.com/go v0.100.2 h1:t9Iw5QH5v4XtlEQaCtUY7x6sCABps8sW0acw7e2WQ6Y= +cloud.google.com/go v0.100.2/go.mod h1:4Xra9TjzAeYHrl5+oeLlzbM2k3mjVhZh4UqTZ//w99A= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= +cloud.google.com/go/compute v0.1.0/go.mod h1:GAesmwr110a34z04OlxYkATPBEfVhkymfTBXtfbBFow= +cloud.google.com/go/compute v1.3.0 h1:mPL/MzDDYHsh5tHRS9mhmhWlcgClCrCa6ApQCU6wnHI= +cloud.google.com/go/compute v1.3.0/go.mod h1:cCZiE1NHEtai4wiufUhW8I8S1JKkAnhnQJWM7YD99wM= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= +cloud.google.com/go/iam v0.1.0 h1:W2vbGCrE3Z7J/x3WXLxxGl9LMSB2uhsAA7Ss/6u/qRY= +cloud.google.com/go/iam v0.1.0/go.mod h1:vcUNEa0pEm0qRVpmWepWaFMIAI8/hjB9mO8rNCJtF6c= +cloud.google.com/go/kms v1.4.0 h1:iElbfoE61VeLhnZcGOltqL8HIly8Nhbe5t6JlH9GXjo= +cloud.google.com/go/kms v1.4.0/go.mod h1:fajBHndQ+6ubNw6Ss2sSd+SWvjL26RNo/dr7uxsnnOA= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= +cloud.google.com/go/security v1.3.0 h1:BhCl33x+KQI4qiZnFrfr2gAGhb2aZ0ZvKB3Y4QlEfgo= +cloud.google.com/go/security v1.3.0/go.mod h1:pQsnLAXfMzuWVJdctBs8BV3tGd3Jr0SMYu6KK3QXYAs= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= @@ -162,12 +180,16 @@ github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cpuguy83/go-md2man v1.0.10 h1:BSKMNlYxDvnunlTymqtgONjNnaRV1sTpcovwwjF22jk= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= @@ -254,6 +276,8 @@ github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-test/deep v1.0.2 h1:onZX1rnHT3Wv6cqNgYyFOOlgVKJrksuCMCRvJStbMYw= github.com/go-test/deep v1.0.2/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= @@ -310,8 +334,9 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -330,6 +355,8 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -337,8 +364,10 @@ github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= -github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0= +github.com/googleapis/gax-go/v2 v2.1.1 h1:dp3bWCh+PPO1zjRRiCSczJav13sBvG4UhNyVTa1KqdU= +github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= @@ -435,6 +464,55 @@ github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE= github.com/jhump/protoreflect v1.6.0/go.mod h1:eaTn3RZAmMBcV0fifFvlm6VHNz3wSkYyXYWUh7ymB74= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.10.1 h1:DzdIHIjG1AxGwoEEqS+mGsURyjt4enSmqzACXvVzOT8= +github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.2.0 h1:r7JypeP2D3onoQTCxWdTpCtJ4D+qpKr0TxvoyMhZ5ns= +github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.9.0 h1:/SH1RxEtltvJgsDqp3TbiTFApD3mey3iygpuEGeuBXk= +github.com/jackc/pgtype v1.9.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.14.0 h1:TgdrmgnM7VY72EuSQzBbBd4JA1RLqJolrw9nQVZABVc= +github.com/jackc/pgx/v4 v4.14.0/go.mod h1:jT3ibf/A0ZVCp89rtCIN0zCJxcE74ypROmHEZYsG/j8= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= @@ -458,6 +536,7 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/klauspost/compress v1.12.3 h1:G5AfA94pHPysR56qqrkO2pxEexdDzrpFJ6yt/VqWxVU= github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -465,9 +544,16 @@ github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lxn/walk v0.0.0-20210112085537-c389da54e794/go.mod h1:E23UucZGqpuUANJooIbHWCufXvOcT6E7Stq81gU+CSQ= @@ -477,12 +563,15 @@ github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czP github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= @@ -605,6 +694,8 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/russross/blackfriday v1.5.2 h1:HyvC0ARfnZBqnXwABFeSZHpKvJHJJfPz81GNueLj0oo= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= @@ -614,12 +705,15 @@ github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFo github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= @@ -630,8 +724,8 @@ github.com/slackhq/nebula v1.5.2/go.mod h1:xaCM6wqbFk/NRmmUe1bv88fWBm3a1UioXJVIp github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= -github.com/smallstep/nosql v0.3.9 h1:YPy5PR3PXClqmpFaVv0wfXDXDc7NXGBE1auyU2c87dc= -github.com/smallstep/nosql v0.3.9/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs= +github.com/smallstep/nosql v0.4.0 h1:Go3WYwttUuvwqMtFiiU4g7kBIlY+hR0bIZAqVdakQ3M= +github.com/smallstep/nosql v0.4.0/go.mod h1:yKZT5h7cdIVm6wEKM9+jN5dgK80Hljpuy8HNsnI7Gzo= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= @@ -657,13 +751,15 @@ github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5J github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/thales-e-security/pool v0.0.2 h1:RAPs4q2EbWsTit6tpzuvTFlgFRJ3S8Evf5gtvVDbmPg= github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpupY8mv0Phz0gjhU= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= @@ -685,6 +781,7 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= @@ -706,33 +803,40 @@ go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqe go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= -go.step.sm/crypto v0.15.0 h1:VioBln+x3+RoejgeBhvxkLGVYdWRy6PFiAaUUN29/E0= -go.step.sm/crypto v0.15.0/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= -go.step.sm/linkedca v0.9.2 h1:CpAkd174sLXFfrOZrbPEiTzik91QRj3+L0omsiwsiok= -go.step.sm/linkedca v0.9.2/go.mod h1:5uTRjozEGSPAZal9xJqlaD38cvJcLe3o1VAFVjqcORo= +go.step.sm/crypto v0.16.1 h1:4mnZk21cSxyMGxsEpJwZKKvJvDu1PN09UVrWWFNUBdk= +go.step.sm/crypto v0.16.1/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= +go.step.sm/linkedca v0.12.0 h1:FA18uJO5P6W2pklcezMs+w+N3dVbpKEE1LP9HLsJgg4= +go.step.sm/linkedca v0.12.0/go.mod h1:W59ucS4vFpuR0g4PtkGbbtXAwxbDEnNCg+ovkej1ANM= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200414173820-0848c9571904/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= @@ -823,8 +927,8 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211020060615-d418f374d309/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d h1:1n1fc535VhN8SYtD4cDUyNlfpAF2ROMM9+11equK3hs= -golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -836,8 +940,12 @@ golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c h1:pkQiBZBvdos9qq4wBAHqlzuZHEXo07pqV06ef90u1WI= golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210805134026-6f1e6394065a/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 h1:RerP+noqYHUQ8CMRcPlC2nvTa4dcBIjegkuWdcUDuqg= +golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -862,6 +970,7 @@ golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -871,6 +980,7 @@ golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -915,15 +1025,25 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211031064116-611d5d643895/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220209214540-3681064d5158 h1:rm+CHSpPEEW2IsXUib1ThaHIjuBVZjxNgSKmBLFfD4c= +golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -954,12 +1074,14 @@ golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -1000,7 +1122,12 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1030,8 +1157,19 @@ google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34q google.golang.org/api v0.40.0/go.mod h1:fYKFpnQN0DsDSKRVRcQSDQNtqWPfM9i+zNPxepjRCQ8= google.golang.org/api v0.41.0/go.mod h1:RkxM5lITDfTzmyKFPt+wGrCJbVfniCr2ool8kTBzRTU= google.golang.org/api v0.43.0/go.mod h1:nQsDGjRXMo4lvh5hP0TKqF244gqhGcr/YSIykhUk/94= -google.golang.org/api v0.47.0 h1:sQLWZQvP6jPGIP4JGPkJu4zHswrv81iobiyszr3b/0I= google.golang.org/api v0.47.0/go.mod h1:Wbvgpq1HddcWVtzsVLyfLp8lDg6AA241LmgIL59tHXo= +google.golang.org/api v0.48.0/go.mod h1:71Pr1vy+TAZRPkPs/xlCf5SsU8WjuAWv1Pfjbtukyy4= +google.golang.org/api v0.50.0/go.mod h1:4bNT5pAuq5ji4SRZm+5QIkjny9JAyVD/3gaSihNefaw= +google.golang.org/api v0.51.0/go.mod h1:t4HdrdoNgyN5cbEfm7Lum0lcLDLiise1F8qDKX00sOU= +google.golang.org/api v0.54.0/go.mod h1:7C4bFFOvVDGXjfDTAsgGwDgAxRDeQ4X8NvUedIt6z3k= +google.golang.org/api v0.55.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE= +google.golang.org/api v0.56.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE= +google.golang.org/api v0.57.0/go.mod h1:dVPlbZyBo2/OjBpmvNdpn2GRm6rPy75jyU7bmhdrMgI= +google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I= +google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo= +google.golang.org/api v0.67.0/go.mod h1:ShHKP8E60yPsKNw/w8w+VYaj9H6buA5UqDp8dhbQZ6g= +google.golang.org/api v0.70.0 h1:67zQnAE0T2rB0A3CwLSas0K+SbVzSxP+zTLkQLexeiw= +google.golang.org/api v0.70.0/go.mod h1:Bs4ZM2HGifEvXwd50TtW70ovgJffJYw2oRCOFU/SkfA= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1085,9 +1223,29 @@ google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= google.golang.org/genproto v0.0.0-20210513213006-bf773b8c8384/go.mod h1:P3QM42oQyzQSnHPnZ/vqoCdDmzH28fzWByN9asMeM8A= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= -google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5 h1:zzNejm+EgrbLfDZ6lu9Uud2IVvHySPl8vQzf04laR5Q= -google.golang.org/genproto v0.0.0-20220118154757-00ab72f36ad5/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= -google.golang.org/grpc v1.8.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/genproto v0.0.0-20210604141403-392c879c8b08/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= +google.golang.org/genproto v0.0.0-20210608205507-b6d2f5bf0d7d/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= +google.golang.org/genproto v0.0.0-20210624195500-8bfb893ecb84/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24= +google.golang.org/genproto v0.0.0-20210713002101-d411969a0d9a/go.mod h1:AxrInvYm1dci+enl5hChSFPOmmUF1+uAa/UsgNRWd7k= +google.golang.org/genproto v0.0.0-20210716133855-ce7ef5c701ea/go.mod h1:AxrInvYm1dci+enl5hChSFPOmmUF1+uAa/UsgNRWd7k= +google.golang.org/genproto v0.0.0-20210728212813-7823e685a01f/go.mod h1:ob2IJxKrgPT52GcgX759i1sleT07tiKowYBGbczaW48= +google.golang.org/genproto v0.0.0-20210805201207-89edb61ffb67/go.mod h1:ob2IJxKrgPT52GcgX759i1sleT07tiKowYBGbczaW48= +google.golang.org/genproto v0.0.0-20210813162853-db860fec028c/go.mod h1:cFeNkxwySK631ADgubI+/XFU/xp8FD5KIVV4rj8UC5w= +google.golang.org/genproto v0.0.0-20210821163610-241b8fcbd6c8/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= +google.golang.org/genproto v0.0.0-20210828152312-66f60bf46e71/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= +google.golang.org/genproto v0.0.0-20210831024726-fe130286e0e2/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= +google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= +google.golang.org/genproto v0.0.0-20210909211513-a8c4777a87af/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= +google.golang.org/genproto v0.0.0-20210924002016-3dee208752a0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20211221195035-429b39de9b1c/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20220126215142-9970aeb2e350/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20220207164111-0872dc986b00/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= +google.golang.org/genproto v0.0.0-20220218161850-94dd64e39d7c/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI= +google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf h1:SVYXkUz2yZS9FWb2Gm8ivSlbNQzL2Z/NpPKE3RG2jWk= +google.golang.org/genproto v0.0.0-20220222213610-43724f9ea8cf/go.mod h1:kGP+zUP2Ddo0ayMi4YuN7C3WZyJvGLZRh8Z5wnAqvEI= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= @@ -1115,10 +1273,12 @@ google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.37.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.37.1/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= +google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= +google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= -google.golang.org/grpc v1.41.0/go.mod h1:U3l9uK9J0sini8mHphKoXyaqDA/8VyGnDee1zzIUK6k= -google.golang.org/grpc v1.43.0 h1:Eeu7bZtDZ2DpRCsLhUlcrLnvYaMK1Gz86a+hMVvELmM= -google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= +google.golang.org/grpc v1.40.1/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= +google.golang.org/grpc v1.44.0 h1:weqSxi/TMs1SqFRMHCtBgXRs8k3X39QIDEZ0pRcttUg= +google.golang.org/grpc v1.44.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= @@ -1138,11 +1298,13 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= diff --git a/scep/api/api.go b/scep/api/api.go index 4f8d897b..31f0f10d 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -1,23 +1,25 @@ +// Package api implements a SCEP HTTP server. package api import ( "context" "crypto/x509" "encoding/base64" + "errors" + "fmt" "io" "net/http" "net/url" "strings" "github.com/go-chi/chi" - "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/scep" + microscep "github.com/micromdm/scep/v2/scep" "go.mozilla.org/pkcs7" - "github.com/pkg/errors" - - microscep "github.com/micromdm/scep/v2/scep" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/api/log" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/scep" ) const ( @@ -30,22 +32,14 @@ const ( const maxPayloadSize = 2 << 20 -type nextHTTP = func(http.ResponseWriter, *http.Request) - -const ( - certChainHeader = "application/x-x509-ca-ra-cert" - leafHeader = "application/x-x509-ca-cert" - pkiOperationHeader = "application/x-pki-message" -) - -// SCEPRequest is a SCEP server request. -type SCEPRequest struct { +// request is a SCEP server request. +type request struct { Operation string Message []byte } -// SCEPResponse is a SCEP server response. -type SCEPResponse struct { +// response is a SCEP server response. +type response struct { Operation string CACertNum int Data []byte @@ -53,82 +47,86 @@ type SCEPResponse struct { Error error } -// Handler is the SCEP request handler. -type Handler struct { - Auth scep.Interface +// handler is the SCEP request handler. +type handler struct { + auth *scep.Authority } // New returns a new SCEP API router. -func New(scepAuth scep.Interface) api.RouterHandler { - return &Handler{scepAuth} +func New(auth *scep.Authority) api.RouterHandler { + return &handler{ + auth: auth, + } } // Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { - getLink := h.Auth.GetLinkExplicit +func (h *handler) Route(r api.Router) { + getLink := h.auth.GetLinkExplicit + r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get)) r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get)) + r.MethodFunc(http.MethodPost, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Post)) r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post)) } // Get handles all SCEP GET requests -func (h *Handler) Get(w http.ResponseWriter, r *http.Request) { +func (h *handler) Get(w http.ResponseWriter, r *http.Request) { - request, err := decodeSCEPRequest(r) + req, err := decodeRequest(r) if err != nil { - writeError(w, errors.Wrap(err, "invalid scep get request")) + fail(w, fmt.Errorf("invalid scep get request: %w", err)) return } ctx := r.Context() - var response SCEPResponse + var res response - switch request.Operation { + switch req.Operation { case opnGetCACert: - response, err = h.GetCACert(ctx) + res, err = h.GetCACert(ctx) case opnGetCACaps: - response, err = h.GetCACaps(ctx) + res, err = h.GetCACaps(ctx) case opnPKIOperation: // TODO: implement the GET for PKI operation? Default CACAPS doesn't specify this is in use, though default: - err = errors.Errorf("unknown operation: %s", request.Operation) + err = fmt.Errorf("unknown operation: %s", req.Operation) } if err != nil { - writeError(w, errors.Wrap(err, "scep get request failed")) + fail(w, fmt.Errorf("scep get request failed: %w", err)) return } - writeSCEPResponse(w, response) + writeResponse(w, res) } // Post handles all SCEP POST requests -func (h *Handler) Post(w http.ResponseWriter, r *http.Request) { +func (h *handler) Post(w http.ResponseWriter, r *http.Request) { - request, err := decodeSCEPRequest(r) + req, err := decodeRequest(r) if err != nil { - writeError(w, errors.Wrap(err, "invalid scep post request")) + fail(w, fmt.Errorf("invalid scep post request: %w", err)) return } ctx := r.Context() - var response SCEPResponse + var res response - switch request.Operation { + switch req.Operation { case opnPKIOperation: - response, err = h.PKIOperation(ctx, request) + res, err = h.PKIOperation(ctx, req) default: - err = errors.Errorf("unknown operation: %s", request.Operation) + err = fmt.Errorf("unknown operation: %s", req.Operation) } if err != nil { - writeError(w, errors.Wrap(err, "scep post request failed")) + fail(w, fmt.Errorf("scep post request failed: %w", err)) return } - writeSCEPResponse(w, response) + writeResponse(w, res) } -func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { +func decodeRequest(r *http.Request) (request, error) { defer r.Body.Close() @@ -144,7 +142,7 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { case http.MethodGet: switch operation { case opnGetCACert, opnGetCACaps: - return SCEPRequest{ + return request{ Operation: operation, Message: []byte{}, }, nil @@ -156,50 +154,50 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { // TODO: verify this; it seems like it should be StdEncoding instead of URLEncoding decodedMessage, err := base64.URLEncoding.DecodeString(message) if err != nil { - return SCEPRequest{}, err + return request{}, err } - return SCEPRequest{ + return request{ Operation: operation, Message: decodedMessage, }, nil default: - return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation) + return request{}, fmt.Errorf("unsupported operation: %s", operation) } case http.MethodPost: body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) if err != nil { - return SCEPRequest{}, err + return request{}, err } - return SCEPRequest{ + return request{ Operation: operation, Message: body, }, nil default: - return SCEPRequest{}, errors.Errorf("unsupported method: %s", method) + return request{}, fmt.Errorf("unsupported method: %s", method) } } // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. -func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { +func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "provisionerName") provisionerName, err := url.PathUnescape(name) if err != nil { - api.WriteError(w, errors.Errorf("error url unescaping provisioner name '%s'", name)) + fail(w, fmt.Errorf("error url unescaping provisioner name '%s'", name)) return } - p, err := h.Auth.LoadProvisionerByName(provisionerName) + p, err := h.auth.LoadProvisionerByName(provisionerName) if err != nil { - api.WriteError(w, err) + fail(w, err) return } prov, ok := p.(*provisioner.SCEP) if !ok { - api.WriteError(w, errors.New("provisioner must be of type SCEP")) + fail(w, errors.New("provisioner must be of type SCEP")) return } @@ -210,59 +208,59 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { } // GetCACert returns the CA certificates in a SCEP response -func (h *Handler) GetCACert(ctx context.Context) (SCEPResponse, error) { +func (h *handler) GetCACert(ctx context.Context) (response, error) { - certs, err := h.Auth.GetCACertificates(ctx) + certs, err := h.auth.GetCACertificates(ctx) if err != nil { - return SCEPResponse{}, err + return response{}, err } if len(certs) == 0 { - return SCEPResponse{}, errors.New("missing CA cert") + return response{}, errors.New("missing CA cert") } - response := SCEPResponse{ + res := response{ Operation: opnGetCACert, CACertNum: len(certs), } if len(certs) == 1 { - response.Data = certs[0].Raw + res.Data = certs[0].Raw } else { // create degenerate pkcs7 certificate structure, according to // https://tools.ietf.org/html/rfc8894#section-4.2.1.2, because // not signed or encrypted data has to be returned. data, err := microscep.DegenerateCertificates(certs) if err != nil { - return SCEPResponse{}, err + return response{}, err } - response.Data = data + res.Data = data } - return response, nil + return res, nil } // GetCACaps returns the CA capabilities in a SCEP response -func (h *Handler) GetCACaps(ctx context.Context) (SCEPResponse, error) { +func (h *handler) GetCACaps(ctx context.Context) (response, error) { - caps := h.Auth.GetCACaps(ctx) + caps := h.auth.GetCACaps(ctx) - response := SCEPResponse{ + res := response{ Operation: opnGetCACaps, Data: formatCapabilities(caps), } - return response, nil + return res, nil } // PKIOperation performs PKI operations and returns a SCEP response -func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPResponse, error) { +func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) { // parse the message using microscep implementation - microMsg, err := microscep.ParsePKIMessage(request.Message) + microMsg, err := microscep.ParsePKIMessage(req.Message) if err != nil { // return the error, because we can't use the msg for creating a CertRep - return SCEPResponse{}, err + return response{}, err } // this is essentially doing the same as microscep.ParsePKIMessage, but @@ -270,7 +268,7 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe // wrapper for the microscep implementation. p7, err := pkcs7.Parse(microMsg.Raw) if err != nil { - return SCEPResponse{}, err + return response{}, err } // copy over properties to our internal PKIMessage @@ -282,8 +280,8 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe P7: p7, } - if err := h.Auth.DecryptPKIEnvelope(ctx, msg); err != nil { - return SCEPResponse{}, err + if err := h.auth.DecryptPKIEnvelope(ctx, msg); err != nil { + return response{}, err } // NOTE: at this point we have sufficient information for returning nicely signed CertReps @@ -295,7 +293,7 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe // a certificate exists; then it will use RenewalReq. Adding the challenge check here may be a small breaking change for clients. // We'll have to see how it works out. if msg.MessageType == microscep.PKCSReq || msg.MessageType == microscep.RenewalReq { - challengeMatches, err := h.Auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) + challengeMatches, err := h.auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) if err != nil { return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password")) } @@ -313,72 +311,67 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe // Authentication by the (self-signed) certificate with an optional challenge is required; supporting renewals incl. verification // of the client cert is not. - certRep, err := h.Auth.SignCSR(ctx, csr, msg) + certRep, err := h.auth.SignCSR(ctx, csr, msg) if err != nil { - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.Wrap(err, "error when signing new certificate")) + return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) } - response := SCEPResponse{ + res := response{ Operation: opnPKIOperation, Data: certRep.Raw, Certificate: certRep.Certificate, } - return response, nil + return res, nil } func formatCapabilities(caps []string) []byte { return []byte(strings.Join(caps, "\r\n")) } -// writeSCEPResponse writes a SCEP response back to the SCEP client. -func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) { +// writeResponse writes a SCEP response back to the SCEP client. +func writeResponse(w http.ResponseWriter, res response) { - if response.Error != nil { - api.LogError(w, response.Error) + if res.Error != nil { + log.Error(w, res.Error) } - if response.Certificate != nil { - api.LogCertificate(w, response.Certificate) + if res.Certificate != nil { + api.LogCertificate(w, res.Certificate) } - w.Header().Set("Content-Type", contentHeader(response)) - _, err := w.Write(response.Data) - if err != nil { - writeError(w, errors.Wrap(err, "error when writing scep response")) // This could end up as an error again - } + w.Header().Set("Content-Type", contentHeader(res)) + _, _ = w.Write(res.Data) } -func writeError(w http.ResponseWriter, err error) { - scepError := &scep.Error{ - Message: err.Error(), - Status: http.StatusInternalServerError, // TODO: make this a param? - } - api.WriteError(w, scepError) +func fail(w http.ResponseWriter, err error) { + log.Error(w, err) + + http.Error(w, err.Error(), http.StatusInternalServerError) } -func (h *Handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (SCEPResponse, error) { - certRepMsg, err := h.Auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) +func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) { + certRepMsg, err := h.auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) if err != nil { - return SCEPResponse{}, err + return response{}, err } - return SCEPResponse{ + return response{ Operation: opnPKIOperation, Data: certRepMsg.Raw, Error: failError, }, nil } -func contentHeader(r SCEPResponse) string { +func contentHeader(r response) string { switch r.Operation { - case opnGetCACert: - if r.CACertNum > 1 { - return certChainHeader - } - return leafHeader - case opnPKIOperation: - return pkiOperationHeader default: return "text/plain" + case opnGetCACert: + if r.CACertNum > 1 { + return "application/x-x509-ca-ra-cert" + } + return "application/x-x509-ca-cert" + case opnPKIOperation: + return "application/x-pki-message" } } diff --git a/scep/authority.go b/scep/authority.go index 269e3ae1..71f92152 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -4,33 +4,19 @@ import ( "context" "crypto/subtle" "crypto/x509" + "errors" + "fmt" "net/url" - "github.com/smallstep/certificates/authority/provisioner" - microx509util "github.com/micromdm/scep/v2/cryptoutil/x509util" microscep "github.com/micromdm/scep/v2/scep" - - "github.com/pkg/errors" - "go.mozilla.org/pkcs7" "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/authority/provisioner" ) -// Interface is the SCEP authority interface. -type Interface interface { - LoadProvisionerByName(string) (provisioner.Interface, error) - GetLinkExplicit(provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string - - GetCACertificates(ctx context.Context) ([]*x509.Certificate, error) - DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) error - SignCSR(ctx context.Context, csr *x509.CertificateRequest, msg *PKIMessage) (*PKIMessage, error) - CreateFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *PKIMessage, info FailInfoName, infoText string) (*PKIMessage, error) - MatchChallengePassword(ctx context.Context, password string) (bool, error) - GetCACaps(ctx context.Context) []string -} - // Authority is the layer that handles all SCEP interactions. type Authority struct { prefix string @@ -180,12 +166,12 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err p7c, err := pkcs7.Parse(msg.P7.Content) if err != nil { - return errors.Wrap(err, "error parsing pkcs7 content") + return fmt.Errorf("error parsing pkcs7 content: %w", err) } envelope, err := p7c.Decrypt(a.intermediateCertificate, a.service.decrypter) if err != nil { - return errors.Wrap(err, "error decrypting encrypted pkcs7 content") + return fmt.Errorf("error decrypting encrypted pkcs7 content: %w", err) } msg.pkiEnvelope = envelope @@ -194,19 +180,19 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err case microscep.CertRep: certs, err := microscep.CACerts(msg.pkiEnvelope) if err != nil { - return errors.Wrap(err, "error extracting CA certs from pkcs7 degenerate data") + return fmt.Errorf("error extracting CA certs from pkcs7 degenerate data: %w", err) } msg.CertRepMessage.Certificate = certs[0] return nil case microscep.PKCSReq, microscep.UpdateReq, microscep.RenewalReq: csr, err := x509.ParseCertificateRequest(msg.pkiEnvelope) if err != nil { - return errors.Wrap(err, "parse CSR from pkiEnvelope") + return fmt.Errorf("parse CSR from pkiEnvelope: %w", err) } // check for challengePassword cp, err := microx509util.ParseChallengePassword(msg.pkiEnvelope) if err != nil { - return errors.Wrap(err, "parse challenge password in pkiEnvelope") + return fmt.Errorf("parse challenge password in pkiEnvelope: %w", err) } msg.CSRReqMessage = µscep.CSRReqMessage{ RawDecrypted: msg.pkiEnvelope, @@ -215,7 +201,7 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err } return nil case microscep.GetCRL, microscep.GetCert, microscep.CertPoll: - return errors.Errorf("not implemented") + return errors.New("not implemented") } return nil @@ -274,19 +260,19 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOps, err := p.AuthorizeSign(ctx, "") if err != nil { - return nil, errors.Wrap(err, "error retrieving authorization options from SCEP provisioner") + return nil, fmt.Errorf("error retrieving authorization options from SCEP provisioner: %w", err) } opts := provisioner.SignOptions{} templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) if err != nil { - return nil, errors.Wrap(err, "error creating template options from SCEP provisioner") + return nil, fmt.Errorf("error creating template options from SCEP provisioner: %w", err) } signOps = append(signOps, templateOptions) certChain, err := a.signAuth.Sign(csr, opts, signOps...) if err != nil { - return nil, errors.Wrap(err, "error generating certificate for order") + return nil, fmt.Errorf("error generating certificate for order: %w", err) } // take the issued certificate (only); https://tools.ietf.org/html/rfc8894#section-3.3.2 diff --git a/scep/errors.go b/scep/errors.go deleted file mode 100644 index 4287403b..00000000 --- a/scep/errors.go +++ /dev/null @@ -1,12 +0,0 @@ -package scep - -// Error is an SCEP error type -type Error struct { - Message string `json:"message"` - Status int `json:"-"` -} - -// Error implements the error interface. -func (e *Error) Error() string { - return e.Message -} diff --git a/scep/scep.go b/scep/scep.go index afabf368..372a5436 100644 --- a/scep/scep.go +++ b/scep/scep.go @@ -1,3 +1,4 @@ +// Package scep implements Simple Certificate Enrollment Protocol related functionality. package scep import ( @@ -5,9 +6,6 @@ import ( "encoding/asn1" microscep "github.com/micromdm/scep/v2/scep" - - //"github.com/smallstep/certificates/scep/pkcs7" - "go.mozilla.org/pkcs7" )