Merge branch 'master' into feat/vault

This commit is contained in:
Mariano Cano 2022-04-11 14:57:45 -07:00 committed by GitHub
commit 37b521ec6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
119 changed files with 4455 additions and 2003 deletions

View file

@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.15', '1.16', '1.17' ] go: [ '1.17', '1.18' ]
outputs: outputs:
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
steps: steps:
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.44.0' version: 'v1.45.0'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -106,7 +106,7 @@ jobs:
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.17 go-version: 1.18
- -
name: APT Install name: APT Install
id: aptInstall id: aptInstall
@ -159,7 +159,7 @@ jobs:
name: Setup Go name: Setup Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: '1.17' go-version: '1.18'
- -
name: Install cosign name: Install cosign
uses: sigstore/cosign-installer@v1.1.0 uses: sigstore/cosign-installer@v1.1.0

View file

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.16', '1.17' ] go: [ '1.17', '1.18' ]
steps: steps:
- -
name: Checkout name: Checkout
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.44.0' version: 'v1.45.0'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -58,7 +58,7 @@ jobs:
run: V=1 make ci run: V=1 make ci
- -
name: Codecov name: Codecov
if: matrix.go == '1.17' if: matrix.go == '1.18'
uses: codecov/codecov-action@v1.2.1 uses: codecov/codecov-action@v1.2.1
with: with:
file: ./coverage.out # optional file: ./coverage.out # optional

1
.gitignore vendored
View file

@ -19,3 +19,4 @@ coverage.txt
output output
vendor vendor
.idea .idea
.envrc

View file

@ -19,6 +19,7 @@ builds:
- linux_386 - linux_386
- linux_amd64 - linux_amd64
- linux_arm64 - linux_arm64
- linux_arm_5
- linux_arm_6 - linux_arm_6
- linux_arm_7 - linux_arm_7
- windows_amd64 - windows_amd64
@ -39,6 +40,7 @@ builds:
- linux_386 - linux_386
- linux_amd64 - linux_amd64
- linux_arm64 - linux_arm64
- linux_arm_5
- linux_arm_6 - linux_arm_6
- linux_arm_7 - linux_arm_7
- windows_amd64 - windows_amd64
@ -59,6 +61,7 @@ builds:
- linux_386 - linux_386
- linux_amd64 - linux_amd64
- linux_arm64 - linux_arm64
- linux_arm_5
- linux_arm_6 - linux_arm_6
- linux_arm_7 - linux_arm_7
- windows_amd64 - windows_amd64

View file

@ -4,16 +4,32 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased - 0.18.2] - DATE ## [Unreleased - 0.18.3] - DATE
### Added ### Added
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
- Added support for `extraNames` in X.509 templates.
### Changed ### Changed
- IPv6 addresses are normalized as IP addresses instead of hostnames. - Made SCEP CA URL paths dynamic
- More descriptive JWK decryption error message. - Support two latest versions of Go (1.17, 1.18)
### Deprecated ### Deprecated
### Removed ### Removed
### Fixed ### Fixed
### Security ### Security
## [0.18.2] - 2022-03-01
### Added
- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner.
- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering
out database drivers using Go tags. For example, using the Go flag
`--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx
driver for PostgreSQL.
### Changed
- IPv6 addresses are normalized as IP addresses instead of hostnames.
- More descriptive JWK decryption error message.
- Make the X5C leaf certificate available to the templates using `{{ .AuthorizationCrt }}`.
### Fixed
- During provisioner add - validate provisioner configuration before storing to DB.
## [0.18.1] - 2022-02-03 ## [0.18.1] - 2022-02-03
### Added ### Added
- Support for ACME revocation. - Support for ACME revocation.

View file

@ -5,8 +5,9 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
@ -70,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var nar NewAccountRequest var nar NewAccountRequest
if err := json.Unmarshal(payload.value, &nar); err != nil { if err := json.Unmarshal(payload.value, &nar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
if err := nar.Validate(); err != nil { if err := nar.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := acmeProvisionerFromContext(ctx) prov, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -96,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
acmeErr, ok := err.(*acme.Error) acmeErr, ok := err.(*acme.Error)
if !ok || acmeErr.Status != http.StatusBadRequest { if !ok || acmeErr.Status != http.StatusBadRequest {
// Something went wrong ... // Something went wrong ...
api.WriteError(w, err) render.Error(w, err)
return return
} }
// Account does not exist // // Account does not exist //
if nar.OnlyReturnExisting { if nar.OnlyReturnExisting {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
"account does not exist")) "account does not exist"))
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
eak, err := h.validateExternalAccountBinding(ctx, &nar) eak, err := h.validateExternalAccountBinding(ctx, &nar)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -125,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusValid, Status: acme.StatusValid,
} }
if err := h.db.CreateAccount(ctx, acc); err != nil { if err := h.db.CreateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) render.Error(w, acme.WrapErrorISE(err, "error creating account"))
return return
} }
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
err := eak.BindTo(acc) err := eak.BindTo(acc)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key")) render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
return return
} }
acc.ExternalAccountBinding = nar.ExternalAccountBinding acc.ExternalAccountBinding = nar.ExternalAccountBinding
@ -149,7 +150,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc) h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
api.JSONStatus(w, acc, httpStatus) render.JSONStatus(w, acc, httpStatus)
} }
// GetOrUpdateAccount is the api for updating an ACME account. // GetOrUpdateAccount is the api for updating an ACME account.
@ -157,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -171,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
if !payload.isPostAsGet { if !payload.isPostAsGet {
var uar UpdateAccountRequest var uar UpdateAccountRequest
if err := json.Unmarshal(payload.value, &uar); err != nil { if err := json.Unmarshal(payload.value, &uar); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
if err := uar.Validate(); err != nil { if err := uar.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if len(uar.Status) > 0 || len(uar.Contact) > 0 { if len(uar.Status) > 0 || len(uar.Contact) > 0 {
@ -187,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
} }
if err := h.db.UpdateAccount(ctx, acc); err != nil { if err := h.db.UpdateAccount(ctx, acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating account")) render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return return
} }
} }
@ -196,7 +197,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc) h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
api.JSON(w, acc) render.JSON(w, acc)
} }
func logOrdersByAccount(w http.ResponseWriter, oids []string) { func logOrdersByAccount(w http.ResponseWriter, oids []string) {
@ -213,22 +214,22 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
accID := chi.URLParam(r, "accID") accID := chi.URLParam(r, "accID")
if acc.ID != accID { if acc.ID != accID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return return
} }
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID) orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
h.linker.LinkOrdersByAccountID(ctx, orders) h.linker.LinkOrdersByAccountID(ctx, orders)
api.JSON(w, orders) render.JSON(w, orders)
logOrdersByAccount(w, orders) logOrdersByAccount(w, orders)
} }

View file

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
@ -11,8 +12,10 @@ import (
"time" "time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
@ -43,6 +46,7 @@ type Handler struct {
ca acme.CertificateAuthority ca acme.CertificateAuthority
linker Linker linker Linker
validateChallengeOptions *acme.ValidateChallengeOptions validateChallengeOptions *acme.ValidateChallengeOptions
prerequisitesChecker func(ctx context.Context) (bool, error)
} }
// HandlerOptions required to create a new ACME API request handler. // HandlerOptions required to create a new ACME API request handler.
@ -60,6 +64,9 @@ type HandlerOptions struct {
// "acme" is the prefix from which the ACME api is accessed. // "acme" is the prefix from which the ACME api is accessed.
Prefix string Prefix string
CA acme.CertificateAuthority CA acme.CertificateAuthority
// PrerequisitesChecker checks if all prerequisites for serving ACME are
// met by the CA configuration.
PrerequisitesChecker func(ctx context.Context) (bool, error)
} }
// NewHandler returns a new ACME API handler. // NewHandler returns a new ACME API handler.
@ -76,6 +83,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
} }
prerequisitesChecker := func(ctx context.Context) (bool, error) {
// by default all prerequisites are met
return true, nil
}
if ops.PrerequisitesChecker != nil {
prerequisitesChecker = ops.PrerequisitesChecker
}
return &Handler{ return &Handler{
ca: ops.CA, ca: ops.CA,
db: ops.DB, db: ops.DB,
@ -88,6 +102,7 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
return tls.DialWithDialer(dialer, network, addr, config) return tls.DialWithDialer(dialer, network, addr, config)
}, },
}, },
prerequisitesChecker: prerequisitesChecker,
} }
} }
@ -95,13 +110,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
func (h *Handler) Route(r api.Router) { func (h *Handler) Route(r api.Router) {
getPath := h.linker.GetUnescapedPathSuffix getPath := h.linker.GetUnescapedPathSuffix
// Standard ACME API // Standard ACME API
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
validatingMiddleware := func(next nextHTTP) nextHTTP { validatingMiddleware := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))) return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))))
} }
extractPayloadByJWK := func(next nextHTTP) nextHTTP { extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
@ -168,11 +183,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
api.JSON(w, &Directory{ render.JSON(w, &Directory{
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
@ -187,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
// NotImplemented returns a 501 and is generally a placeholder for functionality which // NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such. // MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
} }
// GetAuthorization ACME api for retrieving an Authz. // GetAuthorization ACME api for retrieving an Authz.
@ -195,28 +210,28 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization")) render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
return return
} }
if acc.ID != az.AccountID { if acc.ID != az.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own authorization '%s'", acc.ID, az.ID)) "account '%s' does not own authorization '%s'", acc.ID, az.ID))
return return
} }
if err = az.UpdateStatus(ctx, h.db); err != nil { if err = az.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
return return
} }
h.linker.LinkAuthorization(ctx, az) h.linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
api.JSON(w, az) render.JSON(w, az)
} }
// GetChallenge ACME api for retrieving a Challenge. // GetChallenge ACME api for retrieving a Challenge.
@ -224,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
// Just verify that the payload was set, since we're not strictly adhering // Just verify that the payload was set, since we're not strictly adhering
// to ACME V2 spec for reasons specified below. // to ACME V2 spec for reasons specified below.
_, err = payloadFromContext(ctx) _, err = payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -244,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
azID := chi.URLParam(r, "authzID") azID := chi.URLParam(r, "authzID")
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge")) render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
return return
} }
ch.AuthorizationID = azID ch.AuthorizationID = azID
if acc.ID != ch.AccountID { if acc.ID != ch.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own challenge '%s'", acc.ID, ch.ID)) "account '%s' does not own challenge '%s'", acc.ID, ch.ID))
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge")) render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return return
} }
@ -267,7 +282,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
api.JSON(w, ch) render.JSON(w, ch)
} }
// GetCertificate ACME api for retrieving a Certificate. // GetCertificate ACME api for retrieving a Certificate.
@ -275,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
certID := chi.URLParam(r, "certID") certID := chi.URLParam(r, "certID")
cert, err := h.db.GetCertificate(ctx, certID) cert, err := h.db.GetCertificate(ctx, certID)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
return return
} }
if cert.AccountID != acc.ID { if cert.AccountID != acc.ID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own certificate '%s'", acc.ID, certID)) "account '%s' does not own certificate '%s'", acc.ID, certID))
return return
} }

View file

@ -10,13 +10,14 @@ import (
"strings" "strings"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
) )
type nextHTTP = func(http.ResponseWriter, *http.Request) type nextHTTP = func(http.ResponseWriter, *http.Request)
@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.db.CreateNonce(r.Context()) nonce, err := h.db.CreateNonce(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
w.Header().Set("Replay-Nonce", string(nonce)) w.Header().Set("Replay-Nonce", string(nonce))
@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
var expected []string var expected []string
p, err := provisionerFromContext(r.Context()) p, err := provisionerFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return return
} }
} }
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"expected content-type to be in %s, but got %s", expected, ct)) "expected content-type to be in %s, but got %s", expected, ct))
} }
} }
@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) render.Error(w, acme.WrapErrorISE(err, "failed to read request body"))
return return
} }
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
return return
} }
ctx := context.WithValue(r.Context(), jwsContextKey, jws) ctx := context.WithValue(r.Context(), jwsContextKey, jws)
@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) jws, err := jwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if len(jws.Signatures) == 0 { if len(jws.Signatures) == 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
return return
} }
if len(jws.Signatures) > 1 { if len(jws.Signatures) > 1 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
return return
} }
@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
len(uh.Algorithm) > 0 || len(uh.Algorithm) > 0 ||
len(uh.Nonce) > 0 || len(uh.Nonce) > 0 ||
len(uh.ExtraHeaders) > 0 { len(uh.ExtraHeaders) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return return
} }
hdr := sig.Protected hdr := sig.Protected
@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
switch k := hdr.JSONWebKey.Key.(type) { switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
if k.Size() < keyutil.MinRSAKeyBytes { if k.Size() < keyutil.MinRSAKeyBytes {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"rsa keys must be at least %d bits (%d bytes) in size", "rsa keys must be at least %d bits (%d bytes) in size",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)) 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
return return
} }
default: default:
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"jws key type and algorithm do not match")) "jws key type and algorithm do not match"))
return return
} }
@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good // we good
default: default:
api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm)) render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
return return
} }
// Check the validity/freshness of the Nonce. // Check the validity/freshness of the Nonce.
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
// Check that the JWS url matches the requested url. // Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string) jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok { if !ok {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
return return
} }
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() { if jwsURL != reqURL.String() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL)) "url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
return return
} }
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return return
} }
if hdr.JSONWebKey == nil && hdr.KeyID == "" { if hdr.JSONWebKey == nil && hdr.KeyID == "" {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return return
} }
next(w, r) next(w, r)
@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) jws, err := jwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
jwk := jws.Signatures[0].Protected.JSONWebKey jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil { if jwk == nil {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
return return
} }
if !jwk.Valid() { if !jwk.Valid() {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
return return
} }
// Overwrite KeyID with the JWK thumbprint. // Overwrite KeyID with the JWK thumbprint.
jwk.KeyID, err = acme.KeyToID(jwk) jwk.KeyID, err = acme.KeyToID(jwk)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
return return
} }
@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
// For NewAccount and Revoke requests ... // For NewAccount and Revoke requests ...
break break
case err != nil: case err != nil:
api.WriteError(w, err) render.Error(w, err)
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
@ -283,25 +284,24 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
} }
// lookupProvisioner loads the provisioner associated with the request. // lookupProvisioner loads the provisioner associated with the request.
// Responsds 404 if the provisioner does not exist. // Responds 404 if the provisioner does not exist.
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
nameEscaped := chi.URLParam(r, "provisionerID") nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped) name, err := url.PathUnescape(nameEscaped)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped)) render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return return
} }
p, err := h.ca.LoadProvisionerByName(name) p, err := h.ca.LoadProvisionerByName(name)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
acmeProv, ok := p.(*provisioner.ACME) acmeProv, ok := p.(*provisioner.ACME)
if !ok { if !ok {
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME")) render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return return
} }
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv)) ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
@ -309,6 +309,24 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
} }
} }
// checkPrerequisites checks if all prerequisites for serving ACME
// are met by the CA configuration.
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ok, err := h.prerequisitesChecker(ctx)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
next(w, r.WithContext(ctx))
}
}
// lookupJWK loads the JWK associated with the acme account referenced by the // lookupJWK loads the JWK associated with the acme account referenced by the
// kid parameter of the signed payload. // kid parameter of the signed payload.
// Make sure to parse and validate the JWS before running this middleware. // Make sure to parse and validate the JWS before running this middleware.
@ -317,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
"kid does not have required prefix; expected %s, but got %s", "kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid)) kidPrefix, kid))
return return
@ -334,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
acc, err := h.db.GetAccount(ctx, accID) acc, err := h.db.GetAccount(ctx, accID)
switch { switch {
case nosql.IsErrNotFound(err): case nosql.IsErrNotFound(err):
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return return
case err != nil: case err != nil:
api.WriteError(w, err) render.Error(w, err)
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
@ -359,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -395,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
jwk, err := jwkFromContext(ctx) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return return
} }
payload, err := jws.Verify(jwk) payload, err := jws.Verify(jwk)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
return return
} }
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
@ -426,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context()) payload, err := payloadFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if !payload.isPostAsGet { if !payload.isPostAsGet {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET")) render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
return return
} }
next(w, r) next(w, r)

View file

@ -1656,3 +1656,91 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
}) })
} }
} }
func TestHandler_checkPrerequisites(t *testing.T) {
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
u := fmt.Sprintf("%s/acme/%s/account/1234",
baseURL, provName)
type test struct {
linker Linker
ctx context.Context
prerequisitesChecker func(context.Context) (bool, error)
next func(http.ResponseWriter, *http.Request)
err *acme.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
err: acme.WrapErrorISE(errors.New("force"), "error checking acme provisioner prerequisites"),
statusCode: 500,
}
},
"fail/prerequisites-nok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
err: acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"),
statusCode: 501,
}
},
"ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{
linker: NewLinker("dns", "acme"),
ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
statusCode: 200,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.checkPrerequisites(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
}
})
}
}

View file

@ -11,9 +11,11 @@ import (
"time" "time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render"
) )
// NewOrderRequest represents the body for a NewOrder request. // NewOrderRequest represents the body for a NewOrder request.
@ -70,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var nor NewOrderRequest var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil { if err := json.Unmarshal(payload.value, &nor); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-order request payload")) "failed to unmarshal new-order request payload"))
return return
} }
if err := nor.Validate(); err != nil { if err := nor.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -116,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
Status: acme.StatusPending, Status: acme.StatusPending,
} }
if err := h.newAuthorization(ctx, az); err != nil { if err := h.newAuthorization(ctx, az); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
o.AuthorizationIDs[i] = az.ID o.AuthorizationIDs[i] = az.ID
@ -135,14 +137,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
} }
if err := h.db.CreateOrder(ctx, o); err != nil { if err := h.db.CreateOrder(ctx, o); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating order")) render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return return
} }
h.linker.LinkOrder(ctx, o) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSONStatus(w, o, http.StatusCreated) render.JSONStatus(w, o, http.StatusCreated)
} }
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
@ -186,38 +188,38 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
} }
if acc.ID != o.AccountID { if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID)) "account '%s' does not own order '%s'", acc.ID, o.ID))
return return
} }
if prov.GetID() != o.ProvisionerID { if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.UpdateStatus(ctx, h.db); err != nil { if err = o.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating order status")) render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
return return
} }
h.linker.LinkOrder(ctx, o) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o) render.JSON(w, o)
} }
// FinalizeOrder attemptst to finalize an order and create a certificate. // FinalizeOrder attemptst to finalize an order and create a certificate.
@ -225,54 +227,54 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var fr FinalizeRequest var fr FinalizeRequest
if err := json.Unmarshal(payload.value, &fr); err != nil { if err := json.Unmarshal(payload.value, &fr); err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal finalize-order request payload")) "failed to unmarshal finalize-order request payload"))
return return
} }
if err := fr.Validate(); err != nil { if err := fr.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
} }
if acc.ID != o.AccountID { if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID)) "account '%s' does not own order '%s'", acc.ID, o.ID))
return return
} }
if prov.GetID() != o.ProvisionerID { if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil { if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order")) render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
return return
} }
h.linker.LinkOrder(ctx, o) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o) render.JSON(w, o)
} }
// challengeTypes determines the types of challenges that should be used // challengeTypes determines the types of challenges that should be used

View file

@ -10,13 +10,14 @@ import (
"net/http" "net/http"
"strings" "strings"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
) )
type revokePayload struct { type revokePayload struct {
@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx) prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
var p revokePayload var p revokePayload
err = json.Unmarshal(payload.value, &p) err = json.Unmarshal(payload.value, &p)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload")) render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
return return
} }
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
if err != nil { if err != nil {
// in this case the most likely cause is a client that didn't properly encode the certificate // in this case the most likely cause is a client that didn't properly encode the certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
return return
} }
certToBeRevoked, err := x509.ParseCertificate(certBytes) certToBeRevoked, err := x509.ParseCertificate(certBytes)
if err != nil { if err != nil {
// in this case a client may have encoded something different than a certificate // in this case a client may have encoded something different than a certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
return return
} }
serial := certToBeRevoked.SerialNumber.String() serial := certToBeRevoked.SerialNumber.String()
dbCert, err := h.db.GetCertificateBySerial(ctx, serial) dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return return
} }
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
// this should never happen // this should never happen
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal")) render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal"))
return return
} }
if shouldCheckAccountFrom(jws) { if shouldCheckAccountFrom(jws) {
account, err := accountFromContext(ctx) account, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil { if acmeErr != nil {
api.WriteError(w, acmeErr) render.Error(w, acmeErr)
return return
} }
} else { } else {
@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
_, err := jws.Verify(certToBeRevoked.PublicKey) _, err := jws.Verify(certToBeRevoked.PublicKey)
if err != nil { if err != nil {
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
return return
} }
} }
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return return
} }
if hasBeenRevokedBefore { if hasBeenRevokedBefore {
api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
return return
} }
reasonCode := p.ReasonCode reasonCode := p.ReasonCode
acmeErr := validateReasonCode(reasonCode) acmeErr := validateReasonCode(reasonCode)
if acmeErr != nil { if acmeErr != nil {
api.WriteError(w, acmeErr) render.Error(w, acmeErr)
return return
} }
@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
err = prov.AuthorizeRevoke(ctx, "") err = prov.AuthorizeRevoke(ctx, "")
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
return return
} }
options := revokeOptions(serial, certToBeRevoked, reasonCode) options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = h.ca.Revoke(ctx, options) err = h.ca.Revoke(ctx, options)
if err != nil { if err != nil {
api.WriteError(w, wrapRevokeErr(err)) render.Error(w, wrapRevokeErr(err))
return return
} }

View file

@ -79,7 +79,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
} }
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(u.String()) resp, err := vo.HTTPGet(u.String())
if err != nil { if err != nil {
@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
return nil return nil
} }
// http01ChallengeHost checks if a Challenge value is an IPv6 address
// and adds square brackets if that's the case, so that it can be used
// as a hostname. Returns the original Challenge value as the host to
// use in other cases.
func http01ChallengeHost(value string) string {
if ip := net.ParseIP(value); ip != nil && ip.To4() == nil {
value = "[" + value + "]"
}
return value
}
func tlsAlert(err error) uint8 { func tlsAlert(err error) uint8 {
var opErr *net.OpError var opErr *net.OpError
if errors.As(err, &opErr) { if errors.As(err, &opErr) {

View file

@ -13,6 +13,7 @@ import (
"encoding/asn1" "encoding/asn1"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
@ -23,9 +24,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
) )
func Test_storeError(t *testing.T) { func Test_storeError(t *testing.T) {
@ -2350,3 +2351,34 @@ func Test_serverName(t *testing.T) {
}) })
} }
} }
func Test_http01ChallengeHost(t *testing.T) {
tests := []struct {
name string
value string
want string
}{
{
name: "dns",
value: "www.example.com",
want: "www.example.com",
},
{
name: "ipv4",
value: "127.0.0.1",
want: "127.0.0.1",
},
{
name: "ipv6",
value: "::1",
want: "[::1]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := http01ChallengeHost(tt.value); got != tt.want {
t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -3,13 +3,10 @@ package acme
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/logging"
) )
// ProblemType is the type of the ACME problem. // ProblemType is the type of the ACME problem.
@ -353,26 +350,8 @@ func (e *Error) ToLog() (interface{}, error) {
return string(b), nil return string(b), nil
} }
// WriteError writes to w a JSON representation of the given error. // Render implements render.RenderableError for Error.
func WriteError(w http.ResponseWriter, err *Error) { func (e *Error) Render(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/problem+json") w.Header().Set("Content-Type", "application/problem+json")
w.WriteHeader(err.StatusCode()) render.JSONStatus(w, e, e.StatusCode())
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err.Err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.Err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
log.Println(err)
}
} }

View file

@ -20,6 +20,9 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -33,6 +36,7 @@ type Authority interface {
// context specifies the Authorize[Sign|Revoke|etc.] method. // context specifies the Authorize[Sign|Revoke|etc.] method.
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error)
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -43,7 +47,7 @@ type Authority interface {
GetProvisioners(cursor string, limit int) (provisioner.List, string, error) GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
Revoke(context.Context, *authority.RevokeOptions) error Revoke(context.Context, *authority.RevokeOptions) error
GetEncryptedKey(kid string) (string, error) GetEncryptedKey(kid string) (string, error)
GetRoots() (federation []*x509.Certificate, err error) GetRoots() ([]*x509.Certificate, error)
GetFederation() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error)
Version() authority.Version Version() authority.Version
} }
@ -257,6 +261,7 @@ func (h *caHandler) Route(r Router) {
r.MethodFunc("GET", "/provisioners", h.Provisioners) r.MethodFunc("GET", "/provisioners", h.Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
r.MethodFunc("GET", "/roots", h.Roots) r.MethodFunc("GET", "/roots", h.Roots)
r.MethodFunc("GET", "/roots.pem", h.RootsPEM)
r.MethodFunc("GET", "/federation", h.Federation) r.MethodFunc("GET", "/federation", h.Federation)
// SSH CA // SSH CA
r.MethodFunc("POST", "/ssh/sign", h.SSHSign) r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
@ -280,7 +285,7 @@ func (h *caHandler) Route(r Router) {
// Version is an HTTP handler that returns the version of the server. // Version is an HTTP handler that returns the version of the server.
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
v := h.Authority.Version() v := h.Authority.Version()
JSON(w, VersionResponse{ render.JSON(w, VersionResponse{
Version: v.Version, Version: v.Version,
RequireClientAuthentication: v.RequireClientAuthentication, RequireClientAuthentication: v.RequireClientAuthentication,
}) })
@ -288,7 +293,7 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
// Health is an HTTP handler that returns the status of the server. // Health is an HTTP handler that returns the status of the server.
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
JSON(w, HealthResponse{Status: "ok"}) render.JSON(w, HealthResponse{Status: "ok"})
} }
// Root is an HTTP handler that using the SHA256 from the URL, returns the root // Root is an HTTP handler that using the SHA256 from the URL, returns the root
@ -299,11 +304,11 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the // Load root certificate with the
cert, err := h.Authority.Root(sum) cert, err := h.Authority.Root(sum)
if err != nil { if err != nil {
WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return return
} }
JSON(w, &RootResponse{RootPEM: Certificate{cert}}) render.JSON(w, &RootResponse{RootPEM: Certificate{cert}})
} }
func certChainToPEM(certChain []*x509.Certificate) []Certificate { func certChainToPEM(certChain []*x509.Certificate) []Certificate {
@ -318,16 +323,16 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r) cursor, limit, err := ParseCursor(r)
if err != nil { if err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
p, next, err := h.Authority.GetProvisioners(cursor, limit) p, next, err := h.Authority.GetProvisioners(cursor, limit)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &ProvisionersResponse{ render.JSON(w, &ProvisionersResponse{
Provisioners: p, Provisioners: p,
NextCursor: next, NextCursor: next,
}) })
@ -338,17 +343,17 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid") kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid) key, err := h.Authority.GetEncryptedKey(kid)
if err != nil { if err != nil {
WriteError(w, errs.NotFoundErr(err)) render.Error(w, errs.NotFoundErr(err))
return return
} }
JSON(w, &ProvisionerKeyResponse{key}) render.JSON(w, &ProvisionerKeyResponse{key})
} }
// Roots returns all the root certificates for the CA. // Roots returns all the root certificates for the CA.
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := h.Authority.GetRoots()
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error getting roots")) render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
return return
} }
@ -357,16 +362,39 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{roots[i]} certs[i] = Certificate{roots[i]}
} }
JSONStatus(w, &RootsResponse{ render.JSONStatus(w, &RootsResponse{
Certificates: certs, Certificates: certs,
}, http.StatusCreated) }, http.StatusCreated)
} }
// RootsPEM returns all the root certificates for the CA in PEM format.
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
if err != nil {
render.Error(w, errs.InternalServerErr(err))
return
}
w.Header().Set("Content-Type", "application/x-pem-file")
for _, root := range roots {
block := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: root.Raw,
})
if _, err := w.Write(block); err != nil {
log.Error(w, err)
return
}
}
}
// Federation returns all the public certificates in the federation. // Federation returns all the public certificates in the federation.
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation() federated, err := h.Authority.GetFederation()
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error getting federated roots")) render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
return return
} }
@ -375,7 +403,7 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
certs[i] = Certificate{federated[i]} certs[i] = Certificate{federated[i]}
} }
JSONStatus(w, &FederationResponse{ render.JSONStatus(w, &FederationResponse{
Certificates: certs, Certificates: certs,
}, http.StatusCreated) }, http.StatusCreated)
} }

View file

@ -13,6 +13,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
@ -27,14 +28,17 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ssh"
) )
const ( const (
@ -171,6 +175,7 @@ type mockAuthority struct {
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
authorizeSign func(ott string) ([]provisioner.SignOption, error) authorizeSign func(ott string) ([]provisioner.SignOption, error)
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -208,6 +213,13 @@ func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, err
return m.ret1.([]provisioner.SignOption), m.err return m.ret1.([]provisioner.SignOption), m.err
} }
func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
if m.authorizeRenewToken != nil {
return m.authorizeRenewToken(ctx, ott)
}
return m.ret1.(*x509.Certificate), m.err
}
func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
if m.getTLSOptions != nil { if m.getTLSOptions != nil {
return m.getTLSOptions() return m.getTLSOptions()
@ -920,48 +932,141 @@ func Test_caHandler_Renew(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
// Prepare root and leaf for renew after expiry test.
now := time.Now()
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
root := &x509.Certificate{
Subject: pkix.Name{CommonName: "Test Root CA"},
PublicKey: rootPub,
KeyUsage: x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
NotBefore: now.Add(-2 * time.Hour),
NotAfter: now.Add(time.Hour),
}
root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv)
if err != nil {
t.Fatal(err)
}
expiredLeaf := &x509.Certificate{
Subject: pkix.Name{CommonName: "Leaf certificate"},
PublicKey: leafPub,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
EmailAddresses: []string{"test@example.org"},
}
expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv)
if err != nil {
t.Fatal(err)
}
// Generate renew after expiry token
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)})
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so)
if err != nil {
t.Fatal(err)
}
generateX5cToken := func(claims jose.Claims) string {
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return s
}
tests := []struct { tests := []struct {
name string name string
tls *tls.ConnectionState tls *tls.ConnectionState
header http.Header
cert *x509.Certificate cert *x509.Certificate
root *x509.Certificate root *x509.Certificate
err error err error
statusCode int statusCode int
}{ }{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"ok", cs, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest}, {"ok renew after expiry", &tls.ConnectionState{}, http.Header{
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
})},
}, expiredLeaf, root, nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
{"fail expired token", &tls.ConnectionState{}, http.Header{
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
})},
}, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized},
{"fail invalid root", &tls.ConnectionState{}, http.Header{
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
})},
}, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized},
} }
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ h := New(&mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err, ret1: tt.cert, ret2: tt.root, err: tt.err,
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
if err != nil {
return nil, errs.Unauthorized(err.Error())
}
var claims jose.Claims
if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil {
return nil, errs.Unauthorized(err.Error())
}
if err := claims.ValidateWithLeeway(jose.Expected{
Time: now,
}, time.Minute); err != nil {
return nil, errs.Unauthorized(err.Error())
}
return chain[0][0], nil
},
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) }).(*caHandler)
req := httptest.NewRequest("POST", "http://example.com/renew", nil) req := httptest.NewRequest("POST", "http://example.com/renew", nil)
req.TLS = tt.tls req.TLS = tt.tls
req.Header = tt.header
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Renew(logging.NewResponseLogger(w), req) h.Renew(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode { res := w.Result()
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) defer res.Body.Close()
}
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Renew unexpected error = %v", err) t.Errorf("caHandler.Renew unexpected error = %v", err)
} }
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
t.Errorf("%s", body)
}
if tt.statusCode < http.StatusBadRequest { if tt.statusCode < http.StatusBadRequest {
expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` +
`"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` +
`"certChain":["` +
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` +
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`)
if !bytes.Equal(bytes.TrimSpace(body), expected) { if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected)
} }
} }
}) })
@ -1239,6 +1344,46 @@ func Test_caHandler_Roots(t *testing.T) {
} }
} }
func Test_caHandler_RootsPEM(t *testing.T) {
parsedRoot := parseCertificate(rootPEM)
tests := []struct {
name string
roots []*x509.Certificate
err error
statusCode int
expect string
}{
{"one root", []*x509.Certificate{parsedRoot}, nil, http.StatusOK, rootPEM},
{"two roots", []*x509.Certificate{parsedRoot, parsedRoot}, nil, http.StatusOK, rootPEM + "\n" + rootPEM},
{"fail", nil, errors.New("an error"), http.StatusInternalServerError, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler)
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
w := httptest.NewRecorder()
h.RootsPEM(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.RootsPEM StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.RootsPEM unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), []byte(tt.expect)) {
t.Errorf("caHandler.RootsPEM Body = %s, wants %s", body, tt.expect)
}
}
})
}
}
func Test_caHandler_Federation(t *testing.T) { func Test_caHandler_Federation(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},

View file

@ -1,64 +0,0 @@
package api
import (
"encoding/json"
"fmt"
"net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/scep"
)
// WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err error) {
switch k := err.(type) {
case *acme.Error:
acme.WriteError(w, k)
return
case *admin.Error:
admin.WriteError(w, k)
return
case *scep.Error:
w.Header().Set("Content-Type", "text/plain")
default:
w.Header().Set("Content-Type", "application/json")
}
cause := errors.Cause(err)
if sc, ok := err.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
if sc, ok := cause.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
w.WriteHeader(http.StatusInternalServerError)
}
}
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
} else if e, ok := cause.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
LogError(w, err)
}
}

79
api/log/log.go Normal file
View file

@ -0,0 +1,79 @@
// Package log implements API-related logging helpers.
package log
import (
"fmt"
"log"
"net/http"
"os"
"github.com/pkg/errors"
"github.com/smallstep/certificates/logging"
)
// StackTracedError is the set of errors implementing the StackTrace function.
//
// Errors implementing this interface have their stack traces logged when passed
// to the Error function of this package.
type StackTracedError interface {
error
StackTrace() errors.StackTrace
}
// Error adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func Error(rw http.ResponseWriter, err error) {
rl, ok := rw.(logging.ResponseLogger)
if !ok {
log.Println(err)
return
}
rl.WithFields(map[string]interface{}{
"error": err,
})
if os.Getenv("STEPDEBUG") != "1" {
return
}
e, ok := err.(StackTracedError)
if !ok {
e, ok = errors.Cause(err).(StackTracedError)
}
if ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e.StackTrace()),
})
}
}
// EnabledResponse log the response object if it implements the EnableLogger
// interface.
func EnabledResponse(rw http.ResponseWriter, v interface{}) {
type enableLogger interface {
ToLog() (interface{}, error)
}
if el, ok := v.(enableLogger); ok {
out, err := el.ToLog()
if err != nil {
Error(rw, err)
return
}
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"response": out,
})
} else {
log.Println(out)
}
}
}

44
api/log/log_test.go Normal file
View file

@ -0,0 +1,44 @@
package log
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/smallstep/certificates/logging"
)
func TestError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct {
name string
args args
withFields bool
}{
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Error(tt.args.rw, tt.args.err)
if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else {
t.Error("ResponseWriter does not implement logging.ResponseLogger")
}
}
})
}
}

31
api/read/read.go Normal file
View file

@ -0,0 +1,31 @@
// Package read implements request object readers.
package read
import (
"encoding/json"
"io"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/errs"
)
// JSON reads JSON from the request body and stores it in the value
// pointed by v.
func JSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
}
return nil
}
// ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
}

46
api/read/read_test.go Normal file
View file

@ -0,0 +1,46 @@
package read
import (
"io"
"reflect"
"strings"
"testing"
"github.com/smallstep/certificates/errs"
)
func TestJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := JSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
})
}
}

View file

@ -3,6 +3,8 @@ package api
import ( import (
"net/http" "net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -27,24 +29,24 @@ func (s *RekeyRequest) Validate() error {
// Rekey is similar to renew except that the certificate will be renewed with new key from csr. // Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing client certificate")) render.Error(w, errs.BadRequest("missing client certificate"))
return return
} }
var body RekeyRequest var body RekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
@ -54,7 +56,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{ render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,

122
api/render/render.go Normal file
View file

@ -0,0 +1,122 @@
// Package render implements functionality related to response rendering.
package render
import (
"bytes"
"encoding/json"
"net/http"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/log"
)
// JSON is shorthand for JSONStatus(w, v, http.StatusOK).
func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK)
}
// JSONStatus marshals v into w. It additionally sets the status code of
// w to the given one.
//
// JSONStatus sets the Content-Type of w to application/json unless one is
// specified.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(v); err != nil {
panic(err)
}
setContentTypeUnlessPresent(w, "application/json")
w.WriteHeader(status)
_, _ = b.WriteTo(w)
log.EnabledResponse(w, v)
}
// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK).
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
ProtoJSONStatus(w, m, http.StatusOK)
}
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
b, err := protojson.Marshal(m)
if err != nil {
panic(err)
}
setContentTypeUnlessPresent(w, "application/json")
w.WriteHeader(status)
_, _ = w.Write(b)
}
func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) {
const header = "Content-Type"
h := w.Header()
if _, ok := h[header]; !ok {
h.Set(header, contentType)
}
}
// RenderableError is the set of errors that implement the basic Render method.
//
// Errors that implement this interface will use their own Render method when
// being rendered into responses.
type RenderableError interface {
error
Render(http.ResponseWriter)
}
// Error marshals the JSON representation of err to w. In case err implements
// RenderableError its own Render method will be called instead.
func Error(w http.ResponseWriter, err error) {
log.Error(w, err)
if e, ok := err.(RenderableError); ok {
e.Render(w)
return
}
JSONStatus(w, err, statusCodeFromError(err))
}
// StatusCodedError is the set of errors that implement the basic StatusCode
// function.
//
// Errors that implement this interface will use the code reported by StatusCode
// as the HTTP response code when being rendered by this package.
type StatusCodedError interface {
error
StatusCode() int
}
func statusCodeFromError(err error) (code int) {
code = http.StatusInternalServerError
type causer interface {
Cause() error
}
for err != nil {
if sc, ok := err.(StatusCodedError); ok {
code = sc.StatusCode()
break
}
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return
}

115
api/render/render_test.go Normal file
View file

@ -0,0 +1,115 @@
package render
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/smallstep/certificates/logging"
)
func TestJSON(t *testing.T) {
rec := httptest.NewRecorder()
rw := logging.NewResponseLogger(rec)
JSON(rw, map[string]interface{}{"foo": "bar"})
assert.Equal(t, http.StatusOK, rec.Result().StatusCode)
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String())
assert.Empty(t, rw.Fields())
}
func TestJSONPanics(t *testing.T) {
assert.Panics(t, func() {
JSON(httptest.NewRecorder(), make(chan struct{}))
})
}
type renderableError struct {
Code int `json:"-"`
Message string `json:"message"`
}
func (err renderableError) Error() string {
return err.Message
}
func (err renderableError) Render(w http.ResponseWriter) {
w.Header().Set("Content-Type", "something/custom")
JSONStatus(w, err, err.Code)
}
type statusedError struct {
Contents string
}
func (err statusedError) Error() string { return err.Contents }
func (statusedError) StatusCode() int { return 432 }
func TestError(t *testing.T) {
cases := []struct {
err error
code int
body string
header string
}{
0: {
err: renderableError{532, "some string"},
code: 532,
body: "{\"message\":\"some string\"}\n",
header: "something/custom",
},
1: {
err: statusedError{"123"},
code: 432,
body: "{\"Contents\":\"123\"}\n",
header: "application/json",
},
}
for caseIndex := range cases {
kase := cases[caseIndex]
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
rec := httptest.NewRecorder()
Error(rec, kase.err)
assert.Equal(t, kase.code, rec.Result().StatusCode)
assert.Equal(t, kase.body, rec.Body.String())
assert.Equal(t, kase.header, rec.Header().Get("Content-Type"))
})
}
}
type causedError struct {
cause error
}
func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) }
func (err causedError) Cause() error { return err.cause }
func TestStatusCodeFromError(t *testing.T) {
cases := []struct {
err error
exp int
}{
0: {nil, http.StatusInternalServerError},
1: {io.EOF, http.StatusInternalServerError},
2: {statusedError{"123"}, 432},
3: {causedError{statusedError{"432"}}, 432},
}
for caseIndex, kase := range cases {
assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex)
}
}

View file

@ -1,22 +1,31 @@
package api package api
import ( import (
"crypto/x509"
"net/http" "net/http"
"strings"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
const (
authorizationHeader = "Authorization"
bearerScheme = "Bearer"
)
// Renew uses the information of certificate in the TLS connection to create a // Renew uses the information of certificate in the TLS connection to create a
// new one. // new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { cert, err := h.getPeerCertificate(r)
WriteError(w, errs.BadRequest("missing client certificate")) if err != nil {
render.Error(w, err)
return return
} }
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) certChain, err := h.Authority.Renew(cert)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
@ -26,10 +35,22 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{ render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated) }, http.StatusCreated)
} }
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], nil
}
if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
}
}
return nil, errs.BadRequest("missing client certificate")
}

View file

@ -4,11 +4,14 @@ import (
"context" "context"
"net/http" "net/http"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
) )
// RevokeResponse is the response object that returns the health of the server. // RevokeResponse is the response object that returns the health of the server.
@ -48,13 +51,13 @@ func (r *RevokeRequest) Validate() (err error) {
// TODO: Add CRL and OCSP support. // TODO: Add CRL and OCSP support.
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
@ -71,7 +74,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
if len(body.OTT) > 0 { if len(body.OTT) > 0 {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
@ -80,12 +83,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number // the client certificate Serial Number must match the serial number
// being revoked. // being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing ott or client certificate")) render.Error(w, errs.BadRequest("missing ott or client certificate"))
return return
} }
opts.Crt = r.TLS.PeerCertificates[0] opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial { if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest("serial number in client certificate different than body")) render.Error(w, errs.BadRequest("serial number in client certificate different than body"))
return return
} }
// TODO: should probably be checking if the certificate was revoked here. // TODO: should probably be checking if the certificate was revoked here.
@ -96,12 +99,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
} }
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err, "error revoking certificate")) render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
return return
} }
logRevoke(w, opts) logRevoke(w, opts)
JSON(w, &RevokeResponse{Status: "ok"}) render.JSON(w, &RevokeResponse{Status: "ok"})
} }
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

View file

@ -13,6 +13,7 @@ import (
"testing" "testing"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"

View file

@ -5,6 +5,8 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -49,14 +51,14 @@ type SignResponse struct {
// information in the certificate request. // information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
@ -68,13 +70,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
signOpts, err := h.Authority.AuthorizeSign(body.OTT) signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)
@ -83,7 +85,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
caPEM = certChainPEM[1] caPEM = certChainPEM[1]
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
JSONStatus(w, &SignResponse{ render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,

View file

@ -9,12 +9,15 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh"
) )
// SSHAuthority is the interface implemented by a SSH CA authority. // SSHAuthority is the interface implemented by a SSH CA authority.
@ -249,20 +252,20 @@ type SSHBastionResponse struct {
// the request. // the request.
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }
@ -270,7 +273,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil { if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey")) render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
return return
} }
} }
@ -287,13 +290,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
} }
@ -301,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
} }
addUserCertificate = &SSHCertificate{addUserCert} addUserCertificate = &SSHCertificate{addUserCert}
@ -314,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
@ -326,13 +329,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return return
} }
identityCertificate = certChainToPEM(certChain) identityCertificate = certChainToPEM(certChain)
} }
JSONStatus(w, &SSHSignResponse{ render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{cert}, Certificate: SSHCertificate{cert},
AddUserCertificate: addUserCertificate, AddUserCertificate: addUserCertificate,
IdentityCertificate: identityCertificate, IdentityCertificate: identityCertificate,
@ -344,12 +347,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots(r.Context()) keys, err := h.Authority.GetSSHRoots(r.Context())
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound("no keys found")) render.Error(w, errs.NotFound("no keys found"))
return return
} }
@ -361,7 +364,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
} }
JSON(w, resp) render.JSON(w, resp)
} }
// SSHFederation is an HTTP handler that returns the federated SSH public keys // SSHFederation is an HTTP handler that returns the federated SSH public keys
@ -369,12 +372,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation(r.Context()) keys, err := h.Authority.GetSSHFederation(r.Context())
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 { if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, errs.NotFound("no keys found")) render.Error(w, errs.NotFound("no keys found"))
return return
} }
@ -386,25 +389,25 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k}) resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
} }
JSON(w, resp) render.JSON(w, resp)
} }
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients // SSHConfig is an HTTP handler that returns rendered templates for ssh clients
// and servers. // and servers.
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
@ -415,31 +418,31 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert: case provisioner.SSHHostCert:
cfg.HostTemplates = ts cfg.HostTemplates = ts
default: default:
WriteError(w, errs.InternalServer("it should hot get here")) render.Error(w, errs.InternalServer("it should hot get here"))
return return
} }
JSON(w, cfg) render.JSON(w, cfg)
} }
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest var body SSHCheckPrincipalRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHCheckPrincipalResponse{ render.JSON(w, &SSHCheckPrincipalResponse{
Exists: exists, Exists: exists,
}) })
} }
@ -453,10 +456,10 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHGetHostsResponse{ render.JSON(w, &SSHGetHostsResponse{
Hosts: hosts, Hosts: hosts,
}) })
} }
@ -464,22 +467,22 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
// SSHBastion provides returns the bastion configured if any. // SSHBastion provides returns the bastion configured if any.
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
JSON(w, &SSHBastionResponse{ render.JSON(w, &SSHBastionResponse{
Hostname: body.Hostname, Hostname: body.Hostname,
Bastion: bastion, Bastion: bastion,
}) })

View file

@ -4,9 +4,12 @@ import (
"net/http" "net/http"
"time" "time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
) )
// SSHRekeyRequest is the request body of an SSH certificate request. // SSHRekeyRequest is the request body of an SSH certificate request.
@ -38,37 +41,38 @@ type SSHRekeyResponse struct {
// the request. // the request.
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey")) render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return
} }
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
return return
} }
@ -78,11 +82,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
} }
JSONStatus(w, &SSHRekeyResponse{ render.JSONStatus(w, &SSHRekeyResponse{
Certificate: SSHCertificate{newCert}, Certificate: SSHCertificate{newCert},
IdentityCertificate: identity, IdentityCertificate: identity,
}, http.StatusCreated) }, http.StatusCreated)

View file

@ -6,6 +6,9 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -36,31 +39,32 @@ type SSHRenewResponse struct {
// the request. // the request.
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT) _, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT) oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return
} }
newCert, err := h.Authority.RenewSSH(ctx, oldCert) newCert, err := h.Authority.RenewSSH(ctx, oldCert)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
return return
} }
@ -70,11 +74,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
} }
JSONStatus(w, &SSHSignResponse{ render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{newCert}, Certificate: SSHCertificate{newCert},
IdentityCertificate: identity, IdentityCertificate: identity,
}, http.StatusCreated) }, http.StatusCreated)

View file

@ -3,11 +3,14 @@ package api
import ( import (
"net/http" "net/http"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
) )
// SSHRevokeResponse is the response object that returns the health of the server. // SSHRevokeResponse is the response object that returns the health of the server.
@ -47,13 +50,13 @@ func (r *SSHRevokeRequest) Validate() (err error) {
// NOTE: currently only Passive revocation is supported. // NOTE: currently only Passive revocation is supported.
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, err) render.Error(w, err)
return return
} }
@ -69,18 +72,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
// otherwise it is assumed that the certificate is revoking itself over mTLS. // otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
return return
} }
logSSHRevoke(w, opts) logSSHRevoke(w, opts)
JSON(w, &SSHRevokeResponse{Status: "ok"}) render.JSON(w, &SSHRevokeResponse{Status: "ok"})
} }
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {

View file

@ -18,12 +18,13 @@ import (
"testing" "testing"
"time" "time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh"
) )
var ( var (

View file

@ -1,109 +0,0 @@
package api
import (
"encoding/json"
"io"
"log"
"net/http"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// EnableLogger is an interface that enables response logging for an object.
type EnableLogger interface {
ToLog() (interface{}, error)
}
// LogError adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func LogError(rw http.ResponseWriter, err error) {
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
} else {
log.Println(err)
}
}
// LogEnabledResponse log the response object if it implements the EnableLogger
// interface.
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
if el, ok := v.(EnableLogger); ok {
out, err := el.ToLog()
if err != nil {
LogError(rw, err)
return
}
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"response": out,
})
} else {
log.Println(out)
}
}
}
// JSON writes the passed value into the http.ResponseWriter.
func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK)
}
// JSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil {
LogError(w, err)
return
}
LogEnabledResponse(w, v)
}
// ProtoJSON writes the passed value into the http.ResponseWriter.
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
ProtoJSONStatus(w, m, http.StatusOK)
}
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
// given status is written as the status code of the response.
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
b, err := protojson.Marshal(m)
if err != nil {
LogError(w, err)
return
}
if _, err := w.Write(b); err != nil {
LogError(w, err)
return
}
//LogEnabledResponse(w, v)
}
// ReadJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
}
return nil
}
// ReadProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ReadProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
}

View file

@ -1,125 +0,0 @@
package api
import (
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
)
func TestLogError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct {
name string
args args
withFields bool
}{
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
LogError(tt.args.rw, tt.args.err)
if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else {
t.Error("ResponseWriter does not implement logging.ResponseLogger")
}
}
})
}
}
func TestJSON(t *testing.T) {
type args struct {
rw http.ResponseWriter
v interface{}
}
tests := []struct {
name string
args args
ok bool
}{
{"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true},
{"fail", args{httptest.NewRecorder(), make(chan int)}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rw := logging.NewResponseLogger(tt.args.rw)
JSON(rw, tt.args.v)
rr, ok := tt.args.rw.(*httptest.ResponseRecorder)
if !ok {
t.Error("ResponseWriter does not implement *httptest.ResponseRecorder")
return
}
fields := rw.Fields()
if tt.ok {
if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" {
t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body)
}
if len(fields) != 0 {
t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields)
}
} else {
if body := rr.Body.String(); body != "" {
t.Errorf("Unexpected body = %s, want empty string", body)
}
if len(fields) != 1 {
t.Errorf("ResponseLogger fields = %v, wants 1 element", fields)
}
}
})
}
}
func TestReadJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ReadJSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
})
}
}

View file

@ -6,10 +6,12 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/api"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
) )
const ( const (
@ -44,11 +46,11 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
provName := chi.URLParam(r, "provisionerName") provName := chi.URLParam(r, "provisionerName")
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if !eabEnabled { if !eabEnabled {
api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
return return
} }
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, provisionerContextKey, prov)
@ -101,15 +103,15 @@ func NewACMEAdminResponder() *ACMEAdminResponder {
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint // GetExternalAccountKeys writes the response for the EAB keys GET endpoint
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// CreateExternalAccountKey writes the response for the EAB key POST endpoint // CreateExternalAccountKey writes the response for the EAB key POST endpoint
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }

View file

@ -5,10 +5,14 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
) )
type adminAuthority interface { type adminAuthority interface {
@ -82,28 +86,28 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
adm, ok := h.auth.LoadAdminByID(id) adm, ok := h.auth.LoadAdminByID(id)
if !ok { if !ok {
api.WriteError(w, admin.NewError(admin.ErrorNotFoundType, render.Error(w, admin.NewError(admin.ErrorNotFoundType,
"admin %s not found", id)) "admin %s not found", id))
return return
} }
api.ProtoJSON(w, adm) render.ProtoJSON(w, adm)
} }
// GetAdmins returns a segment of admins associated with the authority. // GetAdmins returns a segment of admins associated with the authority.
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params")) "error parsing cursor and limit from query params"))
return return
} }
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
return return
} }
api.JSON(w, &GetAdminsResponse{ render.JSON(w, &GetAdminsResponse{
Admins: admins, Admins: admins,
NextCursor: nextCursor, NextCursor: nextCursor,
}) })
@ -112,19 +116,19 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
// CreateAdmin creates a new admin. // CreateAdmin creates a new admin.
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest var body CreateAdminRequest
if err := api.ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
p, err := h.auth.LoadProvisionerByName(body.Provisioner) p, err := h.auth.LoadProvisionerByName(body.Provisioner)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
return return
} }
adm := &linkedca.Admin{ adm := &linkedca.Admin{
@ -134,11 +138,11 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
} }
// Store to authority collection. // Store to authority collection.
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error storing admin")) render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
return return
} }
api.ProtoJSONStatus(w, adm, http.StatusCreated) render.ProtoJSONStatus(w, adm, http.StatusCreated)
} }
// DeleteAdmin deletes admin. // DeleteAdmin deletes admin.
@ -146,23 +150,23 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
return return
} }
api.JSON(w, &DeleteResponse{Status: "ok"}) render.JSON(w, &DeleteResponse{Status: "ok"})
} }
// UpdateAdmin updates an existing admin. // UpdateAdmin updates an existing admin.
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest var body UpdateAdminRequest
if err := api.ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
@ -170,9 +174,9 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id)) render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
return return
} }
api.ProtoJSON(w, adm) render.ProtoJSON(w, adm)
} }

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
) )
@ -15,7 +15,7 @@ type nextHTTP = func(http.ResponseWriter, *http.Request)
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !h.auth.IsAdminAPIEnabled() { if !h.auth.IsAdminAPIEnabled() {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
"administration API not enabled")) "administration API not enabled"))
return return
} }
@ -28,14 +28,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization") tok := r.Header.Get("Authorization")
if tok == "" { if tok == "" {
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
"missing authorization header token")) "missing authorization header token"))
return return
} }
adm, err := h.auth.AuthorizeAdminToken(r, tok) adm, err := h.auth.AuthorizeAdminToken(r, tok)
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }

View file

@ -4,12 +4,16 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/linkedca"
) )
// GetProvisionersResponse is the type for GET /admin/provisioners responses. // GetProvisionersResponse is the type for GET /admin/provisioners responses.
@ -31,39 +35,39 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
) )
if len(id) > 0 { if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil { if p, err = h.auth.LoadProvisionerByID(id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil { if p, err = h.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
if err != nil { if err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
api.ProtoJSON(w, prov) render.ProtoJSON(w, prov)
} }
// GetProvisioners returns the given segment of provisioners associated with the authority. // GetProvisioners returns the given segment of provisioners associated with the authority.
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
"error parsing cursor and limit from query params")) "error parsing cursor and limit from query params"))
return return
} }
p, next, err := h.auth.GetProvisioners(cursor, limit) p, next, err := h.auth.GetProvisioners(cursor, limit)
if err != nil { if err != nil {
api.WriteError(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
api.JSON(w, &GetProvisionersResponse{ render.JSON(w, &GetProvisionersResponse{
Provisioners: p, Provisioners: p,
NextCursor: next, NextCursor: next,
}) })
@ -72,22 +76,22 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
// CreateProvisioner creates a new prov. // CreateProvisioner creates a new prov.
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner) var prov = new(linkedca.Provisioner)
if err := api.ReadProtoJSON(r.Body, prov); err != nil { if err := read.ProtoJSON(r.Body, prov); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
// TODO: Validate inputs // TODO: Validate inputs
if err := authority.ValidateClaims(prov.Claims); err != nil { if err := authority.ValidateClaims(prov.Claims); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
return return
} }
api.ProtoJSONStatus(w, prov, http.StatusCreated) render.ProtoJSONStatus(w, prov, http.StatusCreated)
} }
// DeleteProvisioner deletes a provisioner. // DeleteProvisioner deletes a provisioner.
@ -101,75 +105,75 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
) )
if len(id) > 0 { if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil { if p, err = h.auth.LoadProvisionerByID(id); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil { if p, err = h.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
return return
} }
api.JSON(w, &DeleteResponse{Status: "ok"}) render.JSON(w, &DeleteResponse{Status: "ok"})
} }
// UpdateProvisioner updates an existing prov. // UpdateProvisioner updates an existing prov.
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner) var nu = new(linkedca.Provisioner)
if err := api.ReadProtoJSON(r.Body, nu); err != nil { if err := read.ProtoJSON(r.Body, nu); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
name := chi.URLParam(r, "name") name := chi.URLParam(r, "name")
_old, err := h.auth.LoadProvisionerByName(name) _old, err := h.auth.LoadProvisionerByName(name)
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
return return
} }
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
if err != nil { if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
return return
} }
if nu.Id != old.Id { if nu.Id != old.Id {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID")) render.Error(w, admin.NewErrorISE("cannot change provisioner ID"))
return return
} }
if nu.Type != old.Type { if nu.Type != old.Type {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner type")) render.Error(w, admin.NewErrorISE("cannot change provisioner type"))
return return
} }
if nu.AuthorityId != old.AuthorityId { if nu.AuthorityId != old.AuthorityId {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID")) render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID"))
return return
} }
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) { if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt")) render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt"))
return return
} }
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) { if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt")) render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
return return
} }
// TODO: Validate inputs // TODO: Validate inputs
if err := authority.ValidateClaims(nu.Claims); err != nil { if err := authority.ValidateClaims(nu.Claims); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
api.WriteError(w, err) render.Error(w, err)
return return
} }
api.ProtoJSON(w, nu) render.ProtoJSON(w, nu)
} }

View file

@ -3,13 +3,10 @@ package admin
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/logging"
) )
// ProblemType is the type of the Admin problem. // ProblemType is the type of the Admin problem.
@ -197,27 +194,9 @@ func (e *Error) ToLog() (interface{}, error) {
return string(b), nil return string(b), nil
} }
// WriteError writes to w a JSON representation of the given error. // Render implements render.RenderableError for Error.
func WriteError(w http.ResponseWriter, err *Error) { func (e *Error) Render(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json") e.Message = e.Err.Error()
w.WriteHeader(err.StatusCode())
err.Message = err.Err.Error() render.JSONStatus(w, e, e.StatusCode())
// Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err.Err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.Err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
}
}
}
if err := json.NewEncoder(w).Encode(err); err != nil {
log.Println(err)
}
} }

View file

@ -70,14 +70,24 @@ type Authority struct {
startTime time.Time startTime time.Time
// Custom functions // Custom functions
sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error)
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
getIdentityFunc provisioner.GetIdentityFunc getIdentityFunc provisioner.GetIdentityFunc
authorizeRenewFunc provisioner.AuthorizeRenewFunc
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
adminMutex sync.RWMutex adminMutex sync.RWMutex
} }
type Info struct {
StartTime time.Time
RootX509Certs []*x509.Certificate
SSHCAUserPublicKey []byte
SSHCAHostPublicKey []byte
DNSNames []string
}
// New creates and initiates a new Authority type. // New creates and initiates a new Authority type.
func New(cfg *config.Config, opts ...Option) (*Authority, error) { func New(cfg *config.Config, opts ...Option) (*Authority, error) {
err := cfg.Validate() err := cfg.Validate()
@ -175,7 +185,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
// Create provisioner collection. // Create provisioner collection.
provClxn := provisioner.NewCollection(provisionerConfig.Audiences) provClxn := provisioner.NewCollection(provisionerConfig.Audiences)
for _, p := range provList { for _, p := range provList {
if err := p.Init(*provisionerConfig); err != nil { if err := p.Init(provisionerConfig); err != nil {
return err return err
} }
if err := provClxn.Store(p); err != nil { if err := provClxn.Store(p); err != nil {
@ -251,6 +261,21 @@ func (a *Authority) init() error {
} }
} }
// Initialize linkedca client if necessary. On a linked RA, the issuer
// configuration might come from majordomo.
var linkedcaClient *linkedCaClient
if a.config.AuthorityConfig.EnableAdmin && a.linkedCAToken != "" && a.adminDB == nil {
linkedcaClient, err = newLinkedCAClient(a.linkedCAToken)
if err != nil {
return err
}
// If authorityId is configured make sure it matches the one in the token
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) {
return errors.New("error initializing linkedca: token authority and configured authority do not match")
}
linkedcaClient.Run()
}
// Initialize the X.509 CA Service if it has not been set in the options. // Initialize the X.509 CA Service if it has not been set in the options.
if a.x509CAService == nil { if a.x509CAService == nil {
var options casapi.Options var options casapi.Options
@ -258,6 +283,22 @@ func (a *Authority) init() error {
options = *a.config.AuthorityConfig.Options options = *a.config.AuthorityConfig.Options
} }
// Configure linked RA
if linkedcaClient != nil && options.CertificateAuthority == "" {
conf, err := linkedcaClient.GetConfiguration(context.Background())
if err != nil {
return err
}
if conf.RaConfig != nil {
options.CertificateAuthority = conf.RaConfig.CaUrl
options.CertificateAuthorityFingerprint = conf.RaConfig.Fingerprint
options.CertificateIssuer = &casapi.CertificateIssuer{
Type: conf.RaConfig.Provisioner.Type.String(),
Provisioner: conf.RaConfig.Provisioner.Name,
}
}
}
// Set the issuer password if passed in the flags. // Set the issuer password if passed in the flags.
if options.CertificateIssuer != nil && a.issuerPassword != nil { if options.CertificateIssuer != nil && a.issuerPassword != nil {
options.CertificateIssuer.Password = string(a.issuerPassword) options.CertificateIssuer.Password = string(a.issuerPassword)
@ -292,8 +333,6 @@ func (a *Authority) init() error {
return err return err
} }
a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate) a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate)
sum := sha256.Sum256(resp.RootCertificate.Raw)
log.Printf("Using root fingerprint '%s'", hex.EncodeToString(sum[:]))
} }
} }
@ -479,24 +518,13 @@ func (a *Authority) init() error {
// Initialize step-ca Admin Database if it's not already initialized using // Initialize step-ca Admin Database if it's not already initialized using
// WithAdminDB. // WithAdminDB.
if a.adminDB == nil { if a.adminDB == nil {
if a.linkedCAToken == "" { if linkedcaClient != nil {
// Check if AuthConfig already exists a.adminDB = linkedcaClient
} else {
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID) a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
if err != nil { if err != nil {
return err return err
} }
} else {
// Use the linkedca client as the admindb.
client, err := newLinkedCAClient(a.linkedCAToken)
if err != nil {
return err
}
// If authorityId is configured make sure it matches the one in the token
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) {
return errors.New("error initializing linkedca: token authority and configured authority do not match")
}
client.Run()
a.adminDB = client
} }
} }
@ -559,6 +587,21 @@ func (a *Authority) GetAdminDatabase() admin.DB {
return a.adminDB return a.adminDB
} }
func (a *Authority) GetInfo() Info {
ai := Info{
StartTime: a.startTime,
RootX509Certs: a.rootX509Certs,
DNSNames: a.config.DNSNames,
}
if a.sshCAUserCertSignKey != nil {
ai.SSHCAUserPublicKey = ssh.MarshalAuthorizedKey(a.sshCAUserCertSignKey.PublicKey())
}
if a.sshCAHostCertSignKey != nil {
ai.SSHCAHostPublicKey = ssh.MarshalAuthorizedKey(a.sshCAHostCertSignKey.PublicKey())
}
return ai
}
// IsAdminAPIEnabled returns a boolean indicating whether the Admin API has // IsAdminAPIEnabled returns a boolean indicating whether the Admin API has
// been enabled. // been enabled.
func (a *Authority) IsAdminAPIEnabled() bool { func (a *Authority) IsAdminAPIEnabled() bool {

View file

@ -6,6 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -276,6 +277,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
func (a *Authority) authorizeRenew(cert *x509.Certificate) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
serial := cert.SerialNumber.String() serial := cert.SerialNumber.String()
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
isRevoked, err := a.IsRevoked(serial) isRevoked, err := a.IsRevoked(serial)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
@ -283,7 +285,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
if isRevoked { if isRevoked {
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
} }
p, ok := a.provisioners.LoadByCertificate(cert) p, ok := a.provisioners.LoadByCertificate(cert)
if !ok { if !ok {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
@ -371,3 +372,80 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error
} }
return nil return nil
} }
// AuthorizeRenewToken validates the renew token and returns the leaf
// certificate in the x5cInsecure header.
func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
var claims jose.Claims
jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs)
if err != nil {
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
}
leaf := chain[0][0]
if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
}
p, ok := a.provisioners.LoadByCertificate(leaf)
if !ok {
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
}
if err := a.UseToken(ott, p); err != nil {
return nil, err
}
if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: p.GetName(),
Subject: leaf.Subject.CommonName,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
switch err {
case jose.ErrInvalidIssuer:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)"))
case jose.ErrInvalidSubject:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)"))
case jose.ErrNotValidYet:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)"))
case jose.ErrExpired:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)"))
case jose.ErrIssuedInTheFuture:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)"))
default:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
}
}
audiences := a.config.GetAudiences().Renew
if !matchesAudience(claims.Audience, audiences) {
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
}
return leaf, nil
}
// matchesAudience returns true if A and B share at least one element.
func matchesAudience(as, bs []string) bool {
if len(bs) == 0 || len(as) == 0 {
return false
}
for _, b := range bs {
for _, a := range as {
if b == a || stripPort(a) == stripPort(b) {
return true
}
}
}
return false
}
// stripPort attempts to strip the port from the given url. If parsing the url
// produces errors it will just return the passed argument.
func stripPort(rawurl string) string {
u, err := url.Parse(rawurl)
if err != nil {
return rawurl
}
u.Host = u.Hostname()
return u.String()
}

View file

@ -3,24 +3,32 @@ package authority
import ( import (
"context" "context"
"crypto" "crypto"
"crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"reflect"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"golang.org/x/crypto/ssh" "go.step.sm/crypto/x509util"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
) )
var testAudiences = provisioner.Audiences{ var testAudiences = provisioner.Audiences{
@ -305,8 +313,8 @@ func TestAuthority_authorizeToken(t *testing.T) {
p, err := tc.auth.authorizeToken(context.Background(), tc.token) p, err := tc.auth.authorizeToken(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -391,8 +399,8 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil { if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -476,14 +484,14 @@ func TestAuthority_authorizeSign(t *testing.T) {
got, err := tc.auth.authorizeSign(context.Background(), tc.token) got, err := tc.auth.authorizeSign(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Len(t, 7, got) assert.Len(t, 8, got)
} }
} }
}) })
@ -735,8 +743,8 @@ func TestAuthority_Authorize(t *testing.T) {
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, got) assert.Nil(t, got)
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -753,6 +761,7 @@ func TestAuthority_Authorize(t *testing.T) {
func TestAuthority_authorizeRenew(t *testing.T) { func TestAuthority_authorizeRenew(t *testing.T) {
fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt") fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
fooCrt.NotAfter = time.Now().Add(time.Hour)
assert.FatalError(t, err) assert.FatalError(t, err)
renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt") renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
@ -822,7 +831,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
cert: renewDisabledCrt, cert: renewDisabledCrt,
err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"), err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -847,7 +856,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
err := tc.auth.authorizeRenew(tc.cert) err := tc.auth.authorizeRenew(tc.cert)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -909,6 +918,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner.
} }
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
now := time.Now()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -917,6 +927,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now.Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
}
if err := cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -988,8 +1004,8 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token) got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -1003,6 +1019,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
} }
func TestAuthority_authorizeSSHRenew(t *testing.T) { func TestAuthority_authorizeSSHRenew(t *testing.T) {
now := time.Now().UTC()
sshpop := func(a *Authority) (*ssh.Certificate, string) {
p, ok := a.provisioners.Load("sshpop/sshpop")
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return cert, token
}
a := testAuthority(t) a := testAuthority(t)
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
@ -1012,8 +1045,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err) assert.FatalError(t, err)
now := time.Now().UTC()
validIssuer := "step-cli" validIssuer := "step-cli"
type authorizeTest struct { type authorizeTest struct {
@ -1050,27 +1081,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
"fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
return errs.Forbidden("forbidden")
}))
_, token := sshpop(aa)
return &authorizeTest{
auth: aa,
token: token,
err: errors.New("authority.authorizeSSHRenew: forbidden"),
code: http.StatusForbidden,
}
},
"ok": func(t *testing.T) *authorizeTest { "ok": func(t *testing.T) *authorizeTest {
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") cert, token := sshpop(a)
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
p, ok := a.provisioners.Load("sshpop/sshpop")
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop",
[]string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: tok, token: token,
cert: cert,
}
},
"ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
return nil
}))
cert, token := sshpop(aa)
return &authorizeTest{
auth: aa,
token: token,
cert: cert, cert: cert,
} }
}, },
@ -1083,8 +1121,8 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token) got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -1183,8 +1221,8 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) {
if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil { if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -1276,8 +1314,8 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token) cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -1290,3 +1328,283 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
}) })
} }
} }
func TestAuthority_AuthorizeRenewToken(t *testing.T) {
ctx := context.Background()
type stepProvisionerASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
_, signer, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer)
if err != nil {
t.Fatal(err)
}
_, otherSigner, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
if err != nil {
t.Fatal(err)
}
var x5c []string
for _, c := range chain {
x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw))
}
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("x5cInsecure", x5c)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so)
if err != nil {
t.Fatal(err)
}
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return s, chain[0]
}
now := time.Now()
a1 := testAuthority(t)
t1, c1 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
t2, c2 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now.Add(-time.Hour)
cert.NotAfter = now.Add(-time.Minute)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "bad-issuer",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "bad-subject",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
type args struct {
ctx context.Context
ott string
}
tests := []struct {
name string
authority *Authority
args args
want *x509.Certificate
wantErr bool
}{
{"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
{"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true},
{"fail token iss", a1, args{ctx, badIssuer}, nil, true},
{"fail token sub", a1, args{ctx, badSubject}, nil, true},
{"fail token iat", a1, args{ctx, badNotBefore}, nil, true},
{"fail token iat", a1, args{ctx, badExpiry}, nil, true},
{"fail token iat", a1, args{ctx, badIssuedAt}, nil, true},
{"fail token aud", a1, args{ctx, badAudience}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -26,23 +26,27 @@ var (
DefaultBackdate = time.Minute DefaultBackdate = time.Minute
// DefaultDisableRenewal disables renewals per provisioner. // DefaultDisableRenewal disables renewals per provisioner.
DefaultDisableRenewal = false DefaultDisableRenewal = false
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is
// expired.
DefaultAllowRenewAfterExpiry = false
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally // DefaultEnableSSHCA enable SSH CA features per provisioner or globally
// for all provisioners. // for all provisioners.
DefaultEnableSSHCA = false DefaultEnableSSHCA = false
// GlobalProvisionerClaims default claims for the Authority. Can be overridden // GlobalProvisionerClaims default claims for the Authority. Can be overridden
// by provisioner specific claims. // by provisioner specific claims.
GlobalProvisionerClaims = provisioner.Claims{ GlobalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &DefaultDisableRenewal, MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, EnableSSHCA: &DefaultEnableSSHCA,
EnableSSHCA: &DefaultEnableSSHCA, DisableRenewal: &DefaultDisableRenewal,
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry,
} }
) )
@ -273,28 +277,32 @@ func (c *Config) GetAudiences() provisioner.Audiences {
} }
for _, name := range c.DNSNames { for _, name := range c.DNSNames {
hostname := toHostname(name)
audiences.Sign = append(audiences.Sign, audiences.Sign = append(audiences.Sign,
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/sign", hostname),
fmt.Sprintf("https://%s/sign", toHostname(name)), fmt.Sprintf("https://%s/sign", hostname),
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
fmt.Sprintf("https://%s/ssh/sign", toHostname(name))) fmt.Sprintf("https://%s/ssh/sign", hostname))
audiences.Renew = append(audiences.Renew,
fmt.Sprintf("https://%s/1.0/renew", hostname),
fmt.Sprintf("https://%s/renew", hostname))
audiences.Revoke = append(audiences.Revoke, audiences.Revoke = append(audiences.Revoke,
fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)), fmt.Sprintf("https://%s/1.0/revoke", hostname),
fmt.Sprintf("https://%s/revoke", toHostname(name))) fmt.Sprintf("https://%s/revoke", hostname))
audiences.SSHSign = append(audiences.SSHSign, audiences.SSHSign = append(audiences.SSHSign,
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
fmt.Sprintf("https://%s/ssh/sign", toHostname(name)), fmt.Sprintf("https://%s/ssh/sign", hostname),
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/sign", hostname),
fmt.Sprintf("https://%s/sign", toHostname(name))) fmt.Sprintf("https://%s/sign", hostname))
audiences.SSHRevoke = append(audiences.SSHRevoke, audiences.SSHRevoke = append(audiences.SSHRevoke,
fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname),
fmt.Sprintf("https://%s/ssh/revoke", toHostname(name))) fmt.Sprintf("https://%s/ssh/revoke", hostname))
audiences.SSHRenew = append(audiences.SSHRenew, audiences.SSHRenew = append(audiences.SSHRenew,
fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/renew", hostname),
fmt.Sprintf("https://%s/ssh/renew", toHostname(name))) fmt.Sprintf("https://%s/ssh/renew", hostname))
audiences.SSHRekey = append(audiences.SSHRekey, audiences.SSHRekey = append(audiences.SSHRekey,
fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname),
fmt.Sprintf("https://%s/ssh/rekey", toHostname(name))) fmt.Sprintf("https://%s/ssh/rekey", hostname))
} }
return audiences return audiences

View file

@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
@ -151,13 +152,21 @@ func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linked
} }
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) { func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
resp, err := c.GetConfiguration(ctx)
if err != nil {
return nil, err
}
return resp.Provisioners, nil
}
func (c *linkedCaClient) GetConfiguration(ctx context.Context) (*linkedca.ConfigurationResponse, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
AuthorityId: c.authorityID, AuthorityId: c.authorityID,
}) })
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error getting provisioners") return nil, errors.Wrap(err, "error getting configuration")
} }
return resp.Provisioners, nil return resp, nil
} }
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
@ -204,11 +213,9 @@ func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Adm
} }
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) { func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ resp, err := c.GetConfiguration(ctx)
AuthorityId: c.authorityID,
})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error getting admins") return nil, err
} }
return resp.Admins, nil return resp.Admins, nil
} }
@ -228,12 +235,13 @@ func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
return errors.Wrap(err, "error deleting admin") return errors.Wrap(err, "error deleting admin")
} }
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error { func (c *linkedCaClient) StoreCertificateChain(prov provisioner.Interface, fullchain ...*x509.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel() defer cancel()
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
PemCertificate: serializeCertificateChain(fullchain[0]), PemCertificate: serializeCertificateChain(fullchain[0]),
PemCertificateChain: serializeCertificateChain(fullchain[1:]...), PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
Provisioner: createProvisionerIdentity(prov),
}) })
return errors.Wrap(err, "error posting certificate") return errors.Wrap(err, "error posting certificate")
} }
@ -310,6 +318,17 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
} }
func createProvisionerIdentity(prov provisioner.Interface) *linkedca.ProvisionerIdentity {
if prov == nil {
return nil
}
return &linkedca.ProvisionerIdentity{
Id: prov.GetID(),
Type: linkedca.Provisioner_Type(prov.GetType()),
Name: prov.GetName(),
}
}
func serializeCertificate(crt *x509.Certificate) string { func serializeCertificate(crt *x509.Certificate) string {
if crt == nil { if crt == nil {
return "" return ""

View file

@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e
} }
} }
// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of
// an X.509 certificate.
func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option {
return func(a *Authority) error {
a.authorizeRenewFunc = fn
return nil
}
}
// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal
// of a SSH certificate.
func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option {
return func(a *Authority) error {
a.authorizeSSHRenewFunc = fn
return nil
}
}
// WithSSHBastionFunc sets a custom function to get the bastion for a // WithSSHBastionFunc sets a custom function to get the bastion for a
// given user-host pair. // given user-host pair.
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option { func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option {
@ -145,6 +163,22 @@ func WithX509Signer(crt *x509.Certificate, s crypto.Signer) Option {
} }
} }
// WithX509SignerFunc defines the function used to get the chain of certificates
// and signer used when we sign X.509 certificates.
func WithX509SignerFunc(fn func() ([]*x509.Certificate, crypto.Signer, error)) Option {
return func(a *Authority) error {
srv, err := cas.New(context.Background(), casapi.Options{
Type: casapi.SoftCAS,
CertificateSigner: fn,
})
if err != nil {
return err
}
a.x509CAService = srv
return nil
}
}
// WithSSHUserSigner defines the signer used to sign SSH user certificates. // WithSSHUserSigner defines the signer used to sign SSH user certificates.
func WithSSHUserSigner(s crypto.Signer) Option { func WithSSHUserSigner(s crypto.Signer) Option {
return func(a *Authority) error { return func(a *Authority) error {

View file

@ -6,7 +6,6 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
) )
// ACME is the acme provisioner type, an entity that can authorize the ACME // ACME is the acme provisioner type, an entity that can authorize the ACME
@ -24,7 +23,7 @@ type ACME struct {
RequireEAB bool `json:"requireEAB,omitempty"` RequireEAB bool `json:"requireEAB,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer ctl *Controller
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -69,7 +68,7 @@ func (p *ACME) GetOptions() *Options {
// DefaultTLSCertDuration returns the default TLS cert duration enforced by // DefaultTLSCertDuration returns the default TLS cert duration enforced by
// the provisioner. // the provisioner.
func (p *ACME) DefaultTLSCertDuration() time.Duration { func (p *ACME) DefaultTLSCertDuration() time.Duration {
return p.claimer.DefaultTLSCertDuration() return p.ctl.Claimer.DefaultTLSCertDuration()
} }
// Init initializes and validates the fields of a JWK type. // Init initializes and validates the fields of a JWK type.
@ -81,12 +80,8 @@ func (p *ACME) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty") return errors.New("provisioner name cannot be empty")
} }
// Update claims with global ones p.ctl, err = NewController(p, p.Claims, config)
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { return
return err
}
return err
} }
// AuthorizeSign does not do any validation, because all validation is handled // AuthorizeSign does not do any validation, because all validation is handled
@ -94,13 +89,14 @@ func (p *ACME) Init(config Config) (err error) {
// on the resulting certificate. // on the resulting certificate.
func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{ return []SignOption{
p,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeACME, p.Name, ""), newProvisionerExtensionOption(TypeACME, p.Name, ""),
newForceCNOption(p.ForceCN), newForceCNOption(p.ForceCN),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil }, nil
} }
@ -118,8 +114,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName())
}
return nil
} }

View file

@ -3,13 +3,14 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"errors"
"fmt"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/api/render"
) )
func TestACME_Getters(t *testing.T) { func TestACME_Getters(t *testing.T) {
@ -91,6 +92,7 @@ func TestACME_Init(t *testing.T) {
} }
func TestACME_AuthorizeRenew(t *testing.T) { func TestACME_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct { type test struct {
p *ACME p *ACME
cert *x509.Certificate cert *x509.Certificate
@ -104,21 +106,27 @@ func TestACME_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()), err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateACME() p, err := generateACME()
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
} }
}, },
} }
@ -126,8 +134,8 @@ func TestACME_AuthorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -161,31 +169,32 @@ func TestACME_AuthorizeSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) { if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
assert.Len(t, 5, opts) assert.Len(t, 6, opts)
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case *ACME:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeACME)) assert.Equals(t, v.Type, TypeACME)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case *forceCNOption: case *forceCNOption:
assert.Equals(t, v.ForceCN, tc.p.ForceCN) assert.Equals(t, v.ForceCN, tc.p.ForceCN)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }

View file

@ -264,9 +264,8 @@ type AWS struct {
IIDRoots string `json:"iidRoots,omitempty"` IIDRoots string `json:"iidRoots,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
config *awsConfig config *awsConfig
audiences Audiences ctl *Controller
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -400,15 +399,11 @@ func (p *AWS) Init(config Config) (err error) {
case p.InstanceAge.Value() < 0: case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative") return errors.New("provisioner instanceAge cannot be negative")
} }
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Add default config // Add default config
if p.config, err = newAWSConfig(p.IIDRoots); err != nil { if p.config, err = newAWSConfig(p.IIDRoots); err != nil {
return err return err
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
// validate IMDS versions // validate IMDS versions
if len(p.IMDSVersions) == 0 { if len(p.IMDSVersions) == 0 {
@ -425,7 +420,9 @@ func (p *AWS) Init(config Config) (err error) {
} }
} }
return nil config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
} }
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
@ -470,14 +467,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
} }
return append(so, return append(so,
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID), newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
commonNameValidator(payload.Claims.Subject), commonNameValidator(payload.Claims.Subject),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
), nil ), nil
} }
@ -486,10 +484,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName())
}
return nil
} }
// assertConfig initializes the config if it has not been initialized // assertConfig initializes the config if it has not been initialized
@ -664,7 +659,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(payload.Audience, p.audiences.Sign) { if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) {
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)") return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
} }
@ -704,7 +699,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
@ -752,11 +747,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil

View file

@ -9,6 +9,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -17,10 +18,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestAWS_Getters(t *testing.T) { func TestAWS_Getters(t *testing.T) {
@ -521,8 +522,8 @@ func TestAWS_authorizeToken(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil { if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -641,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false}, {"ok", p1, args{t1, "foo.local"}, 7, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false}, {"ok", p2, args{t2, "instance-id"}, 11, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false}, {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 11, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false}, {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 11, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false}, {"ok", p1, args{t4, "instance-id"}, 7, http.StatusOK, false},
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true}, {"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true}, {"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
@ -668,27 +669,28 @@ func TestAWS_AuthorizeSign(t *testing.T) {
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
case err != nil: case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
default: default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *AWS:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAWS)) assert.Equals(t, v.Type, TypeAWS)
assert.Equals(t, v.Name, tt.aws.GetName()) assert.Equals(t, v.Name, tt.aws.GetName())
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
assert.Len(t, 2, v.KeyValuePairs) assert.Len(t, 2, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), tt.args.cn) assert.Equals(t, string(v), tt.args.cn)
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator: case ipAddressesValidator:
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
case emailAddressesValidator: case emailAddressesValidator:
@ -698,7 +700,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
case dnsNamesValidator: case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -726,7 +728,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p3.Claims = &Claims{EnableSSHCA: &disable} p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com") t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
@ -747,7 +749,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{ expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -802,8 +804,8 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
@ -824,6 +826,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
} }
func TestAWS_AuthorizeRenew(t *testing.T) { func TestAWS_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateAWS() p1, err := generateAWS()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateAWS() p2, err := generateAWS()
@ -832,7 +835,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -845,16 +848,22 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })

View file

@ -30,7 +30,7 @@ const azureDefaultAudience = "https://management.azure.com/"
// azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim. // azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim.
// Using case insensitive as resourceGroups appears as resourcegroups. // Using case insensitive as resourceGroups appears as resourcegroups.
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`) var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
type azureConfig struct { type azureConfig struct {
oidcDiscoveryURL string oidcDiscoveryURL string
@ -89,15 +89,17 @@ type Azure struct {
Name string `json:"name"` Name string `json:"name"`
TenantID string `json:"tenantID"` TenantID string `json:"tenantID"`
ResourceGroups []string `json:"resourceGroups"` ResourceGroups []string `json:"resourceGroups"`
SubscriptionIDs []string `json:"subscriptionIDs"`
ObjectIDs []string `json:"objectIDs"`
Audience string `json:"audience,omitempty"` Audience string `json:"audience,omitempty"`
DisableCustomSANs bool `json:"disableCustomSANs"` DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
config *azureConfig config *azureConfig
oidcConfig openIDConfiguration oidcConfig openIDConfiguration
keyStore *keyStore keyStore *keyStore
ctl *Controller
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -201,37 +203,34 @@ func (p *Azure) Init(config Config) (err error) {
case p.Audience == "": // use default audience case p.Audience == "": // use default audience
p.Audience = azureDefaultAudience p.Audience = azureDefaultAudience
} }
// Initialize config // Initialize config
p.assertConfig() p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint // Decode and validate openid-configuration endpoint
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
return err return
} }
if err := p.oidcConfig.Validate(); err != nil { if err := p.oidcConfig.Validate(); err != nil {
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL) return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
} }
// Get JWK key set // Get JWK key set
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil { if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
return err return
} }
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// authorizeToken returns the claims, name, group, error. // authorizeToken returns the claims, name, group, subscription, identityObjectID, error.
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) { func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, string, string, error) {
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token") return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
} }
if len(jwt.Headers) == 0 { if len(jwt.Headers) == 0 {
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header") return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
} }
var found bool var found bool
@ -244,7 +243,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
} }
} }
if !found { if !found {
return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token") return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
} }
if err := claims.ValidateWithLeeway(jose.Expected{ if err := claims.ValidateWithLeeway(jose.Expected{
@ -252,26 +251,30 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
Issuer: p.oidcConfig.Issuer, Issuer: p.oidcConfig.Issuer,
Time: time.Now(), Time: time.Now(),
}, 1*time.Minute); err != nil { }, 1*time.Minute); err != nil {
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload") return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload")
} }
// Validate TenantID // Validate TenantID
if claims.TenantID != p.TenantID { if claims.TenantID != p.TenantID {
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)") return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
} }
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 { if len(re) != 5 {
return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID) return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
} }
group, name := re[2], re[3]
return &claims, name, group, nil var subscription, group, name string
identityObjectID := claims.ObjectID
subscription, group, name = re[1], re[2], re[4]
return &claims, name, group, subscription, identityObjectID, nil
} }
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation. // will be used on certificate creation.
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
_, name, group, err := p.authorizeToken(token) _, name, group, subscription, identityObjectID, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign")
} }
@ -290,6 +293,34 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
} }
} }
// Filter by subscription id
if len(p.SubscriptionIDs) > 0 {
var found bool
for _, s := range p.SubscriptionIDs {
if s == subscription {
found = true
break
}
}
if !found {
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid subscription id")
}
}
// Filter by Azure AD identity object id
if len(p.ObjectIDs) > 0 {
var found bool
for _, i := range p.ObjectIDs {
if i == identityObjectID {
found = true
break
}
}
if !found {
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid identity object id")
}
}
// Template options // Template options
data := x509util.NewTemplateData() data := x509util.NewTemplateData()
data.SetCommonName(name) data.SetCommonName(name)
@ -321,13 +352,14 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
} }
return append(so, return append(so,
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID), newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
), nil ), nil
} }
@ -336,19 +368,16 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName())
} }
_, name, _, err := p.authorizeToken(token) _, name, _, _, _, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign")
} }
@ -389,11 +418,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil

View file

@ -8,6 +8,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -15,10 +16,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestAzure_Getters(t *testing.T) { func TestAzure_Getters(t *testing.T) {
@ -95,7 +96,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
@ -237,7 +238,7 @@ func TestAzure_authorizeToken(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), jwk) time.Now(), jwk)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -252,7 +253,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -267,7 +268,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
"foo", "subscriptionID", "resourceGroup", "virtualMachine", "foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -321,7 +322,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -333,10 +334,10 @@ func TestAzure_authorizeToken(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, name, group, err := tc.p.authorizeToken(tc.token); err != nil { if claims, name, group, subscriptionID, objectID, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -348,6 +349,8 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.Equals(t, name, "virtualMachine") assert.Equals(t, name, "virtualMachine")
assert.Equals(t, group, "resourceGroup") assert.Equals(t, group, "resourceGroup")
assert.Equals(t, subscriptionID, "subscriptionID")
assert.Equals(t, objectID, "the-oid")
} }
} }
}) })
@ -382,6 +385,38 @@ func TestAzure_AuthorizeSign(t *testing.T) {
p4.oidcConfig = p1.oidcConfig p4.oidcConfig = p1.oidcConfig
p4.keyStore = p1.keyStore p4.keyStore = p1.keyStore
p5, err := generateAzure()
assert.FatalError(t, err)
p5.TenantID = p1.TenantID
p5.SubscriptionIDs = []string{"subscriptionID"}
p5.config = p1.config
p5.oidcConfig = p1.oidcConfig
p5.keyStore = p1.keyStore
p6, err := generateAzure()
assert.FatalError(t, err)
p6.TenantID = p1.TenantID
p6.SubscriptionIDs = []string{"foobarzar"}
p6.config = p1.config
p6.oidcConfig = p1.oidcConfig
p6.keyStore = p1.keyStore
p7, err := generateAzure()
assert.FatalError(t, err)
p7.TenantID = p1.TenantID
p7.ObjectIDs = []string{"the-oid"}
p7.config = p1.config
p7.oidcConfig = p1.oidcConfig
p7.keyStore = p1.keyStore
p8, err := generateAzure()
assert.FatalError(t, err)
p8.TenantID = p1.TenantID
p8.ObjectIDs = []string{"foobarzar"}
p8.config = p1.config
p8.oidcConfig = p1.oidcConfig
p8.keyStore = p1.keyStore
badKey, err := generateJSONWebKey() badKey, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -393,30 +428,38 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t4, err := p4.GetIdentityToken("subject", "caURL") t4, err := p4.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err) assert.FatalError(t, err)
t5, err := p5.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
t6, err := p6.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
t7, err := p6.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
t8, err := p6.GetIdentityToken("subject", "caURL")
assert.FatalError(t, err)
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience", failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0]) time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0]) time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), badKey) time.Now(), badKey)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -431,11 +474,15 @@ func TestAzure_AuthorizeSign(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1}, 5, http.StatusOK, false}, {"ok", p1, args{t1}, 6, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false}, {"ok", p2, args{t2}, 11, http.StatusOK, false},
{"ok", p1, args{t11}, 5, http.StatusOK, false}, {"ok", p1, args{t11}, 6, http.StatusOK, false},
{"ok", p5, args{t5}, 6, http.StatusOK, false},
{"ok", p7, args{t7}, 6, http.StatusOK, false},
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true}, {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true}, {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
{"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true},
{"fail object id", p8, args{t8}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true}, {"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
{"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true}, {"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
@ -451,27 +498,28 @@ func TestAzure_AuthorizeSign(t *testing.T) {
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
case err != nil: case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
default: default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *Azure:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAzure)) assert.Equals(t, v.Type, TypeAzure)
assert.Equals(t, v.Name, tt.azure.GetName()) assert.Equals(t, v.Name, tt.azure.GetName())
assert.Equals(t, v.CredentialID, tt.azure.TenantID) assert.Equals(t, v.CredentialID, tt.azure.TenantID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), "virtualMachine") assert.Equals(t, string(v), "virtualMachine")
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator: case ipAddressesValidator:
assert.Equals(t, v, nil) assert.Equals(t, v, nil)
case emailAddressesValidator: case emailAddressesValidator:
@ -481,7 +529,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
case dnsNamesValidator: case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"virtualMachine"}) assert.Equals(t, []string(v), []string{"virtualMachine"})
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -490,6 +538,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
} }
func TestAzure_AuthorizeRenew(t *testing.T) { func TestAzure_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateAzure() p1, err := generateAzure()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateAzure() p2, err := generateAzure()
@ -498,7 +547,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -511,16 +560,22 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })
@ -549,7 +604,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p3.Claims = &Claims{EnableSSHCA: &disable} p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("subject", "caURL") t1, err := p1.GetIdentityToken("subject", "caURL")
@ -570,7 +625,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{ expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"virtualMachine"}, CertType: "host", Principals: []string{"virtualMachine"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -615,8 +670,8 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {

View file

@ -10,10 +10,10 @@ import (
// Claims so that individual provisioners can override global claims. // Claims so that individual provisioners can override global claims.
type Claims struct { type Claims struct {
// TLS CA properties // TLS CA properties
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"`
// SSH CA properties // SSH CA properties
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"` MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"` MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
@ -22,6 +22,10 @@ type Claims struct {
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"` MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"` DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
EnableSSHCA *bool `json:"enableSSHCA,omitempty"` EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
// Renewal properties
DisableRenewal *bool `json:"disableRenewal,omitempty"`
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"`
} }
// Claimer is the type that controls claims. It provides an interface around the // Claimer is the type that controls claims. It provides an interface around the
@ -40,19 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
// Claims returns the merge of the inner and global claims. // Claims returns the merge of the inner and global claims.
func (c *Claimer) Claims() Claims { func (c *Claimer) Claims() Claims {
disableRenewal := c.IsDisableRenewal() disableRenewal := c.IsDisableRenewal()
allowRenewAfterExpiry := c.AllowRenewAfterExpiry()
enableSSHCA := c.IsSSHCAEnabled() enableSSHCA := c.IsSSHCAEnabled()
return Claims{ return Claims{
MinTLSDur: &Duration{c.MinTLSCertDuration()}, MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
DisableRenewal: &disableRenewal, MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, EnableSSHCA: &enableSSHCA,
EnableSSHCA: &enableSSHCA, DisableRenewal: &disableRenewal,
AllowRenewAfterExpiry: &allowRenewAfterExpiry,
} }
} }
@ -102,6 +109,16 @@ func (c *Claimer) IsDisableRenewal() bool {
return *c.claims.DisableRenewal return *c.claims.DisableRenewal
} }
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the
// certificate is expired. If the property is not set within the provisioner
// then the global value from the authority configuration will be used.
func (c *Claimer) AllowRenewAfterExpiry() bool {
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil {
return *c.global.AllowRenewAfterExpiry
}
return *c.claims.AllowRenewAfterExpiry
}
// DefaultSSHCertDuration returns the default SSH certificate duration for the // DefaultSSHCertDuration returns the default SSH certificate duration for the
// given certificate type. // given certificate type.
func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) { func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) {

View file

@ -152,8 +152,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
// proper id to load the provisioner. // proper id to load the provisioner.
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) { func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
for _, e := range cert.Extensions { for _, e := range cert.Extensions {
if e.Id.Equal(stepOIDProvisioner) { if e.Id.Equal(StepOIDProvisioner) {
var provisioner stepProvisionerASN1 var provisioner extensionASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false return nil, false
} }

View file

@ -147,6 +147,17 @@ func TestCollection_LoadByToken(t *testing.T) {
} }
func TestCollection_LoadByCertificate(t *testing.T) { func TestCollection_LoadByCertificate(t *testing.T) {
mustExtension := func(typ Type, name, credentialID string) pkix.Extension {
e := Extension{
Type: typ, Name: name, CredentialID: credentialID,
}
ext, err := e.ToExtension()
if err != nil {
t.Fatal(err)
}
return ext
}
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateOIDC() p2, err := generateOIDC()
@ -159,30 +170,21 @@ func TestCollection_LoadByCertificate(t *testing.T) {
byName.Store(p2.GetName(), p2) byName.Store(p2.GetName(), p2)
byName.Store(p3.GetName(), p3) byName.Store(p3.GetName(), p3)
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
assert.FatalError(t, err)
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
assert.FatalError(t, err)
ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "")
assert.FatalError(t, err)
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
assert.FatalError(t, err)
ok1Cert := &x509.Certificate{ ok1Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok1Ext}, Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)},
} }
ok2Cert := &x509.Certificate{ ok2Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok2Ext}, Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)},
} }
ok3Cert := &x509.Certificate{ ok3Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok3Ext}, Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")},
} }
notFoundCert := &x509.Certificate{ notFoundCert := &x509.Certificate{
Extensions: []pkix.Extension{notFoundExt}, Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")},
} }
badCert := &x509.Certificate{ badCert := &x509.Certificate{
Extensions: []pkix.Extension{ Extensions: []pkix.Extension{
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")}, {Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")},
}, },
} }

View file

@ -0,0 +1,194 @@
package provisioner
import (
"context"
"crypto/x509"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
)
// Controller wraps a provisioner with other attributes useful in callback
// functions.
type Controller struct {
Interface
Audiences *Audiences
Claimer *Claimer
IdentityFunc GetIdentityFunc
AuthorizeRenewFunc AuthorizeRenewFunc
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
}
// NewController initializes a new provisioner controller.
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) {
claimer, err := NewClaimer(claims, config.Claims)
if err != nil {
return nil, err
}
return &Controller{
Interface: p,
Audiences: &config.Audiences,
Claimer: claimer,
IdentityFunc: config.GetIdentityFunc,
AuthorizeRenewFunc: config.AuthorizeRenewFunc,
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
}, nil
}
// GetIdentity returns the identity for a given email.
func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) {
if c.IdentityFunc != nil {
return c.IdentityFunc(ctx, c.Interface, email)
}
return DefaultIdentityFunc(ctx, c.Interface, email)
}
// AuthorizeRenew returns nil if the given cert can be renewed, returns an error
// otherwise.
func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if c.AuthorizeRenewFunc != nil {
return c.AuthorizeRenewFunc(ctx, c, cert)
}
return DefaultAuthorizeRenew(ctx, c, cert)
}
// AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an
// error otherwise.
func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error {
if c.AuthorizeSSHRenewFunc != nil {
return c.AuthorizeSSHRenewFunc(ctx, c, cert)
}
return DefaultAuthorizeSSHRenew(ctx, c, cert)
}
// Identity is the type representing an externally supplied identity that is used
// by provisioners to populate certificate fields.
type Identity struct {
Usernames []string `json:"usernames"`
Permissions `json:"permissions"`
}
// GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
// AuthorizeRenewFunc is a function that returns nil if the renewal of a
// certificate is enabled.
type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error
// AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the
// given SSH certificate is enabled.
type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error
// DefaultIdentityFunc return a default identity depending on the provisioner
// type. For OIDC email is always present and the usernames might
// contain empty strings.
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) {
case *OIDC:
// OIDC principals would be:
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
// 2. Sanitized local.
// 3. Raw local (if different).
// 4. Email address.
name := SanitizeSSHUserPrincipal(email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
}
usernames := []string{name}
if i := strings.LastIndex(email, "@"); i >= 0 {
usernames = append(usernames, email[:i])
}
usernames = append(usernames, email)
return &Identity{
Usernames: SanitizeStringSlices(usernames),
}, nil
default:
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
}
}
// DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It
// will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled.
func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error {
if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
}
now := time.Now().Truncate(time.Second)
if now.Before(cert.NotBefore) {
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
}
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
return nil
}
// DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It
// will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled.
func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
}
unixNow := time.Now().Unix()
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
return errs.Unauthorized("certificate is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
return nil
}
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
// SanitizeStringSlices removes duplicated an empty strings.
func SanitizeStringSlices(original []string) []string {
output := []string{}
seen := make(map[string]struct{})
for _, entry := range original {
if entry == "" {
continue
}
if _, value := seen[entry]; !value {
seen[entry] = struct{}{}
output = append(output, entry)
}
}
return output
}
// SanitizeSSHUserPrincipal grabs an email or a string with the format
// local@domain and returns a sanitized version of the local, valid to be used
// as a user name. If the email starts with a letter between a and z, the
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
func SanitizeSSHUserPrincipal(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i]
}
return strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= '0' && r <= '9':
return r
case r == '-':
return '-'
case r == '.': // drop dots
return -1
default:
return '_'
}
}, strings.ToLower(email))
}

View file

@ -0,0 +1,391 @@
package provisioner
import (
"context"
"crypto/x509"
"fmt"
"reflect"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
var trueValue = true
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
t.Helper()
c, err := NewClaimer(claims, global)
if err != nil {
t.Fatal(err)
}
return c
}
func mustDuration(t *testing.T, s string) *Duration {
t.Helper()
d, err := NewDuration(s)
if err != nil {
t.Fatal(err)
}
return d
}
func TestNewController(t *testing.T) {
type args struct {
p Interface
claims *Claims
config Config
}
tests := []struct {
name string
args args
want *Controller
wantErr bool
}{
{"ok", args{&JWK{}, nil, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, false},
{"ok with claims", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims),
}, false},
{"fail claimer", args{&JWK{}, &Claims{
MinTLSDur: mustDuration(t, "24h"),
MaxTLSDur: mustDuration(t, "2h"),
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
if (err != nil) != tt.wantErr {
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewController() = %v, want %v", got, tt.want)
}
})
}
}
func TestController_GetIdentity(t *testing.T) {
ctx := context.Background()
type fields struct {
Interface Interface
IdentityFunc GetIdentityFunc
}
type args struct {
ctx context.Context
email string
}
tests := []struct {
name string
fields fields
args args
want *Identity
wantErr bool
}{
{"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{
Usernames: []string{"jane", "jane@doe.org"},
}, false},
{"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
return &Identity{Usernames: []string{"jane"}}, nil
}}, args{ctx, "jane@doe.org"}, &Identity{
Usernames: []string{"jane"},
}, false},
{"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true},
{"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
return nil, fmt.Errorf("an error")
}}, args{ctx, "jane@doe.org"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
IdentityFunc: tt.fields.IdentityFunc,
}
got, err := c.GetIdentity(tt.args.ctx, tt.args.email)
if (err != nil) != tt.wantErr {
t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want)
}
})
}
}
func TestController_AuthorizeRenew(t *testing.T) {
ctx := context.Background()
now := time.Now().Truncate(time.Second)
type fields struct {
Interface Interface
Claimer *Claimer
AuthorizeRenewFunc AuthorizeRenewFunc
}
type args struct {
ctx context.Context
cert *x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(time.Hour),
NotAfter: now.Add(2 * time.Hour),
}}, true},
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, true},
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return fmt.Errorf("an error")
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
Claimer: tt.fields.Claimer,
AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc,
}
if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestController_AuthorizeSSHRenew(t *testing.T) {
ctx := context.Background()
now := time.Now()
type fields struct {
Interface Interface
Claimer *Claimer
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
}
type args struct {
ctx context.Context
cert *ssh.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(time.Hour).Unix()),
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
}}, true},
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, true},
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return fmt.Errorf("an error")
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
Claimer: tt.fields.Claimer,
AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
}
if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestDefaultAuthorizeRenew(t *testing.T) {
ctx := context.Background()
now := time.Now().Truncate(time.Second)
type args struct {
ctx context.Context
p *Controller
cert *x509.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
{"fail disabled", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
{"fail not yet valid", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(time.Hour),
NotAfter: now.Add(2 * time.Hour),
}}, true},
{"fail expired", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
ctx := context.Background()
now := time.Now()
type args struct {
ctx context.Context
p *Controller
cert *ssh.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
{"fail disabled", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
{"fail not yet valid", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(time.Hour).Unix()),
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
}}, true},
{"fail expired", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -0,0 +1,73 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
)
var (
// StepOIDRoot is the root OID for smallstep.
StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
// StepOIDProvisioner is the OID for the provisioner extension.
StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...)
)
// Extension is the Go representation of the provisioner extension.
type Extension struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
type extensionASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
// Marshal marshals the extension using encoding/asn1.
func (e *Extension) Marshal() ([]byte, error) {
return asn1.Marshal(extensionASN1{
Type: int(e.Type),
Name: []byte(e.Name),
CredentialID: []byte(e.CredentialID),
KeyValuePairs: e.KeyValuePairs,
})
}
// ToExtension returns the pkix.Extension representation of the provisioner
// extension.
func (e *Extension) ToExtension() (pkix.Extension, error) {
b, err := e.Marshal()
if err != nil {
return pkix.Extension{}, err
}
return pkix.Extension{
Id: StepOIDProvisioner,
Value: b,
}, nil
}
// GetProvisionerExtension goes through all the certificate extensions and
// returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1).
func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) {
for _, e := range cert.Extensions {
if e.Id.Equal(StepOIDProvisioner) {
var provisioner extensionASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false
}
return &Extension{
Type: Type(provisioner.Type),
Name: string(provisioner.Name),
CredentialID: string(provisioner.CredentialID),
KeyValuePairs: provisioner.KeyValuePairs,
}, true
}
}
return nil, false
}

View file

@ -0,0 +1,158 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"reflect"
"testing"
"go.step.sm/crypto/pemutil"
)
func TestExtension_Marshal(t *testing.T) {
type fields struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
tests := []struct {
name string
fields fields
want []byte
wantErr bool
}{
{"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
}, false},
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
0x13, 0x03, 0x62, 0x61, 0x72,
}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Extension{
Type: tt.fields.Type,
Name: tt.fields.Name,
CredentialID: tt.fields.CredentialID,
KeyValuePairs: tt.fields.KeyValuePairs,
}
got, err := e.Marshal()
if (err != nil) != tt.wantErr {
t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want)
}
})
}
}
func TestExtension_ToExtension(t *testing.T) {
type fields struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
tests := []struct {
name string
fields fields
want pkix.Extension
wantErr bool
}{
{"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
},
}, false},
{"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
},
}, false},
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
0x13, 0x03, 0x62, 0x61, 0x72,
},
}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Extension{
Type: tt.fields.Type,
Name: tt.fields.Name,
CredentialID: tt.fields.CredentialID,
KeyValuePairs: tt.fields.KeyValuePairs,
}
got, err := e.ToExtension()
if (err != nil) != tt.wantErr {
t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetProvisionerExtension(t *testing.T) {
mustCertificate := func(fn string) *x509.Certificate {
cert, err := pemutil.ReadCertificate(fn)
if err != nil {
t.Fatal(err)
}
return cert
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
args args
want *Extension
want1 bool
}{
{"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{
Type: TypeJWK,
Name: "mariano@smallstep.com",
CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ",
}, true},
{"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false},
{"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := GetProvisionerExtension(tt.args.cert)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1)
}
})
}
}

View file

@ -88,10 +88,9 @@ type GCP struct {
InstanceAge Duration `json:"instanceAge,omitempty"` InstanceAge Duration `json:"instanceAge,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
config *gcpConfig config *gcpConfig
keyStore *keyStore keyStore *keyStore
audiences Audiences ctl *Controller
} }
// GetID returns the provisioner unique identifier. The name should uniquely // GetID returns the provisioner unique identifier. The name should uniquely
@ -194,8 +193,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
} }
// Init validates and initializes the GCP provisioner. // Init validates and initializes the GCP provisioner.
func (p *GCP) Init(config Config) error { func (p *GCP) Init(config Config) (err error) {
var err error
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -204,20 +202,18 @@ func (p *GCP) Init(config Config) error {
case p.InstanceAge.Value() < 0: case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative") return errors.New("provisioner instanceAge cannot be negative")
} }
// Initialize config // Initialize config
p.assertConfig() p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize key store // Initialize key store
p.keyStore, err = newKeyStore(p.config.CertsURL) if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil {
if err != nil { return
return err
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
@ -266,22 +262,20 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
} }
return append(so, return append(so,
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName), newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
), nil ), nil
} }
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName())
}
return nil
} }
// assertConfig initializes the config if it has not been initialized. // assertConfig initializes the config if it has not been initialized.
@ -328,7 +322,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) { if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) {
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)") return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
} }
@ -383,7 +377,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
@ -431,11 +425,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil

View file

@ -8,6 +8,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -16,10 +17,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestGCP_Getters(t *testing.T) { func TestGCP_Getters(t *testing.T) {
@ -390,8 +391,8 @@ func TestGCP_authorizeToken(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token); err != nil { if claims, err := tc.p.authorizeToken(tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -515,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1}, 5, http.StatusOK, false}, {"ok", p1, args{t1}, 6, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false}, {"ok", p2, args{t2}, 11, http.StatusOK, false},
{"ok", p3, args{t3}, 5, http.StatusOK, false}, {"ok", p3, args{t3}, 6, http.StatusOK, false},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true}, {"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
@ -540,27 +541,28 @@ func TestGCP_AuthorizeSign(t *testing.T) {
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
case err != nil: case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
default: default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *GCP:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeGCP)) assert.Equals(t, v.Type, TypeGCP)
assert.Equals(t, v.Name, tt.gcp.GetName()) assert.Equals(t, v.Name, tt.gcp.GetName())
assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0]) assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0])
assert.Len(t, 4, v.KeyValuePairs) assert.Len(t, 4, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration())
case commonNameSliceValidator: case commonNameSliceValidator:
assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator: case ipAddressesValidator:
assert.Equals(t, v, nil) assert.Equals(t, v, nil)
case emailAddressesValidator: case emailAddressesValidator:
@ -570,7 +572,7 @@ func TestGCP_AuthorizeSign(t *testing.T) {
case dnsNamesValidator: case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -595,7 +597,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p3.Claims = &Claims{EnableSSHCA: &disable} p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0], t1, err := generateGCPToken(p1.ServiceAccounts[0],
@ -622,7 +624,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{ expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -677,8 +679,8 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
@ -698,6 +700,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
} }
func TestGCP_AuthorizeRenew(t *testing.T) { func TestGCP_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateGCP() p1, err := generateGCP()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateGCP() p2, err := generateGCP()
@ -706,7 +709,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -719,15 +722,21 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renewal-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }

View file

@ -35,8 +35,7 @@ type JWK struct {
EncryptedKey string `json:"encryptedKey,omitempty"` EncryptedKey string `json:"encryptedKey,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer ctl *Controller
audiences Audiences
} }
// GetID returns the provisioner unique identifier. The name and credential id // GetID returns the provisioner unique identifier. The name and credential id
@ -98,13 +97,8 @@ func (p *JWK) Init(config Config) (err error) {
return errors.New("provisioner key cannot be empty") return errors.New("provisioner key cannot be empty")
} }
// Update claims with global ones p.ctl, err = NewController(p, p.Claims, config)
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { return
return err
}
p.audiences = config.Audiences
return err
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -146,13 +140,13 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
} }
@ -176,15 +170,16 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
} }
return []SignOption{ return []SignOption{
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
commonNameValidator(claims.Subject), commonNameValidator(claims.Subject),
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
defaultSANsValidator(claims.SANs), defaultSANsValidator(claims.SANs),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil }, nil
} }
@ -193,18 +188,15 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
} }
@ -261,11 +253,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
return append(signOptions, return append(signOptions,
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
@ -273,6 +265,6 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.SSHRevoke) _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
} }

View file

@ -6,15 +6,17 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"errors"
"fmt"
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestJWK_Getters(t *testing.T) { func TestJWK_Getters(t *testing.T) {
@ -76,13 +78,13 @@ func TestJWK_Init(t *testing.T) {
}, },
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
err: errors.New("claims: MinTLSCertDuration must be greater than 0"), err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
} }
}, },
"ok": func(t *testing.T) ProvisionerValidateTest { "ok": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences}, p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
} }
}, },
} }
@ -183,8 +185,8 @@ func TestJWK_authorizeToken(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil { if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
@ -223,8 +225,8 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil { if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
@ -288,34 +290,35 @@ func TestJWK_AuthorizeSign(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil { if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
} else { } else {
if assert.NotNil(t, got) { if assert.NotNil(t, got) {
assert.Len(t, 7, got) assert.Len(t, 8, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *JWK:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeJWK)) assert.Equals(t, v.Type, TypeJWK)
assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.Name, tt.prov.GetName())
assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID) assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), "subject") assert.Equals(t, string(v), "subject")
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case defaultSANsValidator: case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans) assert.Equals(t, []string(v), tt.sans)
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -325,6 +328,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
} }
func TestJWK_AuthorizeRenew(t *testing.T) { func TestJWK_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateJWK() p2, err := generateJWK()
@ -333,7 +337,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -346,16 +350,22 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr { if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })
@ -373,7 +383,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p2.Claims = &Claims{EnableSSHCA: &disable} p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p1.EncryptedKey) jwk, err := decryptJSONWebKey(p1.EncryptedKey)
@ -402,8 +412,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name"}, CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
@ -448,8 +458,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
@ -485,8 +495,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
signer, err := generateJSONWebKey() signer, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name"}, CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
@ -613,8 +623,8 @@ func TestJWK_AuthorizeSSHRevoke(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -42,16 +42,15 @@ type k8sSAPayload struct {
// entity trusted to make signature requests. // entity trusted to make signature requests.
type K8sSA struct { type K8sSA struct {
*base *base
ID string `json:"-"` ID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
PubKeys []byte `json:"publicKeys,omitempty"` PubKeys []byte `json:"publicKeys,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
audiences Audiences
//kauthn kauthn.AuthenticationV1Interface //kauthn kauthn.AuthenticationV1Interface
pubKeys []interface{} pubKeys []interface{}
ctl *Controller
} }
// GetID returns the provisioner unique identifier. The name and credential id // GetID returns the provisioner unique identifier. The name and credential id
@ -138,13 +137,8 @@ func (p *K8sSA) Init(config Config) (err error) {
p.kauthn = k8s.AuthenticationV1() p.kauthn = k8s.AuthenticationV1()
*/ */
// Update claims with global ones p.ctl, err = NewController(p, p.Claims, config)
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { return
return err
}
p.audiences = config.Audiences
return err
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -211,13 +205,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error { func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
} }
@ -237,30 +231,28 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
} }
return []SignOption{ return []SignOption{
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""), newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil }, nil
} }
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign validates an request for an SSH certificate. // AuthorizeSSHSign validates an request for an SSH certificate.
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
} }
@ -282,11 +274,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Require type, key-id and principals in the SignSSHOptions. // Require type, key-id and principals in the SignSSHOptions.
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil

View file

@ -3,14 +3,16 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"errors"
"fmt"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestK8sSA_Getters(t *testing.T) { func TestK8sSA_Getters(t *testing.T) {
@ -116,8 +118,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -165,8 +167,8 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -179,6 +181,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
} }
func TestK8sSA_AuthorizeRenew(t *testing.T) { func TestK8sSA_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct { type test struct {
p *K8sSA p *K8sSA
cert *x509.Certificate cert *x509.Certificate
@ -192,21 +195,27 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()), err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateK8sSA(nil) p, err := generateK8sSA(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
} }
}, },
} }
@ -214,8 +223,8 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil { if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -263,8 +272,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -274,24 +283,25 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
tot := 0 tot := 0
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case *K8sSA:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeK8sSA)) assert.Equals(t, v.Type, TypeK8sSA)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
tot++ tot++
} }
assert.Equals(t, tot, 5) assert.Equals(t, tot, 6)
} }
} }
} }
@ -313,13 +323,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p.Claims = &Claims{EnableSSHCA: &disable} p.Claims = &Claims{EnableSSHCA: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()), err: fmt.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()),
} }
}, },
"fail/invalid-token": func(t *testing.T) test { "fail/invalid-token": func(t *testing.T) test {
@ -350,8 +360,8 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -365,13 +375,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case *sshCertOptionsRequireValidator: case *sshCertOptionsRequireValidator:
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}) assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshDefaultPublicKeyValidator: case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator: case *sshCertDefaultValidator:
case *sshDefaultDuration: case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
tot++ tot++
} }

View file

@ -34,19 +34,18 @@ const (
// https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by // https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by
// go.step.sm/crypto/x25519. // go.step.sm/crypto/x25519.
type Nebula struct { type Nebula struct {
ID string `json:"-"` ID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Roots []byte `json:"roots"` Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer caPool *nebula.NebulaCAPool
caPool *nebula.NebulaCAPool ctl *Controller
audiences Audiences
} }
// Init verifies and initializes the Nebula provisioner. // Init verifies and initializes the Nebula provisioner.
func (p *Nebula) Init(config Config) error { func (p *Nebula) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -56,19 +55,14 @@ func (p *Nebula) Init(config Config) error {
return errors.New("provisioner root(s) cannot be empty") return errors.New("provisioner root(s) cannot be empty")
} }
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots) p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots)
if err != nil { if err != nil {
return errs.InternalServer("failed to create ca pool: %v", err) return errs.InternalServer("failed to create ca pool: %v", err)
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return nil return
} }
// GetID returns the provisioner id. // GetID returns the provisioner id.
@ -120,7 +114,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) {
// AuthorizeSign returns the list of SignOption for a Sign request. // AuthorizeSign returns the list of SignOption for a Sign request.
func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
crt, claims, err := p.authorizeToken(token, p.audiences.Sign) crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -139,8 +133,9 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
data.SetToken(v) data.SetToken(v)
} }
// The Nebula certificate will be available using the template variable Crt. // The Nebula certificate will be available using the template variable
// For example {{ .Crt.Details.Groups }} can be used to get all the groups. // AuthorizationCrt. For example {{ .AuthorizationCrt.Details.Groups }} can
// be used to get all the groups.
data.SetAuthorizationCertificate(crt) data.SetAuthorizationCertificate(crt)
templateOptions, err := TemplateOptions(p.Options, data) templateOptions, err := TemplateOptions(p.Options, data)
@ -149,11 +144,12 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
} }
return []SignOption{ return []SignOption{
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeNebula, p.Name, ""), newProvisionerExtensionOption(TypeNebula, p.Name, ""),
profileLimitDuration{ profileLimitDuration{
def: p.claimer.DefaultTLSCertDuration(), def: p.ctl.Claimer.DefaultTLSCertDuration(),
notBefore: crt.Details.NotBefore, notBefore: crt.Details.NotBefore,
notAfter: crt.Details.NotAfter, notAfter: crt.Details.NotAfter,
}, },
@ -164,18 +160,18 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
IPs: crt.Details.Ips, IPs: crt.Details.Ips,
}, },
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil }, nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
// Currently the Nebula provisioner only grants host SSH certificates. // Currently the Nebula provisioner only grants host SSH certificates.
func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
} }
crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign) crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -253,11 +249,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
return append(signOptions, return append(signOptions,
templateOptions, templateOptions,
// Checks the validity bounds, and set the validity if has not been set. // Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.claimer, crt.Details.NotAfter}, &sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter},
// Validate public key. // Validate public key.
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil
@ -265,23 +261,20 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error { func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, crt)
return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeRevoke returns an error if the token is not valid. // AuthorizeRevoke returns an error if the token is not valid.
func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error { func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error {
return p.validateToken(token, p.audiences.Revoke) return p.validateToken(token, p.ctl.Audiences.Revoke)
} }
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid. // AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
} }
if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil { if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil {
return err return err
} }
return nil return nil

View file

@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) {
func TestNebula_GetTokenID(t *testing.T) { func TestNebula_GetTokenID(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer) c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv)
_, claims, err := parseToken(t1) _, claims, err := parseToken(t1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv) okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv)
pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions, _, _ := mustNebulaProvisioner(t)
pBadOptions.caPool = p.caPool pBadOptions.caPool = p.caPool
@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
// Ok provisioner // Ok provisioner
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "host", CertType: "host",
KeyID: "test.lan", KeyID: "test.lan",
Principals: []string{"test.lan", "10.1.0.1"}, Principals: []string{"test.lan", "10.1.0.1"},
}, crt, priv) }, crt, priv)
okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv) okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv)
okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)), ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)),
ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)), ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)),
}, crt, priv) }, crt, priv)
failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "user", CertType: "user",
}, crt, priv) }, crt, priv)
failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "host", CertType: "host",
KeyID: "test.lan", KeyID: "test.lan",
Principals: []string{"test.lan", "10.1.0.1", "foo.bar"}, Principals: []string{"test.lan", "10.1.0.1", "foo.bar"},
@ -549,6 +549,8 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
func TestNebula_AuthorizeRenew(t *testing.T) { func TestNebula_AuthorizeRenew(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
now := time.Now().Truncate(time.Second)
// Ok provisioner // Ok provisioner
p, _, _ := mustNebulaProvisioner(t) p, _, _ := mustNebulaProvisioner(t)
@ -567,8 +569,14 @@ func TestNebula_AuthorizeRenew(t *testing.T) {
args args args args
wantErr bool wantErr bool
}{ }{
{"ok", p, args{ctx, &x509.Certificate{}}, false}, {"ok", p, args{ctx, &x509.Certificate{
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -584,12 +592,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) {
// Ok provisioner // Ok provisioner
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
// Fail different CA // Fail different CA
nc, signer := mustNebulaCA(t) nc, signer := mustNebulaCA(t)
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -618,12 +626,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
// Ok provisioner // Ok provisioner
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
// Fail different CA // Fail different CA
nc, signer := mustNebulaCA(t) nc, signer := mustNebulaCA(t)
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
// Provisioner with SSH disabled // Provisioner with SSH disabled
var bFalse bool var bFalse bool
@ -657,7 +665,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
func TestNebula_AuthorizeSSHRenew(t *testing.T) { func TestNebula_AuthorizeSSHRenew(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv) t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv)
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -689,7 +697,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) {
func TestNebula_AuthorizeSSHRekey(t *testing.T) { func TestNebula_AuthorizeSSHRekey(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv) t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv)
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -726,20 +734,20 @@ func TestNebula_authorizeToken(t *testing.T) {
t1 := now() t1 := now()
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv) okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv)
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{ okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{
CertType: "host", CertType: "host",
KeyID: "test.lan", KeyID: "test.lan",
Principals: []string{"test.lan"}, Principals: []string{"test.lan"},
}, crt, priv) }, crt, priv)
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv) okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv)
// Token with errors // Token with errors
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv) failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv)
failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
// Not a nebula token // Not a nebula token
jwk, err := generateJSONWebKey() jwk, err := generateJSONWebKey()
@ -761,7 +769,7 @@ func TestNebula_authorizeToken(t *testing.T) {
IssuedAt: jose.NewNumericDate(t1), IssuedAt: jose.NewNumericDate(t1),
NotBefore: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1),
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
Audience: []string{p.audiences.Sign[0]}, Audience: []string{p.ctl.Audiences.Sign[0]},
} }
sshClaims := jose.Claims{ sshClaims := jose.Claims{
ID: "[REPLACEME]", ID: "[REPLACEME]",
@ -770,7 +778,7 @@ func TestNebula_authorizeToken(t *testing.T) {
IssuedAt: jose.NewNumericDate(t1), IssuedAt: jose.NewNumericDate(t1),
NotBefore: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1),
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
Audience: []string{p.audiences.SSHSign[0]}, Audience: []string{p.ctl.Audiences.SSHSign[0]},
} }
type args struct { type args struct {
@ -785,14 +793,14 @@ func TestNebula_authorizeToken(t *testing.T) {
want1 *jwtPayload want1 *jwtPayload
wantErr bool wantErr bool
}{ }{
{"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{ {"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{
Claims: x509Claims, Claims: x509Claims,
SANs: []string{"10.1.0.1"}, SANs: []string{"10.1.0.1"},
}, false}, }, false},
{"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{ {"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{
Claims: x509Claims, Claims: x509Claims,
}, false}, }, false},
{"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{ {"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
Claims: sshClaims, Claims: sshClaims,
Step: &stepPayload{ Step: &stepPayload{
SSH: &SignSSHOptions{ SSH: &SignSSHOptions{
@ -802,16 +810,16 @@ func TestNebula_authorizeToken(t *testing.T) {
}, },
}, },
}, false}, }, false},
{"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{ {"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
Claims: sshClaims, Claims: sshClaims,
}, false}, }, false},
{"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true}, {"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true},
{"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true}, {"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true}, {"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true}, {"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true}, {"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true}, {"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true}, {"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -38,7 +38,7 @@ func (p *noop) Init(config Config) error {
} }
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{}, nil return []SignOption{p}, nil
} }
func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {

View file

@ -24,6 +24,6 @@ func Test_noop(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
sigOptions, err := p.AuthorizeSign(ctx, "foo") sigOptions, err := p.AuthorizeSign(ctx, "foo")
assert.Equals(t, []SignOption{}, sigOptions) assert.Equals(t, []SignOption{&p}, sigOptions)
assert.Equals(t, nil, err) assert.Equals(t, nil, err)
} }

View file

@ -92,8 +92,7 @@ type OIDC struct {
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
configuration openIDConfiguration configuration openIDConfiguration
keyStore *keyStore keyStore *keyStore
claimer *Claimer ctl *Controller
getIdentityFunc GetIdentityFunc
} }
func sanitizeEmail(email string) string { func sanitizeEmail(email string) string {
@ -172,11 +171,6 @@ func (o *OIDC) Init(config Config) (err error) {
} }
} }
// Update claims with global ones
if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint // Decode and validate openid-configuration endpoint
u, err := url.Parse(o.ConfigurationEndpoint) u, err := url.Parse(o.ConfigurationEndpoint)
if err != nil { if err != nil {
@ -201,13 +195,8 @@ func (o *OIDC) Init(config Config) (err error) {
return err return err
} }
// Set the identity getter if it exists, otherwise use the default. o.ctl, err = NewController(o, o.Claims, config)
if config.GetIdentityFunc == nil { return
o.getIdentityFunc = DefaultIdentityFunc
} else {
o.getIdentityFunc = config.GetIdentityFunc
}
return nil
} }
// ValidatePayload validates the given token payload. // ValidatePayload validates the given token payload.
@ -356,13 +345,14 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
} }
return []SignOption{ return []SignOption{
o,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()), profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()), newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()),
}, nil }, nil
} }
@ -371,15 +361,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if o.claimer.IsDisableRenewal() { return o.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !o.claimer.IsSSHCAEnabled() { if !o.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName()) return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName())
} }
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
@ -394,7 +381,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Get the identity using either the default identityFunc or one injected // Get the identity using either the default identityFunc or one injected
// externally. Note that the PreferredUsername might be empty. // externally. Note that the PreferredUsername might be empty.
// TBD: Would preferred_username present a safety issue here? // TBD: Would preferred_username present a safety issue here?
iden, err := o.getIdentityFunc(ctx, o, claims.Email) iden, err := o.ctl.GetIdentity(ctx, claims.Email)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
} }
@ -445,11 +432,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
return append(signOptions, return append(signOptions,
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{o.claimer}, &sshDefaultDuration{o.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{o.claimer}, &sshCertValidityValidator{o.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil

View file

@ -6,16 +6,17 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func Test_openIDConfiguration_Validate(t *testing.T) { func Test_openIDConfiguration_Validate(t *testing.T) {
@ -246,8 +247,8 @@ func TestOIDC_authorizeToken(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else { } else {
@ -317,30 +318,31 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
assert.Len(t, 5, got) assert.Len(t, 6, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *OIDC:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeOIDC)) assert.Equals(t, v.Type, TypeOIDC)
assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.Name, tt.prov.GetName())
assert.Equals(t, v.CredentialID, tt.prov.ClientID) assert.Equals(t, v.CredentialID, tt.prov.ClientID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case emailOnlyIdentity: case emailOnlyIdentity:
assert.Equals(t, string(v), "name@smallstep.com") assert.Equals(t, string(v), "name@smallstep.com")
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -402,8 +404,8 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })
@ -411,6 +413,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
} }
func TestOIDC_AuthorizeRenew(t *testing.T) { func TestOIDC_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateOIDC() p1, err := generateOIDC()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateOIDC() p2, err := generateOIDC()
@ -419,7 +422,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -432,8 +435,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -441,8 +450,8 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })
@ -478,7 +487,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p6.Claims = &Claims{EnableSSHCA: &disable} p6.Claims = &Claims{EnableSSHCA: &disable}
p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
// Update configuration endpoints and initialize // Update configuration endpoints and initialize
@ -494,10 +503,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p4.Init(config))
assert.FatalError(t, p5.Init(config)) assert.FatalError(t, p5.Init(config))
p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
return &Identity{Usernames: []string{"max", "mariano"}}, nil return &Identity{Usernames: []string{"max", "mariano"}}, nil
} }
p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
return nil, errors.New("force") return nil, errors.New("force")
} }
// Additional test needed for empty usernames and duplicate email and usernames // Additional test needed for empty usernames and duplicate email and usernames
@ -527,8 +536,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name", "name@smallstep.com"}, CertType: "user", Principals: []string{"name", "name@smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
@ -597,8 +606,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
return return
} }
if err != nil { if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
@ -665,8 +674,8 @@ func TestOIDC_AuthorizeSSHRevoke(t *testing.T) {
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr)
} else if err != nil { } else if err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} }
}) })

View file

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
stderrors "errors" stderrors "errors"
"net/url" "net/url"
"regexp"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -47,6 +46,7 @@ var ErrAllowTokenReuse = stderrors.New("allow token reuse")
// Audiences stores all supported audiences by request type. // Audiences stores all supported audiences by request type.
type Audiences struct { type Audiences struct {
Sign []string Sign []string
Renew []string
Revoke []string Revoke []string
SSHSign []string SSHSign []string
SSHRevoke []string SSHRevoke []string
@ -57,6 +57,7 @@ type Audiences struct {
// All returns all supported audiences across all request types in one list. // All returns all supported audiences across all request types in one list.
func (a Audiences) All() (auds []string) { func (a Audiences) All() (auds []string) {
auds = a.Sign auds = a.Sign
auds = append(auds, a.Renew...)
auds = append(auds, a.Revoke...) auds = append(auds, a.Revoke...)
auds = append(auds, a.SSHSign...) auds = append(auds, a.SSHSign...)
auds = append(auds, a.SSHRevoke...) auds = append(auds, a.SSHRevoke...)
@ -70,6 +71,7 @@ func (a Audiences) All() (auds []string) {
func (a Audiences) WithFragment(fragment string) Audiences { func (a Audiences) WithFragment(fragment string) Audiences {
ret := Audiences{ ret := Audiences{
Sign: make([]string, len(a.Sign)), Sign: make([]string, len(a.Sign)),
Renew: make([]string, len(a.Renew)),
Revoke: make([]string, len(a.Revoke)), Revoke: make([]string, len(a.Revoke)),
SSHSign: make([]string, len(a.SSHSign)), SSHSign: make([]string, len(a.SSHSign)),
SSHRevoke: make([]string, len(a.SSHRevoke)), SSHRevoke: make([]string, len(a.SSHRevoke)),
@ -83,6 +85,13 @@ func (a Audiences) WithFragment(fragment string) Audiences {
ret.Sign[i] = s ret.Sign[i] = s
} }
} }
for i, s := range a.Renew {
if u, err := url.Parse(s); err == nil {
ret.Renew[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
} else {
ret.Renew[i] = s
}
}
for i, s := range a.Revoke { for i, s := range a.Revoke {
if u, err := url.Parse(s); err == nil { if u, err := url.Parse(s); err == nil {
ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
@ -210,6 +219,12 @@ type Config struct {
// GetIdentityFunc is a function that returns an identity that will be // GetIdentityFunc is a function that returns an identity that will be
// used by the provisioner to populate certificate attributes. // used by the provisioner to populate certificate attributes.
GetIdentityFunc GetIdentityFunc GetIdentityFunc GetIdentityFunc
// AuthorizeRenewFunc is a function that returns nil if a given X.509
// certificate can be renewed.
AuthorizeRenewFunc AuthorizeRenewFunc
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
// certificate can be renewed.
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
} }
type provisioner struct { type provisioner struct {
@ -278,32 +293,6 @@ func (l *List) UnmarshalJSON(data []byte) error {
return nil return nil
} }
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
// SanitizeSSHUserPrincipal grabs an email or a string with the format
// local@domain and returns a sanitized version of the local, valid to be used
// as a user name. If the email starts with a letter between a and z, the
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
func SanitizeSSHUserPrincipal(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i]
}
return strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= '0' && r <= '9':
return r
case r == '-':
return '-'
case r == '.': // drop dots
return -1
default:
return '_'
}
}, strings.ToLower(email))
}
type base struct{} type base struct{}
// AuthorizeSign returns an unimplemented error. Provisioners should overwrite // AuthorizeSign returns an unimplemented error. Provisioners should overwrite
@ -348,66 +337,12 @@ func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certif
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented") return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
} }
// Identity is the type representing an externally supplied identity that is used
// by provisioners to populate certificate fields.
type Identity struct {
Usernames []string `json:"usernames"`
Permissions `json:"permissions"`
}
// Permissions defines extra extensions and critical options to grant to an SSH certificate. // Permissions defines extra extensions and critical options to grant to an SSH certificate.
type Permissions struct { type Permissions struct {
Extensions map[string]string `json:"extensions"` Extensions map[string]string `json:"extensions"`
CriticalOptions map[string]string `json:"criticalOptions"` CriticalOptions map[string]string `json:"criticalOptions"`
} }
// GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
// DefaultIdentityFunc return a default identity depending on the provisioner
// type. For OIDC email is always present and the usernames might
// contain empty strings.
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) {
case *OIDC:
// OIDC principals would be:
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
// 2. Sanitized local.
// 3. Raw local (if different).
// 4. Email address.
name := SanitizeSSHUserPrincipal(email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
}
usernames := []string{name}
if i := strings.LastIndex(email, "@"); i >= 0 {
usernames = append(usernames, email[:i])
}
usernames = append(usernames, email)
return &Identity{
Usernames: SanitizeStringSlices(usernames),
}, nil
default:
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
}
}
// SanitizeStringSlices removes duplicated an empty strings.
func SanitizeStringSlices(original []string) []string {
output := []string{}
seen := make(map[string]struct{})
for _, entry := range original {
if entry == "" {
continue
}
if _, value := seen[entry]; !value {
seen[entry] = struct{}{}
output = append(output, entry)
}
}
return output
}
// MockProvisioner for testing // MockProvisioner for testing
type MockProvisioner struct { type MockProvisioner struct {
Mret1, Mret2, Mret3 interface{} Mret1, Mret2, Mret3 interface{}

View file

@ -2,13 +2,14 @@ package provisioner
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"testing" "testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestType_String(t *testing.T) { func TestType_String(t *testing.T) {
@ -240,8 +241,8 @@ func TestUnimplementedMethods(t *testing.T) {
default: default:
t.Errorf("unexpected method %s", tt.method) t.Errorf("unexpected method %s", tt.method)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized) assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
assert.Equals(t, err.Error(), msg) assert.Equals(t, err.Error(), msg)
}) })

View file

@ -11,28 +11,30 @@ import (
// SCEP provisioning flow // SCEP provisioning flow
type SCEP struct { type SCEP struct {
*base *base
ID string `json:"-"` ID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
ForceCN bool `json:"forceCN,omitempty"` ForceCN bool `json:"forceCN,omitempty"`
ChallengePassword string `json:"challenge,omitempty"` ChallengePassword string `json:"challenge,omitempty"`
Capabilities []string `json:"capabilities,omitempty"` Capabilities []string `json:"capabilities,omitempty"`
// IncludeRoot makes the provisioner return the CA root in addition to the // IncludeRoot makes the provisioner return the CA root in addition to the
// intermediate in the GetCACerts response // intermediate in the GetCACerts response
IncludeRoot bool `json:"includeRoot,omitempty"` IncludeRoot bool `json:"includeRoot,omitempty"`
// MinimumPublicKeyLength is the minimum length for public keys in CSRs // MinimumPublicKeyLength is the minimum length for public keys in CSRs
MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"`
// Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7
// at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63
// Defaults to 0, being DES-CBC // Defaults to 0, being DES-CBC
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
Options *Options `json:"options,omitempty"`
Claims *Claims `json:"claims,omitempty"`
claimer *Claimer
Options *Options `json:"options,omitempty"`
Claims *Claims `json:"claims,omitempty"`
secretChallengePassword string secretChallengePassword string
encryptionAlgorithm int encryptionAlgorithm int
ctl *Controller
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -77,7 +79,7 @@ func (s *SCEP) GetOptions() *Options {
// DefaultTLSCertDuration returns the default TLS cert duration enforced by // DefaultTLSCertDuration returns the default TLS cert duration enforced by
// the provisioner. // the provisioner.
func (s *SCEP) DefaultTLSCertDuration() time.Duration { func (s *SCEP) DefaultTLSCertDuration() time.Duration {
return s.claimer.DefaultTLSCertDuration() return s.ctl.Claimer.DefaultTLSCertDuration()
} }
// Init initializes and validates the fields of a SCEP type. // Init initializes and validates the fields of a SCEP type.
@ -90,11 +92,6 @@ func (s *SCEP) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty") return errors.New("provisioner name cannot be empty")
} }
// Update claims with global ones
if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil {
return err
}
// Mask the actual challenge value, so it won't be marshaled // Mask the actual challenge value, so it won't be marshaled
s.secretChallengePassword = s.ChallengePassword s.secretChallengePassword = s.ChallengePassword
s.ChallengePassword = "*** redacted ***" s.ChallengePassword = "*** redacted ***"
@ -115,7 +112,8 @@ func (s *SCEP) Init(config Config) (err error) {
// TODO: add other, SCEP specific, options? // TODO: add other, SCEP specific, options?
return err s.ctl, err = NewController(s, s.Claims, config)
return
} }
// AuthorizeSign does not do any verification, because all verification is handled // AuthorizeSign does not do any verification, because all verification is handled
@ -123,13 +121,14 @@ func (s *SCEP) Init(config Config) (err error) {
// on the resulting certificate. // on the resulting certificate.
func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{ return []SignOption{
s,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeSCEP, s.Name, ""), newProvisionerExtensionOption(TypeSCEP, s.Name, ""),
newForceCNOption(s.ForceCN), newForceCNOption(s.ForceCN),
profileDefaultDuration(s.claimer.DefaultTLSCertDuration()), profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength),
newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()), newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()),
}, nil }, nil
} }

View file

@ -6,7 +6,6 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1"
"encoding/json" "encoding/json"
"net" "net"
"net/http" "net/http"
@ -14,7 +13,6 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
@ -404,17 +402,12 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
return nil return nil
} }
var ( // type stepProvisionerASN1 struct {
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} // Type int
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) // Name []byte
) // CredentialID []byte
// KeyValuePairs []string `asn1:"optional,omitempty"`
type stepProvisionerASN1 struct { // }
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
type forceCNOption struct { type forceCNOption struct {
ForceCN bool ForceCN bool
@ -441,23 +434,22 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
} }
type provisionerExtensionOption struct { type provisionerExtensionOption struct {
Type int Extension
Name string
CredentialID string
KeyValuePairs []string
} }
func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption { func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption {
return &provisionerExtensionOption{ return &provisionerExtensionOption{
Type: int(typ), Extension: Extension{
Name: name, Type: typ,
CredentialID: credentialID, Name: name,
KeyValuePairs: keyValuePairs, CredentialID: credentialID,
KeyValuePairs: keyValuePairs,
},
} }
} }
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...) ext, err := o.ToExtension()
if err != nil { if err != nil {
return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") return errs.NewError(http.StatusInternalServerError, err, "error creating certificate")
} }
@ -471,20 +463,3 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption
cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...) cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...)
return nil return nil
} }
func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) {
b, err := asn1.Marshal(stepProvisionerASN1{
Type: typ,
Name: []byte(name),
CredentialID: []byte(credentialID),
KeyValuePairs: keyValuePairs,
})
if err != nil {
return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension")
}
return pkix.Extension{
Id: stepOIDProvisioner,
Critical: false,
Value: b,
}, nil
}

View file

@ -636,18 +636,18 @@ func Test_newProvisionerExtension_Option(t *testing.T) {
valid: func(cert *x509.Certificate) { valid: func(cert *x509.Certificate) {
if assert.Len(t, 1, cert.ExtraExtensions) { if assert.Len(t, 1, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0] ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner) assert.Equals(t, ext.Id, StepOIDProvisioner)
} }
}, },
} }
}, },
"ok/prepend": func() test { "ok/prepend": func() test {
return test{ return test{
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
valid: func(cert *x509.Certificate) { valid: func(cert *x509.Certificate) {
if assert.Len(t, 3, cert.ExtraExtensions) { if assert.Len(t, 3, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0] ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner) assert.Equals(t, ext.Id, StepOIDProvisioner)
assert.False(t, ext.Critical) assert.False(t, ext.Critical)
} }
}, },

View file

@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) {
func Test_sshCertValidityValidator(t *testing.T) { func Test_sshCertValidityValidator(t *testing.T) {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
v := sshCertValidityValidator{p.claimer} v := sshCertValidityValidator{p.ctl.Claimer}
n := now() n := now()
tests := []struct { tests := []struct {
name string name string
@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) {
tests := map[string]func() test{ tests := map[string]func() test{
"fail/type-not-set": func() test { "fail/type-not-set": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) {
}, },
"fail/type-not-recognized": func() test { "fail/type-not-recognized": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 4, CertType: 4,
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) {
}, },
"fail/requested-validAfter-after-limit": func() test { "fail/requested-validAfter-after-limit": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: uint64(n.Add(2 * time.Hour).Unix()), ValidAfter: uint64(n.Add(2 * time.Hour).Unix()),
@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) {
}, },
"fail/requested-validBefore-after-limit": func() test { "fail/requested-validBefore-after-limit": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/no-limit": func() test { "ok/no-limit": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
}, },
@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/defaults": func() test { "ok/defaults": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
}, },
@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/valid-requested-validBefore": func() test { "ok/valid-requested-validBefore": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix()) va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: va, ValidAfter: va,
@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/empty-requested-validBefore-limit-after-default": func() test { "ok/empty-requested-validBefore-limit-after-default": func() test {
va := uint64(n.Unix()) va := uint64(n.Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: va, ValidAfter: va,
@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/empty-requested-validBefore-limit-before-default": func() test { "ok/empty-requested-validBefore-limit-before-default": func() test {
va := uint64(n.Unix()) va := uint64(n.Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: va, ValidAfter: va,

View file

@ -29,8 +29,7 @@ type SSHPOP struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
claimer *Claimer ctl *Controller
audiences Audiences
sshPubKeys *SSHKeys sshPubKeys *SSHKeys
} }
@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) {
} }
// Init initializes and validates the fields of a SSHPOP type. // Init initializes and validates the fields of a SSHPOP type.
func (p *SSHPOP) Init(config Config) error { func (p *SSHPOP) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -93,15 +92,11 @@ func (p *SSHPOP) Init(config Config) error {
return errors.New("provisioner public SSH validation keys cannot be empty") return errors.New("provisioner public SSH validation keys cannot be empty")
} }
// Update claims with global ones
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.sshPubKeys = config.SSHKeys p.sshPubKeys = config.SSHKeys
return nil
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -109,7 +104,7 @@ func (p *SSHPOP) Init(config Config) error {
// e.g. a Sign request will auth/validate different fields than a Revoke request. // e.g. a Sign request will auth/validate different fields than a Revoke request.
// //
// Checking for certificate revocation has been moved to the authority package. // Checking for certificate revocation has been moved to the authority package.
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) {
sshCert, jwt, err := ExtractSSHPOPCert(token) sshCert, jwt, err := ExtractSSHPOPCert(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, return nil, errs.Wrap(http.StatusUnauthorized, err,
@ -117,13 +112,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
} }
// Check validity period of the certificate. // Check validity period of the certificate.
n := time.Now() //
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { // Controller.AuthorizeSSHRenew will validate this on the renewal flow.
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") if checkValidity {
} unixNow := time.Now().Unix()
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
}
if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
}
} }
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
if !ok { if !ok {
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey") return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
@ -186,7 +186,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
// AuthorizeSSHRevoke validates the authorization token and extracts/validates // AuthorizeSSHRevoke validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := p.authorizeToken(token, p.audiences.SSHRevoke) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
} }
@ -199,22 +199,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
// AuthorizeSSHRenew validates the authorization token and extracts/validates // AuthorizeSSHRenew validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRenew) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert)
return claims.sshCert, nil
} }
// AuthorizeSSHRekey validates the authorization token and extracts/validates // AuthorizeSSHRekey validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRekey) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true)
if err != nil { if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
} }
@ -225,11 +223,10 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
}, nil }, nil
} }
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate

View file

@ -5,16 +5,19 @@ import (
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"fmt"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestSSHPOP_Getters(t *testing.T) { func TestSSHPOP_Getters(t *testing.T) {
@ -38,6 +41,7 @@ func TestSSHPOP_Getters(t *testing.T) {
} }
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
now := time.Now()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -46,6 +50,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now.Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
}
if err := cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -207,9 +217,9 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -279,8 +289,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil { if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -360,8 +370,8 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
tc := tt(t) tc := tt(t)
if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil { if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -442,8 +452,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
tc := tt(t) tc := tt(t)
if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil { if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -455,9 +465,9 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
case *sshDefaultPublicKeyValidator: case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator: case *sshCertDefaultValidator:
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
assert.Equals(t, tc.cert.Nonce, cert.Nonce) assert.Equals(t, tc.cert.Nonce, cert.Nonce)

View file

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy
MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc
pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49
BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY
wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g=
-----END CERTIFICATE-----

View file

@ -0,0 +1,22 @@
-----BEGIN CERTIFICATE-----
MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx
MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0
4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz
bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf
SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB
bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w==
-----END CERTIFICATE-----

View file

@ -24,20 +24,22 @@ import (
) )
var ( var (
defaultDisableRenewal = false defaultDisableRenewal = false
defaultEnableSSHCA = true defaultAllowRenewAfterExpiry = false
globalProvisionerClaims = Claims{ defaultEnableSSHCA = true
MinTLSDur: &Duration{5 * time.Minute}, globalProvisionerClaims = Claims{
MaxTLSDur: &Duration{24 * time.Hour}, MinTLSDur: &Duration{5 * time.Minute},
DefaultTLSDur: &Duration{24 * time.Hour}, MaxTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &defaultDisableRenewal, DefaultTLSDur: &Duration{24 * time.Hour},
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &defaultEnableSSHCA, EnableSSHCA: &defaultEnableSSHCA,
DisableRenewal: &defaultDisableRenewal,
AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry,
} }
testAudiences = Audiences{ testAudiences = Audiences{
Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"},
@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil { p := &JWK{
return nil, err
}
return &JWK{
Name: name, Name: name,
Type: "JWK", Type: "JWK",
Key: &public, Key: &public,
EncryptedKey: encrypted, EncryptedKey: encrypted,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, }
claimer: claimer, p.ctl, err = NewController(p, p.Claims, Config{
}, nil Audiences: testAudiences,
})
return p, err
} }
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
pubKeys := []interface{}{fooPub, barPub} pubKeys := []interface{}{fooPub, barPub}
if inputPubKey != nil { if inputPubKey != nil {
pubKeys = append(pubKeys, inputPubKey) pubKeys = append(pubKeys, inputPubKey)
} }
return &K8sSA{ p := &K8sSA{
Name: K8sSAName, Name: K8sSAName,
Type: "K8sSA", Type: "K8sSA",
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, pubKeys: pubKeys,
claimer: claimer, }
pubKeys: pubKeys, p.ctl, err = NewController(p, p.Claims, Config{
}, nil Audiences: testAudiences,
})
return p, err
} }
func generateSSHPOP() (*SSHPOP, error) { func generateSSHPOP() (*SSHPOP, error) {
@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub") userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
if err != nil { if err != nil {
return nil, err return nil, err
@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) {
return nil, err return nil, err
} }
return &SSHPOP{ p := &SSHPOP{
Name: name, Name: name,
Type: "SSHPOP", Type: "SSHPOP",
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences,
claimer: claimer,
sshPubKeys: &SSHKeys{ sshPubKeys: &SSHKeys{
UserKeys: []ssh.PublicKey{userKey}, UserKeys: []ssh.PublicKey{userKey},
HostKeys: []ssh.PublicKey{hostKey}, HostKeys: []ssh.PublicKey{hostKey},
}, },
}, nil }
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
} }
func generateX5C(root []byte) (*X5C, error) { func generateX5C(root []byte) (*X5C, error) {
@ -283,11 +279,6 @@ M46l92gdOozT
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
rootPool := x509.NewCertPool() rootPool := x509.NewCertPool()
var ( var (
@ -305,15 +296,17 @@ M46l92gdOozT
} }
rootPool.AddCert(cert) rootPool.AddCert(cert)
} }
return &X5C{ p := &X5C{
Name: name, Name: name,
Type: "X5C", Type: "X5C",
Roots: root, Roots: root,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, rootPool: rootPool,
claimer: claimer, }
rootPool: rootPool, p.ctl, err = NewController(p, p.Claims, Config{
}, nil Audiences: testAudiences,
})
return p, err
} }
func generateOIDC() (*OIDC, error) { func generateOIDC() (*OIDC, error) {
@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims) p := &OIDC{
if err != nil {
return nil, err
}
return &OIDC{
Name: name, Name: name,
Type: "OIDC", Type: "OIDC",
ClientID: clientID, ClientID: clientID,
@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
claimer: claimer, }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
} }
func generateGCP() (*GCP, error) { func generateGCP() (*GCP, error) {
@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims) p := &GCP{
if err != nil {
return nil, err
}
return &GCP{
Type: "GCP", Type: "GCP",
Name: name, Name: name,
ServiceAccounts: []string{serviceAccount}, ServiceAccounts: []string{serviceAccount},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
claimer: claimer,
config: newGCPConfig(), config: newGCPConfig(),
keyStore: &keyStore{ keyStore: &keyStore{
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
audiences: testAudiences.WithFragment("gcp/" + name), }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("gcp/" + name),
})
return p, err
} }
func generateAWS() (*AWS, error) { func generateAWS() (*AWS, error) {
@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate)) block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" { if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate") return nil, errors.New("error decoding AWS certificate")
@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate") return nil, errors.Wrap(err, "error parsing AWS certificate")
} }
return &AWS{ p := &AWS{
Type: "AWS", Type: "AWS",
Name: name, Name: name,
Accounts: []string{accountID}, Accounts: []string{accountID},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v2", "v1"}, IMDSVersions: []string{"v2", "v1"},
claimer: claimer,
config: &awsConfig{ config: &awsConfig{
identityURL: awsIdentityURL, identityURL: awsIdentityURL,
signatureURL: awsSignatureURL, signatureURL: awsSignatureURL,
@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) {
certificates: []*x509.Certificate{cert}, certificates: []*x509.Certificate{cert},
signatureAlgorithm: awsSignatureAlgorithm, signatureAlgorithm: awsSignatureAlgorithm,
}, },
audiences: testAudiences.WithFragment("aws/" + name), }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
return p, err
} }
func generateAWSWithServer() (*AWS, *httptest.Server, error) { func generateAWSWithServer() (*AWS, *httptest.Server, error) {
@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate)) block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" { if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate") return nil, errors.New("error decoding AWS certificate")
@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate") return nil, errors.Wrap(err, "error parsing AWS certificate")
} }
return &AWS{ p := &AWS{
Type: "AWS", Type: "AWS",
Name: name, Name: name,
Accounts: []string{accountID}, Accounts: []string{accountID},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v1"}, IMDSVersions: []string{"v1"},
claimer: claimer,
config: &awsConfig{ config: &awsConfig{
identityURL: awsIdentityURL, identityURL: awsIdentityURL,
signatureURL: awsSignatureURL, signatureURL: awsSignatureURL,
@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) {
certificates: []*x509.Certificate{cert}, certificates: []*x509.Certificate{cert},
signatureAlgorithm: awsSignatureAlgorithm, signatureAlgorithm: awsSignatureAlgorithm,
}, },
audiences: testAudiences.WithFragment("aws/" + name), }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
return p, err
} }
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) { func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
jwk, err := generateJSONWebKey() jwk, err := generateJSONWebKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Azure{ p := &Azure{
Type: "Azure", Type: "Azure",
Name: name, Name: name,
TenantID: tenantID, TenantID: tenantID,
Audience: azureDefaultAudience, Audience: azureDefaultAudience,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
claimer: claimer,
config: newAzureConfig(tenantID), config: newAzureConfig(tenantID),
oidcConfig: openIDConfiguration{ oidcConfig: openIDConfiguration{
Issuer: "https://sts.windows.net/" + tenantID + "/", Issuer: "https://sts.windows.net/" + tenantID + "/",
@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
}, nil }
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
} }
func generateAzureWithServer() (*Azure, *httptest.Server, error) { func generateAzureWithServer() (*Azure, *httptest.Server, error) {
@ -671,7 +656,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
w.Header().Add("Cache-Control", "max-age=5") w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(az.keyStore.keySet)) writeJSON(w, getPublic(az.keyStore.keySet))
case "/metadata/identity/oauth2/token": case "/metadata/identity/oauth2/token":
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0]) tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0])
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} else { } else {
@ -1009,7 +994,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner( sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
@ -1017,6 +1002,12 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
if err != nil { if err != nil {
return "", err return "", err
} }
var xmsMirID string
if resourceType == "vm" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName)
} else if resourceType == "uai" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName)
}
claims := azurePayload{ claims := azurePayload{
Claims: jose.Claims{ Claims: jose.Claims{
@ -1034,7 +1025,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
ObjectID: "the-oid", ObjectID: "the-oid",
TenantID: tenantID, TenantID: tenantID,
Version: "the-version", Version: "the-version",
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine), XMSMirID: xmsMirID,
} }
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }

View file

@ -26,15 +26,14 @@ type x5cPayload struct {
// signature requests. // signature requests.
type X5C struct { type X5C struct {
*base *base
ID string `json:"-"` ID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Roots []byte `json:"roots"` Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer ctl *Controller
audiences Audiences rootPool *x509.CertPool
rootPool *x509.CertPool
} }
// GetID returns the provisioner unique identifier. The name and credential id // GetID returns the provisioner unique identifier. The name and credential id
@ -86,7 +85,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) {
} }
// Init initializes and validates the fields of a X5C type. // Init initializes and validates the fields of a X5C type.
func (p *X5C) Init(config Config) error { func (p *X5C) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -101,6 +100,7 @@ func (p *X5C) Init(config Config) error {
var ( var (
block *pem.Block block *pem.Block
rest = p.Roots rest = p.Roots
count int
) )
for rest != nil { for rest != nil {
block, rest = pem.Decode(rest) block, rest = pem.Decode(rest)
@ -111,22 +111,18 @@ func (p *X5C) Init(config Config) error {
if err != nil { if err != nil {
return errors.Wrap(err, "error parsing x509 certificate from PEM block") return errors.Wrap(err, "error parsing x509 certificate from PEM block")
} }
count++
p.rootPool.AddCert(cert) p.rootPool.AddCert(cert)
} }
// Verify that at least one root was found. // Verify that at least one root was found.
if len(p.rootPool.Subjects()) == 0 { if count == 0 {
return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName())
} }
// Update claims with global ones config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
var err error p.ctl, err = NewController(p, p.Claims, config)
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil { return
return err
}
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
return nil
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -189,13 +185,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error { func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
} }
@ -213,40 +209,45 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
data.SetToken(v) data.SetToken(v)
} }
// The X509 certificate will be available using the template variable
// AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be
// used to get all the domains.
data.SetAuthorizationCertificate(claims.chains[0][0])
templateOptions, err := TemplateOptions(p.Options, data) templateOptions, err := TemplateOptions(p.Options, data)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
} }
return []SignOption{ return []SignOption{
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeX5C, p.Name, ""), newProvisionerExtensionOption(TypeX5C, p.Name, ""),
profileLimitDuration{p.claimer.DefaultTLSCertDuration(), profileLimitDuration{
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter}, p.ctl.Claimer.DefaultTLSCertDuration(),
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter,
},
// validators // validators
commonNameValidator(claims.Subject), commonNameValidator(claims.Subject),
defaultSANsValidator(claims.SANs), defaultSANsValidator(claims.SANs),
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil }, nil
} }
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
} }
@ -287,6 +288,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
data.SetToken(v) data.SetToken(v)
} }
// The X509 certificate will be available using the template variable
// AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be
// used to get all the domains.
data.SetAuthorizationCertificate(claims.chains[0][0])
templateOptions, err := TemplateSSHOptions(p.Options, data) templateOptions, err := TemplateSSHOptions(p.Options, data)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
@ -304,11 +310,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
return append(signOptions, return append(signOptions,
// Checks the validity bounds, and set the validity if has not been set. // Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter}, &sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter},
// Validate public key. // Validate public key.
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
), nil ), nil

View file

@ -2,16 +2,19 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509"
"errors"
"fmt"
"net/http" "net/http"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestX5C_Getters(t *testing.T) { func TestX5C_Getters(t *testing.T) {
@ -69,8 +72,8 @@ func TestX5C_Init(t *testing.T) {
}, },
"fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences}, p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")},
err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), err: errors.New("no x509 certificates found in roots attribute for provisioner 'foo'"),
} }
}, },
"fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest { "fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest {
@ -117,9 +120,11 @@ M46l92gdOozT
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: p, p: p,
extraValid: func(p *X5C) error { extraValid: func(p *X5C) error {
// nolint:staticcheck // We don't have a different way to
// check the number of certificates in the pool.
numCerts := len(p.rootPool.Subjects()) numCerts := len(p.rootPool.Subjects())
if numCerts != 2 { if numCerts != 2 {
return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts) return fmt.Errorf("unexpected number of certs: want 2, but got %d", numCerts)
} }
return nil return nil
}, },
@ -141,7 +146,7 @@ M46l92gdOozT
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID())) assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID()))
if tc.extraValid != nil { if tc.extraValid != nil {
assert.Nil(t, tc.extraValid(tc.p)) assert.Nil(t, tc.extraValid(tc.p))
} }
@ -384,8 +389,8 @@ lgsqsR63is+0YQ==
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -455,7 +460,7 @@ func TestX5C_AuthorizeSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -463,19 +468,20 @@ func TestX5C_AuthorizeSign(t *testing.T) {
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) { if assert.NotNil(t, opts) {
assert.Equals(t, len(opts), 7) assert.Equals(t, len(opts), 8)
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case *X5C:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeX5C)) assert.Equals(t, v.Type, TypeX5C)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileLimitDuration: case profileLimitDuration:
assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration()) assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration())
claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign) claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter)
case commonNameValidator: case commonNameValidator:
@ -484,10 +490,10 @@ func TestX5C_AuthorizeSign(t *testing.T) {
case defaultSANsValidator: case defaultSANsValidator:
assert.Equals(t, []string(v), tc.sans) assert.Equals(t, []string(v), tc.sans)
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
} }
} }
@ -538,8 +544,8 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil { if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -551,6 +557,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
} }
func TestX5C_AuthorizeRenew(t *testing.T) { func TestX5C_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct { type test struct {
p *X5C p *X5C
code int code int
@ -563,12 +570,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()), err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -582,10 +589,13 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil { if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -618,13 +628,13 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
enable := false enable := false
p.Claims = &Claims{EnableSSHCA: &enable} p.Claims = &Claims{EnableSSHCA: &enable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
token: "foo", token: "foo",
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()), err: fmt.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()),
} }
}, },
"fail/invalid-token": func(t *testing.T) test { "fail/invalid-token": func(t *testing.T) test {
@ -745,7 +755,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
tc := tt(t) tc := tt(t)
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil { if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -774,13 +784,13 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
case sshCertDefaultsModifier: case sshCertDefaultsModifier:
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert}) assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
case *sshLimitDuration: case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc:
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
tot++ tot++
} }

View file

@ -87,20 +87,20 @@ func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, e
return p, nil return p, nil
} }
func (a *Authority) generateProvisionerConfig(ctx context.Context) (*provisioner.Config, error) { func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.Config, error) {
// Merge global and configuration claims // Merge global and configuration claims
claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims) claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims)
if err != nil { if err != nil {
return nil, err return provisioner.Config{}, err
} }
// TODO: should we also be combining the ssh federated roots here? // TODO: should we also be combining the ssh federated roots here?
// If we rotate ssh roots keys, sshpop provisioner will lose ability to // If we rotate ssh roots keys, sshpop provisioner will lose ability to
// validate old SSH certificates, unless they are added as federated certs. // validate old SSH certificates, unless they are added as federated certs.
sshKeys, err := a.GetSSHRoots(ctx) sshKeys, err := a.GetSSHRoots(ctx)
if err != nil { if err != nil {
return nil, err return provisioner.Config{}, err
} }
return &provisioner.Config{ return provisioner.Config{
Claims: claimer.Claims(), Claims: claimer.Claims(),
Audiences: a.config.GetAudiences(), Audiences: a.config.GetAudiences(),
DB: a.db, DB: a.db,
@ -108,7 +108,9 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (*provisioner
UserKeys: sshKeys.UserKeys, UserKeys: sshKeys.UserKeys,
HostKeys: sshKeys.HostKeys, HostKeys: sshKeys.HostKeys,
}, },
GetIdentityFunc: a.getIdentityFunc, GetIdentityFunc: a.getIdentityFunc,
AuthorizeRenewFunc: a.authorizeRenewFunc,
AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc,
}, nil }, nil
} }
@ -133,9 +135,18 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi
"provisioner with token ID %s already exists", certProv.GetIDForToken()) "provisioner with token ID %s already exists", certProv.GetIDForToken())
} }
provisionerConfig, err := a.generateProvisionerConfig(ctx)
if err != nil {
return admin.WrapErrorISE(err, "error generating provisioner config")
}
if err := certProv.Init(provisionerConfig); err != nil {
return admin.WrapError(admin.ErrorBadRequestType, err, "error validating configuration for provisioner %s", prov.Name)
}
// Store to database -- this will set the ID. // Store to database -- this will set the ID.
if err := a.adminDB.CreateProvisioner(ctx, prov); err != nil { if err := a.adminDB.CreateProvisioner(ctx, prov); err != nil {
return admin.WrapErrorISE(err, "error creating admin") return admin.WrapErrorISE(err, "error creating provisioner")
} }
// We need a new conversion that has the newly set ID. // We need a new conversion that has the newly set ID.
@ -145,12 +156,7 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi
"error converting to certificates provisioner from linkedca provisioner") "error converting to certificates provisioner from linkedca provisioner")
} }
provisionerConfig, err := a.generateProvisionerConfig(ctx) if err := certProv.Init(provisionerConfig); err != nil {
if err != nil {
return admin.WrapErrorISE(err, "error generating provisioner config")
}
if err := certProv.Init(*provisionerConfig); err != nil {
return admin.WrapErrorISE(err, "error initializing provisioner %s", prov.Name) return admin.WrapErrorISE(err, "error initializing provisioner %s", prov.Name)
} }
@ -179,7 +185,7 @@ func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisio
return admin.WrapErrorISE(err, "error generating provisioner config") return admin.WrapErrorISE(err, "error generating provisioner config")
} }
if err := certProv.Init(*provisionerConfig); err != nil { if err := certProv.Init(provisionerConfig); err != nil {
return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name) return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name)
} }
@ -431,7 +437,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
} }
pc := &provisioner.Claims{ pc := &provisioner.Claims{
DisableRenewal: &c.DisableRenewal, DisableRenewal: &c.DisableRenewal,
AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry,
} }
var err error var err error
@ -469,12 +476,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims {
} }
disableRenewal := config.DefaultDisableRenewal disableRenewal := config.DefaultDisableRenewal
allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry
if c.DisableRenewal != nil { if c.DisableRenewal != nil {
disableRenewal = *c.DisableRenewal disableRenewal = *c.DisableRenewal
} }
if c.AllowRenewAfterExpiry != nil {
allowRenewAfterExpiry = *c.AllowRenewAfterExpiry
}
lc := &linkedca.Claims{ lc := &linkedca.Claims{
DisableRenewal: disableRenewal, DisableRenewal: disableRenewal,
AllowRenewAfterExpiry: allowRenewAfterExpiry,
} }
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil { if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {
@ -706,6 +719,8 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
Name: p.Name, Name: p.Name,
TenantID: cfg.TenantId, TenantID: cfg.TenantId,
ResourceGroups: cfg.ResourceGroups, ResourceGroups: cfg.ResourceGroups,
SubscriptionIDs: cfg.SubscriptionIds,
ObjectIDs: cfg.ObjectIds,
Audience: cfg.Audience, Audience: cfg.Audience,
DisableCustomSANs: cfg.DisableCustomSans, DisableCustomSANs: cfg.DisableCustomSans,
DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse, DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse,
@ -865,6 +880,8 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro
Azure: &linkedca.AzureProvisioner{ Azure: &linkedca.AzureProvisioner{
TenantId: p.TenantID, TenantId: p.TenantID,
ResourceGroups: p.ResourceGroups, ResourceGroups: p.ResourceGroups,
SubscriptionIds: p.SubscriptionIDs,
ObjectIds: p.ObjectIDs,
Audience: p.Audience, Audience: p.Audience,
DisableCustomSans: p.DisableCustomSANs, DisableCustomSans: p.DisableCustomSANs,
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse, DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,

View file

@ -1,13 +1,13 @@
package authority package authority
import ( import (
"errors"
"net/http" "net/http"
"testing" "testing"
"github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
) )
func TestGetEncryptedKey(t *testing.T) { func TestGetEncryptedKey(t *testing.T) {
@ -49,8 +49,8 @@ func TestGetEncryptedKey(t *testing.T) {
ek, err := tc.a.GetEncryptedKey(tc.kid) ek, err := tc.a.GetEncryptedKey(tc.kid)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -90,8 +90,8 @@ func TestGetProvisioners(t *testing.T) {
ps, next, err := tc.a.GetProvisioners("", 0) ps, next, err := tc.a.GetProvisioners("", 0)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -2,14 +2,15 @@ package authority
import ( import (
"crypto/x509" "crypto/x509"
"errors"
"net/http" "net/http"
"reflect" "reflect"
"testing" "testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
) )
func TestRoot(t *testing.T) { func TestRoot(t *testing.T) {
@ -31,7 +32,7 @@ func TestRoot(t *testing.T) {
crt, err := a.Root(tc.sum) crt, err := a.Root(tc.sum)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())

View file

@ -7,21 +7,22 @@ import (
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/templates"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/templates"
) )
type sshTestModifier ssh.Certificate type sshTestModifier ssh.Certificate
@ -716,8 +717,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { } else if err != nil {
_, ok := err.(errs.StatusCoder) _, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want) t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want)
@ -806,8 +807,8 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
hosts, err := auth.GetSSHHosts(context.Background(), tc.cert) hosts, err := auth.GetSSHHosts(context.Background(), tc.cert)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -1033,8 +1034,8 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...) cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -1,11 +0,0 @@
package status
// Type is the type for status.
type Type string
var (
// Active active
Active = Type("active")
// Deleted deleted
Deleted = Type("deleted")
)

View file

@ -89,8 +89,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
// Set backdate with the configured value // Set backdate with the configured value
signOpts.Backdate = a.config.AuthorityConfig.Backdate.Duration signOpts.Backdate = a.config.AuthorityConfig.Backdate.Duration
var prov provisioner.Interface
for _, op := range extraOpts { for _, op := range extraOpts {
switch k := op.(type) { switch k := op.(type) {
// Capture current provisioner
case provisioner.Interface:
prov = k
// Adds new options to NewCertificate // Adds new options to NewCertificate
case provisioner.CertificateOptions: case provisioner.CertificateOptions:
certOptions = append(certOptions, k.Options(signOpts)...) certOptions = append(certOptions, k.Options(signOpts)...)
@ -204,7 +209,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
} }
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
if err = a.storeCertificate(fullchain); err != nil { if err = a.storeCertificate(prov, fullchain); err != nil {
if err != db.ErrNotImplemented { if err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, return nil, errs.Wrap(http.StatusInternalServerError, err,
"authority.Sign; error storing certificate in db", opts...) "authority.Sign; error storing certificate in db", opts...)
@ -325,19 +330,28 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
// TODO: at some point we should replace the db.AuthDB interface to implement // TODO: at some point we should replace the db.AuthDB interface to implement
// `StoreCertificate(...*x509.Certificate) error` instead of just // `StoreCertificate(...*x509.Certificate) error` instead of just
// `StoreCertificate(*x509.Certificate) error`. // `StoreCertificate(*x509.Certificate) error`.
func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { func (a *Authority) storeCertificate(prov provisioner.Interface, fullchain []*x509.Certificate) error {
type linkedChainStorer interface {
StoreCertificateChain(provisioner.Interface, ...*x509.Certificate) error
}
type certificateChainStorer interface { type certificateChainStorer interface {
StoreCertificateChain(...*x509.Certificate) error StoreCertificateChain(...*x509.Certificate) error
} }
// Store certificate in linkedca // Store certificate in linkedca
if s, ok := a.adminDB.(certificateChainStorer); ok { switch s := a.adminDB.(type) {
case linkedChainStorer:
return s.StoreCertificateChain(prov, fullchain...)
case certificateChainStorer:
return s.StoreCertificateChain(fullchain...) return s.StoreCertificateChain(fullchain...)
} }
// Store certificate in local db // Store certificate in local db
if s, ok := a.db.(certificateChainStorer); ok { switch s := a.db.(type) {
case certificateChainStorer:
return s.StoreCertificateChain(fullchain...) return s.StoreCertificateChain(fullchain...)
default:
return a.db.StoreCertificate(fullchain[0])
} }
return a.db.StoreCertificate(fullchain[0])
} }
// storeRenewedCertificate allows to use an extension of the db.AuthDB interface // storeRenewedCertificate allows to use an extension of the db.AuthDB interface

View file

@ -11,24 +11,26 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/smallstep/certificates/cas/softcas" "gopkg.in/square/go-jose.v2/jwt"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas/softcas"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
) )
var ( var (
@ -187,14 +189,14 @@ func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) {
func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
b, err := x509.MarshalPKIXPublicKey(pub) b, err := x509.MarshalPKIXPublicKey(pub)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling public key") return nil, fmt.Errorf("error marshaling public key: %w", err)
} }
info := struct { info := struct {
Algorithm pkix.AlgorithmIdentifier Algorithm pkix.AlgorithmIdentifier
SubjectPublicKey asn1.BitString SubjectPublicKey asn1.BitString
}{} }{}
if _, err = asn1.Unmarshal(b, &info); err != nil { if _, err = asn1.Unmarshal(b, &info); err != nil {
return nil, errors.Wrap(err, "error unmarshaling public key") return nil, fmt.Errorf("error unmarshaling public key: %w", err)
} }
hash := sha1.Sum(info.SubjectPublicKey.Bytes) hash := sha1.Sum(info.SubjectPublicKey.Bytes)
return hash[:], nil return hash[:], nil
@ -661,8 +663,8 @@ ZYtQ9Ot36qc=
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -757,7 +759,7 @@ func TestAuthority_Renew(t *testing.T) {
now := time.Now().UTC() now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7) nb1 := now.Add(-time.Minute * 7)
na1 := now na1 := now.Add(time.Hour)
so := &provisioner.SignOptions{ so := &provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(nb1), NotBefore: provisioner.NewTimeDuration(nb1),
NotAfter: provisioner.NewTimeDuration(na1), NotAfter: provisioner.NewTimeDuration(na1),
@ -798,7 +800,20 @@ func TestAuthority_Renew(t *testing.T) {
"fail/unauthorized": func() (*renewTest, error) { "fail/unauthorized": func() (*renewTest, error) {
return &renewTest{ return &renewTest{
cert: certNoRenew, cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized,
}, nil
},
"fail/WithAuthorizeRenewFunc": func() (*renewTest, error) {
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
return errs.Unauthorized("not authorized")
}))
aa.x509CAService = a.x509CAService
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
return &renewTest{
auth: aa,
cert: cert,
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
}, nil }, nil
}, },
@ -820,6 +835,17 @@ func TestAuthority_Renew(t *testing.T) {
cert: cert, cert: cert,
}, nil }, nil
}, },
"ok/WithAuthorizeRenewFunc": func() (*renewTest, error) {
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
return nil
}))
aa.x509CAService = a.x509CAService
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
return &renewTest{
auth: aa,
cert: cert,
}, nil
},
} }
for name, genTestCase := range tests { for name, genTestCase := range tests {
@ -836,8 +862,8 @@ func TestAuthority_Renew(t *testing.T) {
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -856,7 +882,7 @@ func TestAuthority_Renew(t *testing.T) {
expiry := now.Add(time.Minute * 7) expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.Subject.String(), assert.Equals(t, leaf.Subject.String(),
@ -956,7 +982,7 @@ func TestAuthority_Rekey(t *testing.T) {
now := time.Now().UTC() now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7) nb1 := now.Add(-time.Minute * 7)
na1 := now na1 := now.Add(time.Hour)
so := &provisioner.SignOptions{ so := &provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(nb1), NotBefore: provisioner.NewTimeDuration(nb1),
NotAfter: provisioner.NewTimeDuration(na1), NotAfter: provisioner.NewTimeDuration(na1),
@ -998,7 +1024,7 @@ func TestAuthority_Rekey(t *testing.T) {
"fail/unauthorized": func() (*renewTest, error) { "fail/unauthorized": func() (*renewTest, error) {
return &renewTest{ return &renewTest{
cert: certNoRenew, cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
}, nil }, nil
}, },
@ -1043,8 +1069,8 @@ func TestAuthority_Rekey(t *testing.T) {
if err != nil { if err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
assert.Nil(t, certChain) assert.Nil(t, certChain)
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
@ -1063,7 +1089,7 @@ func TestAuthority_Rekey(t *testing.T) {
expiry := now.Add(time.Minute * 7) expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.Subject.String(), assert.Equals(t, leaf.Subject.String(),
@ -1432,8 +1458,8 @@ func TestAuthority_Revoke(t *testing.T) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
if err := tc.auth.Revoke(ctx, tc.opts); err != nil { if err := tc.auth.Revoke(ctx, tc.opts); err != nil {
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())

View file

@ -12,12 +12,13 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api" acmeAPI "github.com/smallstep/certificates/acme/api"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/render"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
) )
func TestNewACMEClient(t *testing.T) { func TestNewACMEClient(t *testing.T) {
@ -112,15 +113,15 @@ func TestNewACMEClient(t *testing.T) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
switch { switch {
case i == 0: case i == 0:
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
case i == 1: case i == 1:
w.Header().Set("Replay-Nonce", "abc123") w.Header().Set("Replay-Nonce", "abc123")
api.JSONStatus(w, []byte{}, 200) render.JSONStatus(w, []byte{}, 200)
i++ i++
default: default:
w.Header().Set("Location", accLocation) w.Header().Set("Location", accLocation)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
} }
}) })
@ -206,7 +207,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
}) })
if nonce, err := ac.GetNonce(); err != nil { if nonce, err := ac.GetNonce(); err != nil {
@ -315,7 +316,7 @@ func TestACMEClient_post(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -338,7 +339,7 @@ func TestACMEClient_post(t *testing.T) {
assert.Equals(t, hdr.KeyID, ac.kid) assert.Equals(t, hdr.KeyID, ac.kid)
} }
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil { if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil {
@ -455,7 +456,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -477,7 +478,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, payload, norb) assert.Equals(t, payload, norb)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := ac.NewOrder(norb); err != nil { if res, err := ac.NewOrder(norb); err != nil {
@ -577,7 +578,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -599,7 +600,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := ac.GetOrder(url); err != nil { if res, err := ac.GetOrder(url); err != nil {
@ -699,7 +700,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -721,7 +722,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := ac.GetAuthz(url); err != nil { if res, err := ac.GetAuthz(url); err != nil {
@ -821,7 +822,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -844,7 +845,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := ac.GetChallenge(url); err != nil { if res, err := ac.GetChallenge(url); err != nil {
@ -944,7 +945,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -967,7 +968,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
assert.Equals(t, payload, []byte("{}")) assert.Equals(t, payload, []byte("{}"))
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if err := ac.ValidateChallenge(url); err != nil { if err := ac.ValidateChallenge(url); err != nil {
@ -1071,7 +1072,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -1093,7 +1094,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, payload, frb) assert.Equals(t, payload, frb)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if err := ac.FinalizeOrder(url, csr); err != nil { if err := ac.FinalizeOrder(url, csr); err != nil {
@ -1200,7 +1201,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -1222,7 +1223,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, len(payload), 0) assert.Equals(t, len(payload), 0)
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
}) })
if res, err := tc.client.GetAccountOrders(); err != nil { if res, err := tc.client.GetAccountOrders(); err != nil {
@ -1331,7 +1332,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
w.Header().Set("Replay-Nonce", expectedNonce) w.Header().Set("Replay-Nonce", expectedNonce)
if i == 0 { if i == 0 {
api.JSONStatus(w, tc.r1, tc.rc1) render.JSONStatus(w, tc.r1, tc.rc1)
i++ i++
return return
} }
@ -1356,7 +1357,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
if tc.certBytes != nil { if tc.certBytes != nil {
w.Write(tc.certBytes) w.Write(tc.certBytes)
} else { } else {
api.JSONStatus(w, tc.r2, tc.rc2) render.JSONStatus(w, tc.r2, tc.rc2)
} }
}) })

View file

@ -14,11 +14,14 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs"
) )
func newLocalListener() net.Listener { func newLocalListener() net.Listener {
@ -79,7 +82,7 @@ func startCAServer(configFile string) (*CA, string, error) {
func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler { func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/version" { if r.URL.Path == "/version" {
api.JSON(w, api.VersionResponse{ render.JSON(w, api.VersionResponse{
Version: "test", Version: "test",
RequireClientAuthentication: true, RequireClientAuthentication: true,
}) })
@ -93,7 +96,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han
} }
isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0 isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0
if !isMTLS { if !isMTLS {
api.WriteError(w, errs.Unauthorized("missing peer certificate")) render.Error(w, errs.Unauthorized("missing peer certificate"))
} else { } else {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }
@ -408,6 +411,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
server.ServeTLS(listener, "", "") server.ServeTLS(listener, "", "")
}() }()
defer server.Close() defer server.Close()
time.Sleep(1 * time.Second)
// Create bootstrap client // Create bootstrap client
token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
@ -419,7 +423,6 @@ func TestBootstrapClientServerRotation(t *testing.T) {
// doTest does a request that requires mTLS // doTest does a request that requires mTLS
doTest := func(client *http.Client) error { doTest := func(client *http.Client) error {
time.Sleep(1 * time.Second)
// test with ca // test with ca
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
if err != nil { if err != nil {

View file

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
"strings"
"sync" "sync"
"github.com/go-chi/chi" "github.com/go-chi/chi"
@ -26,11 +27,14 @@ import (
scepAPI "github.com/smallstep/certificates/scep/api" scepAPI "github.com/smallstep/certificates/scep/api"
"github.com/smallstep/certificates/server" "github.com/smallstep/certificates/server"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/cli-utils/step"
"go.step.sm/crypto/x509util"
) )
type options struct { type options struct {
configFile string configFile string
linkedCAToken string linkedCAToken string
quiet bool
password []byte password []byte
issuerPassword []byte issuerPassword []byte
sshHostPassword []byte sshHostPassword []byte
@ -101,6 +105,13 @@ func WithLinkedCAToken(token string) Option {
} }
} }
// WithQuiet sets the quiet flag.
func WithQuiet(quiet bool) Option {
return func(o *options) {
o.quiet = quiet
}
}
// CA is the type used to build the complete certificate authority. It builds // CA is the type used to build the complete certificate authority. It builds
// the HTTP server, set ups the middlewares and the HTTP handlers. // the HTTP server, set ups the middlewares and the HTTP handlers.
type CA struct { type CA struct {
@ -288,6 +299,35 @@ func (ca *CA) Run() error {
var wg sync.WaitGroup var wg sync.WaitGroup
errs := make(chan error, 1) errs := make(chan error, 1)
if !ca.opts.quiet {
authorityInfo := ca.auth.GetInfo()
log.Printf("Starting %s", step.Version())
log.Printf("Documentation: https://u.step.sm/docs/ca")
log.Printf("Community Discord: https://u.step.sm/discord")
if step.Contexts().GetCurrent() != nil {
log.Printf("Current context: %s", step.Contexts().GetCurrent().Name)
}
log.Printf("Config file: %s", ca.opts.configFile)
baseURL := fmt.Sprintf("https://%s%s",
authorityInfo.DNSNames[0],
ca.config.Address[strings.LastIndex(ca.config.Address, ":"):])
log.Printf("The primary server URL is %s", baseURL)
log.Printf("Root certificates are available at %s/roots.pem", baseURL)
if len(authorityInfo.DNSNames) > 1 {
log.Printf("Additional configured hostnames: %s",
strings.Join(authorityInfo.DNSNames[1:], ", "))
}
for _, crt := range authorityInfo.RootX509Certs {
log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt))
}
if authorityInfo.SSHCAHostPublicKey != nil {
log.Printf("SSH Host CA Key is %s\n", authorityInfo.SSHCAHostPublicKey)
}
if authorityInfo.SSHCAUserPublicKey != nil {
log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey)
}
}
if ca.insecureSrv != nil { if ca.insecureSrv != nil {
wg.Add(1) wg.Add(1)
go func() { go func() {
@ -355,6 +395,7 @@ func (ca *CA) Reload() error {
WithSSHUserPassword(ca.opts.sshUserPassword), WithSSHUserPassword(ca.opts.sshUserPassword),
WithIssuerPassword(ca.opts.issuerPassword), WithIssuerPassword(ca.opts.issuerPassword),
WithLinkedCAToken(ca.opts.linkedCAToken), WithLinkedCAToken(ca.opts.linkedCAToken),
WithQuiet(ca.opts.quiet),
WithConfigFile(ca.opts.configFile), WithConfigFile(ca.opts.configFile),
WithDatabase(ca.auth.GetDatabase()), WithDatabase(ca.auth.GetDatabase()),
) )
@ -450,9 +491,6 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
tlsConfig.ClientCAs = certPool tlsConfig.ClientCAs = certPool
// Use server's most preferred ciphersuite
tlsConfig.PreferServerCipherSuites = true
return tlsConfig, nil return tlsConfig, nil
} }

View file

@ -563,6 +563,11 @@ func (c *Client) retryOnError(r *http.Response) bool {
return false return false
} }
// GetCaURL returns the configured CA url.
func (c *Client) GetCaURL() string {
return c.endpoint.String()
}
// GetRootCAs returns the RootCAs certificate pool from the configured // GetRootCAs returns the RootCAs certificate pool from the configured
// transport. // transport.
func (c *Client) GetRootCAs() *x509.CertPool { func (c *Client) GetRootCAs() *x509.CertPool {
@ -723,6 +728,36 @@ retry:
return &sign, nil return &sign, nil
} }
// RenewWithToken performs the renew request to the CA with the given
// authorization token and returns the api.SignResponse struct. This method is
// generally used to renew an expired certificate.
func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
req, err := http.NewRequest("POST", u.String(), http.NoBody)
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request")
}
req.Header.Add("Authorization", "Bearer "+token)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body)
}
var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u)
}
return &sign, nil
}
// Rekey performs the rekey request to the CA and returns the api.SignResponse // Rekey performs the rekey request to the CA and returns the api.SignResponse
// struct. // struct.
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {

View file

@ -8,6 +8,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -16,14 +17,16 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors" "go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
) )
const ( const (
@ -179,7 +182,7 @@ func TestClient_Version(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Version() got, err := c.Version()
@ -229,7 +232,7 @@ func TestClient_Health(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Health() got, err := c.Health()
@ -287,7 +290,7 @@ func TestClient_Root(t *testing.T) {
if req.RequestURI != expected { if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Root(tt.shasum) got, err := c.Root(tt.shasum)
@ -354,10 +357,10 @@ func TestClient_Sign(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest) body := new(api.SignRequest)
if err := api.ReadJSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
if tt.request == nil { if tt.request == nil {
@ -368,7 +371,7 @@ func TestClient_Sign(t *testing.T) {
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request) t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
} }
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Sign(tt.request) got, err := c.Sign(tt.request)
@ -426,10 +429,10 @@ func TestClient_Revoke(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest) body := new(api.RevokeRequest)
if err := api.ReadJSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e) render.Error(w, e)
return return
} else if !equalJSON(t, body, tt.request) { } else if !equalJSON(t, body, tt.request) {
if tt.request == nil { if tt.request == nil {
@ -440,7 +443,7 @@ func TestClient_Revoke(t *testing.T) {
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request) t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
} }
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Revoke(tt.request, nil) got, err := c.Revoke(tt.request, nil)
@ -500,7 +503,7 @@ func TestClient_Renew(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Renew(nil) got, err := c.Renew(nil)
@ -516,8 +519,8 @@ func TestClient_Renew(t *testing.T) {
t.Errorf("Client.Renew() = %v, want nil", got) t.Errorf("Client.Renew() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
@ -529,6 +532,74 @@ func TestClient_Renew(t *testing.T) {
} }
} }
func TestClient_RenewWithToken(t *testing.T) {
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)},
{Certificate: parseCertificate(rootPEM)},
},
}
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer token" {
render.JSONStatus(w, errs.InternalServer("force"), 500)
} else {
render.JSONStatus(w, tt.response, tt.responseCode)
}
})
got, err := c.RenewWithToken("token")
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.RenewWithToken() = %v, want nil", got)
}
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Rekey(t *testing.T) { func TestClient_Rekey(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
@ -569,7 +640,7 @@ func TestClient_Rekey(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Rekey(tt.request, nil) got, err := c.Rekey(tt.request, nil)
@ -585,8 +656,8 @@ func TestClient_Rekey(t *testing.T) {
t.Errorf("Client.Renew() = %v, want nil", got) t.Errorf("Client.Renew() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
@ -634,7 +705,7 @@ func TestClient_Provisioners(t *testing.T) {
if req.RequestURI != tt.expectedURI { if req.RequestURI != tt.expectedURI {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Provisioners(tt.args...) got, err := c.Provisioners(tt.args...)
@ -691,7 +762,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
if req.RequestURI != expected { if req.RequestURI != expected {
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
} }
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.ProvisionerKey(tt.kid) got, err := c.ProvisionerKey(tt.kid)
@ -706,8 +777,8 @@ func TestClient_ProvisionerKey(t *testing.T) {
t.Errorf("Client.ProvisionerKey() = %v, want nil", got) t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
@ -750,7 +821,7 @@ func TestClient_Roots(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Roots() got, err := c.Roots()
@ -765,8 +836,8 @@ func TestClient_Roots(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Roots() = %v, want nil", got) t.Errorf("Client.Roots() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
@ -808,7 +879,7 @@ func TestClient_Federation(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.Federation() got, err := c.Federation()
@ -823,8 +894,8 @@ func TestClient_Federation(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Federation() = %v, want nil", got) t.Errorf("Client.Federation() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
@ -870,7 +941,7 @@ func TestClient_SSHRoots(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.SSHRoots() got, err := c.SSHRoots()
@ -885,8 +956,8 @@ func TestClient_SSHRoots(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.SSHKeys() = %v, want nil", got) t.Errorf("Client.SSHKeys() = %v, want nil", got)
} }
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, tt.err.Error(), err.Error())
default: default:
@ -970,7 +1041,7 @@ func TestClient_RootFingerprint(t *testing.T) {
} }
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.RootFingerprint() got, err := c.RootFingerprint()
@ -1031,7 +1102,7 @@ func TestClient_SSHBastion(t *testing.T) {
} }
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
api.JSONStatus(w, tt.response, tt.responseCode) render.JSONStatus(w, tt.response, tt.responseCode)
}) })
got, err := c.SSHBastion(tt.request) got, err := c.SSHBastion(tt.request)
@ -1047,8 +1118,8 @@ func TestClient_SSHBastion(t *testing.T) {
t.Errorf("Client.SSHBastion() = %v, want nil", got) t.Errorf("Client.SSHBastion() = %v, want nil", got)
} }
if tt.responseCode != 200 { if tt.responseCode != 200 {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
@ -1060,3 +1131,28 @@ func TestClient_SSHBastion(t *testing.T) {
}) })
} }
} }
func TestClient_GetCaURL(t *testing.T) {
tests := []struct {
name string
caURL string
want string
}{
{"ok", "https://ca.com", "https://ca.com"},
{"ok no schema", "ca.com", "https://ca.com"},
{"ok with port", "https://ca.com:9000", "https://ca.com:9000"},
{"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
if got := c.GetCaURL(); got != tt.want {
t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"os" "os"
"reflect" "reflect"
"sort"
"testing" "testing"
) )
@ -196,7 +197,7 @@ func TestLoadClient(t *testing.T) {
switch { switch {
case gotTransport.TLSClientConfig.GetClientCertificate == nil: case gotTransport.TLSClientConfig.GetClientCertificate == nil:
t.Error("LoadClient() transport does not define GetClientCertificate") t.Error("LoadClient() transport does not define GetClientCertificate")
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()): case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs):
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
default: default:
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
@ -238,3 +239,23 @@ func Test_defaultsConfig_Validate(t *testing.T) {
}) })
} }
} }
// nolint:staticcheck,gocritic
func equalPools(a, b *x509.CertPool) bool {
if reflect.DeepEqual(a, b) {
return true
}
subjects := a.Subjects()
sA := make([]string, len(subjects))
for i := range subjects {
sA[i] = string(subjects[i])
}
subjects = b.Subjects()
sB := make([]string, len(subjects))
for i := range subjects {
sB[i] = string(subjects[i])
}
sort.Strings(sA)
sort.Strings(sB)
return reflect.DeepEqual(sA, sB)
}

View file

@ -346,6 +346,8 @@ func TestIdentity_GetCertPool(t *testing.T) {
return return
} }
if got != nil { if got != nil {
// nolint:staticcheck // we don't have a different way to check
// the certificates in the pool.
subjects := got.Subjects() subjects := got.Subjects()
if !reflect.DeepEqual(subjects, tt.wantSubjects) { if !reflect.DeepEqual(subjects, tt.wantSubjects) {
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects) t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)

View file

@ -60,7 +60,10 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
} }
} }
period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore) // Use the current time to calculate the initial period. Using a notBefore
// in the past might set a renewBefore too large, causing continuous
// renewals due to the negative values in nextRenewDuration.
period := cert.Leaf.NotAfter.Sub(time.Now().Truncate(time.Second))
if period < minCertDuration { if period < minCertDuration {
return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period) return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period)
} }
@ -181,7 +184,7 @@ func (r *TLSRenewer) renewCertificate() {
} }
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration { func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {
d := time.Until(notAfter) - r.renewBefore d := time.Until(notAfter).Truncate(time.Second) - r.renewBefore
n := rand.Int63n(int64(r.renewJitter)) n := rand.Int63n(int64(r.renewJitter))
d -= time.Duration(n) d -= time.Duration(n)
if d < 0 { if d < 0 {

2
ca/testdata/ca.json vendored
View file

@ -6,7 +6,7 @@
"password": "password", "password": "password",
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"], "dnsNames": ["127.0.0.1"],
"logger": {"format": "text"}, "_logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.3, "maxVersion": 1.3,

Some files were not shown because too many files have changed in this diff Show more