forked from TrueCloudLab/certificates
Merge branch 'master' into feat/vault
This commit is contained in:
commit
37b521ec6c
119 changed files with 4455 additions and 2003 deletions
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
|
@ -12,7 +12,7 @@ jobs:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: [ '1.15', '1.16', '1.17' ]
|
go: [ '1.17', '1.18' ]
|
||||||
outputs:
|
outputs:
|
||||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||||
steps:
|
steps:
|
||||||
|
@ -33,7 +33,7 @@ jobs:
|
||||||
uses: golangci/golangci-lint-action@v2
|
uses: golangci/golangci-lint-action@v2
|
||||||
with:
|
with:
|
||||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||||
version: 'v1.44.0'
|
version: 'v1.45.0'
|
||||||
|
|
||||||
# Optional: working directory, useful for monorepos
|
# Optional: working directory, useful for monorepos
|
||||||
# working-directory: somedir
|
# working-directory: somedir
|
||||||
|
@ -106,7 +106,7 @@ jobs:
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: 1.17
|
go-version: 1.18
|
||||||
-
|
-
|
||||||
name: APT Install
|
name: APT Install
|
||||||
id: aptInstall
|
id: aptInstall
|
||||||
|
@ -159,7 +159,7 @@ jobs:
|
||||||
name: Setup Go
|
name: Setup Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: '1.17'
|
go-version: '1.18'
|
||||||
-
|
-
|
||||||
name: Install cosign
|
name: Install cosign
|
||||||
uses: sigstore/cosign-installer@v1.1.0
|
uses: sigstore/cosign-installer@v1.1.0
|
||||||
|
|
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
|
@ -14,7 +14,7 @@ jobs:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: [ '1.16', '1.17' ]
|
go: [ '1.17', '1.18' ]
|
||||||
steps:
|
steps:
|
||||||
-
|
-
|
||||||
name: Checkout
|
name: Checkout
|
||||||
|
@ -33,7 +33,7 @@ jobs:
|
||||||
uses: golangci/golangci-lint-action@v2
|
uses: golangci/golangci-lint-action@v2
|
||||||
with:
|
with:
|
||||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||||
version: 'v1.44.0'
|
version: 'v1.45.0'
|
||||||
|
|
||||||
# Optional: working directory, useful for monorepos
|
# Optional: working directory, useful for monorepos
|
||||||
# working-directory: somedir
|
# working-directory: somedir
|
||||||
|
@ -58,7 +58,7 @@ jobs:
|
||||||
run: V=1 make ci
|
run: V=1 make ci
|
||||||
-
|
-
|
||||||
name: Codecov
|
name: Codecov
|
||||||
if: matrix.go == '1.17'
|
if: matrix.go == '1.18'
|
||||||
uses: codecov/codecov-action@v1.2.1
|
uses: codecov/codecov-action@v1.2.1
|
||||||
with:
|
with:
|
||||||
file: ./coverage.out # optional
|
file: ./coverage.out # optional
|
||||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -19,3 +19,4 @@ coverage.txt
|
||||||
output
|
output
|
||||||
vendor
|
vendor
|
||||||
.idea
|
.idea
|
||||||
|
.envrc
|
||||||
|
|
|
@ -19,6 +19,7 @@ builds:
|
||||||
- linux_386
|
- linux_386
|
||||||
- linux_amd64
|
- linux_amd64
|
||||||
- linux_arm64
|
- linux_arm64
|
||||||
|
- linux_arm_5
|
||||||
- linux_arm_6
|
- linux_arm_6
|
||||||
- linux_arm_7
|
- linux_arm_7
|
||||||
- windows_amd64
|
- windows_amd64
|
||||||
|
@ -39,6 +40,7 @@ builds:
|
||||||
- linux_386
|
- linux_386
|
||||||
- linux_amd64
|
- linux_amd64
|
||||||
- linux_arm64
|
- linux_arm64
|
||||||
|
- linux_arm_5
|
||||||
- linux_arm_6
|
- linux_arm_6
|
||||||
- linux_arm_7
|
- linux_arm_7
|
||||||
- windows_amd64
|
- windows_amd64
|
||||||
|
@ -59,6 +61,7 @@ builds:
|
||||||
- linux_386
|
- linux_386
|
||||||
- linux_amd64
|
- linux_amd64
|
||||||
- linux_arm64
|
- linux_arm64
|
||||||
|
- linux_arm_5
|
||||||
- linux_arm_6
|
- linux_arm_6
|
||||||
- linux_arm_7
|
- linux_arm_7
|
||||||
- windows_amd64
|
- windows_amd64
|
||||||
|
|
22
CHANGELOG.md
22
CHANGELOG.md
|
@ -4,16 +4,32 @@ All notable changes to this project will be documented in this file.
|
||||||
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
||||||
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## [Unreleased - 0.18.2] - DATE
|
## [Unreleased - 0.18.3] - DATE
|
||||||
### Added
|
### Added
|
||||||
|
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
|
||||||
|
- Added support for `extraNames` in X.509 templates.
|
||||||
### Changed
|
### Changed
|
||||||
- IPv6 addresses are normalized as IP addresses instead of hostnames.
|
- Made SCEP CA URL paths dynamic
|
||||||
- More descriptive JWK decryption error message.
|
- Support two latest versions of Go (1.17, 1.18)
|
||||||
### Deprecated
|
### Deprecated
|
||||||
### Removed
|
### Removed
|
||||||
### Fixed
|
### Fixed
|
||||||
### Security
|
### Security
|
||||||
|
|
||||||
|
## [0.18.2] - 2022-03-01
|
||||||
|
### Added
|
||||||
|
- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner.
|
||||||
|
- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering
|
||||||
|
out database drivers using Go tags. For example, using the Go flag
|
||||||
|
`--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx
|
||||||
|
driver for PostgreSQL.
|
||||||
|
### Changed
|
||||||
|
- IPv6 addresses are normalized as IP addresses instead of hostnames.
|
||||||
|
- More descriptive JWK decryption error message.
|
||||||
|
- Make the X5C leaf certificate available to the templates using `{{ .AuthorizationCrt }}`.
|
||||||
|
### Fixed
|
||||||
|
- During provisioner add - validate provisioner configuration before storing to DB.
|
||||||
|
|
||||||
## [0.18.1] - 2022-02-03
|
## [0.18.1] - 2022-02-03
|
||||||
### Added
|
### Added
|
||||||
- Support for ACME revocation.
|
- Support for ACME revocation.
|
||||||
|
|
|
@ -5,8 +5,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -70,23 +71,23 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
payload, err := payloadFromContext(ctx)
|
payload, err := payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var nar NewAccountRequest
|
var nar NewAccountRequest
|
||||||
if err := json.Unmarshal(payload.value, &nar); err != nil {
|
if err := json.Unmarshal(payload.value, &nar); err != nil {
|
||||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
|
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||||
"failed to unmarshal new-account request payload"))
|
"failed to unmarshal new-account request payload"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := nar.Validate(); err != nil {
|
if err := nar.Validate(); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prov, err := acmeProvisionerFromContext(ctx)
|
prov, err := acmeProvisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,26 +97,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
acmeErr, ok := err.(*acme.Error)
|
acmeErr, ok := err.(*acme.Error)
|
||||||
if !ok || acmeErr.Status != http.StatusBadRequest {
|
if !ok || acmeErr.Status != http.StatusBadRequest {
|
||||||
// Something went wrong ...
|
// Something went wrong ...
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Account does not exist //
|
// Account does not exist //
|
||||||
if nar.OnlyReturnExisting {
|
if nar.OnlyReturnExisting {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
|
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
|
||||||
"account does not exist"))
|
"account does not exist"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
jwk, err := jwkFromContext(ctx)
|
jwk, err := jwkFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
eak, err := h.validateExternalAccountBinding(ctx, &nar)
|
eak, err := h.validateExternalAccountBinding(ctx, &nar)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,18 +126,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
Status: acme.StatusValid,
|
Status: acme.StatusValid,
|
||||||
}
|
}
|
||||||
if err := h.db.CreateAccount(ctx, acc); err != nil {
|
if err := h.db.CreateAccount(ctx, acc); err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error creating account"))
|
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
|
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
|
||||||
err := eak.BindTo(acc)
|
err := eak.BindTo(acc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key"))
|
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acc.ExternalAccountBinding = nar.ExternalAccountBinding
|
acc.ExternalAccountBinding = nar.ExternalAccountBinding
|
||||||
|
@ -149,7 +150,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
h.linker.LinkAccount(ctx, acc)
|
h.linker.LinkAccount(ctx, acc)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
|
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
|
||||||
api.JSONStatus(w, acc, httpStatus)
|
render.JSONStatus(w, acc, httpStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOrUpdateAccount is the api for updating an ACME account.
|
// GetOrUpdateAccount is the api for updating an ACME account.
|
||||||
|
@ -157,12 +158,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payload, err := payloadFromContext(ctx)
|
payload, err := payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,12 +172,12 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
if !payload.isPostAsGet {
|
if !payload.isPostAsGet {
|
||||||
var uar UpdateAccountRequest
|
var uar UpdateAccountRequest
|
||||||
if err := json.Unmarshal(payload.value, &uar); err != nil {
|
if err := json.Unmarshal(payload.value, &uar); err != nil {
|
||||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
|
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||||
"failed to unmarshal new-account request payload"))
|
"failed to unmarshal new-account request payload"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := uar.Validate(); err != nil {
|
if err := uar.Validate(); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(uar.Status) > 0 || len(uar.Contact) > 0 {
|
if len(uar.Status) > 0 || len(uar.Contact) > 0 {
|
||||||
|
@ -187,7 +188,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.UpdateAccount(ctx, acc); err != nil {
|
if err := h.db.UpdateAccount(ctx, acc); err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error updating account"))
|
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -196,7 +197,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
h.linker.LinkAccount(ctx, acc)
|
h.linker.LinkAccount(ctx, acc)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
|
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
|
||||||
api.JSON(w, acc)
|
render.JSON(w, acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
||||||
|
@ -213,22 +214,22 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
accID := chi.URLParam(r, "accID")
|
accID := chi.URLParam(r, "accID")
|
||||||
if acc.ID != accID {
|
if acc.ID != accID {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
|
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkOrdersByAccountID(ctx, orders)
|
h.linker.LinkOrdersByAccountID(ctx, orders)
|
||||||
|
|
||||||
api.JSON(w, orders)
|
render.JSON(w, orders)
|
||||||
logOrdersByAccount(w, orders)
|
logOrdersByAccount(w, orders)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -11,8 +12,10 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,6 +46,7 @@ type Handler struct {
|
||||||
ca acme.CertificateAuthority
|
ca acme.CertificateAuthority
|
||||||
linker Linker
|
linker Linker
|
||||||
validateChallengeOptions *acme.ValidateChallengeOptions
|
validateChallengeOptions *acme.ValidateChallengeOptions
|
||||||
|
prerequisitesChecker func(ctx context.Context) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlerOptions required to create a new ACME API request handler.
|
// HandlerOptions required to create a new ACME API request handler.
|
||||||
|
@ -60,6 +64,9 @@ type HandlerOptions struct {
|
||||||
// "acme" is the prefix from which the ACME api is accessed.
|
// "acme" is the prefix from which the ACME api is accessed.
|
||||||
Prefix string
|
Prefix string
|
||||||
CA acme.CertificateAuthority
|
CA acme.CertificateAuthority
|
||||||
|
// PrerequisitesChecker checks if all prerequisites for serving ACME are
|
||||||
|
// met by the CA configuration.
|
||||||
|
PrerequisitesChecker func(ctx context.Context) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler returns a new ACME API handler.
|
// NewHandler returns a new ACME API handler.
|
||||||
|
@ -76,6 +83,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
prerequisitesChecker := func(ctx context.Context) (bool, error) {
|
||||||
|
// by default all prerequisites are met
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if ops.PrerequisitesChecker != nil {
|
||||||
|
prerequisitesChecker = ops.PrerequisitesChecker
|
||||||
|
}
|
||||||
return &Handler{
|
return &Handler{
|
||||||
ca: ops.CA,
|
ca: ops.CA,
|
||||||
db: ops.DB,
|
db: ops.DB,
|
||||||
|
@ -88,6 +102,7 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||||
return tls.DialWithDialer(dialer, network, addr, config)
|
return tls.DialWithDialer(dialer, network, addr, config)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
prerequisitesChecker: prerequisitesChecker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,13 +110,13 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||||
func (h *Handler) Route(r api.Router) {
|
func (h *Handler) Route(r api.Router) {
|
||||||
getPath := h.linker.GetUnescapedPathSuffix
|
getPath := h.linker.GetUnescapedPathSuffix
|
||||||
// Standard ACME API
|
// Standard ACME API
|
||||||
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
|
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||||
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
|
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce))))))
|
||||||
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
|
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||||
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
|
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
|
||||||
|
|
||||||
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
validatingMiddleware := func(next nextHTTP) nextHTTP {
|
||||||
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))
|
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))))
|
||||||
}
|
}
|
||||||
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
|
||||||
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
|
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
|
||||||
|
@ -168,11 +183,11 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
api.JSON(w, &Directory{
|
render.JSON(w, &Directory{
|
||||||
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
|
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
|
||||||
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
|
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
|
||||||
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
|
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
|
||||||
|
@ -187,7 +202,7 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||||
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
||||||
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
||||||
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthorization ACME api for retrieving an Authz.
|
// GetAuthorization ACME api for retrieving an Authz.
|
||||||
|
@ -195,28 +210,28 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if acc.ID != az.AccountID {
|
if acc.ID != az.AccountID {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = az.UpdateStatus(ctx, h.db); err != nil {
|
if err = az.UpdateStatus(ctx, h.db); err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkAuthorization(ctx, az)
|
h.linker.LinkAuthorization(ctx, az)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
|
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
|
||||||
api.JSON(w, az)
|
render.JSON(w, az)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetChallenge ACME api for retrieving a Challenge.
|
// GetChallenge ACME api for retrieving a Challenge.
|
||||||
|
@ -224,14 +239,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Just verify that the payload was set, since we're not strictly adhering
|
// Just verify that the payload was set, since we're not strictly adhering
|
||||||
// to ACME V2 spec for reasons specified below.
|
// to ACME V2 spec for reasons specified below.
|
||||||
_, err = payloadFromContext(ctx)
|
_, err = payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,22 +259,22 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
azID := chi.URLParam(r, "authzID")
|
azID := chi.URLParam(r, "authzID")
|
||||||
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ch.AuthorizationID = azID
|
ch.AuthorizationID = azID
|
||||||
if acc.ID != ch.AccountID {
|
if acc.ID != ch.AccountID {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
"account '%s' does not own challenge '%s'", acc.ID, ch.ID))
|
"account '%s' does not own challenge '%s'", acc.ID, ch.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwk, err := jwkFromContext(ctx)
|
jwk, err := jwkFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
|
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge"))
|
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,7 +282,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
|
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
|
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
|
||||||
api.JSON(w, ch)
|
render.JSON(w, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificate ACME api for retrieving a Certificate.
|
// GetCertificate ACME api for retrieving a Certificate.
|
||||||
|
@ -275,18 +290,18 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
certID := chi.URLParam(r, "certID")
|
certID := chi.URLParam(r, "certID")
|
||||||
|
|
||||||
cert, err := h.db.GetCertificate(ctx, certID)
|
cert, err := h.db.GetCertificate(ctx, certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if cert.AccountID != acc.ID {
|
if cert.AccountID != acc.ID {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
"account '%s' does not own certificate '%s'", acc.ID, certID))
|
"account '%s' does not own certificate '%s'", acc.ID, certID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,13 +10,14 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
"go.step.sm/crypto/keyutil"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
"go.step.sm/crypto/jose"
|
|
||||||
"go.step.sm/crypto/keyutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||||
|
@ -64,7 +65,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
nonce, err := h.db.CreateNonce(r.Context())
|
nonce, err := h.db.CreateNonce(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Replay-Nonce", string(nonce))
|
w.Header().Set("Replay-Nonce", string(nonce))
|
||||||
|
@ -90,7 +91,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||||
var expected []string
|
var expected []string
|
||||||
p, err := provisionerFromContext(r.Context())
|
p, err := provisionerFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +111,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
"expected content-type to be in %s, but got %s", expected, ct))
|
"expected content-type to be in %s, but got %s", expected, ct))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -120,12 +121,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
|
render.Error(w, acme.WrapErrorISE(err, "failed to read request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jws, err := jose.ParseJWS(string(body))
|
jws, err := jose.ParseJWS(string(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
|
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
|
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
|
||||||
|
@ -153,15 +154,15 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(r.Context())
|
jws, err := jwsFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(jws.Signatures) == 0 {
|
if len(jws.Signatures) == 0 {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(jws.Signatures) > 1 {
|
if len(jws.Signatures) > 1 {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +173,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
len(uh.Algorithm) > 0 ||
|
len(uh.Algorithm) > 0 ||
|
||||||
len(uh.Nonce) > 0 ||
|
len(uh.Nonce) > 0 ||
|
||||||
len(uh.ExtraHeaders) > 0 {
|
len(uh.ExtraHeaders) > 0 {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hdr := sig.Protected
|
hdr := sig.Protected
|
||||||
|
@ -182,13 +183,13 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
switch k := hdr.JSONWebKey.Key.(type) {
|
switch k := hdr.JSONWebKey.Key.(type) {
|
||||||
case *rsa.PublicKey:
|
case *rsa.PublicKey:
|
||||||
if k.Size() < keyutil.MinRSAKeyBytes {
|
if k.Size() < keyutil.MinRSAKeyBytes {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
"rsa keys must be at least %d bits (%d bytes) in size",
|
"rsa keys must be at least %d bits (%d bytes) in size",
|
||||||
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
|
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
"jws key type and algorithm do not match"))
|
"jws key type and algorithm do not match"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -196,35 +197,35 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
||||||
// we good
|
// we good
|
||||||
default:
|
default:
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
|
render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the validity/freshness of the Nonce.
|
// Check the validity/freshness of the Nonce.
|
||||||
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the JWS url matches the requested url.
|
// Check that the JWS url matches the requested url.
|
||||||
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
||||||
if jwsURL != reqURL.String() {
|
if jwsURL != reqURL.String() {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
|
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
|
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
|
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next(w, r)
|
next(w, r)
|
||||||
|
@ -239,23 +240,23 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(r.Context())
|
jws, err := jwsFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwk := jws.Signatures[0].Protected.JSONWebKey
|
jwk := jws.Signatures[0].Protected.JSONWebKey
|
||||||
if jwk == nil {
|
if jwk == nil {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !jwk.Valid() {
|
if !jwk.Valid() {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Overwrite KeyID with the JWK thumbprint.
|
// Overwrite KeyID with the JWK thumbprint.
|
||||||
jwk.KeyID, err = acme.KeyToID(jwk)
|
jwk.KeyID, err = acme.KeyToID(jwk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
|
render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -269,11 +270,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
// For NewAccount and Revoke requests ...
|
// For NewAccount and Revoke requests ...
|
||||||
break
|
break
|
||||||
case err != nil:
|
case err != nil:
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
if !acc.IsValid() {
|
if !acc.IsValid() {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
@ -283,25 +284,24 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupProvisioner loads the provisioner associated with the request.
|
// lookupProvisioner loads the provisioner associated with the request.
|
||||||
// Responsds 404 if the provisioner does not exist.
|
// Responds 404 if the provisioner does not exist.
|
||||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||||
name, err := url.PathUnescape(nameEscaped)
|
name, err := url.PathUnescape(nameEscaped)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p, err := h.ca.LoadProvisionerByName(name)
|
p, err := h.ca.LoadProvisionerByName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acmeProv, ok := p.(*provisioner.ACME)
|
acmeProv, ok := p.(*provisioner.ACME)
|
||||||
if !ok {
|
if !ok {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
|
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
|
||||||
|
@ -309,6 +309,24 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkPrerequisites checks if all prerequisites for serving ACME
|
||||||
|
// are met by the CA configuration.
|
||||||
|
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
ok, err := h.prerequisitesChecker(ctx)
|
||||||
|
if err != nil {
|
||||||
|
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// lookupJWK loads the JWK associated with the acme account referenced by the
|
// lookupJWK loads the JWK associated with the acme account referenced by the
|
||||||
// kid parameter of the signed payload.
|
// kid parameter of the signed payload.
|
||||||
// Make sure to parse and validate the JWS before running this middleware.
|
// Make sure to parse and validate the JWS before running this middleware.
|
||||||
|
@ -317,14 +335,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
|
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
|
||||||
kid := jws.Signatures[0].Protected.KeyID
|
kid := jws.Signatures[0].Protected.KeyID
|
||||||
if !strings.HasPrefix(kid, kidPrefix) {
|
if !strings.HasPrefix(kid, kidPrefix) {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
|
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
"kid does not have required prefix; expected %s, but got %s",
|
"kid does not have required prefix; expected %s, but got %s",
|
||||||
kidPrefix, kid))
|
kidPrefix, kid))
|
||||||
return
|
return
|
||||||
|
@ -334,14 +352,14 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
|
||||||
acc, err := h.db.GetAccount(ctx, accID)
|
acc, err := h.db.GetAccount(ctx, accID)
|
||||||
switch {
|
switch {
|
||||||
case nosql.IsErrNotFound(err):
|
case nosql.IsErrNotFound(err):
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||||
return
|
return
|
||||||
case err != nil:
|
case err != nil:
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
if !acc.IsValid() {
|
if !acc.IsValid() {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
|
@ -359,7 +377,7 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -395,21 +413,21 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwk, err := jwkFromContext(ctx)
|
jwk, err := jwkFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payload, err := jws.Verify(jwk)
|
payload, err := jws.Verify(jwk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
|
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
|
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
|
||||||
|
@ -426,11 +444,11 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
payload, err := payloadFromContext(r.Context())
|
payload, err := payloadFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !payload.isPostAsGet {
|
if !payload.isPostAsGet {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next(w, r)
|
next(w, r)
|
||||||
|
|
|
@ -1656,3 +1656,91 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandler_checkPrerequisites(t *testing.T) {
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||||
|
baseURL, provName)
|
||||||
|
type test struct {
|
||||||
|
linker Linker
|
||||||
|
ctx context.Context
|
||||||
|
prerequisitesChecker func(context.Context) (bool, error)
|
||||||
|
next func(http.ResponseWriter, *http.Request)
|
||||||
|
err *acme.Error
|
||||||
|
statusCode int
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/error": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("dns", "acme"),
|
||||||
|
ctx: ctx,
|
||||||
|
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
err: acme.WrapErrorISE(errors.New("force"), "error checking acme provisioner prerequisites"),
|
||||||
|
statusCode: 500,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/prerequisites-nok": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("dns", "acme"),
|
||||||
|
ctx: ctx,
|
||||||
|
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
err: acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"),
|
||||||
|
statusCode: 501,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("dns", "acme"),
|
||||||
|
ctx: ctx,
|
||||||
|
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
statusCode: 200,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, run := range tests {
|
||||||
|
tc := run(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
|
||||||
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.checkPrerequisites(tc.next)(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
|
var ae acme.Error
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), testBody)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -11,9 +11,11 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/smallstep/certificates/acme"
|
|
||||||
"github.com/smallstep/certificates/api"
|
|
||||||
"go.step.sm/crypto/randutil"
|
"go.step.sm/crypto/randutil"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/acme"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewOrderRequest represents the body for a NewOrder request.
|
// NewOrderRequest represents the body for a NewOrder request.
|
||||||
|
@ -70,28 +72,28 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
prov, err := provisionerFromContext(ctx)
|
prov, err := provisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payload, err := payloadFromContext(ctx)
|
payload, err := payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var nor NewOrderRequest
|
var nor NewOrderRequest
|
||||||
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
||||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
|
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||||
"failed to unmarshal new-order request payload"))
|
"failed to unmarshal new-order request payload"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := nor.Validate(); err != nil {
|
if err := nor.Validate(); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,7 +118,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
Status: acme.StatusPending,
|
Status: acme.StatusPending,
|
||||||
}
|
}
|
||||||
if err := h.newAuthorization(ctx, az); err != nil {
|
if err := h.newAuthorization(ctx, az); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
o.AuthorizationIDs[i] = az.ID
|
o.AuthorizationIDs[i] = az.ID
|
||||||
|
@ -135,14 +137,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.CreateOrder(ctx, o); err != nil {
|
if err := h.db.CreateOrder(ctx, o); err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error creating order"))
|
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkOrder(ctx, o)
|
h.linker.LinkOrder(ctx, o)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||||
api.JSONStatus(w, o, http.StatusCreated)
|
render.JSONStatus(w, o, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
|
||||||
|
@ -186,38 +188,38 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
prov, err := provisionerFromContext(ctx)
|
prov, err := provisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if acc.ID != o.AccountID {
|
if acc.ID != o.AccountID {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if prov.GetID() != o.ProvisionerID {
|
if prov.GetID() != o.ProvisionerID {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = o.UpdateStatus(ctx, h.db); err != nil {
|
if err = o.UpdateStatus(ctx, h.db); err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error updating order status"))
|
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkOrder(ctx, o)
|
h.linker.LinkOrder(ctx, o)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||||
api.JSON(w, o)
|
render.JSON(w, o)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
// FinalizeOrder attemptst to finalize an order and create a certificate.
|
||||||
|
@ -225,54 +227,54 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
acc, err := accountFromContext(ctx)
|
acc, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
prov, err := provisionerFromContext(ctx)
|
prov, err := provisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payload, err := payloadFromContext(ctx)
|
payload, err := payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var fr FinalizeRequest
|
var fr FinalizeRequest
|
||||||
if err := json.Unmarshal(payload.value, &fr); err != nil {
|
if err := json.Unmarshal(payload.value, &fr); err != nil {
|
||||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
|
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||||
"failed to unmarshal finalize-order request payload"))
|
"failed to unmarshal finalize-order request payload"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := fr.Validate(); err != nil {
|
if err := fr.Validate(); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if acc.ID != o.AccountID {
|
if acc.ID != o.AccountID {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if prov.GetID() != o.ProvisionerID {
|
if prov.GetID() != o.ProvisionerID {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
|
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order"))
|
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.linker.LinkOrder(ctx, o)
|
h.linker.LinkOrder(ctx, o)
|
||||||
|
|
||||||
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
|
||||||
api.JSON(w, o)
|
render.JSON(w, o)
|
||||||
}
|
}
|
||||||
|
|
||||||
// challengeTypes determines the types of challenges that should be used
|
// challengeTypes determines the types of challenges that should be used
|
||||||
|
|
|
@ -10,13 +10,14 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
"golang.org/x/crypto/ocsp"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"go.step.sm/crypto/jose"
|
|
||||||
"golang.org/x/crypto/ocsp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type revokePayload struct {
|
type revokePayload struct {
|
||||||
|
@ -30,65 +31,65 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
jws, err := jwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prov, err := provisionerFromContext(ctx)
|
prov, err := provisionerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
payload, err := payloadFromContext(ctx)
|
payload, err := payloadFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var p revokePayload
|
var p revokePayload
|
||||||
err = json.Unmarshal(payload.value, &p)
|
err = json.Unmarshal(payload.value, &p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
|
render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
|
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// in this case the most likely cause is a client that didn't properly encode the certificate
|
// in this case the most likely cause is a client that didn't properly encode the certificate
|
||||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
|
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
certToBeRevoked, err := x509.ParseCertificate(certBytes)
|
certToBeRevoked, err := x509.ParseCertificate(certBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// in this case a client may have encoded something different than a certificate
|
// in this case a client may have encoded something different than a certificate
|
||||||
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
|
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
serial := certToBeRevoked.SerialNumber.String()
|
serial := certToBeRevoked.SerialNumber.String()
|
||||||
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
|
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
|
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
|
||||||
// this should never happen
|
// this should never happen
|
||||||
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal"))
|
render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldCheckAccountFrom(jws) {
|
if shouldCheckAccountFrom(jws) {
|
||||||
account, err := accountFromContext(ctx)
|
account, err := accountFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||||
if acmeErr != nil {
|
if acmeErr != nil {
|
||||||
api.WriteError(w, acmeErr)
|
render.Error(w, acmeErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -97,26 +98,26 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
_, err := jws.Verify(certToBeRevoked.PublicKey)
|
_, err := jws.Verify(certToBeRevoked.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
|
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
|
||||||
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
|
render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
|
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasBeenRevokedBefore {
|
if hasBeenRevokedBefore {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
|
render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
reasonCode := p.ReasonCode
|
reasonCode := p.ReasonCode
|
||||||
acmeErr := validateReasonCode(reasonCode)
|
acmeErr := validateReasonCode(reasonCode)
|
||||||
if acmeErr != nil {
|
if acmeErr != nil {
|
||||||
api.WriteError(w, acmeErr)
|
render.Error(w, acmeErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,14 +125,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
|
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
|
||||||
err = prov.AuthorizeRevoke(ctx, "")
|
err = prov.AuthorizeRevoke(ctx, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
|
render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
||||||
err = h.ca.Revoke(ctx, options)
|
err = h.ca.Revoke(ctx, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, wrapRevokeErr(err))
|
render.Error(w, wrapRevokeErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||||
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||||
|
|
||||||
resp, err := vo.HTTPGet(u.String())
|
resp, err := vo.HTTPGet(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// http01ChallengeHost checks if a Challenge value is an IPv6 address
|
||||||
|
// and adds square brackets if that's the case, so that it can be used
|
||||||
|
// as a hostname. Returns the original Challenge value as the host to
|
||||||
|
// use in other cases.
|
||||||
|
func http01ChallengeHost(value string) string {
|
||||||
|
if ip := net.ParseIP(value); ip != nil && ip.To4() == nil {
|
||||||
|
value = "[" + value + "]"
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
func tlsAlert(err error) uint8 {
|
func tlsAlert(err error) uint8 {
|
||||||
var opErr *net.OpError
|
var opErr *net.OpError
|
||||||
if errors.As(err, &opErr) {
|
if errors.As(err, &opErr) {
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
@ -23,9 +24,9 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_storeError(t *testing.T) {
|
func Test_storeError(t *testing.T) {
|
||||||
|
@ -2350,3 +2351,34 @@ func Test_serverName(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_http01ChallengeHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "dns",
|
||||||
|
value: "www.example.com",
|
||||||
|
want: "www.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv4",
|
||||||
|
value: "127.0.0.1",
|
||||||
|
want: "127.0.0.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6",
|
||||||
|
value: "::1",
|
||||||
|
want: "[::1]",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := http01ChallengeHost(tt.value); got != tt.want {
|
||||||
|
t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,13 +3,10 @@ package acme
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProblemType is the type of the ACME problem.
|
// ProblemType is the type of the ACME problem.
|
||||||
|
@ -353,26 +350,8 @@ func (e *Error) ToLog() (interface{}, error) {
|
||||||
return string(b), nil
|
return string(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteError writes to w a JSON representation of the given error.
|
// Render implements render.RenderableError for Error.
|
||||||
func WriteError(w http.ResponseWriter, err *Error) {
|
func (e *Error) Render(w http.ResponseWriter) {
|
||||||
w.Header().Set("Content-Type", "application/problem+json")
|
w.Header().Set("Content-Type", "application/problem+json")
|
||||||
w.WriteHeader(err.StatusCode())
|
render.JSONStatus(w, e, e.StatusCode())
|
||||||
|
|
||||||
// Write errors in the response writer
|
|
||||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
|
||||||
"error": err.Err,
|
|
||||||
})
|
|
||||||
if os.Getenv("STEPDEBUG") == "1" {
|
|
||||||
if e, ok := err.Err.(errs.StackTracer); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
|
||||||
"stack-trace": fmt.Sprintf("%+v", e),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
56
api/api.go
56
api/api.go
|
@ -20,6 +20,9 @@ import (
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/log"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/config"
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
@ -33,6 +36,7 @@ type Authority interface {
|
||||||
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||||
|
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||||
GetTLSOptions() *config.TLSOptions
|
GetTLSOptions() *config.TLSOptions
|
||||||
Root(shasum string) (*x509.Certificate, error)
|
Root(shasum string) (*x509.Certificate, error)
|
||||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||||
|
@ -43,7 +47,7 @@ type Authority interface {
|
||||||
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
||||||
Revoke(context.Context, *authority.RevokeOptions) error
|
Revoke(context.Context, *authority.RevokeOptions) error
|
||||||
GetEncryptedKey(kid string) (string, error)
|
GetEncryptedKey(kid string) (string, error)
|
||||||
GetRoots() (federation []*x509.Certificate, err error)
|
GetRoots() ([]*x509.Certificate, error)
|
||||||
GetFederation() ([]*x509.Certificate, error)
|
GetFederation() ([]*x509.Certificate, error)
|
||||||
Version() authority.Version
|
Version() authority.Version
|
||||||
}
|
}
|
||||||
|
@ -257,6 +261,7 @@ func (h *caHandler) Route(r Router) {
|
||||||
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
r.MethodFunc("GET", "/provisioners", h.Provisioners)
|
||||||
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
|
||||||
r.MethodFunc("GET", "/roots", h.Roots)
|
r.MethodFunc("GET", "/roots", h.Roots)
|
||||||
|
r.MethodFunc("GET", "/roots.pem", h.RootsPEM)
|
||||||
r.MethodFunc("GET", "/federation", h.Federation)
|
r.MethodFunc("GET", "/federation", h.Federation)
|
||||||
// SSH CA
|
// SSH CA
|
||||||
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
|
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
|
||||||
|
@ -280,7 +285,7 @@ func (h *caHandler) Route(r Router) {
|
||||||
// Version is an HTTP handler that returns the version of the server.
|
// Version is an HTTP handler that returns the version of the server.
|
||||||
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
||||||
v := h.Authority.Version()
|
v := h.Authority.Version()
|
||||||
JSON(w, VersionResponse{
|
render.JSON(w, VersionResponse{
|
||||||
Version: v.Version,
|
Version: v.Version,
|
||||||
RequireClientAuthentication: v.RequireClientAuthentication,
|
RequireClientAuthentication: v.RequireClientAuthentication,
|
||||||
})
|
})
|
||||||
|
@ -288,7 +293,7 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// Health is an HTTP handler that returns the status of the server.
|
// Health is an HTTP handler that returns the status of the server.
|
||||||
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||||
JSON(w, HealthResponse{Status: "ok"})
|
render.JSON(w, HealthResponse{Status: "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
||||||
|
@ -299,11 +304,11 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
||||||
// Load root certificate with the
|
// Load root certificate with the
|
||||||
cert, err := h.Authority.Root(sum)
|
cert, err := h.Authority.Root(sum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
JSON(w, &RootResponse{RootPEM: Certificate{cert}})
|
render.JSON(w, &RootResponse{RootPEM: Certificate{cert}})
|
||||||
}
|
}
|
||||||
|
|
||||||
func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
||||||
|
@ -318,16 +323,16 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
||||||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||||
cursor, limit, err := ParseCursor(r)
|
cursor, limit, err := ParseCursor(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p, next, err := h.Authority.GetProvisioners(cursor, limit)
|
p, next, err := h.Authority.GetProvisioners(cursor, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
JSON(w, &ProvisionersResponse{
|
render.JSON(w, &ProvisionersResponse{
|
||||||
Provisioners: p,
|
Provisioners: p,
|
||||||
NextCursor: next,
|
NextCursor: next,
|
||||||
})
|
})
|
||||||
|
@ -338,17 +343,17 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||||
kid := chi.URLParam(r, "kid")
|
kid := chi.URLParam(r, "kid")
|
||||||
key, err := h.Authority.GetEncryptedKey(kid)
|
key, err := h.Authority.GetEncryptedKey(kid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.NotFoundErr(err))
|
render.Error(w, errs.NotFoundErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
JSON(w, &ProvisionerKeyResponse{key})
|
render.JSON(w, &ProvisionerKeyResponse{key})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Roots returns all the root certificates for the CA.
|
// Roots returns all the root certificates for the CA.
|
||||||
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
roots, err := h.Authority.GetRoots()
|
roots, err := h.Authority.GetRoots()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error getting roots"))
|
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,16 +362,39 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
|
||||||
certs[i] = Certificate{roots[i]}
|
certs[i] = Certificate{roots[i]}
|
||||||
}
|
}
|
||||||
|
|
||||||
JSONStatus(w, &RootsResponse{
|
render.JSONStatus(w, &RootsResponse{
|
||||||
Certificates: certs,
|
Certificates: certs,
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RootsPEM returns all the root certificates for the CA in PEM format.
|
||||||
|
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||||
|
roots, err := h.Authority.GetRoots()
|
||||||
|
if err != nil {
|
||||||
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/x-pem-file")
|
||||||
|
|
||||||
|
for _, root := range roots {
|
||||||
|
block := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: root.Raw,
|
||||||
|
})
|
||||||
|
|
||||||
|
if _, err := w.Write(block); err != nil {
|
||||||
|
log.Error(w, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Federation returns all the public certificates in the federation.
|
// Federation returns all the public certificates in the federation.
|
||||||
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||||
federated, err := h.Authority.GetFederation()
|
federated, err := h.Authority.GetFederation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -375,7 +403,7 @@ func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
|
||||||
certs[i] = Certificate{federated[i]}
|
certs[i] = Certificate{federated[i]}
|
||||||
}
|
}
|
||||||
|
|
||||||
JSONStatus(w, &FederationResponse{
|
render.JSONStatus(w, &FederationResponse{
|
||||||
Certificates: certs,
|
Certificates: certs,
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
173
api/api_test.go
173
api/api_test.go
|
@ -13,6 +13,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -27,14 +28,17 @@ import (
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
"go.step.sm/crypto/x509util"
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"github.com/smallstep/certificates/templates"
|
"github.com/smallstep/certificates/templates"
|
||||||
"go.step.sm/crypto/jose"
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -171,6 +175,7 @@ type mockAuthority struct {
|
||||||
ret1, ret2 interface{}
|
ret1, ret2 interface{}
|
||||||
err error
|
err error
|
||||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
||||||
|
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||||
getTLSOptions func() *authority.TLSOptions
|
getTLSOptions func() *authority.TLSOptions
|
||||||
root func(shasum string) (*x509.Certificate, error)
|
root func(shasum string) (*x509.Certificate, error)
|
||||||
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||||
|
@ -208,6 +213,13 @@ func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, err
|
||||||
return m.ret1.([]provisioner.SignOption), m.err
|
return m.ret1.([]provisioner.SignOption), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||||
|
if m.authorizeRenewToken != nil {
|
||||||
|
return m.authorizeRenewToken(ctx, ott)
|
||||||
|
}
|
||||||
|
return m.ret1.(*x509.Certificate), m.err
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
|
func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
|
||||||
if m.getTLSOptions != nil {
|
if m.getTLSOptions != nil {
|
||||||
return m.getTLSOptions()
|
return m.getTLSOptions()
|
||||||
|
@ -920,48 +932,141 @@ func Test_caHandler_Renew(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prepare root and leaf for renew after expiry test.
|
||||||
|
now := time.Now()
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
root := &x509.Certificate{
|
||||||
|
Subject: pkix.Name{CommonName: "Test Root CA"},
|
||||||
|
PublicKey: rootPub,
|
||||||
|
KeyUsage: x509.KeyUsageCertSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
NotBefore: now.Add(-2 * time.Hour),
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}
|
||||||
|
root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
expiredLeaf := &x509.Certificate{
|
||||||
|
Subject: pkix.Name{CommonName: "Leaf certificate"},
|
||||||
|
PublicKey: leafPub,
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||||
|
NotBefore: now.Add(-time.Hour),
|
||||||
|
NotAfter: now.Add(-time.Minute),
|
||||||
|
EmailAddresses: []string{"test@example.org"},
|
||||||
|
}
|
||||||
|
expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate renew after expiry token
|
||||||
|
so := new(jose.SignerOptions)
|
||||||
|
so.WithType("JWT")
|
||||||
|
so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)})
|
||||||
|
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
generateX5cToken := func(claims jose.Claims) string {
|
||||||
|
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
tls *tls.ConnectionState
|
tls *tls.ConnectionState
|
||||||
|
header http.Header
|
||||||
cert *x509.Certificate
|
cert *x509.Certificate
|
||||||
root *x509.Certificate
|
root *x509.Certificate
|
||||||
err error
|
err error
|
||||||
statusCode int
|
statusCode int
|
||||||
}{
|
}{
|
||||||
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
{"ok", cs, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||||
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
|
{"ok renew after expiry", &tls.ConnectionState{}, http.Header{
|
||||||
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
|
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
|
||||||
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
})},
|
||||||
|
}, expiredLeaf, root, nil, http.StatusCreated},
|
||||||
|
{"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest},
|
||||||
|
{"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
||||||
|
{"fail expired token", &tls.ConnectionState{}, http.Header{
|
||||||
|
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
|
||||||
|
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
|
||||||
|
})},
|
||||||
|
}, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized},
|
||||||
|
{"fail invalid root", &tls.ConnectionState{}, http.Header{
|
||||||
|
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
|
||||||
|
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
|
||||||
|
})},
|
||||||
|
}, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
h := New(&mockAuthority{
|
||||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||||
|
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||||
|
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errs.Unauthorized(err.Error())
|
||||||
|
}
|
||||||
|
var claims jose.Claims
|
||||||
|
if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil {
|
||||||
|
return nil, errs.Unauthorized(err.Error())
|
||||||
|
}
|
||||||
|
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||||
|
Time: now,
|
||||||
|
}, time.Minute); err != nil {
|
||||||
|
return nil, errs.Unauthorized(err.Error())
|
||||||
|
}
|
||||||
|
return chain[0][0], nil
|
||||||
|
},
|
||||||
getTLSOptions: func() *authority.TLSOptions {
|
getTLSOptions: func() *authority.TLSOptions {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
}).(*caHandler)
|
||||||
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
||||||
req.TLS = tt.tls
|
req.TLS = tt.tls
|
||||||
|
req.Header = tt.header
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Renew(logging.NewResponseLogger(w), req)
|
h.Renew(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
res := w.Result()
|
||||||
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
defer res.Body.Close()
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
res.Body.Close()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("caHandler.Renew unexpected error = %v", err)
|
t.Errorf("caHandler.Renew unexpected error = %v", err)
|
||||||
}
|
}
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
t.Errorf("%s", body)
|
||||||
|
}
|
||||||
|
|
||||||
if tt.statusCode < http.StatusBadRequest {
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
|
expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` +
|
||||||
|
`"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` +
|
||||||
|
`"certChain":["` +
|
||||||
|
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` +
|
||||||
|
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`)
|
||||||
|
|
||||||
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||||||
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
|
t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -1239,6 +1344,46 @@ func Test_caHandler_Roots(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_caHandler_RootsPEM(t *testing.T) {
|
||||||
|
parsedRoot := parseCertificate(rootPEM)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
roots []*x509.Certificate
|
||||||
|
err error
|
||||||
|
statusCode int
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{"one root", []*x509.Certificate{parsedRoot}, nil, http.StatusOK, rootPEM},
|
||||||
|
{"two roots", []*x509.Certificate{parsedRoot, parsedRoot}, nil, http.StatusOK, rootPEM + "\n" + rootPEM},
|
||||||
|
{"fail", nil, errors.New("an error"), http.StatusInternalServerError, ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler)
|
||||||
|
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.RootsPEM(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
if res.StatusCode != tt.statusCode {
|
||||||
|
t.Errorf("caHandler.RootsPEM StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("caHandler.RootsPEM unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
if tt.statusCode < http.StatusBadRequest {
|
||||||
|
if !bytes.Equal(bytes.TrimSpace(body), []byte(tt.expect)) {
|
||||||
|
t.Errorf("caHandler.RootsPEM Body = %s, wants %s", body, tt.expect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_caHandler_Federation(t *testing.T) {
|
func Test_caHandler_Federation(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
|
|
|
@ -1,64 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/acme"
|
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"github.com/smallstep/certificates/logging"
|
|
||||||
"github.com/smallstep/certificates/scep"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WriteError writes to w a JSON representation of the given error.
|
|
||||||
func WriteError(w http.ResponseWriter, err error) {
|
|
||||||
switch k := err.(type) {
|
|
||||||
case *acme.Error:
|
|
||||||
acme.WriteError(w, k)
|
|
||||||
return
|
|
||||||
case *admin.Error:
|
|
||||||
admin.WriteError(w, k)
|
|
||||||
return
|
|
||||||
case *scep.Error:
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
default:
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
}
|
|
||||||
|
|
||||||
cause := errors.Cause(err)
|
|
||||||
if sc, ok := err.(errs.StatusCoder); ok {
|
|
||||||
w.WriteHeader(sc.StatusCode())
|
|
||||||
} else {
|
|
||||||
if sc, ok := cause.(errs.StatusCoder); ok {
|
|
||||||
w.WriteHeader(sc.StatusCode())
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write errors in the response writer
|
|
||||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
|
||||||
"error": err,
|
|
||||||
})
|
|
||||||
if os.Getenv("STEPDEBUG") == "1" {
|
|
||||||
if e, ok := err.(errs.StackTracer); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
|
||||||
"stack-trace": fmt.Sprintf("%+v", e),
|
|
||||||
})
|
|
||||||
} else if e, ok := cause.(errs.StackTracer); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
|
||||||
"stack-trace": fmt.Sprintf("%+v", e),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
|
||||||
LogError(w, err)
|
|
||||||
}
|
|
||||||
}
|
|
79
api/log/log.go
Normal file
79
api/log/log.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
// Package log implements API-related logging helpers.
|
||||||
|
package log
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StackTracedError is the set of errors implementing the StackTrace function.
|
||||||
|
//
|
||||||
|
// Errors implementing this interface have their stack traces logged when passed
|
||||||
|
// to the Error function of this package.
|
||||||
|
type StackTracedError interface {
|
||||||
|
error
|
||||||
|
|
||||||
|
StackTrace() errors.StackTrace
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error adds to the response writer the given error if it implements
|
||||||
|
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||||
|
// using the log package.
|
||||||
|
func Error(rw http.ResponseWriter, err error) {
|
||||||
|
rl, ok := rw.(logging.ResponseLogger)
|
||||||
|
if !ok {
|
||||||
|
log.Println(err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rl.WithFields(map[string]interface{}{
|
||||||
|
"error": err,
|
||||||
|
})
|
||||||
|
|
||||||
|
if os.Getenv("STEPDEBUG") != "1" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e, ok := err.(StackTracedError)
|
||||||
|
if !ok {
|
||||||
|
e, ok = errors.Cause(err).(StackTracedError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
rl.WithFields(map[string]interface{}{
|
||||||
|
"stack-trace": fmt.Sprintf("%+v", e.StackTrace()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnabledResponse log the response object if it implements the EnableLogger
|
||||||
|
// interface.
|
||||||
|
func EnabledResponse(rw http.ResponseWriter, v interface{}) {
|
||||||
|
type enableLogger interface {
|
||||||
|
ToLog() (interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if el, ok := v.(enableLogger); ok {
|
||||||
|
out, err := el.ToLog()
|
||||||
|
if err != nil {
|
||||||
|
Error(rw, err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||||
|
rl.WithFields(map[string]interface{}{
|
||||||
|
"response": out,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
log.Println(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
44
api/log/log_test.go
Normal file
44
api/log/log_test.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package log
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestError(t *testing.T) {
|
||||||
|
theError := errors.New("the error")
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
rw http.ResponseWriter
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
withFields bool
|
||||||
|
}{
|
||||||
|
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
|
||||||
|
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
Error(tt.args.rw, tt.args.err)
|
||||||
|
if tt.withFields {
|
||||||
|
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
|
||||||
|
fields := rl.Fields()
|
||||||
|
if !reflect.DeepEqual(fields["error"], theError) {
|
||||||
|
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error("ResponseWriter does not implement logging.ResponseLogger")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
31
api/read/read.go
Normal file
31
api/read/read.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
// Package read implements request object readers.
|
||||||
|
package read
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/errs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JSON reads JSON from the request body and stores it in the value
|
||||||
|
// pointed by v.
|
||||||
|
func JSON(r io.Reader, v interface{}) error {
|
||||||
|
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||||
|
return errs.BadRequestErr(err, "error decoding json")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProtoJSON reads JSON from the request body and stores it in the value
|
||||||
|
// pointed by v.
|
||||||
|
func ProtoJSON(r io.Reader, m proto.Message) error {
|
||||||
|
data, err := io.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
return errs.BadRequestErr(err, "error reading request body")
|
||||||
|
}
|
||||||
|
return protojson.Unmarshal(data, m)
|
||||||
|
}
|
46
api/read/read_test.go
Normal file
46
api/read/read_test.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package read
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/errs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJSON(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
r io.Reader
|
||||||
|
v interface{}
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
|
||||||
|
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := JSON(tt.args.r, &tt.args.v)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
e, ok := err.(*errs.Error)
|
||||||
|
if ok {
|
||||||
|
if code := e.StatusCode(); code != 400 {
|
||||||
|
t.Errorf("error.StatusCode() = %v, wants 400", code)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("error type = %T, wants *Error", err)
|
||||||
|
}
|
||||||
|
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
|
||||||
|
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
14
api/rekey.go
14
api/rekey.go
|
@ -3,6 +3,8 @@ package api
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -27,24 +29,24 @@ func (s *RekeyRequest) Validate() error {
|
||||||
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
||||||
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||||
WriteError(w, errs.BadRequest("missing client certificate"))
|
render.Error(w, errs.BadRequest("missing client certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var body RekeyRequest
|
var body RekeyRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
certChainPEM := certChainToPEM(certChain)
|
certChainPEM := certChainToPEM(certChain)
|
||||||
|
@ -54,7 +56,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LogCertificate(w, certChain[0])
|
LogCertificate(w, certChain[0])
|
||||||
JSONStatus(w, &SignResponse{
|
render.JSONStatus(w, &SignResponse{
|
||||||
ServerPEM: certChainPEM[0],
|
ServerPEM: certChainPEM[0],
|
||||||
CaPEM: caPEM,
|
CaPEM: caPEM,
|
||||||
CertChainPEM: certChainPEM,
|
CertChainPEM: certChainPEM,
|
||||||
|
|
122
api/render/render.go
Normal file
122
api/render/render.go
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
// Package render implements functionality related to response rendering.
|
||||||
|
package render
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JSON is shorthand for JSONStatus(w, v, http.StatusOK).
|
||||||
|
func JSON(w http.ResponseWriter, v interface{}) {
|
||||||
|
JSONStatus(w, v, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSONStatus marshals v into w. It additionally sets the status code of
|
||||||
|
// w to the given one.
|
||||||
|
//
|
||||||
|
// JSONStatus sets the Content-Type of w to application/json unless one is
|
||||||
|
// specified.
|
||||||
|
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(v); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
setContentTypeUnlessPresent(w, "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_, _ = b.WriteTo(w)
|
||||||
|
|
||||||
|
log.EnabledResponse(w, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK).
|
||||||
|
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
|
||||||
|
ProtoJSONStatus(w, m, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
|
||||||
|
// given status is written as the status code of the response.
|
||||||
|
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
|
||||||
|
b, err := protojson.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
setContentTypeUnlessPresent(w, "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_, _ = w.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) {
|
||||||
|
const header = "Content-Type"
|
||||||
|
|
||||||
|
h := w.Header()
|
||||||
|
if _, ok := h[header]; !ok {
|
||||||
|
h.Set(header, contentType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenderableError is the set of errors that implement the basic Render method.
|
||||||
|
//
|
||||||
|
// Errors that implement this interface will use their own Render method when
|
||||||
|
// being rendered into responses.
|
||||||
|
type RenderableError interface {
|
||||||
|
error
|
||||||
|
|
||||||
|
Render(http.ResponseWriter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error marshals the JSON representation of err to w. In case err implements
|
||||||
|
// RenderableError its own Render method will be called instead.
|
||||||
|
func Error(w http.ResponseWriter, err error) {
|
||||||
|
log.Error(w, err)
|
||||||
|
|
||||||
|
if e, ok := err.(RenderableError); ok {
|
||||||
|
e.Render(w)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
JSONStatus(w, err, statusCodeFromError(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusCodedError is the set of errors that implement the basic StatusCode
|
||||||
|
// function.
|
||||||
|
//
|
||||||
|
// Errors that implement this interface will use the code reported by StatusCode
|
||||||
|
// as the HTTP response code when being rendered by this package.
|
||||||
|
type StatusCodedError interface {
|
||||||
|
error
|
||||||
|
|
||||||
|
StatusCode() int
|
||||||
|
}
|
||||||
|
|
||||||
|
func statusCodeFromError(err error) (code int) {
|
||||||
|
code = http.StatusInternalServerError
|
||||||
|
|
||||||
|
type causer interface {
|
||||||
|
Cause() error
|
||||||
|
}
|
||||||
|
|
||||||
|
for err != nil {
|
||||||
|
if sc, ok := err.(StatusCodedError); ok {
|
||||||
|
code = sc.StatusCode()
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
cause, ok := err.(causer)
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = cause.Cause()
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
115
api/render/render_test.go
Normal file
115
api/render/render_test.go
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
package render
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJSON(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rw := logging.NewResponseLogger(rec)
|
||||||
|
|
||||||
|
JSON(rw, map[string]interface{}{"foo": "bar"})
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Result().StatusCode)
|
||||||
|
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||||
|
assert.Equal(t, "{\"foo\":\"bar\"}\n", rec.Body.String())
|
||||||
|
|
||||||
|
assert.Empty(t, rw.Fields())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONPanics(t *testing.T) {
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
JSON(httptest.NewRecorder(), make(chan struct{}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type renderableError struct {
|
||||||
|
Code int `json:"-"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err renderableError) Error() string {
|
||||||
|
return err.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err renderableError) Render(w http.ResponseWriter) {
|
||||||
|
w.Header().Set("Content-Type", "something/custom")
|
||||||
|
|
||||||
|
JSONStatus(w, err, err.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
type statusedError struct {
|
||||||
|
Contents string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err statusedError) Error() string { return err.Contents }
|
||||||
|
|
||||||
|
func (statusedError) StatusCode() int { return 432 }
|
||||||
|
|
||||||
|
func TestError(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
err error
|
||||||
|
code int
|
||||||
|
body string
|
||||||
|
header string
|
||||||
|
}{
|
||||||
|
0: {
|
||||||
|
err: renderableError{532, "some string"},
|
||||||
|
code: 532,
|
||||||
|
body: "{\"message\":\"some string\"}\n",
|
||||||
|
header: "something/custom",
|
||||||
|
},
|
||||||
|
1: {
|
||||||
|
err: statusedError{"123"},
|
||||||
|
code: 432,
|
||||||
|
body: "{\"Contents\":\"123\"}\n",
|
||||||
|
header: "application/json",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for caseIndex := range cases {
|
||||||
|
kase := cases[caseIndex]
|
||||||
|
|
||||||
|
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
Error(rec, kase.err)
|
||||||
|
|
||||||
|
assert.Equal(t, kase.code, rec.Result().StatusCode)
|
||||||
|
assert.Equal(t, kase.body, rec.Body.String())
|
||||||
|
assert.Equal(t, kase.header, rec.Header().Get("Content-Type"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type causedError struct {
|
||||||
|
cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err causedError) Error() string { return fmt.Sprintf("cause: %s", err.cause) }
|
||||||
|
func (err causedError) Cause() error { return err.cause }
|
||||||
|
|
||||||
|
func TestStatusCodeFromError(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
err error
|
||||||
|
exp int
|
||||||
|
}{
|
||||||
|
0: {nil, http.StatusInternalServerError},
|
||||||
|
1: {io.EOF, http.StatusInternalServerError},
|
||||||
|
2: {statusedError{"123"}, 432},
|
||||||
|
3: {causedError{statusedError{"432"}}, 432},
|
||||||
|
}
|
||||||
|
|
||||||
|
for caseIndex, kase := range cases {
|
||||||
|
assert.Equal(t, kase.exp, statusCodeFromError(kase.err), "case: %d", caseIndex)
|
||||||
|
}
|
||||||
|
}
|
31
api/renew.go
31
api/renew.go
|
@ -1,22 +1,31 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/x509"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
authorizationHeader = "Authorization"
|
||||||
|
bearerScheme = "Bearer"
|
||||||
|
)
|
||||||
|
|
||||||
// Renew uses the information of certificate in the TLS connection to create a
|
// Renew uses the information of certificate in the TLS connection to create a
|
||||||
// new one.
|
// new one.
|
||||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
cert, err := h.getPeerCertificate(r)
|
||||||
WriteError(w, errs.BadRequest("missing client certificate"))
|
if err != nil {
|
||||||
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
|
certChain, err := h.Authority.Renew(cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
certChainPEM := certChainToPEM(certChain)
|
certChainPEM := certChainToPEM(certChain)
|
||||||
|
@ -26,10 +35,22 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LogCertificate(w, certChain[0])
|
LogCertificate(w, certChain[0])
|
||||||
JSONStatus(w, &SignResponse{
|
render.JSONStatus(w, &SignResponse{
|
||||||
ServerPEM: certChainPEM[0],
|
ServerPEM: certChainPEM[0],
|
||||||
CaPEM: caPEM,
|
CaPEM: caPEM,
|
||||||
CertChainPEM: certChainPEM,
|
CertChainPEM: certChainPEM,
|
||||||
TLSOptions: h.Authority.GetTLSOptions(),
|
TLSOptions: h.Authority.GetTLSOptions(),
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||||
|
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||||
|
return r.TLS.PeerCertificates[0], nil
|
||||||
|
}
|
||||||
|
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||||
|
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||||
|
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errs.BadRequest("missing client certificate")
|
||||||
|
}
|
||||||
|
|
|
@ -4,11 +4,14 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ocsp"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"golang.org/x/crypto/ocsp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RevokeResponse is the response object that returns the health of the server.
|
// RevokeResponse is the response object that returns the health of the server.
|
||||||
|
@ -48,13 +51,13 @@ func (r *RevokeRequest) Validate() (err error) {
|
||||||
// TODO: Add CRL and OCSP support.
|
// TODO: Add CRL and OCSP support.
|
||||||
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||||
var body RevokeRequest
|
var body RevokeRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,7 +74,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||||
if len(body.OTT) > 0 {
|
if len(body.OTT) > 0 {
|
||||||
logOtt(w, body.OTT)
|
logOtt(w, body.OTT)
|
||||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
opts.OTT = body.OTT
|
opts.OTT = body.OTT
|
||||||
|
@ -80,12 +83,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||||
// the client certificate Serial Number must match the serial number
|
// the client certificate Serial Number must match the serial number
|
||||||
// being revoked.
|
// being revoked.
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||||
WriteError(w, errs.BadRequest("missing ott or client certificate"))
|
render.Error(w, errs.BadRequest("missing ott or client certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
opts.Crt = r.TLS.PeerCertificates[0]
|
opts.Crt = r.TLS.PeerCertificates[0]
|
||||||
if opts.Crt.SerialNumber.String() != opts.Serial {
|
if opts.Crt.SerialNumber.String() != opts.Serial {
|
||||||
WriteError(w, errs.BadRequest("serial number in client certificate different than body"))
|
render.Error(w, errs.BadRequest("serial number in client certificate different than body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// TODO: should probably be checking if the certificate was revoked here.
|
// TODO: should probably be checking if the certificate was revoked here.
|
||||||
|
@ -96,12 +99,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logRevoke(w, opts)
|
logRevoke(w, opts)
|
||||||
JSON(w, &RevokeResponse{Status: "ok"})
|
render.JSON(w, &RevokeResponse{Status: "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
|
14
api/sign.go
14
api/sign.go
|
@ -5,6 +5,8 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/config"
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
|
@ -49,14 +51,14 @@ type SignResponse struct {
|
||||||
// information in the certificate request.
|
// information in the certificate request.
|
||||||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SignRequest
|
var body SignRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logOtt(w, body.OTT)
|
logOtt(w, body.OTT)
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,13 +70,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error signing certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
certChainPEM := certChainToPEM(certChain)
|
certChainPEM := certChainToPEM(certChain)
|
||||||
|
@ -83,7 +85,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||||
caPEM = certChainPEM[1]
|
caPEM = certChainPEM[1]
|
||||||
}
|
}
|
||||||
LogCertificate(w, certChain[0])
|
LogCertificate(w, certChain[0])
|
||||||
JSONStatus(w, &SignResponse{
|
render.JSONStatus(w, &SignResponse{
|
||||||
ServerPEM: certChainPEM[0],
|
ServerPEM: certChainPEM[0],
|
||||||
CaPEM: caPEM,
|
CaPEM: caPEM,
|
||||||
CertChainPEM: certChainPEM,
|
CertChainPEM: certChainPEM,
|
||||||
|
|
75
api/ssh.go
75
api/ssh.go
|
@ -9,12 +9,15 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/config"
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"github.com/smallstep/certificates/templates"
|
"github.com/smallstep/certificates/templates"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SSHAuthority is the interface implemented by a SSH CA authority.
|
// SSHAuthority is the interface implemented by a SSH CA authority.
|
||||||
|
@ -249,20 +252,20 @@ type SSHBastionResponse struct {
|
||||||
// the request.
|
// the request.
|
||||||
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHSignRequest
|
var body SSHSignRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logOtt(w, body.OTT)
|
logOtt(w, body.OTT)
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -270,7 +273,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
if body.AddUserPublicKey != nil {
|
if body.AddUserPublicKey != nil {
|
||||||
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
|
render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -287,13 +290,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
|
||||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
|
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -301,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
||||||
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
addUserCertificate = &SSHCertificate{addUserCert}
|
addUserCertificate = &SSHCertificate{addUserCert}
|
||||||
|
@ -314,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,13 +329,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
identityCertificate = certChainToPEM(certChain)
|
identityCertificate = certChainToPEM(certChain)
|
||||||
}
|
}
|
||||||
|
|
||||||
JSONStatus(w, &SSHSignResponse{
|
render.JSONStatus(w, &SSHSignResponse{
|
||||||
Certificate: SSHCertificate{cert},
|
Certificate: SSHCertificate{cert},
|
||||||
AddUserCertificate: addUserCertificate,
|
AddUserCertificate: addUserCertificate,
|
||||||
IdentityCertificate: identityCertificate,
|
IdentityCertificate: identityCertificate,
|
||||||
|
@ -344,12 +347,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||||
keys, err := h.Authority.GetSSHRoots(r.Context())
|
keys, err := h.Authority.GetSSHRoots(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||||
WriteError(w, errs.NotFound("no keys found"))
|
render.Error(w, errs.NotFound("no keys found"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -361,7 +364,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||||
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
||||||
}
|
}
|
||||||
|
|
||||||
JSON(w, resp)
|
render.JSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
||||||
|
@ -369,12 +372,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||||
keys, err := h.Authority.GetSSHFederation(r.Context())
|
keys, err := h.Authority.GetSSHFederation(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||||
WriteError(w, errs.NotFound("no keys found"))
|
render.Error(w, errs.NotFound("no keys found"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -386,25 +389,25 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||||
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
||||||
}
|
}
|
||||||
|
|
||||||
JSON(w, resp)
|
render.JSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
||||||
// and servers.
|
// and servers.
|
||||||
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHConfigRequest
|
var body SSHConfigRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
|
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -415,31 +418,31 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
case provisioner.SSHHostCert:
|
case provisioner.SSHHostCert:
|
||||||
cfg.HostTemplates = ts
|
cfg.HostTemplates = ts
|
||||||
default:
|
default:
|
||||||
WriteError(w, errs.InternalServer("it should hot get here"))
|
render.Error(w, errs.InternalServer("it should hot get here"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
JSON(w, cfg)
|
render.JSON(w, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||||
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHCheckPrincipalRequest
|
var body SSHCheckPrincipalRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
|
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
JSON(w, &SSHCheckPrincipalResponse{
|
render.JSON(w, &SSHCheckPrincipalResponse{
|
||||||
Exists: exists,
|
Exists: exists,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -453,10 +456,10 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
|
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
JSON(w, &SSHGetHostsResponse{
|
render.JSON(w, &SSHGetHostsResponse{
|
||||||
Hosts: hosts,
|
Hosts: hosts,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -464,22 +467,22 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||||
// SSHBastion provides returns the bastion configured if any.
|
// SSHBastion provides returns the bastion configured if any.
|
||||||
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHBastionRequest
|
var body SSHBastionRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
|
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
JSON(w, &SSHBastionResponse{
|
render.JSON(w, &SSHBastionResponse{
|
||||||
Hostname: body.Hostname,
|
Hostname: body.Hostname,
|
||||||
Bastion: bastion,
|
Bastion: bastion,
|
||||||
})
|
})
|
||||||
|
|
|
@ -4,9 +4,12 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SSHRekeyRequest is the request body of an SSH certificate request.
|
// SSHRekeyRequest is the request body of an SSH certificate request.
|
||||||
|
@ -38,37 +41,38 @@ type SSHRekeyResponse struct {
|
||||||
// the request.
|
// the request.
|
||||||
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHRekeyRequest
|
var body SSHRekeyRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logOtt(w, body.OTT)
|
logOtt(w, body.OTT)
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
|
||||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,11 +82,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
JSONStatus(w, &SSHRekeyResponse{
|
render.JSONStatus(w, &SSHRekeyResponse{
|
||||||
Certificate: SSHCertificate{newCert},
|
Certificate: SSHCertificate{newCert},
|
||||||
IdentityCertificate: identity,
|
IdentityCertificate: identity,
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
|
|
|
@ -6,6 +6,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
)
|
)
|
||||||
|
@ -36,31 +39,32 @@ type SSHRenewResponse struct {
|
||||||
// the request.
|
// the request.
|
||||||
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHRenewRequest
|
var body SSHRenewRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logOtt(w, body.OTT)
|
logOtt(w, body.OTT)
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
|
||||||
_, err := h.Authority.Authorize(ctx, body.OTT)
|
_, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
|
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,11 +74,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
JSONStatus(w, &SSHSignResponse{
|
render.JSONStatus(w, &SSHSignResponse{
|
||||||
Certificate: SSHCertificate{newCert},
|
Certificate: SSHCertificate{newCert},
|
||||||
IdentityCertificate: identity,
|
IdentityCertificate: identity,
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
|
|
|
@ -3,11 +3,14 @@ package api
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ocsp"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"golang.org/x/crypto/ocsp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SSHRevokeResponse is the response object that returns the health of the server.
|
// SSHRevokeResponse is the response object that returns the health of the server.
|
||||||
|
@ -47,13 +50,13 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
||||||
// NOTE: currently only Passive revocation is supported.
|
// NOTE: currently only Passive revocation is supported.
|
||||||
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||||
var body SSHRevokeRequest
|
var body SSHRevokeRequest
|
||||||
if err := ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,18 +72,18 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||||
logOtt(w, body.OTT)
|
logOtt(w, body.OTT)
|
||||||
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
render.Error(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
opts.OTT = body.OTT
|
opts.OTT = body.OTT
|
||||||
|
|
||||||
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
if err := h.Authority.Revoke(ctx, opts); err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logSSHRevoke(w, opts)
|
logSSHRevoke(w, opts)
|
||||||
JSON(w, &SSHRevokeResponse{Status: "ok"})
|
render.JSON(w, &SSHRevokeResponse{Status: "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||||
|
|
|
@ -18,12 +18,13 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"github.com/smallstep/certificates/templates"
|
"github.com/smallstep/certificates/templates"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
109
api/utils.go
109
api/utils.go
|
@ -1,109 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"github.com/smallstep/certificates/logging"
|
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// EnableLogger is an interface that enables response logging for an object.
|
|
||||||
type EnableLogger interface {
|
|
||||||
ToLog() (interface{}, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogError adds to the response writer the given error if it implements
|
|
||||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
|
||||||
// using the log package.
|
|
||||||
func LogError(rw http.ResponseWriter, err error) {
|
|
||||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
|
||||||
"error": err,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogEnabledResponse log the response object if it implements the EnableLogger
|
|
||||||
// interface.
|
|
||||||
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
|
|
||||||
if el, ok := v.(EnableLogger); ok {
|
|
||||||
out, err := el.ToLog()
|
|
||||||
if err != nil {
|
|
||||||
LogError(rw, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
|
||||||
"response": out,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
log.Println(out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// JSON writes the passed value into the http.ResponseWriter.
|
|
||||||
func JSON(w http.ResponseWriter, v interface{}) {
|
|
||||||
JSONStatus(w, v, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// JSONStatus writes the given value into the http.ResponseWriter and the
|
|
||||||
// given status is written as the status code of the response.
|
|
||||||
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
|
||||||
LogError(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
LogEnabledResponse(w, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProtoJSON writes the passed value into the http.ResponseWriter.
|
|
||||||
func ProtoJSON(w http.ResponseWriter, m proto.Message) {
|
|
||||||
ProtoJSONStatus(w, m, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProtoJSONStatus writes the given value into the http.ResponseWriter and the
|
|
||||||
// given status is written as the status code of the response.
|
|
||||||
func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
|
|
||||||
b, err := protojson.Marshal(m)
|
|
||||||
if err != nil {
|
|
||||||
LogError(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := w.Write(b); err != nil {
|
|
||||||
LogError(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//LogEnabledResponse(w, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadJSON reads JSON from the request body and stores it in the value
|
|
||||||
// pointed by v.
|
|
||||||
func ReadJSON(r io.Reader, v interface{}) error {
|
|
||||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
|
||||||
return errs.BadRequestErr(err, "error decoding json")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadProtoJSON reads JSON from the request body and stores it in the value
|
|
||||||
// pointed by v.
|
|
||||||
func ReadProtoJSON(r io.Reader, m proto.Message) error {
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
return errs.BadRequestErr(err, "error reading request body")
|
|
||||||
}
|
|
||||||
return protojson.Unmarshal(data, m)
|
|
||||||
}
|
|
|
@ -1,125 +0,0 @@
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"github.com/smallstep/certificates/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLogError(t *testing.T) {
|
|
||||||
theError := errors.New("the error")
|
|
||||||
type args struct {
|
|
||||||
rw http.ResponseWriter
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
withFields bool
|
|
||||||
}{
|
|
||||||
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
|
|
||||||
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
LogError(tt.args.rw, tt.args.err)
|
|
||||||
if tt.withFields {
|
|
||||||
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
|
|
||||||
fields := rl.Fields()
|
|
||||||
if !reflect.DeepEqual(fields["error"], theError) {
|
|
||||||
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.Error("ResponseWriter does not implement logging.ResponseLogger")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJSON(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
rw http.ResponseWriter
|
|
||||||
v interface{}
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
ok bool
|
|
||||||
}{
|
|
||||||
{"ok", args{httptest.NewRecorder(), map[string]interface{}{"foo": "bar"}}, true},
|
|
||||||
{"fail", args{httptest.NewRecorder(), make(chan int)}, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
rw := logging.NewResponseLogger(tt.args.rw)
|
|
||||||
JSON(rw, tt.args.v)
|
|
||||||
|
|
||||||
rr, ok := tt.args.rw.(*httptest.ResponseRecorder)
|
|
||||||
if !ok {
|
|
||||||
t.Error("ResponseWriter does not implement *httptest.ResponseRecorder")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := rw.Fields()
|
|
||||||
if tt.ok {
|
|
||||||
if body := rr.Body.String(); body != "{\"foo\":\"bar\"}\n" {
|
|
||||||
t.Errorf(`Unexpected body = %v, want {"foo":"bar"}`, body)
|
|
||||||
}
|
|
||||||
if len(fields) != 0 {
|
|
||||||
t.Errorf("ResponseLogger fields = %v, wants 0 elements", fields)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if body := rr.Body.String(); body != "" {
|
|
||||||
t.Errorf("Unexpected body = %s, want empty string", body)
|
|
||||||
}
|
|
||||||
if len(fields) != 1 {
|
|
||||||
t.Errorf("ResponseLogger fields = %v, wants 1 element", fields)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadJSON(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
r io.Reader
|
|
||||||
v interface{}
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
|
|
||||||
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
err := ReadJSON(tt.args.r, &tt.args.v)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
if tt.wantErr {
|
|
||||||
e, ok := err.(*errs.Error)
|
|
||||||
if ok {
|
|
||||||
if code := e.StatusCode(); code != 400 {
|
|
||||||
t.Errorf("error.StatusCode() = %v, wants 400", code)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.Errorf("error type = %T, wants *Error", err)
|
|
||||||
}
|
|
||||||
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
|
|
||||||
t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -6,10 +6,12 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/smallstep/certificates/api"
|
|
||||||
|
"go.step.sm/linkedca"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"go.step.sm/linkedca"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -44,11 +46,11 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
|
||||||
provName := chi.URLParam(r, "provisionerName")
|
provName := chi.URLParam(r, "provisionerName")
|
||||||
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
|
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !eabEnabled {
|
if !eabEnabled {
|
||||||
api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
|
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||||
|
@ -101,15 +103,15 @@ func NewACMEAdminResponder() *ACMEAdminResponder {
|
||||||
|
|
||||||
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
|
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
|
||||||
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
||||||
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
|
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
|
||||||
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||||
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
|
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
|
||||||
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||||
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,10 +5,14 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
|
|
||||||
|
"go.step.sm/linkedca"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"go.step.sm/linkedca"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type adminAuthority interface {
|
type adminAuthority interface {
|
||||||
|
@ -82,28 +86,28 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
adm, ok := h.auth.LoadAdminByID(id)
|
adm, ok := h.auth.LoadAdminByID(id)
|
||||||
if !ok {
|
if !ok {
|
||||||
api.WriteError(w, admin.NewError(admin.ErrorNotFoundType,
|
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
|
||||||
"admin %s not found", id))
|
"admin %s not found", id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
api.ProtoJSON(w, adm)
|
render.ProtoJSON(w, adm)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdmins returns a segment of admins associated with the authority.
|
// GetAdmins returns a segment of admins associated with the authority.
|
||||||
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||||
cursor, limit, err := api.ParseCursor(r)
|
cursor, limit, err := api.ParseCursor(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||||
"error parsing cursor and limit from query params"))
|
"error parsing cursor and limit from query params"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
|
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
api.JSON(w, &GetAdminsResponse{
|
render.JSON(w, &GetAdminsResponse{
|
||||||
Admins: admins,
|
Admins: admins,
|
||||||
NextCursor: nextCursor,
|
NextCursor: nextCursor,
|
||||||
})
|
})
|
||||||
|
@ -112,19 +116,19 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||||
// CreateAdmin creates a new admin.
|
// CreateAdmin creates a new admin.
|
||||||
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
var body CreateAdminRequest
|
var body CreateAdminRequest
|
||||||
if err := api.ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
|
p, err := h.auth.LoadProvisionerByName(body.Provisioner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
adm := &linkedca.Admin{
|
adm := &linkedca.Admin{
|
||||||
|
@ -134,11 +138,11 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
// Store to authority collection.
|
// Store to authority collection.
|
||||||
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error storing admin"))
|
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
api.ProtoJSONStatus(w, adm, http.StatusCreated)
|
render.ProtoJSONStatus(w, adm, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAdmin deletes admin.
|
// DeleteAdmin deletes admin.
|
||||||
|
@ -146,23 +150,23 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
id := chi.URLParam(r, "id")
|
id := chi.URLParam(r, "id")
|
||||||
|
|
||||||
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
|
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
api.JSON(w, &DeleteResponse{Status: "ok"})
|
render.JSON(w, &DeleteResponse{Status: "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAdmin updates an existing admin.
|
// UpdateAdmin updates an existing admin.
|
||||||
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
var body UpdateAdminRequest
|
var body UpdateAdminRequest
|
||||||
if err := api.ReadJSON(r.Body, &body); err != nil {
|
if err := read.JSON(r.Body, &body); err != nil {
|
||||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := body.Validate(); err != nil {
|
if err := body.Validate(); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,9 +174,9 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
api.ProtoJSON(w, adm)
|
render.ProtoJSON(w, adm)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||||
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !h.auth.IsAdminAPIEnabled() {
|
if !h.auth.IsAdminAPIEnabled() {
|
||||||
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType,
|
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||||
"administration API not enabled"))
|
"administration API not enabled"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -28,14 +28,14 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
tok := r.Header.Get("Authorization")
|
tok := r.Header.Get("Authorization")
|
||||||
if tok == "" {
|
if tok == "" {
|
||||||
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
|
render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
|
||||||
"missing authorization header token"))
|
"missing authorization header token"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
adm, err := h.auth.AuthorizeAdminToken(r, tok)
|
adm, err := h.auth.AuthorizeAdminToken(r, tok)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,16 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
|
|
||||||
|
"go.step.sm/linkedca"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"go.step.sm/linkedca"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetProvisionersResponse is the type for GET /admin/provisioners responses.
|
// GetProvisionersResponse is the type for GET /admin/provisioners responses.
|
||||||
|
@ -31,39 +35,39 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
)
|
)
|
||||||
if len(id) > 0 {
|
if len(id) > 0 {
|
||||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
api.ProtoJSON(w, prov)
|
render.ProtoJSON(w, prov)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProvisioners returns the given segment of provisioners associated with the authority.
|
// GetProvisioners returns the given segment of provisioners associated with the authority.
|
||||||
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||||
cursor, limit, err := api.ParseCursor(r)
|
cursor, limit, err := api.ParseCursor(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||||
"error parsing cursor and limit from query params"))
|
"error parsing cursor and limit from query params"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p, next, err := h.auth.GetProvisioners(cursor, limit)
|
p, next, err := h.auth.GetProvisioners(cursor, limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, errs.InternalServerErr(err))
|
render.Error(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
api.JSON(w, &GetProvisionersResponse{
|
render.JSON(w, &GetProvisionersResponse{
|
||||||
Provisioners: p,
|
Provisioners: p,
|
||||||
NextCursor: next,
|
NextCursor: next,
|
||||||
})
|
})
|
||||||
|
@ -72,22 +76,22 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||||
// CreateProvisioner creates a new prov.
|
// CreateProvisioner creates a new prov.
|
||||||
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
var prov = new(linkedca.Provisioner)
|
var prov = new(linkedca.Provisioner)
|
||||||
if err := api.ReadProtoJSON(r.Body, prov); err != nil {
|
if err := read.ProtoJSON(r.Body, prov); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Validate inputs
|
// TODO: Validate inputs
|
||||||
if err := authority.ValidateClaims(prov.Claims); err != nil {
|
if err := authority.ValidateClaims(prov.Claims); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
|
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
api.ProtoJSONStatus(w, prov, http.StatusCreated)
|
render.ProtoJSONStatus(w, prov, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteProvisioner deletes a provisioner.
|
// DeleteProvisioner deletes a provisioner.
|
||||||
|
@ -101,75 +105,75 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
)
|
)
|
||||||
if len(id) > 0 {
|
if len(id) > 0 {
|
||||||
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
if p, err = h.auth.LoadProvisionerByID(id); err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
api.JSON(w, &DeleteResponse{Status: "ok"})
|
render.JSON(w, &DeleteResponse{Status: "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProvisioner updates an existing prov.
|
// UpdateProvisioner updates an existing prov.
|
||||||
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||||
var nu = new(linkedca.Provisioner)
|
var nu = new(linkedca.Provisioner)
|
||||||
if err := api.ReadProtoJSON(r.Body, nu); err != nil {
|
if err := read.ProtoJSON(r.Body, nu); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
name := chi.URLParam(r, "name")
|
name := chi.URLParam(r, "name")
|
||||||
_old, err := h.auth.LoadProvisionerByName(name)
|
_old, err := h.auth.LoadProvisionerByName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
|
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if nu.Id != old.Id {
|
if nu.Id != old.Id {
|
||||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner ID"))
|
render.Error(w, admin.NewErrorISE("cannot change provisioner ID"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if nu.Type != old.Type {
|
if nu.Type != old.Type {
|
||||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner type"))
|
render.Error(w, admin.NewErrorISE("cannot change provisioner type"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if nu.AuthorityId != old.AuthorityId {
|
if nu.AuthorityId != old.AuthorityId {
|
||||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner authorityID"))
|
render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
|
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
|
||||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner createdAt"))
|
render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
|
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
|
||||||
api.WriteError(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
|
render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Validate inputs
|
// TODO: Validate inputs
|
||||||
if err := authority.ValidateClaims(nu.Claims); err != nil {
|
if err := authority.ValidateClaims(nu.Claims); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||||
api.WriteError(w, err)
|
render.Error(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
api.ProtoJSON(w, nu)
|
render.ProtoJSON(w, nu)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,13 +3,10 @@ package admin
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProblemType is the type of the Admin problem.
|
// ProblemType is the type of the Admin problem.
|
||||||
|
@ -197,27 +194,9 @@ func (e *Error) ToLog() (interface{}, error) {
|
||||||
return string(b), nil
|
return string(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteError writes to w a JSON representation of the given error.
|
// Render implements render.RenderableError for Error.
|
||||||
func WriteError(w http.ResponseWriter, err *Error) {
|
func (e *Error) Render(w http.ResponseWriter) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
e.Message = e.Err.Error()
|
||||||
w.WriteHeader(err.StatusCode())
|
|
||||||
|
|
||||||
err.Message = err.Err.Error()
|
render.JSONStatus(w, e, e.StatusCode())
|
||||||
// Write errors in the response writer
|
|
||||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
|
||||||
"error": err.Err,
|
|
||||||
})
|
|
||||||
if os.Getenv("STEPDEBUG") == "1" {
|
|
||||||
if e, ok := err.Err.(errs.StackTracer); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
|
||||||
"stack-trace": fmt.Sprintf("%+v", e),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,14 +70,24 @@ type Authority struct {
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
|
|
||||||
// Custom functions
|
// Custom functions
|
||||||
sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error)
|
sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error)
|
||||||
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
|
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
|
||||||
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
|
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
|
||||||
getIdentityFunc provisioner.GetIdentityFunc
|
getIdentityFunc provisioner.GetIdentityFunc
|
||||||
|
authorizeRenewFunc provisioner.AuthorizeRenewFunc
|
||||||
|
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
|
||||||
|
|
||||||
adminMutex sync.RWMutex
|
adminMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Info struct {
|
||||||
|
StartTime time.Time
|
||||||
|
RootX509Certs []*x509.Certificate
|
||||||
|
SSHCAUserPublicKey []byte
|
||||||
|
SSHCAHostPublicKey []byte
|
||||||
|
DNSNames []string
|
||||||
|
}
|
||||||
|
|
||||||
// New creates and initiates a new Authority type.
|
// New creates and initiates a new Authority type.
|
||||||
func New(cfg *config.Config, opts ...Option) (*Authority, error) {
|
func New(cfg *config.Config, opts ...Option) (*Authority, error) {
|
||||||
err := cfg.Validate()
|
err := cfg.Validate()
|
||||||
|
@ -175,7 +185,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
|
||||||
// Create provisioner collection.
|
// Create provisioner collection.
|
||||||
provClxn := provisioner.NewCollection(provisionerConfig.Audiences)
|
provClxn := provisioner.NewCollection(provisionerConfig.Audiences)
|
||||||
for _, p := range provList {
|
for _, p := range provList {
|
||||||
if err := p.Init(*provisionerConfig); err != nil {
|
if err := p.Init(provisionerConfig); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := provClxn.Store(p); err != nil {
|
if err := provClxn.Store(p); err != nil {
|
||||||
|
@ -251,6 +261,21 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize linkedca client if necessary. On a linked RA, the issuer
|
||||||
|
// configuration might come from majordomo.
|
||||||
|
var linkedcaClient *linkedCaClient
|
||||||
|
if a.config.AuthorityConfig.EnableAdmin && a.linkedCAToken != "" && a.adminDB == nil {
|
||||||
|
linkedcaClient, err = newLinkedCAClient(a.linkedCAToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// If authorityId is configured make sure it matches the one in the token
|
||||||
|
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) {
|
||||||
|
return errors.New("error initializing linkedca: token authority and configured authority do not match")
|
||||||
|
}
|
||||||
|
linkedcaClient.Run()
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize the X.509 CA Service if it has not been set in the options.
|
// Initialize the X.509 CA Service if it has not been set in the options.
|
||||||
if a.x509CAService == nil {
|
if a.x509CAService == nil {
|
||||||
var options casapi.Options
|
var options casapi.Options
|
||||||
|
@ -258,6 +283,22 @@ func (a *Authority) init() error {
|
||||||
options = *a.config.AuthorityConfig.Options
|
options = *a.config.AuthorityConfig.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure linked RA
|
||||||
|
if linkedcaClient != nil && options.CertificateAuthority == "" {
|
||||||
|
conf, err := linkedcaClient.GetConfiguration(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if conf.RaConfig != nil {
|
||||||
|
options.CertificateAuthority = conf.RaConfig.CaUrl
|
||||||
|
options.CertificateAuthorityFingerprint = conf.RaConfig.Fingerprint
|
||||||
|
options.CertificateIssuer = &casapi.CertificateIssuer{
|
||||||
|
Type: conf.RaConfig.Provisioner.Type.String(),
|
||||||
|
Provisioner: conf.RaConfig.Provisioner.Name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Set the issuer password if passed in the flags.
|
// Set the issuer password if passed in the flags.
|
||||||
if options.CertificateIssuer != nil && a.issuerPassword != nil {
|
if options.CertificateIssuer != nil && a.issuerPassword != nil {
|
||||||
options.CertificateIssuer.Password = string(a.issuerPassword)
|
options.CertificateIssuer.Password = string(a.issuerPassword)
|
||||||
|
@ -292,8 +333,6 @@ func (a *Authority) init() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate)
|
a.rootX509Certs = append(a.rootX509Certs, resp.RootCertificate)
|
||||||
sum := sha256.Sum256(resp.RootCertificate.Raw)
|
|
||||||
log.Printf("Using root fingerprint '%s'", hex.EncodeToString(sum[:]))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -479,24 +518,13 @@ func (a *Authority) init() error {
|
||||||
// Initialize step-ca Admin Database if it's not already initialized using
|
// Initialize step-ca Admin Database if it's not already initialized using
|
||||||
// WithAdminDB.
|
// WithAdminDB.
|
||||||
if a.adminDB == nil {
|
if a.adminDB == nil {
|
||||||
if a.linkedCAToken == "" {
|
if linkedcaClient != nil {
|
||||||
// Check if AuthConfig already exists
|
a.adminDB = linkedcaClient
|
||||||
|
} else {
|
||||||
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
|
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Use the linkedca client as the admindb.
|
|
||||||
client, err := newLinkedCAClient(a.linkedCAToken)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// If authorityId is configured make sure it matches the one in the token
|
|
||||||
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) {
|
|
||||||
return errors.New("error initializing linkedca: token authority and configured authority do not match")
|
|
||||||
}
|
|
||||||
client.Run()
|
|
||||||
a.adminDB = client
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -559,6 +587,21 @@ func (a *Authority) GetAdminDatabase() admin.DB {
|
||||||
return a.adminDB
|
return a.adminDB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Authority) GetInfo() Info {
|
||||||
|
ai := Info{
|
||||||
|
StartTime: a.startTime,
|
||||||
|
RootX509Certs: a.rootX509Certs,
|
||||||
|
DNSNames: a.config.DNSNames,
|
||||||
|
}
|
||||||
|
if a.sshCAUserCertSignKey != nil {
|
||||||
|
ai.SSHCAUserPublicKey = ssh.MarshalAuthorizedKey(a.sshCAUserCertSignKey.PublicKey())
|
||||||
|
}
|
||||||
|
if a.sshCAHostCertSignKey != nil {
|
||||||
|
ai.SSHCAHostPublicKey = ssh.MarshalAuthorizedKey(a.sshCAHostCertSignKey.PublicKey())
|
||||||
|
}
|
||||||
|
return ai
|
||||||
|
}
|
||||||
|
|
||||||
// IsAdminAPIEnabled returns a boolean indicating whether the Admin API has
|
// IsAdminAPIEnabled returns a boolean indicating whether the Admin API has
|
||||||
// been enabled.
|
// been enabled.
|
||||||
func (a *Authority) IsAdminAPIEnabled() bool {
|
func (a *Authority) IsAdminAPIEnabled() bool {
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -276,6 +277,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
||||||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||||
serial := cert.SerialNumber.String()
|
serial := cert.SerialNumber.String()
|
||||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
|
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
|
||||||
|
|
||||||
isRevoked, err := a.IsRevoked(serial)
|
isRevoked, err := a.IsRevoked(serial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||||
|
@ -283,7 +285,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||||
if isRevoked {
|
if isRevoked {
|
||||||
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
|
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, ok := a.provisioners.LoadByCertificate(cert)
|
p, ok := a.provisioners.LoadByCertificate(cert)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
||||||
|
@ -371,3 +372,80 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthorizeRenewToken validates the renew token and returns the leaf
|
||||||
|
// certificate in the x5cInsecure header.
|
||||||
|
func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||||
|
var claims jose.Claims
|
||||||
|
jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
|
||||||
|
}
|
||||||
|
leaf := chain[0][0]
|
||||||
|
if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
|
||||||
|
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
|
||||||
|
}
|
||||||
|
|
||||||
|
p, ok := a.provisioners.LoadByCertificate(leaf)
|
||||||
|
if !ok {
|
||||||
|
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
|
||||||
|
}
|
||||||
|
if err := a.UseToken(ott, p); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||||
|
Issuer: p.GetName(),
|
||||||
|
Subject: leaf.Subject.CommonName,
|
||||||
|
Time: time.Now().UTC(),
|
||||||
|
}, time.Minute); err != nil {
|
||||||
|
switch err {
|
||||||
|
case jose.ErrInvalidIssuer:
|
||||||
|
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)"))
|
||||||
|
case jose.ErrInvalidSubject:
|
||||||
|
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)"))
|
||||||
|
case jose.ErrNotValidYet:
|
||||||
|
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)"))
|
||||||
|
case jose.ErrExpired:
|
||||||
|
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)"))
|
||||||
|
case jose.ErrIssuedInTheFuture:
|
||||||
|
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)"))
|
||||||
|
default:
|
||||||
|
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
audiences := a.config.GetAudiences().Renew
|
||||||
|
if !matchesAudience(claims.Audience, audiences) {
|
||||||
|
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return leaf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesAudience returns true if A and B share at least one element.
|
||||||
|
func matchesAudience(as, bs []string) bool {
|
||||||
|
if len(bs) == 0 || len(as) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range bs {
|
||||||
|
for _, a := range as {
|
||||||
|
if b == a || stripPort(a) == stripPort(b) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripPort attempts to strip the port from the given url. If parsing the url
|
||||||
|
// produces errors it will just return the passed argument.
|
||||||
|
func stripPort(rawurl string) string {
|
||||||
|
u, err := url.Parse(rawurl)
|
||||||
|
if err != nil {
|
||||||
|
return rawurl
|
||||||
|
}
|
||||||
|
u.Host = u.Hostname()
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
|
@ -3,24 +3,32 @@ package authority
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"golang.org/x/crypto/ssh"
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
"go.step.sm/crypto/randutil"
|
"go.step.sm/crypto/randutil"
|
||||||
"golang.org/x/crypto/ssh"
|
"go.step.sm/crypto/x509util"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/certificates/errs"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testAudiences = provisioner.Audiences{
|
var testAudiences = provisioner.Audiences{
|
||||||
|
@ -305,8 +313,8 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
||||||
p, err := tc.auth.authorizeToken(context.Background(), tc.token)
|
p, err := tc.auth.authorizeToken(context.Background(), tc.token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -391,8 +399,8 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
|
||||||
|
|
||||||
if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
|
if err := tc.auth.authorizeRevoke(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -476,14 +484,14 @@ func TestAuthority_authorizeSign(t *testing.T) {
|
||||||
got, err := tc.auth.authorizeSign(context.Background(), tc.token)
|
got, err := tc.auth.authorizeSign(context.Background(), tc.token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if assert.Nil(t, tc.err) {
|
if assert.Nil(t, tc.err) {
|
||||||
assert.Len(t, 7, got)
|
assert.Len(t, 8, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -735,8 +743,8 @@ func TestAuthority_Authorize(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
|
||||||
|
@ -753,6 +761,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
||||||
|
|
||||||
func TestAuthority_authorizeRenew(t *testing.T) {
|
func TestAuthority_authorizeRenew(t *testing.T) {
|
||||||
fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
|
fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
|
||||||
|
fooCrt.NotAfter = time.Now().Add(time.Hour)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
|
renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
|
||||||
|
@ -822,7 +831,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
|
||||||
return &authorizeTest{
|
return &authorizeTest{
|
||||||
auth: a,
|
auth: a,
|
||||||
cert: renewDisabledCrt,
|
cert: renewDisabledCrt,
|
||||||
err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"),
|
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -847,7 +856,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
|
||||||
err := tc.auth.authorizeRenew(tc.cert)
|
err := tc.auth.authorizeRenew(tc.cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
@ -909,6 +918,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner.
|
||||||
}
|
}
|
||||||
|
|
||||||
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
|
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
|
||||||
|
now := time.Now()
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -917,6 +927,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
if cert.ValidAfter == 0 {
|
||||||
|
cert.ValidAfter = uint64(now.Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidBefore == 0 {
|
||||||
|
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
|
||||||
|
}
|
||||||
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -988,8 +1004,8 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
|
||||||
got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token)
|
got, err := tc.auth.authorizeSSHSign(context.Background(), tc.token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -1003,6 +1019,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthority_authorizeSSHRenew(t *testing.T) {
|
func TestAuthority_authorizeSSHRenew(t *testing.T) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
sshpop := func(a *Authority) (*ssh.Certificate, string) {
|
||||||
|
p, ok := a.provisioners.Load("sshpop/sshpop")
|
||||||
|
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
|
||||||
|
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
signer, ok := key.(crypto.Signer)
|
||||||
|
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
|
||||||
|
sshSigner, err := ssh.NewSignerFromSigner(signer)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert))
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return cert, token
|
||||||
|
}
|
||||||
|
|
||||||
a := testAuthority(t)
|
a := testAuthority(t)
|
||||||
|
|
||||||
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||||
|
@ -1012,8 +1045,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
|
||||||
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
|
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
now := time.Now().UTC()
|
|
||||||
|
|
||||||
validIssuer := "step-cli"
|
validIssuer := "step-cli"
|
||||||
|
|
||||||
type authorizeTest struct {
|
type authorizeTest struct {
|
||||||
|
@ -1050,27 +1081,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
|
||||||
|
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
|
||||||
|
return errs.Forbidden("forbidden")
|
||||||
|
}))
|
||||||
|
_, token := sshpop(aa)
|
||||||
|
return &authorizeTest{
|
||||||
|
auth: aa,
|
||||||
|
token: token,
|
||||||
|
err: errors.New("authority.authorizeSSHRenew: forbidden"),
|
||||||
|
code: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
},
|
||||||
"ok": func(t *testing.T) *authorizeTest {
|
"ok": func(t *testing.T) *authorizeTest {
|
||||||
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
|
cert, token := sshpop(a)
|
||||||
assert.FatalError(t, err)
|
|
||||||
signer, ok := key.(crypto.Signer)
|
|
||||||
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
|
|
||||||
sshSigner, err := ssh.NewSignerFromSigner(signer)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
p, ok := a.provisioners.Load("sshpop/sshpop")
|
|
||||||
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
|
|
||||||
|
|
||||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop",
|
|
||||||
[]string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert))
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
return &authorizeTest{
|
return &authorizeTest{
|
||||||
auth: a,
|
auth: a,
|
||||||
token: tok,
|
token: token,
|
||||||
|
cert: cert,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
|
||||||
|
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
cert, token := sshpop(aa)
|
||||||
|
return &authorizeTest{
|
||||||
|
auth: aa,
|
||||||
|
token: token,
|
||||||
cert: cert,
|
cert: cert,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -1083,8 +1121,8 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
|
||||||
got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token)
|
got, err := tc.auth.authorizeSSHRenew(context.Background(), tc.token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -1183,8 +1221,8 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) {
|
||||||
|
|
||||||
if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil {
|
if err := tc.auth.authorizeSSHRevoke(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -1276,8 +1314,8 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
|
||||||
cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token)
|
cert, signOpts, err := tc.auth.authorizeSSHRekey(context.Background(), tc.token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -1290,3 +1328,283 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
type stepProvisionerASN1 struct {
|
||||||
|
Type int
|
||||||
|
Name []byte
|
||||||
|
CredentialID []byte
|
||||||
|
KeyValuePairs []string `asn1:"optional,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
_, signer, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, otherSigner, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
|
||||||
|
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var x5c []string
|
||||||
|
for _, c := range chain {
|
||||||
|
x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
so := new(jose.SignerOptions)
|
||||||
|
so.WithType("JWT")
|
||||||
|
so.WithHeader("x5cInsecure", x5c)
|
||||||
|
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return s, chain[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
a1 := testAuthority(t)
|
||||||
|
t1, c1 := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/renew"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "step-cli",
|
||||||
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
t2, c2 := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/renew"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "step-cli",
|
||||||
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
IssuedAt: jose.NewNumericDate(now),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now.Add(-time.Hour)
|
||||||
|
cert.NotAfter = now.Add(-time.Minute)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/renew"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "step-cli",
|
||||||
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/renew"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "step-cli",
|
||||||
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/renew"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "bad-issuer",
|
||||||
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/renew"},
|
||||||
|
Subject: "bad-subject",
|
||||||
|
Issuer: "step-cli",
|
||||||
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/sign"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "step-cli",
|
||||||
|
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/sign"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "step-cli",
|
||||||
|
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/sign"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "step-cli",
|
||||||
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||||
|
Audience: []string{"https://example.com/1.0/sign"},
|
||||||
|
Subject: "test.example.com",
|
||||||
|
Issuer: "step-cli",
|
||||||
|
NotBefore: jose.NewNumericDate(now),
|
||||||
|
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
cert.NotBefore = now
|
||||||
|
cert.NotAfter = now.Add(time.Hour)
|
||||||
|
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||||
|
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||||
|
Value: b,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
ott string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
authority *Authority
|
||||||
|
args args
|
||||||
|
want *x509.Certificate
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", a1, args{ctx, t1}, c1, false},
|
||||||
|
{"ok expired cert", a1, args{ctx, t2}, c2, false},
|
||||||
|
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
|
||||||
|
{"fail token reuse", a1, args{ctx, t1}, nil, true},
|
||||||
|
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
|
||||||
|
{"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true},
|
||||||
|
{"fail token iss", a1, args{ctx, badIssuer}, nil, true},
|
||||||
|
{"fail token sub", a1, args{ctx, badSubject}, nil, true},
|
||||||
|
{"fail token iat", a1, args{ctx, badNotBefore}, nil, true},
|
||||||
|
{"fail token iat", a1, args{ctx, badExpiry}, nil, true},
|
||||||
|
{"fail token iat", a1, args{ctx, badIssuedAt}, nil, true},
|
||||||
|
{"fail token aud", a1, args{ctx, badAudience}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -26,23 +26,27 @@ var (
|
||||||
DefaultBackdate = time.Minute
|
DefaultBackdate = time.Minute
|
||||||
// DefaultDisableRenewal disables renewals per provisioner.
|
// DefaultDisableRenewal disables renewals per provisioner.
|
||||||
DefaultDisableRenewal = false
|
DefaultDisableRenewal = false
|
||||||
|
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is
|
||||||
|
// expired.
|
||||||
|
DefaultAllowRenewAfterExpiry = false
|
||||||
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally
|
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally
|
||||||
// for all provisioners.
|
// for all provisioners.
|
||||||
DefaultEnableSSHCA = false
|
DefaultEnableSSHCA = false
|
||||||
// GlobalProvisionerClaims default claims for the Authority. Can be overridden
|
// GlobalProvisionerClaims default claims for the Authority. Can be overridden
|
||||||
// by provisioner specific claims.
|
// by provisioner specific claims.
|
||||||
GlobalProvisionerClaims = provisioner.Claims{
|
GlobalProvisionerClaims = provisioner.Claims{
|
||||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
DisableRenewal: &DefaultDisableRenewal,
|
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||||
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||||
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
|
||||||
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
|
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||||
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||||
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||||
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
EnableSSHCA: &DefaultEnableSSHCA,
|
||||||
EnableSSHCA: &DefaultEnableSSHCA,
|
DisableRenewal: &DefaultDisableRenewal,
|
||||||
|
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -273,28 +277,32 @@ func (c *Config) GetAudiences() provisioner.Audiences {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, name := range c.DNSNames {
|
for _, name := range c.DNSNames {
|
||||||
|
hostname := toHostname(name)
|
||||||
audiences.Sign = append(audiences.Sign,
|
audiences.Sign = append(audiences.Sign,
|
||||||
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)),
|
fmt.Sprintf("https://%s/1.0/sign", hostname),
|
||||||
fmt.Sprintf("https://%s/sign", toHostname(name)),
|
fmt.Sprintf("https://%s/sign", hostname),
|
||||||
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)),
|
fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
|
||||||
fmt.Sprintf("https://%s/ssh/sign", toHostname(name)))
|
fmt.Sprintf("https://%s/ssh/sign", hostname))
|
||||||
|
audiences.Renew = append(audiences.Renew,
|
||||||
|
fmt.Sprintf("https://%s/1.0/renew", hostname),
|
||||||
|
fmt.Sprintf("https://%s/renew", hostname))
|
||||||
audiences.Revoke = append(audiences.Revoke,
|
audiences.Revoke = append(audiences.Revoke,
|
||||||
fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)),
|
fmt.Sprintf("https://%s/1.0/revoke", hostname),
|
||||||
fmt.Sprintf("https://%s/revoke", toHostname(name)))
|
fmt.Sprintf("https://%s/revoke", hostname))
|
||||||
audiences.SSHSign = append(audiences.SSHSign,
|
audiences.SSHSign = append(audiences.SSHSign,
|
||||||
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)),
|
fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
|
||||||
fmt.Sprintf("https://%s/ssh/sign", toHostname(name)),
|
fmt.Sprintf("https://%s/ssh/sign", hostname),
|
||||||
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)),
|
fmt.Sprintf("https://%s/1.0/sign", hostname),
|
||||||
fmt.Sprintf("https://%s/sign", toHostname(name)))
|
fmt.Sprintf("https://%s/sign", hostname))
|
||||||
audiences.SSHRevoke = append(audiences.SSHRevoke,
|
audiences.SSHRevoke = append(audiences.SSHRevoke,
|
||||||
fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)),
|
fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname),
|
||||||
fmt.Sprintf("https://%s/ssh/revoke", toHostname(name)))
|
fmt.Sprintf("https://%s/ssh/revoke", hostname))
|
||||||
audiences.SSHRenew = append(audiences.SSHRenew,
|
audiences.SSHRenew = append(audiences.SSHRenew,
|
||||||
fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)),
|
fmt.Sprintf("https://%s/1.0/ssh/renew", hostname),
|
||||||
fmt.Sprintf("https://%s/ssh/renew", toHostname(name)))
|
fmt.Sprintf("https://%s/ssh/renew", hostname))
|
||||||
audiences.SSHRekey = append(audiences.SSHRekey,
|
audiences.SSHRekey = append(audiences.SSHRekey,
|
||||||
fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)),
|
fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname),
|
||||||
fmt.Sprintf("https://%s/ssh/rekey", toHostname(name)))
|
fmt.Sprintf("https://%s/ssh/rekey", hostname))
|
||||||
}
|
}
|
||||||
|
|
||||||
return audiences
|
return audiences
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/keyutil"
|
"go.step.sm/crypto/keyutil"
|
||||||
|
@ -151,13 +152,21 @@ func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linked
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
||||||
|
resp, err := c.GetConfiguration(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp.Provisioners, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) GetConfiguration(ctx context.Context) (*linkedca.ConfigurationResponse, error) {
|
||||||
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
||||||
AuthorityId: c.authorityID,
|
AuthorityId: c.authorityID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error getting provisioners")
|
return nil, errors.Wrap(err, "error getting configuration")
|
||||||
}
|
}
|
||||||
return resp.Provisioners, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||||
|
@ -204,11 +213,9 @@ func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Adm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||||
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
resp, err := c.GetConfiguration(ctx)
|
||||||
AuthorityId: c.authorityID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error getting admins")
|
return nil, err
|
||||||
}
|
}
|
||||||
return resp.Admins, nil
|
return resp.Admins, nil
|
||||||
}
|
}
|
||||||
|
@ -228,12 +235,13 @@ func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
|
||||||
return errors.Wrap(err, "error deleting admin")
|
return errors.Wrap(err, "error deleting admin")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error {
|
func (c *linkedCaClient) StoreCertificateChain(prov provisioner.Interface, fullchain ...*x509.Certificate) error {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
|
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
|
||||||
PemCertificate: serializeCertificateChain(fullchain[0]),
|
PemCertificate: serializeCertificateChain(fullchain[0]),
|
||||||
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
||||||
|
Provisioner: createProvisionerIdentity(prov),
|
||||||
})
|
})
|
||||||
return errors.Wrap(err, "error posting certificate")
|
return errors.Wrap(err, "error posting certificate")
|
||||||
}
|
}
|
||||||
|
@ -310,6 +318,17 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
|
||||||
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createProvisionerIdentity(prov provisioner.Interface) *linkedca.ProvisionerIdentity {
|
||||||
|
if prov == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &linkedca.ProvisionerIdentity{
|
||||||
|
Id: prov.GetID(),
|
||||||
|
Type: linkedca.Provisioner_Type(prov.GetType()),
|
||||||
|
Name: prov.GetName(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func serializeCertificate(crt *x509.Certificate) string {
|
func serializeCertificate(crt *x509.Certificate) string {
|
||||||
if crt == nil {
|
if crt == nil {
|
||||||
return ""
|
return ""
|
||||||
|
|
|
@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of
|
||||||
|
// an X.509 certificate.
|
||||||
|
func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option {
|
||||||
|
return func(a *Authority) error {
|
||||||
|
a.authorizeRenewFunc = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal
|
||||||
|
// of a SSH certificate.
|
||||||
|
func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option {
|
||||||
|
return func(a *Authority) error {
|
||||||
|
a.authorizeSSHRenewFunc = fn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithSSHBastionFunc sets a custom function to get the bastion for a
|
// WithSSHBastionFunc sets a custom function to get the bastion for a
|
||||||
// given user-host pair.
|
// given user-host pair.
|
||||||
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option {
|
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option {
|
||||||
|
@ -145,6 +163,22 @@ func WithX509Signer(crt *x509.Certificate, s crypto.Signer) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithX509SignerFunc defines the function used to get the chain of certificates
|
||||||
|
// and signer used when we sign X.509 certificates.
|
||||||
|
func WithX509SignerFunc(fn func() ([]*x509.Certificate, crypto.Signer, error)) Option {
|
||||||
|
return func(a *Authority) error {
|
||||||
|
srv, err := cas.New(context.Background(), casapi.Options{
|
||||||
|
Type: casapi.SoftCAS,
|
||||||
|
CertificateSigner: fn,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
a.x509CAService = srv
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithSSHUserSigner defines the signer used to sign SSH user certificates.
|
// WithSSHUserSigner defines the signer used to sign SSH user certificates.
|
||||||
func WithSSHUserSigner(s crypto.Signer) Option {
|
func WithSSHUserSigner(s crypto.Signer) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ACME is the acme provisioner type, an entity that can authorize the ACME
|
// ACME is the acme provisioner type, an entity that can authorize the ACME
|
||||||
|
@ -24,7 +23,7 @@ type ACME struct {
|
||||||
RequireEAB bool `json:"requireEAB,omitempty"`
|
RequireEAB bool `json:"requireEAB,omitempty"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
Options *Options `json:"options,omitempty"`
|
Options *Options `json:"options,omitempty"`
|
||||||
claimer *Claimer
|
ctl *Controller
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier.
|
// GetID returns the provisioner unique identifier.
|
||||||
|
@ -69,7 +68,7 @@ func (p *ACME) GetOptions() *Options {
|
||||||
// DefaultTLSCertDuration returns the default TLS cert duration enforced by
|
// DefaultTLSCertDuration returns the default TLS cert duration enforced by
|
||||||
// the provisioner.
|
// the provisioner.
|
||||||
func (p *ACME) DefaultTLSCertDuration() time.Duration {
|
func (p *ACME) DefaultTLSCertDuration() time.Duration {
|
||||||
return p.claimer.DefaultTLSCertDuration()
|
return p.ctl.Claimer.DefaultTLSCertDuration()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes and validates the fields of a JWK type.
|
// Init initializes and validates the fields of a JWK type.
|
||||||
|
@ -81,12 +80,8 @@ func (p *ACME) Init(config Config) (err error) {
|
||||||
return errors.New("provisioner name cannot be empty")
|
return errors.New("provisioner name cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update claims with global ones
|
p.ctl, err = NewController(p, p.Claims, config)
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
return
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign does not do any validation, because all validation is handled
|
// AuthorizeSign does not do any validation, because all validation is handled
|
||||||
|
@ -94,13 +89,14 @@ func (p *ACME) Init(config Config) (err error) {
|
||||||
// on the resulting certificate.
|
// on the resulting certificate.
|
||||||
func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
return []SignOption{
|
return []SignOption{
|
||||||
|
p,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeACME, p.Name, ""),
|
newProvisionerExtensionOption(TypeACME, p.Name, ""),
|
||||||
newForceCNOption(p.ForceCN),
|
newForceCNOption(p.ForceCN),
|
||||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||||
// validators
|
// validators
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,8 +114,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||||
// revocation status. Just confirms that the provisioner that created the
|
// revocation status. Just confirms that the provisioner that created the
|
||||||
// certificate was configured to allow renewals.
|
// certificate was configured to allow renewals.
|
||||||
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
if p.claimer.IsDisableRenewal() {
|
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||||
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,13 +3,14 @@ package provisioner
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestACME_Getters(t *testing.T) {
|
func TestACME_Getters(t *testing.T) {
|
||||||
|
@ -91,6 +92,7 @@ func TestACME_Init(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestACME_AuthorizeRenew(t *testing.T) {
|
func TestACME_AuthorizeRenew(t *testing.T) {
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
type test struct {
|
type test struct {
|
||||||
p *ACME
|
p *ACME
|
||||||
cert *x509.Certificate
|
cert *x509.Certificate
|
||||||
|
@ -104,21 +106,27 @@ func TestACME_AuthorizeRenew(t *testing.T) {
|
||||||
// disable renewal
|
// disable renewal
|
||||||
disable := true
|
disable := true
|
||||||
p.Claims = &Claims{DisableRenewal: &disable}
|
p.Claims = &Claims{DisableRenewal: &disable}
|
||||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
p: p,
|
p: p,
|
||||||
cert: &x509.Certificate{},
|
cert: &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
},
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()),
|
err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
p, err := generateACME()
|
p, err := generateACME()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
p: p,
|
p: p,
|
||||||
cert: &x509.Certificate{},
|
cert: &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -126,8 +134,8 @@ func TestACME_AuthorizeRenew(t *testing.T) {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
|
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
@ -161,31 +169,32 @@ func TestACME_AuthorizeSign(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
|
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
|
||||||
assert.Len(t, 5, opts)
|
assert.Len(t, 6, opts)
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
case *ACME:
|
||||||
case *provisionerExtensionOption:
|
case *provisionerExtensionOption:
|
||||||
assert.Equals(t, v.Type, int(TypeACME))
|
assert.Equals(t, v.Type, TypeACME)
|
||||||
assert.Equals(t, v.Name, tc.p.GetName())
|
assert.Equals(t, v.Name, tc.p.GetName())
|
||||||
assert.Equals(t, v.CredentialID, "")
|
assert.Equals(t, v.CredentialID, "")
|
||||||
assert.Len(t, 0, v.KeyValuePairs)
|
assert.Len(t, 0, v.KeyValuePairs)
|
||||||
case *forceCNOption:
|
case *forceCNOption:
|
||||||
assert.Equals(t, v.ForceCN, tc.p.ForceCN)
|
assert.Equals(t, v.ForceCN, tc.p.ForceCN)
|
||||||
case profileDefaultDuration:
|
case profileDefaultDuration:
|
||||||
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
|
assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
|
||||||
case defaultPublicKeyValidator:
|
case defaultPublicKeyValidator:
|
||||||
case *validityValidator:
|
case *validityValidator:
|
||||||
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
|
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
||||||
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
|
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -264,9 +264,8 @@ type AWS struct {
|
||||||
IIDRoots string `json:"iidRoots,omitempty"`
|
IIDRoots string `json:"iidRoots,omitempty"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
Options *Options `json:"options,omitempty"`
|
Options *Options `json:"options,omitempty"`
|
||||||
claimer *Claimer
|
|
||||||
config *awsConfig
|
config *awsConfig
|
||||||
audiences Audiences
|
ctl *Controller
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier.
|
// GetID returns the provisioner unique identifier.
|
||||||
|
@ -400,15 +399,11 @@ func (p *AWS) Init(config Config) (err error) {
|
||||||
case p.InstanceAge.Value() < 0:
|
case p.InstanceAge.Value() < 0:
|
||||||
return errors.New("provisioner instanceAge cannot be negative")
|
return errors.New("provisioner instanceAge cannot be negative")
|
||||||
}
|
}
|
||||||
// Update claims with global ones
|
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Add default config
|
// Add default config
|
||||||
if p.config, err = newAWSConfig(p.IIDRoots); err != nil {
|
if p.config, err = newAWSConfig(p.IIDRoots); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
|
||||||
|
|
||||||
// validate IMDS versions
|
// validate IMDS versions
|
||||||
if len(p.IMDSVersions) == 0 {
|
if len(p.IMDSVersions) == 0 {
|
||||||
|
@ -425,7 +420,9 @@ func (p *AWS) Init(config Config) (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||||
|
p.ctl, err = NewController(p, p.Claims, config)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign validates the given token and returns the sign options that
|
// AuthorizeSign validates the given token and returns the sign options that
|
||||||
|
@ -470,14 +467,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(so,
|
return append(so,
|
||||||
|
p,
|
||||||
templateOptions,
|
templateOptions,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
|
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
|
||||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||||
// validators
|
// validators
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
commonNameValidator(payload.Claims.Subject),
|
commonNameValidator(payload.Claims.Subject),
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -486,10 +484,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
// revocation status. Just confirms that the provisioner that created the
|
// revocation status. Just confirms that the provisioner that created the
|
||||||
// certificate was configured to allow renewals.
|
// certificate was configured to allow renewals.
|
||||||
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
if p.claimer.IsDisableRenewal() {
|
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||||
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// assertConfig initializes the config if it has not been initialized
|
// assertConfig initializes the config if it has not been initialized
|
||||||
|
@ -664,7 +659,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate audiences with the defaults
|
// validate audiences with the defaults
|
||||||
if !matchesAudience(payload.Audience, p.audiences.Sign) {
|
if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) {
|
||||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
|
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -704,7 +699,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
||||||
|
|
||||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
if !p.claimer.IsSSHCAEnabled() {
|
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||||
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
|
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
|
||||||
}
|
}
|
||||||
claims, err := p.authorizeToken(token)
|
claims, err := p.authorizeToken(token)
|
||||||
|
@ -752,11 +747,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
||||||
// Validate user SignSSHOptions.
|
// Validate user SignSSHOptions.
|
||||||
sshCertOptionsValidator(defaults),
|
sshCertOptionsValidator(defaults),
|
||||||
// Set the validity bounds if not set.
|
// Set the validity bounds if not set.
|
||||||
&sshDefaultDuration{p.claimer},
|
&sshDefaultDuration{p.ctl.Claimer},
|
||||||
// Validate public key
|
// Validate public key
|
||||||
&sshDefaultPublicKeyValidator{},
|
&sshDefaultPublicKeyValidator{},
|
||||||
// Validate the validity period.
|
// Validate the validity period.
|
||||||
&sshCertValidityValidator{p.claimer},
|
&sshCertValidityValidator{p.ctl.Claimer},
|
||||||
// Require all the fields in the SSH certificate
|
// Require all the fields in the SSH certificate
|
||||||
&sshCertDefaultValidator{},
|
&sshCertDefaultValidator{},
|
||||||
), nil
|
), nil
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -17,10 +18,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAWS_Getters(t *testing.T) {
|
func TestAWS_Getters(t *testing.T) {
|
||||||
|
@ -521,8 +522,8 @@ func TestAWS_authorizeToken(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
|
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -641,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
code int
|
code int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false},
|
{"ok", p1, args{t1, "foo.local"}, 7, http.StatusOK, false},
|
||||||
{"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false},
|
{"ok", p2, args{t2, "instance-id"}, 11, http.StatusOK, false},
|
||||||
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false},
|
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 11, http.StatusOK, false},
|
||||||
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false},
|
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 11, http.StatusOK, false},
|
||||||
{"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false},
|
{"ok", p1, args{t4, "instance-id"}, 7, http.StatusOK, false},
|
||||||
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
|
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
|
||||||
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
|
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
|
||||||
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
|
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
|
||||||
|
@ -668,27 +669,28 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
case err != nil:
|
case err != nil:
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
default:
|
default:
|
||||||
assert.Len(t, tt.wantLen, got)
|
assert.Len(t, tt.wantLen, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
case *AWS:
|
||||||
case certificateOptionsFunc:
|
case certificateOptionsFunc:
|
||||||
case *provisionerExtensionOption:
|
case *provisionerExtensionOption:
|
||||||
assert.Equals(t, v.Type, int(TypeAWS))
|
assert.Equals(t, v.Type, TypeAWS)
|
||||||
assert.Equals(t, v.Name, tt.aws.GetName())
|
assert.Equals(t, v.Name, tt.aws.GetName())
|
||||||
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
|
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
|
||||||
assert.Len(t, 2, v.KeyValuePairs)
|
assert.Len(t, 2, v.KeyValuePairs)
|
||||||
case profileDefaultDuration:
|
case profileDefaultDuration:
|
||||||
assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration())
|
assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration())
|
||||||
case commonNameValidator:
|
case commonNameValidator:
|
||||||
assert.Equals(t, string(v), tt.args.cn)
|
assert.Equals(t, string(v), tt.args.cn)
|
||||||
case defaultPublicKeyValidator:
|
case defaultPublicKeyValidator:
|
||||||
case *validityValidator:
|
case *validityValidator:
|
||||||
assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration())
|
assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration())
|
||||||
assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration())
|
assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration())
|
||||||
case ipAddressesValidator:
|
case ipAddressesValidator:
|
||||||
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
|
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
|
||||||
case emailAddressesValidator:
|
case emailAddressesValidator:
|
||||||
|
@ -698,7 +700,7 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
case dnsNamesValidator:
|
case dnsNamesValidator:
|
||||||
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
|
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -726,7 +728,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
||||||
// disable sshCA
|
// disable sshCA
|
||||||
disable := false
|
disable := false
|
||||||
p3.Claims = &Claims{EnableSSHCA: &disable}
|
p3.Claims = &Claims{EnableSSHCA: &disable}
|
||||||
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
|
t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
|
||||||
|
@ -747,7 +749,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
||||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||||
expectedHostOptions := &SignSSHOptions{
|
expectedHostOptions := &SignSSHOptions{
|
||||||
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
|
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
|
||||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
@ -802,8 +804,8 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
} else if assert.NotNil(t, got) {
|
} else if assert.NotNil(t, got) {
|
||||||
|
@ -824,6 +826,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAWS_AuthorizeRenew(t *testing.T) {
|
func TestAWS_AuthorizeRenew(t *testing.T) {
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
p1, err := generateAWS()
|
p1, err := generateAWS()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p2, err := generateAWS()
|
p2, err := generateAWS()
|
||||||
|
@ -832,7 +835,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
|
||||||
// disable renewal
|
// disable renewal
|
||||||
disable := true
|
disable := true
|
||||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
|
@ -845,16 +848,22 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
|
||||||
code int
|
code int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
{"ok", p1, args{&x509.Certificate{
|
||||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusOK, false},
|
||||||
|
{"fail/renew-disabled", p2, args{&x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusUnauthorized, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
if err := tt.aws.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AWS.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -30,7 +30,7 @@ const azureDefaultAudience = "https://management.azure.com/"
|
||||||
|
|
||||||
// azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim.
|
// azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim.
|
||||||
// Using case insensitive as resourceGroups appears as resourcegroups.
|
// Using case insensitive as resourceGroups appears as resourcegroups.
|
||||||
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`)
|
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
|
||||||
|
|
||||||
type azureConfig struct {
|
type azureConfig struct {
|
||||||
oidcDiscoveryURL string
|
oidcDiscoveryURL string
|
||||||
|
@ -89,15 +89,17 @@ type Azure struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
TenantID string `json:"tenantID"`
|
TenantID string `json:"tenantID"`
|
||||||
ResourceGroups []string `json:"resourceGroups"`
|
ResourceGroups []string `json:"resourceGroups"`
|
||||||
|
SubscriptionIDs []string `json:"subscriptionIDs"`
|
||||||
|
ObjectIDs []string `json:"objectIDs"`
|
||||||
Audience string `json:"audience,omitempty"`
|
Audience string `json:"audience,omitempty"`
|
||||||
DisableCustomSANs bool `json:"disableCustomSANs"`
|
DisableCustomSANs bool `json:"disableCustomSANs"`
|
||||||
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
|
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
Options *Options `json:"options,omitempty"`
|
Options *Options `json:"options,omitempty"`
|
||||||
claimer *Claimer
|
|
||||||
config *azureConfig
|
config *azureConfig
|
||||||
oidcConfig openIDConfiguration
|
oidcConfig openIDConfiguration
|
||||||
keyStore *keyStore
|
keyStore *keyStore
|
||||||
|
ctl *Controller
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier.
|
// GetID returns the provisioner unique identifier.
|
||||||
|
@ -201,37 +203,34 @@ func (p *Azure) Init(config Config) (err error) {
|
||||||
case p.Audience == "": // use default audience
|
case p.Audience == "": // use default audience
|
||||||
p.Audience = azureDefaultAudience
|
p.Audience = azureDefaultAudience
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize config
|
// Initialize config
|
||||||
p.assertConfig()
|
p.assertConfig()
|
||||||
|
|
||||||
// Update claims with global ones
|
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode and validate openid-configuration endpoint
|
// Decode and validate openid-configuration endpoint
|
||||||
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
|
if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if err := p.oidcConfig.Validate(); err != nil {
|
if err := p.oidcConfig.Validate(); err != nil {
|
||||||
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
|
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
|
||||||
}
|
}
|
||||||
// Get JWK key set
|
// Get JWK key set
|
||||||
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
|
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
p.ctl, err = NewController(p, p.Claims, config)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// authorizeToken returns the claims, name, group, error.
|
// authorizeToken returns the claims, name, group, subscription, identityObjectID, error.
|
||||||
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, error) {
|
func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, string, string, error) {
|
||||||
jwt, err := jose.ParseSigned(token)
|
jwt, err := jose.ParseSigned(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
|
return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; error parsing azure token")
|
||||||
}
|
}
|
||||||
if len(jwt.Headers) == 0 {
|
if len(jwt.Headers) == 0 {
|
||||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
|
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token missing header")
|
||||||
}
|
}
|
||||||
|
|
||||||
var found bool
|
var found bool
|
||||||
|
@ -244,7 +243,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
|
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; cannot validate azure token")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||||
|
@ -252,26 +251,30 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, err
|
||||||
Issuer: p.oidcConfig.Issuer,
|
Issuer: p.oidcConfig.Issuer,
|
||||||
Time: time.Now(),
|
Time: time.Now(),
|
||||||
}, 1*time.Minute); err != nil {
|
}, 1*time.Minute); err != nil {
|
||||||
return nil, "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload")
|
return nil, "", "", "", "", errs.Wrap(http.StatusUnauthorized, err, "azure.authorizeToken; failed to validate azure token payload")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate TenantID
|
// Validate TenantID
|
||||||
if claims.TenantID != p.TenantID {
|
if claims.TenantID != p.TenantID {
|
||||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
|
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; azure token validation failed - invalid tenant id claim (tid)")
|
||||||
}
|
}
|
||||||
|
|
||||||
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
|
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
|
||||||
if len(re) != 4 {
|
if len(re) != 5 {
|
||||||
return nil, "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
|
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
|
||||||
}
|
}
|
||||||
group, name := re[2], re[3]
|
|
||||||
return &claims, name, group, nil
|
var subscription, group, name string
|
||||||
|
identityObjectID := claims.ObjectID
|
||||||
|
subscription, group, name = re[1], re[2], re[4]
|
||||||
|
|
||||||
|
return &claims, name, group, subscription, identityObjectID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign validates the given token and returns the sign options that
|
// AuthorizeSign validates the given token and returns the sign options that
|
||||||
// will be used on certificate creation.
|
// will be used on certificate creation.
|
||||||
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
_, name, group, err := p.authorizeToken(token)
|
_, name, group, subscription, identityObjectID, err := p.authorizeToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign")
|
||||||
}
|
}
|
||||||
|
@ -290,6 +293,34 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Filter by subscription id
|
||||||
|
if len(p.SubscriptionIDs) > 0 {
|
||||||
|
var found bool
|
||||||
|
for _, s := range p.SubscriptionIDs {
|
||||||
|
if s == subscription {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid subscription id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter by Azure AD identity object id
|
||||||
|
if len(p.ObjectIDs) > 0 {
|
||||||
|
var found bool
|
||||||
|
for _, i := range p.ObjectIDs {
|
||||||
|
if i == identityObjectID {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, errs.Unauthorized("azure.AuthorizeSign; azure token validation failed - invalid identity object id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Template options
|
// Template options
|
||||||
data := x509util.NewTemplateData()
|
data := x509util.NewTemplateData()
|
||||||
data.SetCommonName(name)
|
data.SetCommonName(name)
|
||||||
|
@ -321,13 +352,14 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(so,
|
return append(so,
|
||||||
|
p,
|
||||||
templateOptions,
|
templateOptions,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
|
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
|
||||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||||
// validators
|
// validators
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -336,19 +368,16 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
||||||
// revocation status. Just confirms that the provisioner that created the
|
// revocation status. Just confirms that the provisioner that created the
|
||||||
// certificate was configured to allow renewals.
|
// certificate was configured to allow renewals.
|
||||||
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
if p.claimer.IsDisableRenewal() {
|
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||||
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
if !p.claimer.IsSSHCAEnabled() {
|
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||||
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName())
|
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName())
|
||||||
}
|
}
|
||||||
|
|
||||||
_, name, _, err := p.authorizeToken(token)
|
_, name, _, _, _, err := p.authorizeToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign")
|
||||||
}
|
}
|
||||||
|
@ -389,11 +418,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
||||||
// Validate user SignSSHOptions.
|
// Validate user SignSSHOptions.
|
||||||
sshCertOptionsValidator(defaults),
|
sshCertOptionsValidator(defaults),
|
||||||
// Set the validity bounds if not set.
|
// Set the validity bounds if not set.
|
||||||
&sshDefaultDuration{p.claimer},
|
&sshDefaultDuration{p.ctl.Claimer},
|
||||||
// Validate public key
|
// Validate public key
|
||||||
&sshDefaultPublicKeyValidator{},
|
&sshDefaultPublicKeyValidator{},
|
||||||
// Validate the validity period.
|
// Validate the validity period.
|
||||||
&sshCertValidityValidator{p.claimer},
|
&sshCertValidityValidator{p.ctl.Claimer},
|
||||||
// Require all the fields in the SSH certificate
|
// Require all the fields in the SSH certificate
|
||||||
&sshCertDefaultValidator{},
|
&sshCertDefaultValidator{},
|
||||||
), nil
|
), nil
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -15,10 +16,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAzure_Getters(t *testing.T) {
|
func TestAzure_Getters(t *testing.T) {
|
||||||
|
@ -95,7 +96,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
@ -237,7 +238,7 @@ func TestAzure_authorizeToken(t *testing.T) {
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
||||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now(), jwk)
|
time.Now(), jwk)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
|
@ -252,7 +253,7 @@ func TestAzure_authorizeToken(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
|
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
|
||||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
|
@ -267,7 +268,7 @@ func TestAzure_authorizeToken(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
||||||
"foo", "subscriptionID", "resourceGroup", "virtualMachine",
|
"foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
|
@ -321,7 +322,7 @@ func TestAzure_authorizeToken(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
||||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
|
@ -333,10 +334,10 @@ func TestAzure_authorizeToken(t *testing.T) {
|
||||||
for name, tt := range tests {
|
for name, tt := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if claims, name, group, err := tc.p.authorizeToken(tc.token); err != nil {
|
if claims, name, group, subscriptionID, objectID, err := tc.p.authorizeToken(tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -348,6 +349,8 @@ func TestAzure_authorizeToken(t *testing.T) {
|
||||||
|
|
||||||
assert.Equals(t, name, "virtualMachine")
|
assert.Equals(t, name, "virtualMachine")
|
||||||
assert.Equals(t, group, "resourceGroup")
|
assert.Equals(t, group, "resourceGroup")
|
||||||
|
assert.Equals(t, subscriptionID, "subscriptionID")
|
||||||
|
assert.Equals(t, objectID, "the-oid")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -382,6 +385,38 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
p4.oidcConfig = p1.oidcConfig
|
p4.oidcConfig = p1.oidcConfig
|
||||||
p4.keyStore = p1.keyStore
|
p4.keyStore = p1.keyStore
|
||||||
|
|
||||||
|
p5, err := generateAzure()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p5.TenantID = p1.TenantID
|
||||||
|
p5.SubscriptionIDs = []string{"subscriptionID"}
|
||||||
|
p5.config = p1.config
|
||||||
|
p5.oidcConfig = p1.oidcConfig
|
||||||
|
p5.keyStore = p1.keyStore
|
||||||
|
|
||||||
|
p6, err := generateAzure()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p6.TenantID = p1.TenantID
|
||||||
|
p6.SubscriptionIDs = []string{"foobarzar"}
|
||||||
|
p6.config = p1.config
|
||||||
|
p6.oidcConfig = p1.oidcConfig
|
||||||
|
p6.keyStore = p1.keyStore
|
||||||
|
|
||||||
|
p7, err := generateAzure()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p7.TenantID = p1.TenantID
|
||||||
|
p7.ObjectIDs = []string{"the-oid"}
|
||||||
|
p7.config = p1.config
|
||||||
|
p7.oidcConfig = p1.oidcConfig
|
||||||
|
p7.keyStore = p1.keyStore
|
||||||
|
|
||||||
|
p8, err := generateAzure()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p8.TenantID = p1.TenantID
|
||||||
|
p8.ObjectIDs = []string{"foobarzar"}
|
||||||
|
p8.config = p1.config
|
||||||
|
p8.oidcConfig = p1.oidcConfig
|
||||||
|
p8.keyStore = p1.keyStore
|
||||||
|
|
||||||
badKey, err := generateJSONWebKey()
|
badKey, err := generateJSONWebKey()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
@ -393,30 +428,38 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
t4, err := p4.GetIdentityToken("subject", "caURL")
|
t4, err := p4.GetIdentityToken("subject", "caURL")
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
t5, err := p5.GetIdentityToken("subject", "caURL")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t6, err := p6.GetIdentityToken("subject", "caURL")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t7, err := p6.GetIdentityToken("subject", "caURL")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
t8, err := p6.GetIdentityToken("subject", "caURL")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
|
failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
|
||||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience",
|
failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience",
|
||||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
|
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
|
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||||
time.Now(), badKey)
|
time.Now(), badKey)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
@ -431,11 +474,15 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
code int
|
code int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{t1}, 5, http.StatusOK, false},
|
{"ok", p1, args{t1}, 6, http.StatusOK, false},
|
||||||
{"ok", p2, args{t2}, 10, http.StatusOK, false},
|
{"ok", p2, args{t2}, 11, http.StatusOK, false},
|
||||||
{"ok", p1, args{t11}, 5, http.StatusOK, false},
|
{"ok", p1, args{t11}, 6, http.StatusOK, false},
|
||||||
|
{"ok", p5, args{t5}, 6, http.StatusOK, false},
|
||||||
|
{"ok", p7, args{t7}, 6, http.StatusOK, false},
|
||||||
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
|
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
|
||||||
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
|
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
|
||||||
|
{"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true},
|
||||||
|
{"fail object id", p8, args{t8}, 0, http.StatusUnauthorized, true},
|
||||||
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
|
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
|
||||||
{"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
|
{"fail issuer", p1, args{failIssuer}, 0, http.StatusUnauthorized, true},
|
||||||
{"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
|
{"fail audience", p1, args{failAudience}, 0, http.StatusUnauthorized, true},
|
||||||
|
@ -451,27 +498,28 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
case err != nil:
|
case err != nil:
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
default:
|
default:
|
||||||
assert.Len(t, tt.wantLen, got)
|
assert.Len(t, tt.wantLen, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
case *Azure:
|
||||||
case certificateOptionsFunc:
|
case certificateOptionsFunc:
|
||||||
case *provisionerExtensionOption:
|
case *provisionerExtensionOption:
|
||||||
assert.Equals(t, v.Type, int(TypeAzure))
|
assert.Equals(t, v.Type, TypeAzure)
|
||||||
assert.Equals(t, v.Name, tt.azure.GetName())
|
assert.Equals(t, v.Name, tt.azure.GetName())
|
||||||
assert.Equals(t, v.CredentialID, tt.azure.TenantID)
|
assert.Equals(t, v.CredentialID, tt.azure.TenantID)
|
||||||
assert.Len(t, 0, v.KeyValuePairs)
|
assert.Len(t, 0, v.KeyValuePairs)
|
||||||
case profileDefaultDuration:
|
case profileDefaultDuration:
|
||||||
assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration())
|
assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration())
|
||||||
case commonNameValidator:
|
case commonNameValidator:
|
||||||
assert.Equals(t, string(v), "virtualMachine")
|
assert.Equals(t, string(v), "virtualMachine")
|
||||||
case defaultPublicKeyValidator:
|
case defaultPublicKeyValidator:
|
||||||
case *validityValidator:
|
case *validityValidator:
|
||||||
assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration())
|
assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration())
|
||||||
assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration())
|
assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration())
|
||||||
case ipAddressesValidator:
|
case ipAddressesValidator:
|
||||||
assert.Equals(t, v, nil)
|
assert.Equals(t, v, nil)
|
||||||
case emailAddressesValidator:
|
case emailAddressesValidator:
|
||||||
|
@ -481,7 +529,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
case dnsNamesValidator:
|
case dnsNamesValidator:
|
||||||
assert.Equals(t, []string(v), []string{"virtualMachine"})
|
assert.Equals(t, []string(v), []string{"virtualMachine"})
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -490,6 +538,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzure_AuthorizeRenew(t *testing.T) {
|
func TestAzure_AuthorizeRenew(t *testing.T) {
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
p1, err := generateAzure()
|
p1, err := generateAzure()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p2, err := generateAzure()
|
p2, err := generateAzure()
|
||||||
|
@ -498,7 +547,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
|
||||||
// disable renewal
|
// disable renewal
|
||||||
disable := true
|
disable := true
|
||||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
|
@ -511,16 +560,22 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
|
||||||
code int
|
code int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
{"ok", p1, args{&x509.Certificate{
|
||||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusOK, false},
|
||||||
|
{"fail/renew-disabled", p2, args{&x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusUnauthorized, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
if err := tt.azure.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Azure.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -549,7 +604,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
|
||||||
// disable sshCA
|
// disable sshCA
|
||||||
disable := false
|
disable := false
|
||||||
p3.Claims = &Claims{EnableSSHCA: &disable}
|
p3.Claims = &Claims{EnableSSHCA: &disable}
|
||||||
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
t1, err := p1.GetIdentityToken("subject", "caURL")
|
t1, err := p1.GetIdentityToken("subject", "caURL")
|
||||||
|
@ -570,7 +625,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
|
||||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||||
expectedHostOptions := &SignSSHOptions{
|
expectedHostOptions := &SignSSHOptions{
|
||||||
CertType: "host", Principals: []string{"virtualMachine"},
|
CertType: "host", Principals: []string{"virtualMachine"},
|
||||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
@ -615,8 +670,8 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
} else if assert.NotNil(t, got) {
|
} else if assert.NotNil(t, got) {
|
||||||
|
|
|
@ -10,10 +10,10 @@ import (
|
||||||
// Claims so that individual provisioners can override global claims.
|
// Claims so that individual provisioners can override global claims.
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
// TLS CA properties
|
// TLS CA properties
|
||||||
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
||||||
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
||||||
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
||||||
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
|
||||||
// SSH CA properties
|
// SSH CA properties
|
||||||
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
|
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
|
||||||
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
|
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
|
||||||
|
@ -22,6 +22,10 @@ type Claims struct {
|
||||||
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
|
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
|
||||||
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
|
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
|
||||||
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
|
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
|
||||||
|
|
||||||
|
// Renewal properties
|
||||||
|
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||||
|
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claimer is the type that controls claims. It provides an interface around the
|
// Claimer is the type that controls claims. It provides an interface around the
|
||||||
|
@ -40,19 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
|
||||||
// Claims returns the merge of the inner and global claims.
|
// Claims returns the merge of the inner and global claims.
|
||||||
func (c *Claimer) Claims() Claims {
|
func (c *Claimer) Claims() Claims {
|
||||||
disableRenewal := c.IsDisableRenewal()
|
disableRenewal := c.IsDisableRenewal()
|
||||||
|
allowRenewAfterExpiry := c.AllowRenewAfterExpiry()
|
||||||
enableSSHCA := c.IsSSHCAEnabled()
|
enableSSHCA := c.IsSSHCAEnabled()
|
||||||
|
|
||||||
return Claims{
|
return Claims{
|
||||||
MinTLSDur: &Duration{c.MinTLSCertDuration()},
|
MinTLSDur: &Duration{c.MinTLSCertDuration()},
|
||||||
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
||||||
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
||||||
DisableRenewal: &disableRenewal,
|
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
|
||||||
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
|
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
|
||||||
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
|
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
|
||||||
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
|
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
|
||||||
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
|
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
|
||||||
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
|
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
|
||||||
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
|
EnableSSHCA: &enableSSHCA,
|
||||||
EnableSSHCA: &enableSSHCA,
|
DisableRenewal: &disableRenewal,
|
||||||
|
AllowRenewAfterExpiry: &allowRenewAfterExpiry,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,6 +109,16 @@ func (c *Claimer) IsDisableRenewal() bool {
|
||||||
return *c.claims.DisableRenewal
|
return *c.claims.DisableRenewal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the
|
||||||
|
// certificate is expired. If the property is not set within the provisioner
|
||||||
|
// then the global value from the authority configuration will be used.
|
||||||
|
func (c *Claimer) AllowRenewAfterExpiry() bool {
|
||||||
|
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil {
|
||||||
|
return *c.global.AllowRenewAfterExpiry
|
||||||
|
}
|
||||||
|
return *c.claims.AllowRenewAfterExpiry
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultSSHCertDuration returns the default SSH certificate duration for the
|
// DefaultSSHCertDuration returns the default SSH certificate duration for the
|
||||||
// given certificate type.
|
// given certificate type.
|
||||||
func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) {
|
func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) {
|
||||||
|
|
|
@ -152,8 +152,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
|
||||||
// proper id to load the provisioner.
|
// proper id to load the provisioner.
|
||||||
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
|
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
|
||||||
for _, e := range cert.Extensions {
|
for _, e := range cert.Extensions {
|
||||||
if e.Id.Equal(stepOIDProvisioner) {
|
if e.Id.Equal(StepOIDProvisioner) {
|
||||||
var provisioner stepProvisionerASN1
|
var provisioner extensionASN1
|
||||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
|
@ -147,6 +147,17 @@ func TestCollection_LoadByToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCollection_LoadByCertificate(t *testing.T) {
|
func TestCollection_LoadByCertificate(t *testing.T) {
|
||||||
|
mustExtension := func(typ Type, name, credentialID string) pkix.Extension {
|
||||||
|
e := Extension{
|
||||||
|
Type: typ, Name: name, CredentialID: credentialID,
|
||||||
|
}
|
||||||
|
ext, err := e.ToExtension()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return ext
|
||||||
|
}
|
||||||
|
|
||||||
p1, err := generateJWK()
|
p1, err := generateJWK()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p2, err := generateOIDC()
|
p2, err := generateOIDC()
|
||||||
|
@ -159,30 +170,21 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
||||||
byName.Store(p2.GetName(), p2)
|
byName.Store(p2.GetName(), p2)
|
||||||
byName.Store(p3.GetName(), p3)
|
byName.Store(p3.GetName(), p3)
|
||||||
|
|
||||||
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "")
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
|
|
||||||
ok1Cert := &x509.Certificate{
|
ok1Cert := &x509.Certificate{
|
||||||
Extensions: []pkix.Extension{ok1Ext},
|
Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)},
|
||||||
}
|
}
|
||||||
ok2Cert := &x509.Certificate{
|
ok2Cert := &x509.Certificate{
|
||||||
Extensions: []pkix.Extension{ok2Ext},
|
Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)},
|
||||||
}
|
}
|
||||||
ok3Cert := &x509.Certificate{
|
ok3Cert := &x509.Certificate{
|
||||||
Extensions: []pkix.Extension{ok3Ext},
|
Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")},
|
||||||
}
|
}
|
||||||
notFoundCert := &x509.Certificate{
|
notFoundCert := &x509.Certificate{
|
||||||
Extensions: []pkix.Extension{notFoundExt},
|
Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")},
|
||||||
}
|
}
|
||||||
badCert := &x509.Certificate{
|
badCert := &x509.Certificate{
|
||||||
Extensions: []pkix.Extension{
|
Extensions: []pkix.Extension{
|
||||||
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")},
|
{Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
194
authority/provisioner/controller.go
Normal file
194
authority/provisioner/controller.go
Normal file
|
@ -0,0 +1,194 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/errs"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Controller wraps a provisioner with other attributes useful in callback
|
||||||
|
// functions.
|
||||||
|
type Controller struct {
|
||||||
|
Interface
|
||||||
|
Audiences *Audiences
|
||||||
|
Claimer *Claimer
|
||||||
|
IdentityFunc GetIdentityFunc
|
||||||
|
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||||
|
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewController initializes a new provisioner controller.
|
||||||
|
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) {
|
||||||
|
claimer, err := NewClaimer(claims, config.Claims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Controller{
|
||||||
|
Interface: p,
|
||||||
|
Audiences: &config.Audiences,
|
||||||
|
Claimer: claimer,
|
||||||
|
IdentityFunc: config.GetIdentityFunc,
|
||||||
|
AuthorizeRenewFunc: config.AuthorizeRenewFunc,
|
||||||
|
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIdentity returns the identity for a given email.
|
||||||
|
func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) {
|
||||||
|
if c.IdentityFunc != nil {
|
||||||
|
return c.IdentityFunc(ctx, c.Interface, email)
|
||||||
|
}
|
||||||
|
return DefaultIdentityFunc(ctx, c.Interface, email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRenew returns nil if the given cert can be renewed, returns an error
|
||||||
|
// otherwise.
|
||||||
|
func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
|
if c.AuthorizeRenewFunc != nil {
|
||||||
|
return c.AuthorizeRenewFunc(ctx, c, cert)
|
||||||
|
}
|
||||||
|
return DefaultAuthorizeRenew(ctx, c, cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an
|
||||||
|
// error otherwise.
|
||||||
|
func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error {
|
||||||
|
if c.AuthorizeSSHRenewFunc != nil {
|
||||||
|
return c.AuthorizeSSHRenewFunc(ctx, c, cert)
|
||||||
|
}
|
||||||
|
return DefaultAuthorizeSSHRenew(ctx, c, cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identity is the type representing an externally supplied identity that is used
|
||||||
|
// by provisioners to populate certificate fields.
|
||||||
|
type Identity struct {
|
||||||
|
Usernames []string `json:"usernames"`
|
||||||
|
Permissions `json:"permissions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIdentityFunc is a function that returns an identity.
|
||||||
|
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
|
||||||
|
|
||||||
|
// AuthorizeRenewFunc is a function that returns nil if the renewal of a
|
||||||
|
// certificate is enabled.
|
||||||
|
type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error
|
||||||
|
|
||||||
|
// AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the
|
||||||
|
// given SSH certificate is enabled.
|
||||||
|
type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error
|
||||||
|
|
||||||
|
// DefaultIdentityFunc return a default identity depending on the provisioner
|
||||||
|
// type. For OIDC email is always present and the usernames might
|
||||||
|
// contain empty strings.
|
||||||
|
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||||
|
switch k := p.(type) {
|
||||||
|
case *OIDC:
|
||||||
|
// OIDC principals would be:
|
||||||
|
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
|
||||||
|
// 2. Sanitized local.
|
||||||
|
// 3. Raw local (if different).
|
||||||
|
// 4. Email address.
|
||||||
|
name := SanitizeSSHUserPrincipal(email)
|
||||||
|
if !sshUserRegex.MatchString(name) {
|
||||||
|
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
|
||||||
|
}
|
||||||
|
usernames := []string{name}
|
||||||
|
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||||
|
usernames = append(usernames, email[:i])
|
||||||
|
}
|
||||||
|
usernames = append(usernames, email)
|
||||||
|
return &Identity{
|
||||||
|
Usernames: SanitizeStringSlices(usernames),
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It
|
||||||
|
// will return an error if the provisioner has the renewal disabled, if the
|
||||||
|
// certificate is not yet valid or if the certificate is expired and renew after
|
||||||
|
// expiry is disabled.
|
||||||
|
func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||||
|
if p.Claimer.IsDisableRenewal() {
|
||||||
|
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
|
if now.Before(cert.NotBefore) {
|
||||||
|
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
|
||||||
|
}
|
||||||
|
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() {
|
||||||
|
return errs.Unauthorized("certificate has expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It
|
||||||
|
// will return an error if the provisioner has the renewal disabled, if the
|
||||||
|
// certificate is not yet valid or if the certificate is expired and renew after
|
||||||
|
// expiry is disabled.
|
||||||
|
func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||||
|
if p.Claimer.IsDisableRenewal() {
|
||||||
|
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
|
||||||
|
}
|
||||||
|
|
||||||
|
unixNow := time.Now().Unix()
|
||||||
|
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
|
||||||
|
return errs.Unauthorized("certificate is not yet valid")
|
||||||
|
}
|
||||||
|
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() {
|
||||||
|
return errs.Unauthorized("certificate has expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
|
||||||
|
|
||||||
|
// SanitizeStringSlices removes duplicated an empty strings.
|
||||||
|
func SanitizeStringSlices(original []string) []string {
|
||||||
|
output := []string{}
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
for _, entry := range original {
|
||||||
|
if entry == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, value := seen[entry]; !value {
|
||||||
|
seen[entry] = struct{}{}
|
||||||
|
output = append(output, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeSSHUserPrincipal grabs an email or a string with the format
|
||||||
|
// local@domain and returns a sanitized version of the local, valid to be used
|
||||||
|
// as a user name. If the email starts with a letter between a and z, the
|
||||||
|
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
|
||||||
|
func SanitizeSSHUserPrincipal(email string) string {
|
||||||
|
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||||
|
email = email[:i]
|
||||||
|
}
|
||||||
|
return strings.Map(func(r rune) rune {
|
||||||
|
switch {
|
||||||
|
case r >= 'a' && r <= 'z':
|
||||||
|
return r
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
return r
|
||||||
|
case r == '-':
|
||||||
|
return '-'
|
||||||
|
case r == '.': // drop dots
|
||||||
|
return -1
|
||||||
|
default:
|
||||||
|
return '_'
|
||||||
|
}
|
||||||
|
}, strings.ToLower(email))
|
||||||
|
}
|
391
authority/provisioner/controller_test.go
Normal file
391
authority/provisioner/controller_test.go
Normal file
|
@ -0,0 +1,391 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
var trueValue = true
|
||||||
|
|
||||||
|
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
|
||||||
|
t.Helper()
|
||||||
|
c, err := NewClaimer(claims, global)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
func mustDuration(t *testing.T, s string) *Duration {
|
||||||
|
t.Helper()
|
||||||
|
d, err := NewDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewController(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
p Interface
|
||||||
|
claims *Claims
|
||||||
|
config Config
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *Controller
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{&JWK{}, nil, Config{
|
||||||
|
Claims: globalProvisionerClaims,
|
||||||
|
Audiences: testAudiences,
|
||||||
|
}}, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Audiences: &testAudiences,
|
||||||
|
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||||
|
}, false},
|
||||||
|
{"ok with claims", args{&JWK{}, &Claims{
|
||||||
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
}, Config{
|
||||||
|
Claims: globalProvisionerClaims,
|
||||||
|
Audiences: testAudiences,
|
||||||
|
}}, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Audiences: &testAudiences,
|
||||||
|
Claimer: mustClaimer(t, &Claims{
|
||||||
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
}, globalProvisionerClaims),
|
||||||
|
}, false},
|
||||||
|
{"fail claimer", args{&JWK{}, &Claims{
|
||||||
|
MinTLSDur: mustDuration(t, "24h"),
|
||||||
|
MaxTLSDur: mustDuration(t, "2h"),
|
||||||
|
}, Config{
|
||||||
|
Claims: globalProvisionerClaims,
|
||||||
|
Audiences: testAudiences,
|
||||||
|
}}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("NewController() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestController_GetIdentity(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
type fields struct {
|
||||||
|
Interface Interface
|
||||||
|
IdentityFunc GetIdentityFunc
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
email string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want *Identity
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{
|
||||||
|
Usernames: []string{"jane", "jane@doe.org"},
|
||||||
|
}, false},
|
||||||
|
{"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||||
|
return &Identity{Usernames: []string{"jane"}}, nil
|
||||||
|
}}, args{ctx, "jane@doe.org"}, &Identity{
|
||||||
|
Usernames: []string{"jane"},
|
||||||
|
}, false},
|
||||||
|
{"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true},
|
||||||
|
{"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||||
|
return nil, fmt.Errorf("an error")
|
||||||
|
}}, args{ctx, "jane@doe.org"}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Controller{
|
||||||
|
Interface: tt.fields.Interface,
|
||||||
|
IdentityFunc: tt.fields.IdentityFunc,
|
||||||
|
}
|
||||||
|
got, err := c.GetIdentity(tt.args.ctx, tt.args.email)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestController_AuthorizeRenew(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
|
type fields struct {
|
||||||
|
Interface Interface
|
||||||
|
Claimer *Claimer
|
||||||
|
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, false},
|
||||||
|
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||||
|
return nil
|
||||||
|
}}, args{ctx, &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, false},
|
||||||
|
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||||
|
return nil
|
||||||
|
}}, args{ctx, &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, false},
|
||||||
|
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||||
|
NotBefore: now.Add(-time.Hour),
|
||||||
|
NotAfter: now.Add(-time.Minute),
|
||||||
|
}}, false},
|
||||||
|
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, true},
|
||||||
|
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||||
|
NotBefore: now.Add(time.Hour),
|
||||||
|
NotAfter: now.Add(2 * time.Hour),
|
||||||
|
}}, true},
|
||||||
|
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||||
|
NotBefore: now.Add(-time.Hour),
|
||||||
|
NotAfter: now.Add(-time.Minute),
|
||||||
|
}}, true},
|
||||||
|
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||||
|
return fmt.Errorf("an error")
|
||||||
|
}}, args{ctx, &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Controller{
|
||||||
|
Interface: tt.fields.Interface,
|
||||||
|
Claimer: tt.fields.Claimer,
|
||||||
|
AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc,
|
||||||
|
}
|
||||||
|
if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestController_AuthorizeSSHRenew(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now()
|
||||||
|
type fields struct {
|
||||||
|
Interface Interface
|
||||||
|
Claimer *Claimer
|
||||||
|
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
cert *ssh.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
}}, false},
|
||||||
|
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||||
|
return nil
|
||||||
|
}}, args{ctx, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
}}, false},
|
||||||
|
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||||
|
return nil
|
||||||
|
}}, args{ctx, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
}}, false},
|
||||||
|
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||||
|
}}, false},
|
||||||
|
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
}}, true},
|
||||||
|
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
|
||||||
|
}}, true},
|
||||||
|
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||||
|
}}, true},
|
||||||
|
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||||
|
return fmt.Errorf("an error")
|
||||||
|
}}, args{ctx, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Controller{
|
||||||
|
Interface: tt.fields.Interface,
|
||||||
|
Claimer: tt.fields.Claimer,
|
||||||
|
AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
|
||||||
|
}
|
||||||
|
if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultAuthorizeRenew(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
p *Controller
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||||
|
}, &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, false},
|
||||||
|
{"ok renew after expiry", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||||
|
}, &x509.Certificate{
|
||||||
|
NotBefore: now.Add(-time.Hour),
|
||||||
|
NotAfter: now.Add(-time.Minute),
|
||||||
|
}}, false},
|
||||||
|
{"fail disabled", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||||
|
}, &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, true},
|
||||||
|
{"fail not yet valid", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||||
|
}, &x509.Certificate{
|
||||||
|
NotBefore: now.Add(time.Hour),
|
||||||
|
NotAfter: now.Add(2 * time.Hour),
|
||||||
|
}}, true},
|
||||||
|
{"fail expired", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||||
|
}, &x509.Certificate{
|
||||||
|
NotBefore: now.Add(-time.Hour),
|
||||||
|
NotAfter: now.Add(-time.Minute),
|
||||||
|
}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now()
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
p *Controller
|
||||||
|
cert *ssh.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||||
|
}, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
}}, false},
|
||||||
|
{"ok renew after expiry", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||||
|
}, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||||
|
}}, false},
|
||||||
|
{"fail disabled", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||||
|
}, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
}}, true},
|
||||||
|
{"fail not yet valid", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||||
|
}, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
|
||||||
|
}}, true},
|
||||||
|
{"fail expired", args{ctx, &Controller{
|
||||||
|
Interface: &JWK{},
|
||||||
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||||
|
}, &ssh.Certificate{
|
||||||
|
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||||
|
}}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
73
authority/provisioner/extension.go
Normal file
73
authority/provisioner/extension.go
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// StepOIDRoot is the root OID for smallstep.
|
||||||
|
StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||||
|
|
||||||
|
// StepOIDProvisioner is the OID for the provisioner extension.
|
||||||
|
StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Extension is the Go representation of the provisioner extension.
|
||||||
|
type Extension struct {
|
||||||
|
Type Type
|
||||||
|
Name string
|
||||||
|
CredentialID string
|
||||||
|
KeyValuePairs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type extensionASN1 struct {
|
||||||
|
Type int
|
||||||
|
Name []byte
|
||||||
|
CredentialID []byte
|
||||||
|
KeyValuePairs []string `asn1:"optional,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal marshals the extension using encoding/asn1.
|
||||||
|
func (e *Extension) Marshal() ([]byte, error) {
|
||||||
|
return asn1.Marshal(extensionASN1{
|
||||||
|
Type: int(e.Type),
|
||||||
|
Name: []byte(e.Name),
|
||||||
|
CredentialID: []byte(e.CredentialID),
|
||||||
|
KeyValuePairs: e.KeyValuePairs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToExtension returns the pkix.Extension representation of the provisioner
|
||||||
|
// extension.
|
||||||
|
func (e *Extension) ToExtension() (pkix.Extension, error) {
|
||||||
|
b, err := e.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return pkix.Extension{}, err
|
||||||
|
}
|
||||||
|
return pkix.Extension{
|
||||||
|
Id: StepOIDProvisioner,
|
||||||
|
Value: b,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProvisionerExtension goes through all the certificate extensions and
|
||||||
|
// returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1).
|
||||||
|
func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) {
|
||||||
|
for _, e := range cert.Extensions {
|
||||||
|
if e.Id.Equal(StepOIDProvisioner) {
|
||||||
|
var provisioner extensionASN1
|
||||||
|
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return &Extension{
|
||||||
|
Type: Type(provisioner.Type),
|
||||||
|
Name: string(provisioner.Name),
|
||||||
|
CredentialID: string(provisioner.CredentialID),
|
||||||
|
KeyValuePairs: provisioner.KeyValuePairs,
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
158
authority/provisioner/extension_test.go
Normal file
158
authority/provisioner/extension_test.go
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.step.sm/crypto/pemutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtension_Marshal(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Type Type
|
||||||
|
Name string
|
||||||
|
CredentialID string
|
||||||
|
KeyValuePairs []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{
|
||||||
|
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||||
|
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||||
|
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||||
|
0x44,
|
||||||
|
}, false},
|
||||||
|
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{
|
||||||
|
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||||
|
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||||
|
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||||
|
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
|
||||||
|
0x13, 0x03, 0x62, 0x61, 0x72,
|
||||||
|
}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
e := &Extension{
|
||||||
|
Type: tt.fields.Type,
|
||||||
|
Name: tt.fields.Name,
|
||||||
|
CredentialID: tt.fields.CredentialID,
|
||||||
|
KeyValuePairs: tt.fields.KeyValuePairs,
|
||||||
|
}
|
||||||
|
got, err := e.Marshal()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtension_ToExtension(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Type Type
|
||||||
|
Name string
|
||||||
|
CredentialID string
|
||||||
|
KeyValuePairs []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want pkix.Extension
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{
|
||||||
|
Id: StepOIDProvisioner,
|
||||||
|
Value: []byte{
|
||||||
|
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||||
|
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||||
|
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||||
|
0x44,
|
||||||
|
},
|
||||||
|
}, false},
|
||||||
|
{"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{
|
||||||
|
Id: StepOIDProvisioner,
|
||||||
|
Value: []byte{
|
||||||
|
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||||
|
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||||
|
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||||
|
0x44,
|
||||||
|
},
|
||||||
|
}, false},
|
||||||
|
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{
|
||||||
|
Id: StepOIDProvisioner,
|
||||||
|
Value: []byte{
|
||||||
|
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||||
|
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||||
|
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||||
|
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
|
||||||
|
0x13, 0x03, 0x62, 0x61, 0x72,
|
||||||
|
},
|
||||||
|
}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
e := &Extension{
|
||||||
|
Type: tt.fields.Type,
|
||||||
|
Name: tt.fields.Name,
|
||||||
|
CredentialID: tt.fields.CredentialID,
|
||||||
|
KeyValuePairs: tt.fields.KeyValuePairs,
|
||||||
|
}
|
||||||
|
got, err := e.ToExtension()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetProvisionerExtension(t *testing.T) {
|
||||||
|
mustCertificate := func(fn string) *x509.Certificate {
|
||||||
|
cert, err := pemutil.ReadCertificate(fn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *Extension
|
||||||
|
want1 bool
|
||||||
|
}{
|
||||||
|
{"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{
|
||||||
|
Type: TypeJWK,
|
||||||
|
Name: "mariano@smallstep.com",
|
||||||
|
CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ",
|
||||||
|
}, true},
|
||||||
|
{"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false},
|
||||||
|
{"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, got1 := GetProvisionerExtension(tt.args.cert)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got1 != tt.want1 {
|
||||||
|
t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -88,10 +88,9 @@ type GCP struct {
|
||||||
InstanceAge Duration `json:"instanceAge,omitempty"`
|
InstanceAge Duration `json:"instanceAge,omitempty"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
Options *Options `json:"options,omitempty"`
|
Options *Options `json:"options,omitempty"`
|
||||||
claimer *Claimer
|
|
||||||
config *gcpConfig
|
config *gcpConfig
|
||||||
keyStore *keyStore
|
keyStore *keyStore
|
||||||
audiences Audiences
|
ctl *Controller
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier. The name should uniquely
|
// GetID returns the provisioner unique identifier. The name should uniquely
|
||||||
|
@ -194,8 +193,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init validates and initializes the GCP provisioner.
|
// Init validates and initializes the GCP provisioner.
|
||||||
func (p *GCP) Init(config Config) error {
|
func (p *GCP) Init(config Config) (err error) {
|
||||||
var err error
|
|
||||||
switch {
|
switch {
|
||||||
case p.Type == "":
|
case p.Type == "":
|
||||||
return errors.New("provisioner type cannot be empty")
|
return errors.New("provisioner type cannot be empty")
|
||||||
|
@ -204,20 +202,18 @@ func (p *GCP) Init(config Config) error {
|
||||||
case p.InstanceAge.Value() < 0:
|
case p.InstanceAge.Value() < 0:
|
||||||
return errors.New("provisioner instanceAge cannot be negative")
|
return errors.New("provisioner instanceAge cannot be negative")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize config
|
// Initialize config
|
||||||
p.assertConfig()
|
p.assertConfig()
|
||||||
// Update claims with global ones
|
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Initialize key store
|
// Initialize key store
|
||||||
p.keyStore, err = newKeyStore(p.config.CertsURL)
|
if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil {
|
||||||
if err != nil {
|
return
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||||
return nil
|
p.ctl, err = NewController(p, p.Claims, config)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign validates the given token and returns the sign options that
|
// AuthorizeSign validates the given token and returns the sign options that
|
||||||
|
@ -266,22 +262,20 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(so,
|
return append(so,
|
||||||
|
p,
|
||||||
templateOptions,
|
templateOptions,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
|
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
|
||||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||||
// validators
|
// validators
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||||
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
if p.claimer.IsDisableRenewal() {
|
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||||
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// assertConfig initializes the config if it has not been initialized.
|
// assertConfig initializes the config if it has not been initialized.
|
||||||
|
@ -328,7 +322,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate audiences with the defaults
|
// validate audiences with the defaults
|
||||||
if !matchesAudience(claims.Audience, p.audiences.Sign) {
|
if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) {
|
||||||
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
|
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -383,7 +377,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
||||||
|
|
||||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
if !p.claimer.IsSSHCAEnabled() {
|
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||||
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName())
|
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName())
|
||||||
}
|
}
|
||||||
claims, err := p.authorizeToken(token)
|
claims, err := p.authorizeToken(token)
|
||||||
|
@ -431,11 +425,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
||||||
// Validate user SignSSHOptions.
|
// Validate user SignSSHOptions.
|
||||||
sshCertOptionsValidator(defaults),
|
sshCertOptionsValidator(defaults),
|
||||||
// Set the validity bounds if not set.
|
// Set the validity bounds if not set.
|
||||||
&sshDefaultDuration{p.claimer},
|
&sshDefaultDuration{p.ctl.Claimer},
|
||||||
// Validate public key
|
// Validate public key
|
||||||
&sshDefaultPublicKeyValidator{},
|
&sshDefaultPublicKeyValidator{},
|
||||||
// Validate the validity period.
|
// Validate the validity period.
|
||||||
&sshCertValidityValidator{p.claimer},
|
&sshCertValidityValidator{p.ctl.Claimer},
|
||||||
// Require all the fields in the SSH certificate
|
// Require all the fields in the SSH certificate
|
||||||
&sshCertDefaultValidator{},
|
&sshCertDefaultValidator{},
|
||||||
), nil
|
), nil
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -16,10 +17,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGCP_Getters(t *testing.T) {
|
func TestGCP_Getters(t *testing.T) {
|
||||||
|
@ -390,8 +391,8 @@ func TestGCP_authorizeToken(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
|
if claims, err := tc.p.authorizeToken(tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -515,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
||||||
code int
|
code int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{t1}, 5, http.StatusOK, false},
|
{"ok", p1, args{t1}, 6, http.StatusOK, false},
|
||||||
{"ok", p2, args{t2}, 10, http.StatusOK, false},
|
{"ok", p2, args{t2}, 11, http.StatusOK, false},
|
||||||
{"ok", p3, args{t3}, 5, http.StatusOK, false},
|
{"ok", p3, args{t3}, 6, http.StatusOK, false},
|
||||||
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
|
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
|
||||||
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
|
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
|
||||||
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
|
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
|
||||||
|
@ -540,27 +541,28 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
||||||
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
case err != nil:
|
case err != nil:
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
default:
|
default:
|
||||||
assert.Len(t, tt.wantLen, got)
|
assert.Len(t, tt.wantLen, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
case *GCP:
|
||||||
case certificateOptionsFunc:
|
case certificateOptionsFunc:
|
||||||
case *provisionerExtensionOption:
|
case *provisionerExtensionOption:
|
||||||
assert.Equals(t, v.Type, int(TypeGCP))
|
assert.Equals(t, v.Type, TypeGCP)
|
||||||
assert.Equals(t, v.Name, tt.gcp.GetName())
|
assert.Equals(t, v.Name, tt.gcp.GetName())
|
||||||
assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0])
|
assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0])
|
||||||
assert.Len(t, 4, v.KeyValuePairs)
|
assert.Len(t, 4, v.KeyValuePairs)
|
||||||
case profileDefaultDuration:
|
case profileDefaultDuration:
|
||||||
assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration())
|
assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration())
|
||||||
case commonNameSliceValidator:
|
case commonNameSliceValidator:
|
||||||
assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
|
assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
|
||||||
case defaultPublicKeyValidator:
|
case defaultPublicKeyValidator:
|
||||||
case *validityValidator:
|
case *validityValidator:
|
||||||
assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration())
|
assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration())
|
||||||
assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration())
|
assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration())
|
||||||
case ipAddressesValidator:
|
case ipAddressesValidator:
|
||||||
assert.Equals(t, v, nil)
|
assert.Equals(t, v, nil)
|
||||||
case emailAddressesValidator:
|
case emailAddressesValidator:
|
||||||
|
@ -570,7 +572,7 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
||||||
case dnsNamesValidator:
|
case dnsNamesValidator:
|
||||||
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
|
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -595,7 +597,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
||||||
// disable sshCA
|
// disable sshCA
|
||||||
disable := false
|
disable := false
|
||||||
p3.Claims = &Claims{EnableSSHCA: &disable}
|
p3.Claims = &Claims{EnableSSHCA: &disable}
|
||||||
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
t1, err := generateGCPToken(p1.ServiceAccounts[0],
|
t1, err := generateGCPToken(p1.ServiceAccounts[0],
|
||||||
|
@ -622,7 +624,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
||||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||||
expectedHostOptions := &SignSSHOptions{
|
expectedHostOptions := &SignSSHOptions{
|
||||||
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
|
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
|
||||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||||
|
@ -677,8 +679,8 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
} else if assert.NotNil(t, got) {
|
} else if assert.NotNil(t, got) {
|
||||||
|
@ -698,6 +700,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGCP_AuthorizeRenew(t *testing.T) {
|
func TestGCP_AuthorizeRenew(t *testing.T) {
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
p1, err := generateGCP()
|
p1, err := generateGCP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p2, err := generateGCP()
|
p2, err := generateGCP()
|
||||||
|
@ -706,7 +709,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
|
||||||
// disable renewal
|
// disable renewal
|
||||||
disable := true
|
disable := true
|
||||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
|
@ -719,15 +722,21 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
|
||||||
code int
|
code int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
{"ok", p1, args{&x509.Certificate{
|
||||||
{"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusOK, false},
|
||||||
|
{"fail/renewal-disabled", p2, args{&x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusUnauthorized, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("GCP.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,8 +35,7 @@ type JWK struct {
|
||||||
EncryptedKey string `json:"encryptedKey,omitempty"`
|
EncryptedKey string `json:"encryptedKey,omitempty"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
Options *Options `json:"options,omitempty"`
|
Options *Options `json:"options,omitempty"`
|
||||||
claimer *Claimer
|
ctl *Controller
|
||||||
audiences Audiences
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier. The name and credential id
|
// GetID returns the provisioner unique identifier. The name and credential id
|
||||||
|
@ -98,13 +97,8 @@ func (p *JWK) Init(config Config) (err error) {
|
||||||
return errors.New("provisioner key cannot be empty")
|
return errors.New("provisioner key cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update claims with global ones
|
p.ctl, err = NewController(p, p.Claims, config)
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
return
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
p.audiences = config.Audiences
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// authorizeToken performs common jwt authorization actions and returns the
|
// authorizeToken performs common jwt authorization actions and returns the
|
||||||
|
@ -146,13 +140,13 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
|
||||||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||||
// revoke the certificate with serial number in the `sub` property.
|
// revoke the certificate with serial number in the `sub` property.
|
||||||
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
|
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||||
_, err := p.authorizeToken(token, p.audiences.Revoke)
|
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
|
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign validates the given token.
|
// AuthorizeSign validates the given token.
|
||||||
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
|
||||||
}
|
}
|
||||||
|
@ -176,15 +170,16 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
}
|
}
|
||||||
|
|
||||||
return []SignOption{
|
return []SignOption{
|
||||||
|
p,
|
||||||
templateOptions,
|
templateOptions,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
|
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
|
||||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||||
// validators
|
// validators
|
||||||
commonNameValidator(claims.Subject),
|
commonNameValidator(claims.Subject),
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
defaultSANsValidator(claims.SANs),
|
defaultSANsValidator(claims.SANs),
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,18 +188,15 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
// revocation status. Just confirms that the provisioner that created the
|
// revocation status. Just confirms that the provisioner that created the
|
||||||
// certificate was configured to allow renewals.
|
// certificate was configured to allow renewals.
|
||||||
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
if p.claimer.IsDisableRenewal() {
|
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||||
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
if !p.claimer.IsSSHCAEnabled() {
|
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||||
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName())
|
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName())
|
||||||
}
|
}
|
||||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
|
||||||
}
|
}
|
||||||
|
@ -261,11 +253,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
||||||
|
|
||||||
return append(signOptions,
|
return append(signOptions,
|
||||||
// Set the validity bounds if not set.
|
// Set the validity bounds if not set.
|
||||||
&sshDefaultDuration{p.claimer},
|
&sshDefaultDuration{p.ctl.Claimer},
|
||||||
// Validate public key
|
// Validate public key
|
||||||
&sshDefaultPublicKeyValidator{},
|
&sshDefaultPublicKeyValidator{},
|
||||||
// Validate the validity period.
|
// Validate the validity period.
|
||||||
&sshCertValidityValidator{p.claimer},
|
&sshCertValidityValidator{p.ctl.Claimer},
|
||||||
// Require and validate all the default fields in the SSH certificate.
|
// Require and validate all the default fields in the SSH certificate.
|
||||||
&sshCertDefaultValidator{},
|
&sshCertDefaultValidator{},
|
||||||
), nil
|
), nil
|
||||||
|
@ -273,6 +265,6 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
||||||
|
|
||||||
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
|
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
|
||||||
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||||
_, err := p.authorizeToken(token, p.audiences.SSHRevoke)
|
_, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
|
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,15 +6,17 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJWK_Getters(t *testing.T) {
|
func TestJWK_Getters(t *testing.T) {
|
||||||
|
@ -76,13 +78,13 @@ func TestJWK_Init(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
|
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
|
||||||
return ProvisionerValidateTest{
|
return ProvisionerValidateTest{
|
||||||
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
|
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
|
||||||
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
|
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||||
return ProvisionerValidateTest{
|
return ProvisionerValidateTest{
|
||||||
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences},
|
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -183,8 +185,8 @@ func TestJWK_authorizeToken(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil {
|
if got, err := tt.prov.authorizeToken(tt.args.token, testAudiences.Sign); err != nil {
|
||||||
if assert.NotNil(t, tt.err) {
|
if assert.NotNil(t, tt.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -223,8 +225,8 @@ func TestJWK_AuthorizeRevoke(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
|
if err := tt.prov.AuthorizeRevoke(context.Background(), tt.args.token); err != nil {
|
||||||
if assert.NotNil(t, tt.err) {
|
if assert.NotNil(t, tt.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -288,34 +290,35 @@ func TestJWK_AuthorizeSign(t *testing.T) {
|
||||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
|
if got, err := tt.prov.AuthorizeSign(ctx, tt.args.token); err != nil {
|
||||||
if assert.NotNil(t, tt.err) {
|
if assert.NotNil(t, tt.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if assert.NotNil(t, got) {
|
if assert.NotNil(t, got) {
|
||||||
assert.Len(t, 7, got)
|
assert.Len(t, 8, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
case *JWK:
|
||||||
case certificateOptionsFunc:
|
case certificateOptionsFunc:
|
||||||
case *provisionerExtensionOption:
|
case *provisionerExtensionOption:
|
||||||
assert.Equals(t, v.Type, int(TypeJWK))
|
assert.Equals(t, v.Type, TypeJWK)
|
||||||
assert.Equals(t, v.Name, tt.prov.GetName())
|
assert.Equals(t, v.Name, tt.prov.GetName())
|
||||||
assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID)
|
assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID)
|
||||||
assert.Len(t, 0, v.KeyValuePairs)
|
assert.Len(t, 0, v.KeyValuePairs)
|
||||||
case profileDefaultDuration:
|
case profileDefaultDuration:
|
||||||
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
|
assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
|
||||||
case commonNameValidator:
|
case commonNameValidator:
|
||||||
assert.Equals(t, string(v), "subject")
|
assert.Equals(t, string(v), "subject")
|
||||||
case defaultPublicKeyValidator:
|
case defaultPublicKeyValidator:
|
||||||
case *validityValidator:
|
case *validityValidator:
|
||||||
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
|
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
|
||||||
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
|
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
|
||||||
case defaultSANsValidator:
|
case defaultSANsValidator:
|
||||||
assert.Equals(t, []string(v), tt.sans)
|
assert.Equals(t, []string(v), tt.sans)
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -325,6 +328,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJWK_AuthorizeRenew(t *testing.T) {
|
func TestJWK_AuthorizeRenew(t *testing.T) {
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
p1, err := generateJWK()
|
p1, err := generateJWK()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p2, err := generateJWK()
|
p2, err := generateJWK()
|
||||||
|
@ -333,7 +337,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
|
||||||
// disable renewal
|
// disable renewal
|
||||||
disable := true
|
disable := true
|
||||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
|
@ -346,16 +350,22 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
|
||||||
code int
|
code int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
{"ok", p1, args{&x509.Certificate{
|
||||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusOK, false},
|
||||||
|
{"fail/renew-disabled", p2, args{&x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusUnauthorized, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
if err := tt.prov.AuthorizeRenew(context.Background(), tt.args.cert); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("JWK.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -373,7 +383,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
|
||||||
// disable sshCA
|
// disable sshCA
|
||||||
disable := false
|
disable := false
|
||||||
p2.Claims = &Claims{EnableSSHCA: &disable}
|
p2.Claims = &Claims{EnableSSHCA: &disable}
|
||||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||||
|
@ -402,8 +412,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
|
||||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
|
||||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||||
expectedUserOptions := &SignSSHOptions{
|
expectedUserOptions := &SignSSHOptions{
|
||||||
CertType: "user", Principals: []string{"name"},
|
CertType: "user", Principals: []string{"name"},
|
||||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||||
|
@ -448,8 +458,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
} else if assert.NotNil(t, got) {
|
} else if assert.NotNil(t, got) {
|
||||||
|
@ -485,8 +495,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
|
||||||
signer, err := generateJSONWebKey()
|
signer, err := generateJSONWebKey()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
|
||||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||||
expectedUserOptions := &SignSSHOptions{
|
expectedUserOptions := &SignSSHOptions{
|
||||||
CertType: "user", Principals: []string{"name"},
|
CertType: "user", Principals: []string{"name"},
|
||||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||||
|
@ -613,8 +623,8 @@ func TestJWK_AuthorizeSSHRevoke(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
|
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,16 +42,15 @@ type k8sSAPayload struct {
|
||||||
// entity trusted to make signature requests.
|
// entity trusted to make signature requests.
|
||||||
type K8sSA struct {
|
type K8sSA struct {
|
||||||
*base
|
*base
|
||||||
ID string `json:"-"`
|
ID string `json:"-"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
PubKeys []byte `json:"publicKeys,omitempty"`
|
PubKeys []byte `json:"publicKeys,omitempty"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
Options *Options `json:"options,omitempty"`
|
Options *Options `json:"options,omitempty"`
|
||||||
claimer *Claimer
|
|
||||||
audiences Audiences
|
|
||||||
//kauthn kauthn.AuthenticationV1Interface
|
//kauthn kauthn.AuthenticationV1Interface
|
||||||
pubKeys []interface{}
|
pubKeys []interface{}
|
||||||
|
ctl *Controller
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier. The name and credential id
|
// GetID returns the provisioner unique identifier. The name and credential id
|
||||||
|
@ -138,13 +137,8 @@ func (p *K8sSA) Init(config Config) (err error) {
|
||||||
p.kauthn = k8s.AuthenticationV1()
|
p.kauthn = k8s.AuthenticationV1()
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Update claims with global ones
|
p.ctl, err = NewController(p, p.Claims, config)
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
return
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
p.audiences = config.Audiences
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// authorizeToken performs common jwt authorization actions and returns the
|
// authorizeToken performs common jwt authorization actions and returns the
|
||||||
|
@ -211,13 +205,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
|
||||||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||||
// revoke the certificate with serial number in the `sub` property.
|
// revoke the certificate with serial number in the `sub` property.
|
||||||
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
|
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||||
_, err := p.authorizeToken(token, p.audiences.Revoke)
|
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
|
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign validates the given token.
|
// AuthorizeSign validates the given token.
|
||||||
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
|
||||||
}
|
}
|
||||||
|
@ -237,30 +231,28 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
||||||
}
|
}
|
||||||
|
|
||||||
return []SignOption{
|
return []SignOption{
|
||||||
|
p,
|
||||||
templateOptions,
|
templateOptions,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
|
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
|
||||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||||
// validators
|
// validators
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||||
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
if p.claimer.IsDisableRenewal() {
|
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||||
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSSHSign validates an request for an SSH certificate.
|
// AuthorizeSSHSign validates an request for an SSH certificate.
|
||||||
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
if !p.claimer.IsSSHCAEnabled() {
|
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||||
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
|
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
|
||||||
}
|
}
|
||||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
|
||||||
}
|
}
|
||||||
|
@ -282,11 +274,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
||||||
// Require type, key-id and principals in the SignSSHOptions.
|
// Require type, key-id and principals in the SignSSHOptions.
|
||||||
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
|
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
|
||||||
// Set the validity bounds if not set.
|
// Set the validity bounds if not set.
|
||||||
&sshDefaultDuration{p.claimer},
|
&sshDefaultDuration{p.ctl.Claimer},
|
||||||
// Validate public key
|
// Validate public key
|
||||||
&sshDefaultPublicKeyValidator{},
|
&sshDefaultPublicKeyValidator{},
|
||||||
// Validate the validity period.
|
// Validate the validity period.
|
||||||
&sshCertValidityValidator{p.claimer},
|
&sshCertValidityValidator{p.ctl.Claimer},
|
||||||
// Require and validate all the default fields in the SSH certificate.
|
// Require and validate all the default fields in the SSH certificate.
|
||||||
&sshCertDefaultValidator{},
|
&sshCertDefaultValidator{},
|
||||||
), nil
|
), nil
|
||||||
|
|
|
@ -3,14 +3,16 @@ package provisioner
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestK8sSA_Getters(t *testing.T) {
|
func TestK8sSA_Getters(t *testing.T) {
|
||||||
|
@ -116,8 +118,8 @@ func TestK8sSA_authorizeToken(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
|
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -165,8 +167,8 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
|
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
@ -179,6 +181,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
type test struct {
|
type test struct {
|
||||||
p *K8sSA
|
p *K8sSA
|
||||||
cert *x509.Certificate
|
cert *x509.Certificate
|
||||||
|
@ -192,21 +195,27 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
||||||
// disable renewal
|
// disable renewal
|
||||||
disable := true
|
disable := true
|
||||||
p.Claims = &Claims{DisableRenewal: &disable}
|
p.Claims = &Claims{DisableRenewal: &disable}
|
||||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
p: p,
|
p: p,
|
||||||
cert: &x509.Certificate{},
|
cert: &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
},
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()),
|
err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
p, err := generateK8sSA(nil)
|
p, err := generateK8sSA(nil)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
p: p,
|
p: p,
|
||||||
cert: &x509.Certificate{},
|
cert: &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -214,8 +223,8 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
|
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
@ -263,8 +272,8 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -274,24 +283,25 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
|
||||||
tot := 0
|
tot := 0
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
case *K8sSA:
|
||||||
case certificateOptionsFunc:
|
case certificateOptionsFunc:
|
||||||
case *provisionerExtensionOption:
|
case *provisionerExtensionOption:
|
||||||
assert.Equals(t, v.Type, int(TypeK8sSA))
|
assert.Equals(t, v.Type, TypeK8sSA)
|
||||||
assert.Equals(t, v.Name, tc.p.GetName())
|
assert.Equals(t, v.Name, tc.p.GetName())
|
||||||
assert.Equals(t, v.CredentialID, "")
|
assert.Equals(t, v.CredentialID, "")
|
||||||
assert.Len(t, 0, v.KeyValuePairs)
|
assert.Len(t, 0, v.KeyValuePairs)
|
||||||
case profileDefaultDuration:
|
case profileDefaultDuration:
|
||||||
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
|
assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
|
||||||
case defaultPublicKeyValidator:
|
case defaultPublicKeyValidator:
|
||||||
case *validityValidator:
|
case *validityValidator:
|
||||||
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
|
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
||||||
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
|
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
tot++
|
tot++
|
||||||
}
|
}
|
||||||
assert.Equals(t, tot, 5)
|
assert.Equals(t, tot, 6)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -313,13 +323,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
||||||
// disable sshCA
|
// disable sshCA
|
||||||
disable := false
|
disable := false
|
||||||
p.Claims = &Claims{EnableSSHCA: &disable}
|
p.Claims = &Claims{EnableSSHCA: &disable}
|
||||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
p: p,
|
p: p,
|
||||||
token: "foo",
|
token: "foo",
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
err: errors.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()),
|
err: fmt.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/invalid-token": func(t *testing.T) test {
|
"fail/invalid-token": func(t *testing.T) test {
|
||||||
|
@ -350,8 +360,8 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
|
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -365,13 +375,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
||||||
case *sshCertOptionsRequireValidator:
|
case *sshCertOptionsRequireValidator:
|
||||||
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
|
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
|
||||||
case *sshCertValidityValidator:
|
case *sshCertValidityValidator:
|
||||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||||
case *sshDefaultPublicKeyValidator:
|
case *sshDefaultPublicKeyValidator:
|
||||||
case *sshCertDefaultValidator:
|
case *sshCertDefaultValidator:
|
||||||
case *sshDefaultDuration:
|
case *sshDefaultDuration:
|
||||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
tot++
|
tot++
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,19 +34,18 @@ const (
|
||||||
// https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by
|
// https://signal.org/docs/specifications/xeddsa/#xeddsa and implemented by
|
||||||
// go.step.sm/crypto/x25519.
|
// go.step.sm/crypto/x25519.
|
||||||
type Nebula struct {
|
type Nebula struct {
|
||||||
ID string `json:"-"`
|
ID string `json:"-"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Roots []byte `json:"roots"`
|
Roots []byte `json:"roots"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
Options *Options `json:"options,omitempty"`
|
Options *Options `json:"options,omitempty"`
|
||||||
claimer *Claimer
|
caPool *nebula.NebulaCAPool
|
||||||
caPool *nebula.NebulaCAPool
|
ctl *Controller
|
||||||
audiences Audiences
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init verifies and initializes the Nebula provisioner.
|
// Init verifies and initializes the Nebula provisioner.
|
||||||
func (p *Nebula) Init(config Config) error {
|
func (p *Nebula) Init(config Config) (err error) {
|
||||||
switch {
|
switch {
|
||||||
case p.Type == "":
|
case p.Type == "":
|
||||||
return errors.New("provisioner type cannot be empty")
|
return errors.New("provisioner type cannot be empty")
|
||||||
|
@ -56,19 +55,14 @@ func (p *Nebula) Init(config Config) error {
|
||||||
return errors.New("provisioner root(s) cannot be empty")
|
return errors.New("provisioner root(s) cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots)
|
p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.InternalServer("failed to create ca pool: %v", err)
|
return errs.InternalServer("failed to create ca pool: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||||
|
p.ctl, err = NewController(p, p.Claims, config)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner id.
|
// GetID returns the provisioner id.
|
||||||
|
@ -120,7 +114,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
|
|
||||||
// AuthorizeSign returns the list of SignOption for a Sign request.
|
// AuthorizeSign returns the list of SignOption for a Sign request.
|
||||||
func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
crt, claims, err := p.authorizeToken(token, p.audiences.Sign)
|
crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -139,8 +133,9 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
||||||
data.SetToken(v)
|
data.SetToken(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The Nebula certificate will be available using the template variable Crt.
|
// The Nebula certificate will be available using the template variable
|
||||||
// For example {{ .Crt.Details.Groups }} can be used to get all the groups.
|
// AuthorizationCrt. For example {{ .AuthorizationCrt.Details.Groups }} can
|
||||||
|
// be used to get all the groups.
|
||||||
data.SetAuthorizationCertificate(crt)
|
data.SetAuthorizationCertificate(crt)
|
||||||
|
|
||||||
templateOptions, err := TemplateOptions(p.Options, data)
|
templateOptions, err := TemplateOptions(p.Options, data)
|
||||||
|
@ -149,11 +144,12 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
||||||
}
|
}
|
||||||
|
|
||||||
return []SignOption{
|
return []SignOption{
|
||||||
|
p,
|
||||||
templateOptions,
|
templateOptions,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeNebula, p.Name, ""),
|
newProvisionerExtensionOption(TypeNebula, p.Name, ""),
|
||||||
profileLimitDuration{
|
profileLimitDuration{
|
||||||
def: p.claimer.DefaultTLSCertDuration(),
|
def: p.ctl.Claimer.DefaultTLSCertDuration(),
|
||||||
notBefore: crt.Details.NotBefore,
|
notBefore: crt.Details.NotBefore,
|
||||||
notAfter: crt.Details.NotAfter,
|
notAfter: crt.Details.NotAfter,
|
||||||
},
|
},
|
||||||
|
@ -164,18 +160,18 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
||||||
IPs: crt.Details.Ips,
|
IPs: crt.Details.Ips,
|
||||||
},
|
},
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
// Currently the Nebula provisioner only grants host SSH certificates.
|
// Currently the Nebula provisioner only grants host SSH certificates.
|
||||||
func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
if !p.claimer.IsSSHCAEnabled() {
|
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||||
return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
|
return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -253,11 +249,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
|
||||||
return append(signOptions,
|
return append(signOptions,
|
||||||
templateOptions,
|
templateOptions,
|
||||||
// Checks the validity bounds, and set the validity if has not been set.
|
// Checks the validity bounds, and set the validity if has not been set.
|
||||||
&sshLimitDuration{p.claimer, crt.Details.NotAfter},
|
&sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter},
|
||||||
// Validate public key.
|
// Validate public key.
|
||||||
&sshDefaultPublicKeyValidator{},
|
&sshDefaultPublicKeyValidator{},
|
||||||
// Validate the validity period.
|
// Validate the validity period.
|
||||||
&sshCertValidityValidator{p.claimer},
|
&sshCertValidityValidator{p.ctl.Claimer},
|
||||||
// Require all the fields in the SSH certificate
|
// Require all the fields in the SSH certificate
|
||||||
&sshCertDefaultValidator{},
|
&sshCertDefaultValidator{},
|
||||||
), nil
|
), nil
|
||||||
|
@ -265,23 +261,20 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
|
||||||
|
|
||||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||||
func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error {
|
func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error {
|
||||||
if p.claimer.IsDisableRenewal() {
|
return p.ctl.AuthorizeRenew(ctx, crt)
|
||||||
return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeRevoke returns an error if the token is not valid.
|
// AuthorizeRevoke returns an error if the token is not valid.
|
||||||
func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error {
|
func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||||
return p.validateToken(token, p.audiences.Revoke)
|
return p.validateToken(token, p.ctl.Audiences.Revoke)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
|
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
|
||||||
func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||||
if !p.claimer.IsSSHCAEnabled() {
|
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||||
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
|
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
|
||||||
}
|
}
|
||||||
if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil {
|
if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) {
|
||||||
func TestNebula_GetTokenID(t *testing.T) {
|
func TestNebula_GetTokenID(t *testing.T) {
|
||||||
p, ca, signer := mustNebulaProvisioner(t)
|
p, ca, signer := mustNebulaProvisioner(t)
|
||||||
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
|
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
|
||||||
t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv)
|
t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv)
|
||||||
_, claims, err := parseToken(t1)
|
_, claims, err := parseToken(t1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
p, ca, signer := mustNebulaProvisioner(t)
|
p, ca, signer := mustNebulaProvisioner(t)
|
||||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv)
|
ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv)
|
||||||
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv)
|
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv)
|
||||||
|
|
||||||
pBadOptions, _, _ := mustNebulaProvisioner(t)
|
pBadOptions, _, _ := mustNebulaProvisioner(t)
|
||||||
pBadOptions.caPool = p.caPool
|
pBadOptions.caPool = p.caPool
|
||||||
|
@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
|
||||||
// Ok provisioner
|
// Ok provisioner
|
||||||
p, ca, signer := mustNebulaProvisioner(t)
|
p, ca, signer := mustNebulaProvisioner(t)
|
||||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||||
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
|
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||||
CertType: "host",
|
CertType: "host",
|
||||||
KeyID: "test.lan",
|
KeyID: "test.lan",
|
||||||
Principals: []string{"test.lan", "10.1.0.1"},
|
Principals: []string{"test.lan", "10.1.0.1"},
|
||||||
}, crt, priv)
|
}, crt, priv)
|
||||||
okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv)
|
okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv)
|
||||||
okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
|
okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||||
ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)),
|
ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)),
|
||||||
ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)),
|
ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)),
|
||||||
}, crt, priv)
|
}, crt, priv)
|
||||||
failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
|
failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||||
CertType: "user",
|
CertType: "user",
|
||||||
}, crt, priv)
|
}, crt, priv)
|
||||||
failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
|
failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||||
CertType: "host",
|
CertType: "host",
|
||||||
KeyID: "test.lan",
|
KeyID: "test.lan",
|
||||||
Principals: []string{"test.lan", "10.1.0.1", "foo.bar"},
|
Principals: []string{"test.lan", "10.1.0.1", "foo.bar"},
|
||||||
|
@ -549,6 +549,8 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
|
||||||
|
|
||||||
func TestNebula_AuthorizeRenew(t *testing.T) {
|
func TestNebula_AuthorizeRenew(t *testing.T) {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
|
|
||||||
// Ok provisioner
|
// Ok provisioner
|
||||||
p, _, _ := mustNebulaProvisioner(t)
|
p, _, _ := mustNebulaProvisioner(t)
|
||||||
|
|
||||||
|
@ -567,8 +569,14 @@ func TestNebula_AuthorizeRenew(t *testing.T) {
|
||||||
args args
|
args args
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p, args{ctx, &x509.Certificate{}}, false},
|
{"ok", p, args{ctx, &x509.Certificate{
|
||||||
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true},
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, false},
|
||||||
|
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -584,12 +592,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) {
|
||||||
// Ok provisioner
|
// Ok provisioner
|
||||||
p, ca, signer := mustNebulaProvisioner(t)
|
p, ca, signer := mustNebulaProvisioner(t)
|
||||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv)
|
ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
|
||||||
|
|
||||||
// Fail different CA
|
// Fail different CA
|
||||||
nc, signer := mustNebulaCA(t)
|
nc, signer := mustNebulaCA(t)
|
||||||
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
|
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
|
||||||
failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv)
|
failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
@ -618,12 +626,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
|
||||||
// Ok provisioner
|
// Ok provisioner
|
||||||
p, ca, signer := mustNebulaProvisioner(t)
|
p, ca, signer := mustNebulaProvisioner(t)
|
||||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||||
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv)
|
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
|
||||||
|
|
||||||
// Fail different CA
|
// Fail different CA
|
||||||
nc, signer := mustNebulaCA(t)
|
nc, signer := mustNebulaCA(t)
|
||||||
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
|
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
|
||||||
failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv)
|
failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
|
||||||
|
|
||||||
// Provisioner with SSH disabled
|
// Provisioner with SSH disabled
|
||||||
var bFalse bool
|
var bFalse bool
|
||||||
|
@ -657,7 +665,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
|
||||||
func TestNebula_AuthorizeSSHRenew(t *testing.T) {
|
func TestNebula_AuthorizeSSHRenew(t *testing.T) {
|
||||||
p, ca, signer := mustNebulaProvisioner(t)
|
p, ca, signer := mustNebulaProvisioner(t)
|
||||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||||
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv)
|
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
@ -689,7 +697,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) {
|
||||||
func TestNebula_AuthorizeSSHRekey(t *testing.T) {
|
func TestNebula_AuthorizeSSHRekey(t *testing.T) {
|
||||||
p, ca, signer := mustNebulaProvisioner(t)
|
p, ca, signer := mustNebulaProvisioner(t)
|
||||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||||
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv)
|
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
@ -726,20 +734,20 @@ func TestNebula_authorizeToken(t *testing.T) {
|
||||||
t1 := now()
|
t1 := now()
|
||||||
p, ca, signer := mustNebulaProvisioner(t)
|
p, ca, signer := mustNebulaProvisioner(t)
|
||||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||||
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv)
|
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv)
|
||||||
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{
|
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{
|
||||||
CertType: "host",
|
CertType: "host",
|
||||||
KeyID: "test.lan",
|
KeyID: "test.lan",
|
||||||
Principals: []string{"test.lan"},
|
Principals: []string{"test.lan"},
|
||||||
}, crt, priv)
|
}, crt, priv)
|
||||||
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv)
|
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv)
|
||||||
|
|
||||||
// Token with errors
|
// Token with errors
|
||||||
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
|
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
|
||||||
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||||
failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv)
|
failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv)
|
||||||
failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||||
|
|
||||||
// Not a nebula token
|
// Not a nebula token
|
||||||
jwk, err := generateJSONWebKey()
|
jwk, err := generateJSONWebKey()
|
||||||
|
@ -761,7 +769,7 @@ func TestNebula_authorizeToken(t *testing.T) {
|
||||||
IssuedAt: jose.NewNumericDate(t1),
|
IssuedAt: jose.NewNumericDate(t1),
|
||||||
NotBefore: jose.NewNumericDate(t1),
|
NotBefore: jose.NewNumericDate(t1),
|
||||||
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
|
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
|
||||||
Audience: []string{p.audiences.Sign[0]},
|
Audience: []string{p.ctl.Audiences.Sign[0]},
|
||||||
}
|
}
|
||||||
sshClaims := jose.Claims{
|
sshClaims := jose.Claims{
|
||||||
ID: "[REPLACEME]",
|
ID: "[REPLACEME]",
|
||||||
|
@ -770,7 +778,7 @@ func TestNebula_authorizeToken(t *testing.T) {
|
||||||
IssuedAt: jose.NewNumericDate(t1),
|
IssuedAt: jose.NewNumericDate(t1),
|
||||||
NotBefore: jose.NewNumericDate(t1),
|
NotBefore: jose.NewNumericDate(t1),
|
||||||
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
|
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
|
||||||
Audience: []string{p.audiences.SSHSign[0]},
|
Audience: []string{p.ctl.Audiences.SSHSign[0]},
|
||||||
}
|
}
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
|
@ -785,14 +793,14 @@ func TestNebula_authorizeToken(t *testing.T) {
|
||||||
want1 *jwtPayload
|
want1 *jwtPayload
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{
|
{"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{
|
||||||
Claims: x509Claims,
|
Claims: x509Claims,
|
||||||
SANs: []string{"10.1.0.1"},
|
SANs: []string{"10.1.0.1"},
|
||||||
}, false},
|
}, false},
|
||||||
{"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{
|
{"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{
|
||||||
Claims: x509Claims,
|
Claims: x509Claims,
|
||||||
}, false},
|
}, false},
|
||||||
{"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{
|
{"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
|
||||||
Claims: sshClaims,
|
Claims: sshClaims,
|
||||||
Step: &stepPayload{
|
Step: &stepPayload{
|
||||||
SSH: &SignSSHOptions{
|
SSH: &SignSSHOptions{
|
||||||
|
@ -802,16 +810,16 @@ func TestNebula_authorizeToken(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, false},
|
}, false},
|
||||||
{"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{
|
{"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
|
||||||
Claims: sshClaims,
|
Claims: sshClaims,
|
||||||
}, false},
|
}, false},
|
||||||
{"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true},
|
{"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true},
|
||||||
{"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true},
|
{"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||||
{"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true},
|
{"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||||
{"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true},
|
{"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||||
{"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true},
|
{"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||||
{"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true},
|
{"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||||
{"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true},
|
{"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
|
@ -38,7 +38,7 @@ func (p *noop) Init(config Config) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
return []SignOption{}, nil
|
return []SignOption{p}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
|
|
|
@ -24,6 +24,6 @@ func Test_noop(t *testing.T) {
|
||||||
|
|
||||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
sigOptions, err := p.AuthorizeSign(ctx, "foo")
|
sigOptions, err := p.AuthorizeSign(ctx, "foo")
|
||||||
assert.Equals(t, []SignOption{}, sigOptions)
|
assert.Equals(t, []SignOption{&p}, sigOptions)
|
||||||
assert.Equals(t, nil, err)
|
assert.Equals(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,8 +92,7 @@ type OIDC struct {
|
||||||
Options *Options `json:"options,omitempty"`
|
Options *Options `json:"options,omitempty"`
|
||||||
configuration openIDConfiguration
|
configuration openIDConfiguration
|
||||||
keyStore *keyStore
|
keyStore *keyStore
|
||||||
claimer *Claimer
|
ctl *Controller
|
||||||
getIdentityFunc GetIdentityFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeEmail(email string) string {
|
func sanitizeEmail(email string) string {
|
||||||
|
@ -172,11 +171,6 @@ func (o *OIDC) Init(config Config) (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update claims with global ones
|
|
||||||
if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode and validate openid-configuration endpoint
|
// Decode and validate openid-configuration endpoint
|
||||||
u, err := url.Parse(o.ConfigurationEndpoint)
|
u, err := url.Parse(o.ConfigurationEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -201,13 +195,8 @@ func (o *OIDC) Init(config Config) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the identity getter if it exists, otherwise use the default.
|
o.ctl, err = NewController(o, o.Claims, config)
|
||||||
if config.GetIdentityFunc == nil {
|
return
|
||||||
o.getIdentityFunc = DefaultIdentityFunc
|
|
||||||
} else {
|
|
||||||
o.getIdentityFunc = config.GetIdentityFunc
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidatePayload validates the given token payload.
|
// ValidatePayload validates the given token payload.
|
||||||
|
@ -356,13 +345,14 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
||||||
}
|
}
|
||||||
|
|
||||||
return []SignOption{
|
return []SignOption{
|
||||||
|
o,
|
||||||
templateOptions,
|
templateOptions,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||||
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()),
|
||||||
// validators
|
// validators
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
|
newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -371,15 +361,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
||||||
// revocation status. Just confirms that the provisioner that created the
|
// revocation status. Just confirms that the provisioner that created the
|
||||||
// certificate was configured to allow renewals.
|
// certificate was configured to allow renewals.
|
||||||
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
if o.claimer.IsDisableRenewal() {
|
return o.ctl.AuthorizeRenew(ctx, cert)
|
||||||
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
if !o.claimer.IsSSHCAEnabled() {
|
if !o.ctl.Claimer.IsSSHCAEnabled() {
|
||||||
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName())
|
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName())
|
||||||
}
|
}
|
||||||
claims, err := o.authorizeToken(token)
|
claims, err := o.authorizeToken(token)
|
||||||
|
@ -394,7 +381,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
||||||
// Get the identity using either the default identityFunc or one injected
|
// Get the identity using either the default identityFunc or one injected
|
||||||
// externally. Note that the PreferredUsername might be empty.
|
// externally. Note that the PreferredUsername might be empty.
|
||||||
// TBD: Would preferred_username present a safety issue here?
|
// TBD: Would preferred_username present a safety issue here?
|
||||||
iden, err := o.getIdentityFunc(ctx, o, claims.Email)
|
iden, err := o.ctl.GetIdentity(ctx, claims.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
|
||||||
}
|
}
|
||||||
|
@ -445,11 +432,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
||||||
|
|
||||||
return append(signOptions,
|
return append(signOptions,
|
||||||
// Set the validity bounds if not set.
|
// Set the validity bounds if not set.
|
||||||
&sshDefaultDuration{o.claimer},
|
&sshDefaultDuration{o.ctl.Claimer},
|
||||||
// Validate public key
|
// Validate public key
|
||||||
&sshDefaultPublicKeyValidator{},
|
&sshDefaultPublicKeyValidator{},
|
||||||
// Validate the validity period.
|
// Validate the validity period.
|
||||||
&sshCertValidityValidator{o.claimer},
|
&sshCertValidityValidator{o.ctl.Claimer},
|
||||||
// Require all the fields in the SSH certificate
|
// Require all the fields in the SSH certificate
|
||||||
&sshCertDefaultValidator{},
|
&sshCertDefaultValidator{},
|
||||||
), nil
|
), nil
|
||||||
|
|
|
@ -6,16 +6,17 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_openIDConfiguration_Validate(t *testing.T) {
|
func Test_openIDConfiguration_Validate(t *testing.T) {
|
||||||
|
@ -246,8 +247,8 @@ func TestOIDC_authorizeToken(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
} else {
|
} else {
|
||||||
|
@ -317,30 +318,31 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
} else if assert.NotNil(t, got) {
|
} else if assert.NotNil(t, got) {
|
||||||
assert.Len(t, 5, got)
|
assert.Len(t, 6, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
case *OIDC:
|
||||||
case certificateOptionsFunc:
|
case certificateOptionsFunc:
|
||||||
case *provisionerExtensionOption:
|
case *provisionerExtensionOption:
|
||||||
assert.Equals(t, v.Type, int(TypeOIDC))
|
assert.Equals(t, v.Type, TypeOIDC)
|
||||||
assert.Equals(t, v.Name, tt.prov.GetName())
|
assert.Equals(t, v.Name, tt.prov.GetName())
|
||||||
assert.Equals(t, v.CredentialID, tt.prov.ClientID)
|
assert.Equals(t, v.CredentialID, tt.prov.ClientID)
|
||||||
assert.Len(t, 0, v.KeyValuePairs)
|
assert.Len(t, 0, v.KeyValuePairs)
|
||||||
case profileDefaultDuration:
|
case profileDefaultDuration:
|
||||||
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
|
assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
|
||||||
case defaultPublicKeyValidator:
|
case defaultPublicKeyValidator:
|
||||||
case *validityValidator:
|
case *validityValidator:
|
||||||
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
|
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
|
||||||
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
|
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
|
||||||
case emailOnlyIdentity:
|
case emailOnlyIdentity:
|
||||||
assert.Equals(t, string(v), "name@smallstep.com")
|
assert.Equals(t, string(v), "name@smallstep.com")
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -402,8 +404,8 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
||||||
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -411,6 +413,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOIDC_AuthorizeRenew(t *testing.T) {
|
func TestOIDC_AuthorizeRenew(t *testing.T) {
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
p1, err := generateOIDC()
|
p1, err := generateOIDC()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p2, err := generateOIDC()
|
p2, err := generateOIDC()
|
||||||
|
@ -419,7 +422,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
|
||||||
// disable renewal
|
// disable renewal
|
||||||
disable := true
|
disable := true
|
||||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
type args struct {
|
type args struct {
|
||||||
|
@ -432,8 +435,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
|
||||||
code int
|
code int
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
{"ok", p1, args{&x509.Certificate{
|
||||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusOK, false},
|
||||||
|
{"fail/renew-disabled", p2, args{&x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}}, http.StatusUnauthorized, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -441,8 +450,8 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("OIDC.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -478,7 +487,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
||||||
// disable sshCA
|
// disable sshCA
|
||||||
disable := false
|
disable := false
|
||||||
p6.Claims = &Claims{EnableSSHCA: &disable}
|
p6.Claims = &Claims{EnableSSHCA: &disable}
|
||||||
p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
|
p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
// Update configuration endpoints and initialize
|
// Update configuration endpoints and initialize
|
||||||
|
@ -494,10 +503,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
||||||
assert.FatalError(t, p4.Init(config))
|
assert.FatalError(t, p4.Init(config))
|
||||||
assert.FatalError(t, p5.Init(config))
|
assert.FatalError(t, p5.Init(config))
|
||||||
|
|
||||||
p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||||
return &Identity{Usernames: []string{"max", "mariano"}}, nil
|
return &Identity{Usernames: []string{"max", "mariano"}}, nil
|
||||||
}
|
}
|
||||||
p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
}
|
}
|
||||||
// Additional test needed for empty usernames and duplicate email and usernames
|
// Additional test needed for empty usernames and duplicate email and usernames
|
||||||
|
@ -527,8 +536,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
||||||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
|
||||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||||
expectedUserOptions := &SignSSHOptions{
|
expectedUserOptions := &SignSSHOptions{
|
||||||
CertType: "user", Principals: []string{"name", "name@smallstep.com"},
|
CertType: "user", Principals: []string{"name", "name@smallstep.com"},
|
||||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||||
|
@ -597,8 +606,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
} else if assert.NotNil(t, got) {
|
} else if assert.NotNil(t, got) {
|
||||||
|
@ -665,8 +674,8 @@ func TestOIDC_AuthorizeSSHRevoke(t *testing.T) {
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("OIDC.AuthorizeSSHRevoke() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
stderrors "errors"
|
stderrors "errors"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -47,6 +46,7 @@ var ErrAllowTokenReuse = stderrors.New("allow token reuse")
|
||||||
// Audiences stores all supported audiences by request type.
|
// Audiences stores all supported audiences by request type.
|
||||||
type Audiences struct {
|
type Audiences struct {
|
||||||
Sign []string
|
Sign []string
|
||||||
|
Renew []string
|
||||||
Revoke []string
|
Revoke []string
|
||||||
SSHSign []string
|
SSHSign []string
|
||||||
SSHRevoke []string
|
SSHRevoke []string
|
||||||
|
@ -57,6 +57,7 @@ type Audiences struct {
|
||||||
// All returns all supported audiences across all request types in one list.
|
// All returns all supported audiences across all request types in one list.
|
||||||
func (a Audiences) All() (auds []string) {
|
func (a Audiences) All() (auds []string) {
|
||||||
auds = a.Sign
|
auds = a.Sign
|
||||||
|
auds = append(auds, a.Renew...)
|
||||||
auds = append(auds, a.Revoke...)
|
auds = append(auds, a.Revoke...)
|
||||||
auds = append(auds, a.SSHSign...)
|
auds = append(auds, a.SSHSign...)
|
||||||
auds = append(auds, a.SSHRevoke...)
|
auds = append(auds, a.SSHRevoke...)
|
||||||
|
@ -70,6 +71,7 @@ func (a Audiences) All() (auds []string) {
|
||||||
func (a Audiences) WithFragment(fragment string) Audiences {
|
func (a Audiences) WithFragment(fragment string) Audiences {
|
||||||
ret := Audiences{
|
ret := Audiences{
|
||||||
Sign: make([]string, len(a.Sign)),
|
Sign: make([]string, len(a.Sign)),
|
||||||
|
Renew: make([]string, len(a.Renew)),
|
||||||
Revoke: make([]string, len(a.Revoke)),
|
Revoke: make([]string, len(a.Revoke)),
|
||||||
SSHSign: make([]string, len(a.SSHSign)),
|
SSHSign: make([]string, len(a.SSHSign)),
|
||||||
SSHRevoke: make([]string, len(a.SSHRevoke)),
|
SSHRevoke: make([]string, len(a.SSHRevoke)),
|
||||||
|
@ -83,6 +85,13 @@ func (a Audiences) WithFragment(fragment string) Audiences {
|
||||||
ret.Sign[i] = s
|
ret.Sign[i] = s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for i, s := range a.Renew {
|
||||||
|
if u, err := url.Parse(s); err == nil {
|
||||||
|
ret.Renew[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
|
||||||
|
} else {
|
||||||
|
ret.Renew[i] = s
|
||||||
|
}
|
||||||
|
}
|
||||||
for i, s := range a.Revoke {
|
for i, s := range a.Revoke {
|
||||||
if u, err := url.Parse(s); err == nil {
|
if u, err := url.Parse(s); err == nil {
|
||||||
ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
|
ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
|
||||||
|
@ -210,6 +219,12 @@ type Config struct {
|
||||||
// GetIdentityFunc is a function that returns an identity that will be
|
// GetIdentityFunc is a function that returns an identity that will be
|
||||||
// used by the provisioner to populate certificate attributes.
|
// used by the provisioner to populate certificate attributes.
|
||||||
GetIdentityFunc GetIdentityFunc
|
GetIdentityFunc GetIdentityFunc
|
||||||
|
// AuthorizeRenewFunc is a function that returns nil if a given X.509
|
||||||
|
// certificate can be renewed.
|
||||||
|
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||||
|
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
|
||||||
|
// certificate can be renewed.
|
||||||
|
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
type provisioner struct {
|
type provisioner struct {
|
||||||
|
@ -278,32 +293,6 @@ func (l *List) UnmarshalJSON(data []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
|
|
||||||
|
|
||||||
// SanitizeSSHUserPrincipal grabs an email or a string with the format
|
|
||||||
// local@domain and returns a sanitized version of the local, valid to be used
|
|
||||||
// as a user name. If the email starts with a letter between a and z, the
|
|
||||||
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
|
|
||||||
func SanitizeSSHUserPrincipal(email string) string {
|
|
||||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
|
||||||
email = email[:i]
|
|
||||||
}
|
|
||||||
return strings.Map(func(r rune) rune {
|
|
||||||
switch {
|
|
||||||
case r >= 'a' && r <= 'z':
|
|
||||||
return r
|
|
||||||
case r >= '0' && r <= '9':
|
|
||||||
return r
|
|
||||||
case r == '-':
|
|
||||||
return '-'
|
|
||||||
case r == '.': // drop dots
|
|
||||||
return -1
|
|
||||||
default:
|
|
||||||
return '_'
|
|
||||||
}
|
|
||||||
}, strings.ToLower(email))
|
|
||||||
}
|
|
||||||
|
|
||||||
type base struct{}
|
type base struct{}
|
||||||
|
|
||||||
// AuthorizeSign returns an unimplemented error. Provisioners should overwrite
|
// AuthorizeSign returns an unimplemented error. Provisioners should overwrite
|
||||||
|
@ -348,66 +337,12 @@ func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certif
|
||||||
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
|
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Identity is the type representing an externally supplied identity that is used
|
|
||||||
// by provisioners to populate certificate fields.
|
|
||||||
type Identity struct {
|
|
||||||
Usernames []string `json:"usernames"`
|
|
||||||
Permissions `json:"permissions"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Permissions defines extra extensions and critical options to grant to an SSH certificate.
|
// Permissions defines extra extensions and critical options to grant to an SSH certificate.
|
||||||
type Permissions struct {
|
type Permissions struct {
|
||||||
Extensions map[string]string `json:"extensions"`
|
Extensions map[string]string `json:"extensions"`
|
||||||
CriticalOptions map[string]string `json:"criticalOptions"`
|
CriticalOptions map[string]string `json:"criticalOptions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIdentityFunc is a function that returns an identity.
|
|
||||||
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
|
|
||||||
|
|
||||||
// DefaultIdentityFunc return a default identity depending on the provisioner
|
|
||||||
// type. For OIDC email is always present and the usernames might
|
|
||||||
// contain empty strings.
|
|
||||||
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
|
|
||||||
switch k := p.(type) {
|
|
||||||
case *OIDC:
|
|
||||||
// OIDC principals would be:
|
|
||||||
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
|
|
||||||
// 2. Sanitized local.
|
|
||||||
// 3. Raw local (if different).
|
|
||||||
// 4. Email address.
|
|
||||||
name := SanitizeSSHUserPrincipal(email)
|
|
||||||
if !sshUserRegex.MatchString(name) {
|
|
||||||
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
|
|
||||||
}
|
|
||||||
usernames := []string{name}
|
|
||||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
|
||||||
usernames = append(usernames, email[:i])
|
|
||||||
}
|
|
||||||
usernames = append(usernames, email)
|
|
||||||
return &Identity{
|
|
||||||
Usernames: SanitizeStringSlices(usernames),
|
|
||||||
}, nil
|
|
||||||
default:
|
|
||||||
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SanitizeStringSlices removes duplicated an empty strings.
|
|
||||||
func SanitizeStringSlices(original []string) []string {
|
|
||||||
output := []string{}
|
|
||||||
seen := make(map[string]struct{})
|
|
||||||
for _, entry := range original {
|
|
||||||
if entry == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, value := seen[entry]; !value {
|
|
||||||
seen[entry] = struct{}{}
|
|
||||||
output = append(output, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return output
|
|
||||||
}
|
|
||||||
|
|
||||||
// MockProvisioner for testing
|
// MockProvisioner for testing
|
||||||
type MockProvisioner struct {
|
type MockProvisioner struct {
|
||||||
Mret1, Mret2, Mret3 interface{}
|
Mret1, Mret2, Mret3 interface{}
|
||||||
|
|
|
@ -2,13 +2,14 @@ package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestType_String(t *testing.T) {
|
func TestType_String(t *testing.T) {
|
||||||
|
@ -240,8 +241,8 @@ func TestUnimplementedMethods(t *testing.T) {
|
||||||
default:
|
default:
|
||||||
t.Errorf("unexpected method %s", tt.method)
|
t.Errorf("unexpected method %s", tt.method)
|
||||||
}
|
}
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
|
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
|
||||||
assert.Equals(t, err.Error(), msg)
|
assert.Equals(t, err.Error(), msg)
|
||||||
})
|
})
|
||||||
|
|
|
@ -11,28 +11,30 @@ import (
|
||||||
// SCEP provisioning flow
|
// SCEP provisioning flow
|
||||||
type SCEP struct {
|
type SCEP struct {
|
||||||
*base
|
*base
|
||||||
ID string `json:"-"`
|
ID string `json:"-"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
ForceCN bool `json:"forceCN,omitempty"`
|
ForceCN bool `json:"forceCN,omitempty"`
|
||||||
ChallengePassword string `json:"challenge,omitempty"`
|
ChallengePassword string `json:"challenge,omitempty"`
|
||||||
Capabilities []string `json:"capabilities,omitempty"`
|
Capabilities []string `json:"capabilities,omitempty"`
|
||||||
|
|
||||||
// IncludeRoot makes the provisioner return the CA root in addition to the
|
// IncludeRoot makes the provisioner return the CA root in addition to the
|
||||||
// intermediate in the GetCACerts response
|
// intermediate in the GetCACerts response
|
||||||
IncludeRoot bool `json:"includeRoot,omitempty"`
|
IncludeRoot bool `json:"includeRoot,omitempty"`
|
||||||
|
|
||||||
// MinimumPublicKeyLength is the minimum length for public keys in CSRs
|
// MinimumPublicKeyLength is the minimum length for public keys in CSRs
|
||||||
MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"`
|
MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"`
|
||||||
|
|
||||||
// Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7
|
// Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7
|
||||||
// at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63
|
// at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63
|
||||||
// Defaults to 0, being DES-CBC
|
// Defaults to 0, being DES-CBC
|
||||||
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
|
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
|
||||||
Options *Options `json:"options,omitempty"`
|
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
|
||||||
claimer *Claimer
|
|
||||||
|
|
||||||
|
Options *Options `json:"options,omitempty"`
|
||||||
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
secretChallengePassword string
|
secretChallengePassword string
|
||||||
encryptionAlgorithm int
|
encryptionAlgorithm int
|
||||||
|
ctl *Controller
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier.
|
// GetID returns the provisioner unique identifier.
|
||||||
|
@ -77,7 +79,7 @@ func (s *SCEP) GetOptions() *Options {
|
||||||
// DefaultTLSCertDuration returns the default TLS cert duration enforced by
|
// DefaultTLSCertDuration returns the default TLS cert duration enforced by
|
||||||
// the provisioner.
|
// the provisioner.
|
||||||
func (s *SCEP) DefaultTLSCertDuration() time.Duration {
|
func (s *SCEP) DefaultTLSCertDuration() time.Duration {
|
||||||
return s.claimer.DefaultTLSCertDuration()
|
return s.ctl.Claimer.DefaultTLSCertDuration()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes and validates the fields of a SCEP type.
|
// Init initializes and validates the fields of a SCEP type.
|
||||||
|
@ -90,11 +92,6 @@ func (s *SCEP) Init(config Config) (err error) {
|
||||||
return errors.New("provisioner name cannot be empty")
|
return errors.New("provisioner name cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update claims with global ones
|
|
||||||
if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mask the actual challenge value, so it won't be marshaled
|
// Mask the actual challenge value, so it won't be marshaled
|
||||||
s.secretChallengePassword = s.ChallengePassword
|
s.secretChallengePassword = s.ChallengePassword
|
||||||
s.ChallengePassword = "*** redacted ***"
|
s.ChallengePassword = "*** redacted ***"
|
||||||
|
@ -115,7 +112,8 @@ func (s *SCEP) Init(config Config) (err error) {
|
||||||
|
|
||||||
// TODO: add other, SCEP specific, options?
|
// TODO: add other, SCEP specific, options?
|
||||||
|
|
||||||
return err
|
s.ctl, err = NewController(s, s.Claims, config)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign does not do any verification, because all verification is handled
|
// AuthorizeSign does not do any verification, because all verification is handled
|
||||||
|
@ -123,13 +121,14 @@ func (s *SCEP) Init(config Config) (err error) {
|
||||||
// on the resulting certificate.
|
// on the resulting certificate.
|
||||||
func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
return []SignOption{
|
return []SignOption{
|
||||||
|
s,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeSCEP, s.Name, ""),
|
newProvisionerExtensionOption(TypeSCEP, s.Name, ""),
|
||||||
newForceCNOption(s.ForceCN),
|
newForceCNOption(s.ForceCN),
|
||||||
profileDefaultDuration(s.claimer.DefaultTLSCertDuration()),
|
profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()),
|
||||||
// validators
|
// validators
|
||||||
newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength),
|
newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength),
|
||||||
newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()),
|
newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/asn1"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -14,7 +13,6 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"go.step.sm/crypto/keyutil"
|
"go.step.sm/crypto/keyutil"
|
||||||
"go.step.sm/crypto/x509util"
|
"go.step.sm/crypto/x509util"
|
||||||
|
@ -404,17 +402,12 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
// type stepProvisionerASN1 struct {
|
||||||
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
// Type int
|
||||||
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
// Name []byte
|
||||||
)
|
// CredentialID []byte
|
||||||
|
// KeyValuePairs []string `asn1:"optional,omitempty"`
|
||||||
type stepProvisionerASN1 struct {
|
// }
|
||||||
Type int
|
|
||||||
Name []byte
|
|
||||||
CredentialID []byte
|
|
||||||
KeyValuePairs []string `asn1:"optional,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type forceCNOption struct {
|
type forceCNOption struct {
|
||||||
ForceCN bool
|
ForceCN bool
|
||||||
|
@ -441,23 +434,22 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type provisionerExtensionOption struct {
|
type provisionerExtensionOption struct {
|
||||||
Type int
|
Extension
|
||||||
Name string
|
|
||||||
CredentialID string
|
|
||||||
KeyValuePairs []string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption {
|
func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption {
|
||||||
return &provisionerExtensionOption{
|
return &provisionerExtensionOption{
|
||||||
Type: int(typ),
|
Extension: Extension{
|
||||||
Name: name,
|
Type: typ,
|
||||||
CredentialID: credentialID,
|
Name: name,
|
||||||
KeyValuePairs: keyValuePairs,
|
CredentialID: credentialID,
|
||||||
|
KeyValuePairs: keyValuePairs,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
|
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
|
||||||
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...)
|
ext, err := o.ToExtension()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.NewError(http.StatusInternalServerError, err, "error creating certificate")
|
return errs.NewError(http.StatusInternalServerError, err, "error creating certificate")
|
||||||
}
|
}
|
||||||
|
@ -471,20 +463,3 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption
|
||||||
cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...)
|
cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) {
|
|
||||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
|
||||||
Type: typ,
|
|
||||||
Name: []byte(name),
|
|
||||||
CredentialID: []byte(credentialID),
|
|
||||||
KeyValuePairs: keyValuePairs,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension")
|
|
||||||
}
|
|
||||||
return pkix.Extension{
|
|
||||||
Id: stepOIDProvisioner,
|
|
||||||
Critical: false,
|
|
||||||
Value: b,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -636,18 +636,18 @@ func Test_newProvisionerExtension_Option(t *testing.T) {
|
||||||
valid: func(cert *x509.Certificate) {
|
valid: func(cert *x509.Certificate) {
|
||||||
if assert.Len(t, 1, cert.ExtraExtensions) {
|
if assert.Len(t, 1, cert.ExtraExtensions) {
|
||||||
ext := cert.ExtraExtensions[0]
|
ext := cert.ExtraExtensions[0]
|
||||||
assert.Equals(t, ext.Id, stepOIDProvisioner)
|
assert.Equals(t, ext.Id, StepOIDProvisioner)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/prepend": func() test {
|
"ok/prepend": func() test {
|
||||||
return test{
|
return test{
|
||||||
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
|
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
|
||||||
valid: func(cert *x509.Certificate) {
|
valid: func(cert *x509.Certificate) {
|
||||||
if assert.Len(t, 3, cert.ExtraExtensions) {
|
if assert.Len(t, 3, cert.ExtraExtensions) {
|
||||||
ext := cert.ExtraExtensions[0]
|
ext := cert.ExtraExtensions[0]
|
||||||
assert.Equals(t, ext.Id, stepOIDProvisioner)
|
assert.Equals(t, ext.Id, StepOIDProvisioner)
|
||||||
assert.False(t, ext.Critical)
|
assert.False(t, ext.Critical)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) {
|
||||||
func Test_sshCertValidityValidator(t *testing.T) {
|
func Test_sshCertValidityValidator(t *testing.T) {
|
||||||
p, err := generateX5C(nil)
|
p, err := generateX5C(nil)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
v := sshCertValidityValidator{p.claimer}
|
v := sshCertValidityValidator{p.ctl.Claimer}
|
||||||
n := now()
|
n := now()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
||||||
tests := map[string]func() test{
|
tests := map[string]func() test{
|
||||||
"fail/type-not-set": func() test {
|
"fail/type-not-set": func() test {
|
||||||
return test{
|
return test{
|
||||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
|
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
|
||||||
cert: &ssh.Certificate{
|
cert: &ssh.Certificate{
|
||||||
ValidAfter: uint64(n.Unix()),
|
ValidAfter: uint64(n.Unix()),
|
||||||
ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
|
ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
|
||||||
|
@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/type-not-recognized": func() test {
|
"fail/type-not-recognized": func() test {
|
||||||
return test{
|
return test{
|
||||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
|
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
|
||||||
cert: &ssh.Certificate{
|
cert: &ssh.Certificate{
|
||||||
CertType: 4,
|
CertType: 4,
|
||||||
ValidAfter: uint64(n.Unix()),
|
ValidAfter: uint64(n.Unix()),
|
||||||
|
@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/requested-validAfter-after-limit": func() test {
|
"fail/requested-validAfter-after-limit": func() test {
|
||||||
return test{
|
return test{
|
||||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
|
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
|
||||||
cert: &ssh.Certificate{
|
cert: &ssh.Certificate{
|
||||||
CertType: 1,
|
CertType: 1,
|
||||||
ValidAfter: uint64(n.Add(2 * time.Hour).Unix()),
|
ValidAfter: uint64(n.Add(2 * time.Hour).Unix()),
|
||||||
|
@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/requested-validBefore-after-limit": func() test {
|
"fail/requested-validBefore-after-limit": func() test {
|
||||||
return test{
|
return test{
|
||||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
|
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
|
||||||
cert: &ssh.Certificate{
|
cert: &ssh.Certificate{
|
||||||
CertType: 1,
|
CertType: 1,
|
||||||
ValidAfter: uint64(n.Unix()),
|
ValidAfter: uint64(n.Unix()),
|
||||||
|
@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
||||||
"ok/no-limit": func() test {
|
"ok/no-limit": func() test {
|
||||||
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
|
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
|
||||||
return test{
|
return test{
|
||||||
svm: &sshLimitDuration{Claimer: p.claimer},
|
svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
|
||||||
cert: &ssh.Certificate{
|
cert: &ssh.Certificate{
|
||||||
CertType: 1,
|
CertType: 1,
|
||||||
},
|
},
|
||||||
|
@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
||||||
"ok/defaults": func() test {
|
"ok/defaults": func() test {
|
||||||
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
|
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
|
||||||
return test{
|
return test{
|
||||||
svm: &sshLimitDuration{Claimer: p.claimer},
|
svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
|
||||||
cert: &ssh.Certificate{
|
cert: &ssh.Certificate{
|
||||||
CertType: 1,
|
CertType: 1,
|
||||||
},
|
},
|
||||||
|
@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
||||||
"ok/valid-requested-validBefore": func() test {
|
"ok/valid-requested-validBefore": func() test {
|
||||||
va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix())
|
va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix())
|
||||||
return test{
|
return test{
|
||||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
|
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
|
||||||
cert: &ssh.Certificate{
|
cert: &ssh.Certificate{
|
||||||
CertType: 1,
|
CertType: 1,
|
||||||
ValidAfter: va,
|
ValidAfter: va,
|
||||||
|
@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
||||||
"ok/empty-requested-validBefore-limit-after-default": func() test {
|
"ok/empty-requested-validBefore-limit-after-default": func() test {
|
||||||
va := uint64(n.Unix())
|
va := uint64(n.Unix())
|
||||||
return test{
|
return test{
|
||||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)},
|
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)},
|
||||||
cert: &ssh.Certificate{
|
cert: &ssh.Certificate{
|
||||||
CertType: 1,
|
CertType: 1,
|
||||||
ValidAfter: va,
|
ValidAfter: va,
|
||||||
|
@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
||||||
"ok/empty-requested-validBefore-limit-before-default": func() test {
|
"ok/empty-requested-validBefore-limit-before-default": func() test {
|
||||||
va := uint64(n.Unix())
|
va := uint64(n.Unix())
|
||||||
return test{
|
return test{
|
||||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
|
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
|
||||||
cert: &ssh.Certificate{
|
cert: &ssh.Certificate{
|
||||||
CertType: 1,
|
CertType: 1,
|
||||||
ValidAfter: va,
|
ValidAfter: va,
|
||||||
|
|
|
@ -29,8 +29,7 @@ type SSHPOP struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
claimer *Claimer
|
ctl *Controller
|
||||||
audiences Audiences
|
|
||||||
sshPubKeys *SSHKeys
|
sshPubKeys *SSHKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes and validates the fields of a SSHPOP type.
|
// Init initializes and validates the fields of a SSHPOP type.
|
||||||
func (p *SSHPOP) Init(config Config) error {
|
func (p *SSHPOP) Init(config Config) (err error) {
|
||||||
switch {
|
switch {
|
||||||
case p.Type == "":
|
case p.Type == "":
|
||||||
return errors.New("provisioner type cannot be empty")
|
return errors.New("provisioner type cannot be empty")
|
||||||
|
@ -93,15 +92,11 @@ func (p *SSHPOP) Init(config Config) error {
|
||||||
return errors.New("provisioner public SSH validation keys cannot be empty")
|
return errors.New("provisioner public SSH validation keys cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update claims with global ones
|
|
||||||
var err error
|
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
|
||||||
p.sshPubKeys = config.SSHKeys
|
p.sshPubKeys = config.SSHKeys
|
||||||
return nil
|
|
||||||
|
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||||
|
p.ctl, err = NewController(p, p.Claims, config)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// authorizeToken performs common jwt authorization actions and returns the
|
// authorizeToken performs common jwt authorization actions and returns the
|
||||||
|
@ -109,7 +104,7 @@ func (p *SSHPOP) Init(config Config) error {
|
||||||
// e.g. a Sign request will auth/validate different fields than a Revoke request.
|
// e.g. a Sign request will auth/validate different fields than a Revoke request.
|
||||||
//
|
//
|
||||||
// Checking for certificate revocation has been moved to the authority package.
|
// Checking for certificate revocation has been moved to the authority package.
|
||||||
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
|
func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) {
|
||||||
sshCert, jwt, err := ExtractSSHPOPCert(token)
|
sshCert, jwt, err := ExtractSSHPOPCert(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusUnauthorized, err,
|
return nil, errs.Wrap(http.StatusUnauthorized, err,
|
||||||
|
@ -117,13 +112,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check validity period of the certificate.
|
// Check validity period of the certificate.
|
||||||
n := time.Now()
|
//
|
||||||
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
|
// Controller.AuthorizeSSHRenew will validate this on the renewal flow.
|
||||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
|
if checkValidity {
|
||||||
}
|
unixNow := time.Now().Unix()
|
||||||
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) {
|
if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) {
|
||||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
|
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
|
||||||
|
}
|
||||||
|
if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
|
||||||
|
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
|
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
|
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
|
||||||
|
@ -186,7 +186,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
||||||
// AuthorizeSSHRevoke validates the authorization token and extracts/validates
|
// AuthorizeSSHRevoke validates the authorization token and extracts/validates
|
||||||
// the SSH certificate from the ssh-pop header.
|
// the SSH certificate from the ssh-pop header.
|
||||||
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||||
claims, err := p.authorizeToken(token, p.audiences.SSHRevoke)
|
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
|
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
|
||||||
}
|
}
|
||||||
|
@ -199,22 +199,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||||
// AuthorizeSSHRenew validates the authorization token and extracts/validates
|
// AuthorizeSSHRenew validates the authorization token and extracts/validates
|
||||||
// the SSH certificate from the ssh-pop header.
|
// the SSH certificate from the ssh-pop header.
|
||||||
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
|
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
|
||||||
claims, err := p.authorizeToken(token, p.audiences.SSHRenew)
|
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
|
||||||
}
|
}
|
||||||
if claims.sshCert.CertType != ssh.HostCert {
|
if claims.sshCert.CertType != ssh.HostCert {
|
||||||
return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
|
return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
|
||||||
}
|
}
|
||||||
|
return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert)
|
||||||
return claims.sshCert, nil
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSSHRekey validates the authorization token and extracts/validates
|
// AuthorizeSSHRekey validates the authorization token and extracts/validates
|
||||||
// the SSH certificate from the ssh-pop header.
|
// the SSH certificate from the ssh-pop header.
|
||||||
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
|
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
|
||||||
claims, err := p.authorizeToken(token, p.audiences.SSHRekey)
|
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
|
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
|
||||||
}
|
}
|
||||||
|
@ -225,11 +223,10 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
|
||||||
// Validate public key
|
// Validate public key
|
||||||
&sshDefaultPublicKeyValidator{},
|
&sshDefaultPublicKeyValidator{},
|
||||||
// Validate the validity period.
|
// Validate the validity period.
|
||||||
&sshCertValidityValidator{p.claimer},
|
&sshCertValidityValidator{p.ctl.Claimer},
|
||||||
// Require and validate all the default fields in the SSH certificate.
|
// Require and validate all the default fields in the SSH certificate.
|
||||||
&sshCertDefaultValidator{},
|
&sshCertDefaultValidator{},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate
|
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate
|
||||||
|
|
|
@ -5,16 +5,19 @@ import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"golang.org/x/crypto/ssh"
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSSHPOP_Getters(t *testing.T) {
|
func TestSSHPOP_Getters(t *testing.T) {
|
||||||
|
@ -38,6 +41,7 @@ func TestSSHPOP_Getters(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
|
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
|
||||||
|
now := time.Now()
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -46,6 +50,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
if cert.ValidAfter == 0 {
|
||||||
|
cert.ValidAfter = uint64(now.Unix())
|
||||||
|
}
|
||||||
|
if cert.ValidBefore == 0 {
|
||||||
|
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
|
||||||
|
}
|
||||||
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -207,9 +217,9 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
for name, tt := range tests {
|
for name, tt := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
|
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
@ -279,8 +289,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
|
if err := tc.p.AuthorizeSSHRevoke(context.Background(), tc.token); err != nil {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
@ -360,8 +370,8 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil {
|
if cert, err := tc.p.AuthorizeSSHRenew(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -442,8 +452,8 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil {
|
if cert, opts, err := tc.p.AuthorizeSSHRekey(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -455,9 +465,9 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
||||||
case *sshDefaultPublicKeyValidator:
|
case *sshDefaultPublicKeyValidator:
|
||||||
case *sshCertDefaultValidator:
|
case *sshCertDefaultValidator:
|
||||||
case *sshCertValidityValidator:
|
case *sshCertValidityValidator:
|
||||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
|
assert.Equals(t, tc.cert.Nonce, cert.Nonce)
|
||||||
|
|
21
authority/provisioner/testdata/certs/bad-extension.crt
vendored
Normal file
21
authority/provisioner/testdata/certs/bad-extension.crt
vendored
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi
|
||||||
|
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy
|
||||||
|
MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
|
||||||
|
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
|
||||||
|
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
|
||||||
|
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
|
||||||
|
KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc
|
||||||
|
pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE
|
||||||
|
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
|
||||||
|
ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg
|
||||||
|
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
|
||||||
|
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
|
||||||
|
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
|
||||||
|
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
|
||||||
|
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
|
||||||
|
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
|
||||||
|
LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49
|
||||||
|
BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY
|
||||||
|
wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g=
|
||||||
|
-----END CERTIFICATE-----
|
22
authority/provisioner/testdata/certs/good-extension.crt
vendored
Normal file
22
authority/provisioner/testdata/certs/good-extension.crt
vendored
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi
|
||||||
|
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx
|
||||||
|
MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
|
||||||
|
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
|
||||||
|
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
|
||||||
|
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
|
||||||
|
KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0
|
||||||
|
4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE
|
||||||
|
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
|
||||||
|
ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg
|
||||||
|
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
|
||||||
|
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
|
||||||
|
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
|
||||||
|
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
|
||||||
|
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
|
||||||
|
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
|
||||||
|
LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz
|
||||||
|
bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf
|
||||||
|
SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB
|
||||||
|
bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w==
|
||||||
|
-----END CERTIFICATE-----
|
|
@ -24,20 +24,22 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultDisableRenewal = false
|
defaultDisableRenewal = false
|
||||||
defaultEnableSSHCA = true
|
defaultAllowRenewAfterExpiry = false
|
||||||
globalProvisionerClaims = Claims{
|
defaultEnableSSHCA = true
|
||||||
MinTLSDur: &Duration{5 * time.Minute},
|
globalProvisionerClaims = Claims{
|
||||||
MaxTLSDur: &Duration{24 * time.Hour},
|
MinTLSDur: &Duration{5 * time.Minute},
|
||||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
MaxTLSDur: &Duration{24 * time.Hour},
|
||||||
DisableRenewal: &defaultDisableRenewal,
|
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||||
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
|
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||||
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
|
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
|
||||||
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
|
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
|
||||||
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
|
||||||
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||||
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||||
EnableSSHCA: &defaultEnableSSHCA,
|
EnableSSHCA: &defaultEnableSSHCA,
|
||||||
|
DisableRenewal: &defaultDisableRenewal,
|
||||||
|
AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry,
|
||||||
}
|
}
|
||||||
testAudiences = Audiences{
|
testAudiences = Audiences{
|
||||||
Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"},
|
Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"},
|
||||||
|
@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
|
||||||
if err != nil {
|
p := &JWK{
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &JWK{
|
|
||||||
Name: name,
|
Name: name,
|
||||||
Type: "JWK",
|
Type: "JWK",
|
||||||
Key: &public,
|
Key: &public,
|
||||||
EncryptedKey: encrypted,
|
EncryptedKey: encrypted,
|
||||||
Claims: &globalProvisionerClaims,
|
Claims: &globalProvisionerClaims,
|
||||||
audiences: testAudiences,
|
}
|
||||||
claimer: claimer,
|
p.ctl, err = NewController(p, p.Claims, Config{
|
||||||
}, nil
|
Audiences: testAudiences,
|
||||||
|
})
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
|
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
|
||||||
|
@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pubKeys := []interface{}{fooPub, barPub}
|
pubKeys := []interface{}{fooPub, barPub}
|
||||||
if inputPubKey != nil {
|
if inputPubKey != nil {
|
||||||
pubKeys = append(pubKeys, inputPubKey)
|
pubKeys = append(pubKeys, inputPubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &K8sSA{
|
p := &K8sSA{
|
||||||
Name: K8sSAName,
|
Name: K8sSAName,
|
||||||
Type: "K8sSA",
|
Type: "K8sSA",
|
||||||
Claims: &globalProvisionerClaims,
|
Claims: &globalProvisionerClaims,
|
||||||
audiences: testAudiences,
|
pubKeys: pubKeys,
|
||||||
claimer: claimer,
|
}
|
||||||
pubKeys: pubKeys,
|
p.ctl, err = NewController(p, p.Claims, Config{
|
||||||
}, nil
|
Audiences: testAudiences,
|
||||||
|
})
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateSSHPOP() (*SSHPOP, error) {
|
func generateSSHPOP() (*SSHPOP, error) {
|
||||||
|
@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
|
userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &SSHPOP{
|
p := &SSHPOP{
|
||||||
Name: name,
|
Name: name,
|
||||||
Type: "SSHPOP",
|
Type: "SSHPOP",
|
||||||
Claims: &globalProvisionerClaims,
|
Claims: &globalProvisionerClaims,
|
||||||
audiences: testAudiences,
|
|
||||||
claimer: claimer,
|
|
||||||
sshPubKeys: &SSHKeys{
|
sshPubKeys: &SSHKeys{
|
||||||
UserKeys: []ssh.PublicKey{userKey},
|
UserKeys: []ssh.PublicKey{userKey},
|
||||||
HostKeys: []ssh.PublicKey{hostKey},
|
HostKeys: []ssh.PublicKey{hostKey},
|
||||||
},
|
},
|
||||||
}, nil
|
}
|
||||||
|
p.ctl, err = NewController(p, p.Claims, Config{
|
||||||
|
Audiences: testAudiences,
|
||||||
|
})
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateX5C(root []byte) (*X5C, error) {
|
func generateX5C(root []byte) (*X5C, error) {
|
||||||
|
@ -283,11 +279,6 @@ M46l92gdOozT
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rootPool := x509.NewCertPool()
|
rootPool := x509.NewCertPool()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -305,15 +296,17 @@ M46l92gdOozT
|
||||||
}
|
}
|
||||||
rootPool.AddCert(cert)
|
rootPool.AddCert(cert)
|
||||||
}
|
}
|
||||||
return &X5C{
|
p := &X5C{
|
||||||
Name: name,
|
Name: name,
|
||||||
Type: "X5C",
|
Type: "X5C",
|
||||||
Roots: root,
|
Roots: root,
|
||||||
Claims: &globalProvisionerClaims,
|
Claims: &globalProvisionerClaims,
|
||||||
audiences: testAudiences,
|
rootPool: rootPool,
|
||||||
claimer: claimer,
|
}
|
||||||
rootPool: rootPool,
|
p.ctl, err = NewController(p, p.Claims, Config{
|
||||||
}, nil
|
Audiences: testAudiences,
|
||||||
|
})
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateOIDC() (*OIDC, error) {
|
func generateOIDC() (*OIDC, error) {
|
||||||
|
@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
p := &OIDC{
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &OIDC{
|
|
||||||
Name: name,
|
Name: name,
|
||||||
Type: "OIDC",
|
Type: "OIDC",
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
|
@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) {
|
||||||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||||
expiry: time.Now().Add(24 * time.Hour),
|
expiry: time.Now().Add(24 * time.Hour),
|
||||||
},
|
},
|
||||||
claimer: claimer,
|
}
|
||||||
}, nil
|
p.ctl, err = NewController(p, p.Claims, Config{
|
||||||
|
Audiences: testAudiences,
|
||||||
|
})
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateGCP() (*GCP, error) {
|
func generateGCP() (*GCP, error) {
|
||||||
|
@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
p := &GCP{
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &GCP{
|
|
||||||
Type: "GCP",
|
Type: "GCP",
|
||||||
Name: name,
|
Name: name,
|
||||||
ServiceAccounts: []string{serviceAccount},
|
ServiceAccounts: []string{serviceAccount},
|
||||||
Claims: &globalProvisionerClaims,
|
Claims: &globalProvisionerClaims,
|
||||||
claimer: claimer,
|
|
||||||
config: newGCPConfig(),
|
config: newGCPConfig(),
|
||||||
keyStore: &keyStore{
|
keyStore: &keyStore{
|
||||||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||||
expiry: time.Now().Add(24 * time.Hour),
|
expiry: time.Now().Add(24 * time.Hour),
|
||||||
},
|
},
|
||||||
audiences: testAudiences.WithFragment("gcp/" + name),
|
}
|
||||||
}, nil
|
p.ctl, err = NewController(p, p.Claims, Config{
|
||||||
|
Audiences: testAudiences.WithFragment("gcp/" + name),
|
||||||
|
})
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateAWS() (*AWS, error) {
|
func generateAWS() (*AWS, error) {
|
||||||
|
@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
block, _ := pem.Decode([]byte(awsTestCertificate))
|
block, _ := pem.Decode([]byte(awsTestCertificate))
|
||||||
if block == nil || block.Type != "CERTIFICATE" {
|
if block == nil || block.Type != "CERTIFICATE" {
|
||||||
return nil, errors.New("error decoding AWS certificate")
|
return nil, errors.New("error decoding AWS certificate")
|
||||||
|
@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error parsing AWS certificate")
|
return nil, errors.Wrap(err, "error parsing AWS certificate")
|
||||||
}
|
}
|
||||||
return &AWS{
|
p := &AWS{
|
||||||
Type: "AWS",
|
Type: "AWS",
|
||||||
Name: name,
|
Name: name,
|
||||||
Accounts: []string{accountID},
|
Accounts: []string{accountID},
|
||||||
Claims: &globalProvisionerClaims,
|
Claims: &globalProvisionerClaims,
|
||||||
IMDSVersions: []string{"v2", "v1"},
|
IMDSVersions: []string{"v2", "v1"},
|
||||||
claimer: claimer,
|
|
||||||
config: &awsConfig{
|
config: &awsConfig{
|
||||||
identityURL: awsIdentityURL,
|
identityURL: awsIdentityURL,
|
||||||
signatureURL: awsSignatureURL,
|
signatureURL: awsSignatureURL,
|
||||||
|
@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) {
|
||||||
certificates: []*x509.Certificate{cert},
|
certificates: []*x509.Certificate{cert},
|
||||||
signatureAlgorithm: awsSignatureAlgorithm,
|
signatureAlgorithm: awsSignatureAlgorithm,
|
||||||
},
|
},
|
||||||
audiences: testAudiences.WithFragment("aws/" + name),
|
}
|
||||||
}, nil
|
p.ctl, err = NewController(p, p.Claims, Config{
|
||||||
|
Audiences: testAudiences.WithFragment("aws/" + name),
|
||||||
|
})
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
||||||
|
@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
block, _ := pem.Decode([]byte(awsTestCertificate))
|
block, _ := pem.Decode([]byte(awsTestCertificate))
|
||||||
if block == nil || block.Type != "CERTIFICATE" {
|
if block == nil || block.Type != "CERTIFICATE" {
|
||||||
return nil, errors.New("error decoding AWS certificate")
|
return nil, errors.New("error decoding AWS certificate")
|
||||||
|
@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error parsing AWS certificate")
|
return nil, errors.Wrap(err, "error parsing AWS certificate")
|
||||||
}
|
}
|
||||||
return &AWS{
|
p := &AWS{
|
||||||
Type: "AWS",
|
Type: "AWS",
|
||||||
Name: name,
|
Name: name,
|
||||||
Accounts: []string{accountID},
|
Accounts: []string{accountID},
|
||||||
Claims: &globalProvisionerClaims,
|
Claims: &globalProvisionerClaims,
|
||||||
IMDSVersions: []string{"v1"},
|
IMDSVersions: []string{"v1"},
|
||||||
claimer: claimer,
|
|
||||||
config: &awsConfig{
|
config: &awsConfig{
|
||||||
identityURL: awsIdentityURL,
|
identityURL: awsIdentityURL,
|
||||||
signatureURL: awsSignatureURL,
|
signatureURL: awsSignatureURL,
|
||||||
|
@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) {
|
||||||
certificates: []*x509.Certificate{cert},
|
certificates: []*x509.Certificate{cert},
|
||||||
signatureAlgorithm: awsSignatureAlgorithm,
|
signatureAlgorithm: awsSignatureAlgorithm,
|
||||||
},
|
},
|
||||||
audiences: testAudiences.WithFragment("aws/" + name),
|
}
|
||||||
}, nil
|
p.ctl, err = NewController(p, p.Claims, Config{
|
||||||
|
Audiences: testAudiences.WithFragment("aws/" + name),
|
||||||
|
})
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
|
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
|
||||||
|
@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
jwk, err := generateJSONWebKey()
|
jwk, err := generateJSONWebKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Azure{
|
p := &Azure{
|
||||||
Type: "Azure",
|
Type: "Azure",
|
||||||
Name: name,
|
Name: name,
|
||||||
TenantID: tenantID,
|
TenantID: tenantID,
|
||||||
Audience: azureDefaultAudience,
|
Audience: azureDefaultAudience,
|
||||||
Claims: &globalProvisionerClaims,
|
Claims: &globalProvisionerClaims,
|
||||||
claimer: claimer,
|
|
||||||
config: newAzureConfig(tenantID),
|
config: newAzureConfig(tenantID),
|
||||||
oidcConfig: openIDConfiguration{
|
oidcConfig: openIDConfiguration{
|
||||||
Issuer: "https://sts.windows.net/" + tenantID + "/",
|
Issuer: "https://sts.windows.net/" + tenantID + "/",
|
||||||
|
@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) {
|
||||||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||||
expiry: time.Now().Add(24 * time.Hour),
|
expiry: time.Now().Add(24 * time.Hour),
|
||||||
},
|
},
|
||||||
}, nil
|
}
|
||||||
|
p.ctl, err = NewController(p, p.Claims, Config{
|
||||||
|
Audiences: testAudiences,
|
||||||
|
})
|
||||||
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
||||||
|
@ -671,7 +656,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
||||||
w.Header().Add("Cache-Control", "max-age=5")
|
w.Header().Add("Cache-Control", "max-age=5")
|
||||||
writeJSON(w, getPublic(az.keyStore.keySet))
|
writeJSON(w, getPublic(az.keyStore.keySet))
|
||||||
case "/metadata/identity/oauth2/token":
|
case "/metadata/identity/oauth2/token":
|
||||||
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0])
|
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
} else {
|
} else {
|
||||||
|
@ -1009,7 +994,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r
|
||||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||||
sig, err := jose.NewSigner(
|
sig, err := jose.NewSigner(
|
||||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||||
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||||
|
@ -1017,6 +1002,12 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
var xmsMirID string
|
||||||
|
if resourceType == "vm" {
|
||||||
|
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName)
|
||||||
|
} else if resourceType == "uai" {
|
||||||
|
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName)
|
||||||
|
}
|
||||||
|
|
||||||
claims := azurePayload{
|
claims := azurePayload{
|
||||||
Claims: jose.Claims{
|
Claims: jose.Claims{
|
||||||
|
@ -1034,7 +1025,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
|
||||||
ObjectID: "the-oid",
|
ObjectID: "the-oid",
|
||||||
TenantID: tenantID,
|
TenantID: tenantID,
|
||||||
Version: "the-version",
|
Version: "the-version",
|
||||||
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine),
|
XMSMirID: xmsMirID,
|
||||||
}
|
}
|
||||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,15 +26,14 @@ type x5cPayload struct {
|
||||||
// signature requests.
|
// signature requests.
|
||||||
type X5C struct {
|
type X5C struct {
|
||||||
*base
|
*base
|
||||||
ID string `json:"-"`
|
ID string `json:"-"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Roots []byte `json:"roots"`
|
Roots []byte `json:"roots"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
Options *Options `json:"options,omitempty"`
|
Options *Options `json:"options,omitempty"`
|
||||||
claimer *Claimer
|
ctl *Controller
|
||||||
audiences Audiences
|
rootPool *x509.CertPool
|
||||||
rootPool *x509.CertPool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetID returns the provisioner unique identifier. The name and credential id
|
// GetID returns the provisioner unique identifier. The name and credential id
|
||||||
|
@ -86,7 +85,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes and validates the fields of a X5C type.
|
// Init initializes and validates the fields of a X5C type.
|
||||||
func (p *X5C) Init(config Config) error {
|
func (p *X5C) Init(config Config) (err error) {
|
||||||
switch {
|
switch {
|
||||||
case p.Type == "":
|
case p.Type == "":
|
||||||
return errors.New("provisioner type cannot be empty")
|
return errors.New("provisioner type cannot be empty")
|
||||||
|
@ -101,6 +100,7 @@ func (p *X5C) Init(config Config) error {
|
||||||
var (
|
var (
|
||||||
block *pem.Block
|
block *pem.Block
|
||||||
rest = p.Roots
|
rest = p.Roots
|
||||||
|
count int
|
||||||
)
|
)
|
||||||
for rest != nil {
|
for rest != nil {
|
||||||
block, rest = pem.Decode(rest)
|
block, rest = pem.Decode(rest)
|
||||||
|
@ -111,22 +111,18 @@ func (p *X5C) Init(config Config) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "error parsing x509 certificate from PEM block")
|
return errors.Wrap(err, "error parsing x509 certificate from PEM block")
|
||||||
}
|
}
|
||||||
|
count++
|
||||||
p.rootPool.AddCert(cert)
|
p.rootPool.AddCert(cert)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that at least one root was found.
|
// Verify that at least one root was found.
|
||||||
if len(p.rootPool.Subjects()) == 0 {
|
if count == 0 {
|
||||||
return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName())
|
return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update claims with global ones
|
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||||
var err error
|
p.ctl, err = NewController(p, p.Claims, config)
|
||||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
return
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// authorizeToken performs common jwt authorization actions and returns the
|
// authorizeToken performs common jwt authorization actions and returns the
|
||||||
|
@ -189,13 +185,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
|
||||||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||||
// revoke the certificate with serial number in the `sub` property.
|
// revoke the certificate with serial number in the `sub` property.
|
||||||
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
|
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||||
_, err := p.authorizeToken(token, p.audiences.Revoke)
|
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
|
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSign validates the given token.
|
// AuthorizeSign validates the given token.
|
||||||
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
|
||||||
}
|
}
|
||||||
|
@ -213,40 +209,45 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
data.SetToken(v)
|
data.SetToken(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The X509 certificate will be available using the template variable
|
||||||
|
// AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be
|
||||||
|
// used to get all the domains.
|
||||||
|
data.SetAuthorizationCertificate(claims.chains[0][0])
|
||||||
|
|
||||||
templateOptions, err := TemplateOptions(p.Options, data)
|
templateOptions, err := TemplateOptions(p.Options, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
|
||||||
}
|
}
|
||||||
|
|
||||||
return []SignOption{
|
return []SignOption{
|
||||||
|
p,
|
||||||
templateOptions,
|
templateOptions,
|
||||||
// modifiers / withOptions
|
// modifiers / withOptions
|
||||||
newProvisionerExtensionOption(TypeX5C, p.Name, ""),
|
newProvisionerExtensionOption(TypeX5C, p.Name, ""),
|
||||||
profileLimitDuration{p.claimer.DefaultTLSCertDuration(),
|
profileLimitDuration{
|
||||||
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter},
|
p.ctl.Claimer.DefaultTLSCertDuration(),
|
||||||
|
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter,
|
||||||
|
},
|
||||||
// validators
|
// validators
|
||||||
commonNameValidator(claims.Subject),
|
commonNameValidator(claims.Subject),
|
||||||
defaultSANsValidator(claims.SANs),
|
defaultSANsValidator(claims.SANs),
|
||||||
defaultPublicKeyValidator{},
|
defaultPublicKeyValidator{},
|
||||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||||
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||||
if p.claimer.IsDisableRenewal() {
|
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||||
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||||
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||||
if !p.claimer.IsSSHCAEnabled() {
|
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||||
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName())
|
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName())
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
|
||||||
}
|
}
|
||||||
|
@ -287,6 +288,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
||||||
data.SetToken(v)
|
data.SetToken(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The X509 certificate will be available using the template variable
|
||||||
|
// AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be
|
||||||
|
// used to get all the domains.
|
||||||
|
data.SetAuthorizationCertificate(claims.chains[0][0])
|
||||||
|
|
||||||
templateOptions, err := TemplateSSHOptions(p.Options, data)
|
templateOptions, err := TemplateSSHOptions(p.Options, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
|
||||||
|
@ -304,11 +310,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
||||||
|
|
||||||
return append(signOptions,
|
return append(signOptions,
|
||||||
// Checks the validity bounds, and set the validity if has not been set.
|
// Checks the validity bounds, and set the validity if has not been set.
|
||||||
&sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter},
|
&sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter},
|
||||||
// Validate public key.
|
// Validate public key.
|
||||||
&sshDefaultPublicKeyValidator{},
|
&sshDefaultPublicKeyValidator{},
|
||||||
// Validate the validity period.
|
// Validate the validity period.
|
||||||
&sshCertValidityValidator{p.claimer},
|
&sshCertValidityValidator{p.ctl.Claimer},
|
||||||
// Require all the fields in the SSH certificate
|
// Require all the fields in the SSH certificate
|
||||||
&sshCertDefaultValidator{},
|
&sshCertDefaultValidator{},
|
||||||
), nil
|
), nil
|
||||||
|
|
|
@ -2,16 +2,19 @@ package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
"go.step.sm/crypto/randutil"
|
"go.step.sm/crypto/randutil"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestX5C_Getters(t *testing.T) {
|
func TestX5C_Getters(t *testing.T) {
|
||||||
|
@ -69,8 +72,8 @@ func TestX5C_Init(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest {
|
"fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest {
|
||||||
return ProvisionerValidateTest{
|
return ProvisionerValidateTest{
|
||||||
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences},
|
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")},
|
||||||
err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"),
|
err: errors.New("no x509 certificates found in roots attribute for provisioner 'foo'"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest {
|
"fail/invalid-duration": func(t *testing.T) ProvisionerValidateTest {
|
||||||
|
@ -117,9 +120,11 @@ M46l92gdOozT
|
||||||
return ProvisionerValidateTest{
|
return ProvisionerValidateTest{
|
||||||
p: p,
|
p: p,
|
||||||
extraValid: func(p *X5C) error {
|
extraValid: func(p *X5C) error {
|
||||||
|
// nolint:staticcheck // We don't have a different way to
|
||||||
|
// check the number of certificates in the pool.
|
||||||
numCerts := len(p.rootPool.Subjects())
|
numCerts := len(p.rootPool.Subjects())
|
||||||
if numCerts != 2 {
|
if numCerts != 2 {
|
||||||
return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts)
|
return fmt.Errorf("unexpected number of certs: want 2, but got %d", numCerts)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -141,7 +146,7 @@ M46l92gdOozT
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if assert.Nil(t, tc.err) {
|
if assert.Nil(t, tc.err) {
|
||||||
assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID()))
|
assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID()))
|
||||||
if tc.extraValid != nil {
|
if tc.extraValid != nil {
|
||||||
assert.Nil(t, tc.extraValid(tc.p))
|
assert.Nil(t, tc.extraValid(tc.p))
|
||||||
}
|
}
|
||||||
|
@ -384,8 +389,8 @@ lgsqsR63is+0YQ==
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
|
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -455,7 +460,7 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
@ -463,19 +468,20 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
||||||
} else {
|
} else {
|
||||||
if assert.Nil(t, tc.err) {
|
if assert.Nil(t, tc.err) {
|
||||||
if assert.NotNil(t, opts) {
|
if assert.NotNil(t, opts) {
|
||||||
assert.Equals(t, len(opts), 7)
|
assert.Equals(t, len(opts), 8)
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
case *X5C:
|
||||||
case certificateOptionsFunc:
|
case certificateOptionsFunc:
|
||||||
case *provisionerExtensionOption:
|
case *provisionerExtensionOption:
|
||||||
assert.Equals(t, v.Type, int(TypeX5C))
|
assert.Equals(t, v.Type, TypeX5C)
|
||||||
assert.Equals(t, v.Name, tc.p.GetName())
|
assert.Equals(t, v.Name, tc.p.GetName())
|
||||||
assert.Equals(t, v.CredentialID, "")
|
assert.Equals(t, v.CredentialID, "")
|
||||||
assert.Len(t, 0, v.KeyValuePairs)
|
assert.Len(t, 0, v.KeyValuePairs)
|
||||||
case profileLimitDuration:
|
case profileLimitDuration:
|
||||||
assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration())
|
assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration())
|
||||||
|
|
||||||
claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign)
|
claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter)
|
assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter)
|
||||||
case commonNameValidator:
|
case commonNameValidator:
|
||||||
|
@ -484,10 +490,10 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
||||||
case defaultSANsValidator:
|
case defaultSANsValidator:
|
||||||
assert.Equals(t, []string(v), tc.sans)
|
assert.Equals(t, []string(v), tc.sans)
|
||||||
case *validityValidator:
|
case *validityValidator:
|
||||||
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
|
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
||||||
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
|
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -538,8 +544,8 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
|
if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -551,6 +557,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestX5C_AuthorizeRenew(t *testing.T) {
|
func TestX5C_AuthorizeRenew(t *testing.T) {
|
||||||
|
now := time.Now().Truncate(time.Second)
|
||||||
type test struct {
|
type test struct {
|
||||||
p *X5C
|
p *X5C
|
||||||
code int
|
code int
|
||||||
|
@ -563,12 +570,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
|
||||||
// disable renewal
|
// disable renewal
|
||||||
disable := true
|
disable := true
|
||||||
p.Claims = &Claims{DisableRenewal: &disable}
|
p.Claims = &Claims{DisableRenewal: &disable}
|
||||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
p: p,
|
p: p,
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()),
|
err: fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
|
@ -582,10 +589,13 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
|
||||||
for name, tt := range tests {
|
for name, tt := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil {
|
if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(time.Hour),
|
||||||
|
}); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -618,13 +628,13 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
||||||
// disable sshCA
|
// disable sshCA
|
||||||
enable := false
|
enable := false
|
||||||
p.Claims = &Claims{EnableSSHCA: &enable}
|
p.Claims = &Claims{EnableSSHCA: &enable}
|
||||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
return test{
|
return test{
|
||||||
p: p,
|
p: p,
|
||||||
token: "foo",
|
token: "foo",
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
err: errors.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()),
|
err: fmt.Errorf("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/invalid-token": func(t *testing.T) test {
|
"fail/invalid-token": func(t *testing.T) test {
|
||||||
|
@ -745,7 +755,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
||||||
tc := tt(t)
|
tc := tt(t)
|
||||||
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
|
if opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
@ -774,13 +784,13 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
||||||
case sshCertDefaultsModifier:
|
case sshCertDefaultsModifier:
|
||||||
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
|
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
|
||||||
case *sshLimitDuration:
|
case *sshLimitDuration:
|
||||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||||
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
|
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
|
||||||
case *sshCertValidityValidator:
|
case *sshCertValidityValidator:
|
||||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||||
case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc:
|
case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc:
|
||||||
default:
|
default:
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
|
||||||
}
|
}
|
||||||
tot++
|
tot++
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,20 +87,20 @@ func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, e
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authority) generateProvisionerConfig(ctx context.Context) (*provisioner.Config, error) {
|
func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.Config, error) {
|
||||||
// Merge global and configuration claims
|
// Merge global and configuration claims
|
||||||
claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims)
|
claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return provisioner.Config{}, err
|
||||||
}
|
}
|
||||||
// TODO: should we also be combining the ssh federated roots here?
|
// TODO: should we also be combining the ssh federated roots here?
|
||||||
// If we rotate ssh roots keys, sshpop provisioner will lose ability to
|
// If we rotate ssh roots keys, sshpop provisioner will lose ability to
|
||||||
// validate old SSH certificates, unless they are added as federated certs.
|
// validate old SSH certificates, unless they are added as federated certs.
|
||||||
sshKeys, err := a.GetSSHRoots(ctx)
|
sshKeys, err := a.GetSSHRoots(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return provisioner.Config{}, err
|
||||||
}
|
}
|
||||||
return &provisioner.Config{
|
return provisioner.Config{
|
||||||
Claims: claimer.Claims(),
|
Claims: claimer.Claims(),
|
||||||
Audiences: a.config.GetAudiences(),
|
Audiences: a.config.GetAudiences(),
|
||||||
DB: a.db,
|
DB: a.db,
|
||||||
|
@ -108,7 +108,9 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (*provisioner
|
||||||
UserKeys: sshKeys.UserKeys,
|
UserKeys: sshKeys.UserKeys,
|
||||||
HostKeys: sshKeys.HostKeys,
|
HostKeys: sshKeys.HostKeys,
|
||||||
},
|
},
|
||||||
GetIdentityFunc: a.getIdentityFunc,
|
GetIdentityFunc: a.getIdentityFunc,
|
||||||
|
AuthorizeRenewFunc: a.authorizeRenewFunc,
|
||||||
|
AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc,
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -133,9 +135,18 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi
|
||||||
"provisioner with token ID %s already exists", certProv.GetIDForToken())
|
"provisioner with token ID %s already exists", certProv.GetIDForToken())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
provisionerConfig, err := a.generateProvisionerConfig(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return admin.WrapErrorISE(err, "error generating provisioner config")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := certProv.Init(provisionerConfig); err != nil {
|
||||||
|
return admin.WrapError(admin.ErrorBadRequestType, err, "error validating configuration for provisioner %s", prov.Name)
|
||||||
|
}
|
||||||
|
|
||||||
// Store to database -- this will set the ID.
|
// Store to database -- this will set the ID.
|
||||||
if err := a.adminDB.CreateProvisioner(ctx, prov); err != nil {
|
if err := a.adminDB.CreateProvisioner(ctx, prov); err != nil {
|
||||||
return admin.WrapErrorISE(err, "error creating admin")
|
return admin.WrapErrorISE(err, "error creating provisioner")
|
||||||
}
|
}
|
||||||
|
|
||||||
// We need a new conversion that has the newly set ID.
|
// We need a new conversion that has the newly set ID.
|
||||||
|
@ -145,12 +156,7 @@ func (a *Authority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisi
|
||||||
"error converting to certificates provisioner from linkedca provisioner")
|
"error converting to certificates provisioner from linkedca provisioner")
|
||||||
}
|
}
|
||||||
|
|
||||||
provisionerConfig, err := a.generateProvisionerConfig(ctx)
|
if err := certProv.Init(provisionerConfig); err != nil {
|
||||||
if err != nil {
|
|
||||||
return admin.WrapErrorISE(err, "error generating provisioner config")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := certProv.Init(*provisionerConfig); err != nil {
|
|
||||||
return admin.WrapErrorISE(err, "error initializing provisioner %s", prov.Name)
|
return admin.WrapErrorISE(err, "error initializing provisioner %s", prov.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,7 +185,7 @@ func (a *Authority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisio
|
||||||
return admin.WrapErrorISE(err, "error generating provisioner config")
|
return admin.WrapErrorISE(err, "error generating provisioner config")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := certProv.Init(*provisionerConfig); err != nil {
|
if err := certProv.Init(provisionerConfig); err != nil {
|
||||||
return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name)
|
return admin.WrapErrorISE(err, "error initializing provisioner %s", nu.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -431,7 +437,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
pc := &provisioner.Claims{
|
pc := &provisioner.Claims{
|
||||||
DisableRenewal: &c.DisableRenewal,
|
DisableRenewal: &c.DisableRenewal,
|
||||||
|
AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry,
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
@ -469,12 +476,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims {
|
||||||
}
|
}
|
||||||
|
|
||||||
disableRenewal := config.DefaultDisableRenewal
|
disableRenewal := config.DefaultDisableRenewal
|
||||||
|
allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry
|
||||||
|
|
||||||
if c.DisableRenewal != nil {
|
if c.DisableRenewal != nil {
|
||||||
disableRenewal = *c.DisableRenewal
|
disableRenewal = *c.DisableRenewal
|
||||||
}
|
}
|
||||||
|
if c.AllowRenewAfterExpiry != nil {
|
||||||
|
allowRenewAfterExpiry = *c.AllowRenewAfterExpiry
|
||||||
|
}
|
||||||
|
|
||||||
lc := &linkedca.Claims{
|
lc := &linkedca.Claims{
|
||||||
DisableRenewal: disableRenewal,
|
DisableRenewal: disableRenewal,
|
||||||
|
AllowRenewAfterExpiry: allowRenewAfterExpiry,
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {
|
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {
|
||||||
|
@ -706,6 +719,8 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
|
||||||
Name: p.Name,
|
Name: p.Name,
|
||||||
TenantID: cfg.TenantId,
|
TenantID: cfg.TenantId,
|
||||||
ResourceGroups: cfg.ResourceGroups,
|
ResourceGroups: cfg.ResourceGroups,
|
||||||
|
SubscriptionIDs: cfg.SubscriptionIds,
|
||||||
|
ObjectIDs: cfg.ObjectIds,
|
||||||
Audience: cfg.Audience,
|
Audience: cfg.Audience,
|
||||||
DisableCustomSANs: cfg.DisableCustomSans,
|
DisableCustomSANs: cfg.DisableCustomSans,
|
||||||
DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse,
|
DisableTrustOnFirstUse: cfg.DisableTrustOnFirstUse,
|
||||||
|
@ -865,6 +880,8 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro
|
||||||
Azure: &linkedca.AzureProvisioner{
|
Azure: &linkedca.AzureProvisioner{
|
||||||
TenantId: p.TenantID,
|
TenantId: p.TenantID,
|
||||||
ResourceGroups: p.ResourceGroups,
|
ResourceGroups: p.ResourceGroups,
|
||||||
|
SubscriptionIds: p.SubscriptionIDs,
|
||||||
|
ObjectIds: p.ObjectIDs,
|
||||||
Audience: p.Audience,
|
Audience: p.Audience,
|
||||||
DisableCustomSans: p.DisableCustomSANs,
|
DisableCustomSans: p.DisableCustomSANs,
|
||||||
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,
|
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetEncryptedKey(t *testing.T) {
|
func TestGetEncryptedKey(t *testing.T) {
|
||||||
|
@ -49,8 +49,8 @@ func TestGetEncryptedKey(t *testing.T) {
|
||||||
ek, err := tc.a.GetEncryptedKey(tc.kid)
|
ek, err := tc.a.GetEncryptedKey(tc.kid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -90,8 +90,8 @@ func TestGetProvisioners(t *testing.T) {
|
||||||
ps, next, err := tc.a.GetProvisioners("", 0)
|
ps, next, err := tc.a.GetProvisioners("", 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,14 +2,15 @@ package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRoot(t *testing.T) {
|
func TestRoot(t *testing.T) {
|
||||||
|
@ -31,7 +32,7 @@ func TestRoot(t *testing.T) {
|
||||||
crt, err := a.Root(tc.sum)
|
crt, err := a.Root(tc.sum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
|
|
@ -7,21 +7,22 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"github.com/smallstep/certificates/templates"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/sshutil"
|
"go.step.sm/crypto/sshutil"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/certificates/templates"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sshTestModifier ssh.Certificate
|
type sshTestModifier ssh.Certificate
|
||||||
|
@ -716,8 +717,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
|
||||||
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
_, ok := err.(errs.StatusCoder)
|
_, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want)
|
t.Errorf("Authority.GetSSHBastion() = %v, want %v", got, tt.want)
|
||||||
|
@ -806,8 +807,8 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
|
||||||
hosts, err := auth.GetSSHHosts(context.Background(), tc.cert)
|
hosts, err := auth.GetSSHHosts(context.Background(), tc.cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -1033,8 +1034,8 @@ func TestAuthority_RekeySSH(t *testing.T) {
|
||||||
cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...)
|
cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +0,0 @@
|
||||||
package status
|
|
||||||
|
|
||||||
// Type is the type for status.
|
|
||||||
type Type string
|
|
||||||
|
|
||||||
var (
|
|
||||||
// Active active
|
|
||||||
Active = Type("active")
|
|
||||||
// Deleted deleted
|
|
||||||
Deleted = Type("deleted")
|
|
||||||
)
|
|
|
@ -89,8 +89,13 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
|
||||||
// Set backdate with the configured value
|
// Set backdate with the configured value
|
||||||
signOpts.Backdate = a.config.AuthorityConfig.Backdate.Duration
|
signOpts.Backdate = a.config.AuthorityConfig.Backdate.Duration
|
||||||
|
|
||||||
|
var prov provisioner.Interface
|
||||||
for _, op := range extraOpts {
|
for _, op := range extraOpts {
|
||||||
switch k := op.(type) {
|
switch k := op.(type) {
|
||||||
|
// Capture current provisioner
|
||||||
|
case provisioner.Interface:
|
||||||
|
prov = k
|
||||||
|
|
||||||
// Adds new options to NewCertificate
|
// Adds new options to NewCertificate
|
||||||
case provisioner.CertificateOptions:
|
case provisioner.CertificateOptions:
|
||||||
certOptions = append(certOptions, k.Options(signOpts)...)
|
certOptions = append(certOptions, k.Options(signOpts)...)
|
||||||
|
@ -204,7 +209,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
|
||||||
}
|
}
|
||||||
|
|
||||||
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
|
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
|
||||||
if err = a.storeCertificate(fullchain); err != nil {
|
if err = a.storeCertificate(prov, fullchain); err != nil {
|
||||||
if err != db.ErrNotImplemented {
|
if err != db.ErrNotImplemented {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
||||||
"authority.Sign; error storing certificate in db", opts...)
|
"authority.Sign; error storing certificate in db", opts...)
|
||||||
|
@ -325,19 +330,28 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
|
||||||
// TODO: at some point we should replace the db.AuthDB interface to implement
|
// TODO: at some point we should replace the db.AuthDB interface to implement
|
||||||
// `StoreCertificate(...*x509.Certificate) error` instead of just
|
// `StoreCertificate(...*x509.Certificate) error` instead of just
|
||||||
// `StoreCertificate(*x509.Certificate) error`.
|
// `StoreCertificate(*x509.Certificate) error`.
|
||||||
func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
|
func (a *Authority) storeCertificate(prov provisioner.Interface, fullchain []*x509.Certificate) error {
|
||||||
|
type linkedChainStorer interface {
|
||||||
|
StoreCertificateChain(provisioner.Interface, ...*x509.Certificate) error
|
||||||
|
}
|
||||||
type certificateChainStorer interface {
|
type certificateChainStorer interface {
|
||||||
StoreCertificateChain(...*x509.Certificate) error
|
StoreCertificateChain(...*x509.Certificate) error
|
||||||
}
|
}
|
||||||
// Store certificate in linkedca
|
// Store certificate in linkedca
|
||||||
if s, ok := a.adminDB.(certificateChainStorer); ok {
|
switch s := a.adminDB.(type) {
|
||||||
|
case linkedChainStorer:
|
||||||
|
return s.StoreCertificateChain(prov, fullchain...)
|
||||||
|
case certificateChainStorer:
|
||||||
return s.StoreCertificateChain(fullchain...)
|
return s.StoreCertificateChain(fullchain...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store certificate in local db
|
// Store certificate in local db
|
||||||
if s, ok := a.db.(certificateChainStorer); ok {
|
switch s := a.db.(type) {
|
||||||
|
case certificateChainStorer:
|
||||||
return s.StoreCertificateChain(fullchain...)
|
return s.StoreCertificateChain(fullchain...)
|
||||||
|
default:
|
||||||
|
return a.db.StoreCertificate(fullchain[0])
|
||||||
}
|
}
|
||||||
return a.db.StoreCertificate(fullchain[0])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// storeRenewedCertificate allows to use an extension of the db.AuthDB interface
|
// storeRenewedCertificate allows to use an extension of the db.AuthDB interface
|
||||||
|
|
|
@ -11,24 +11,26 @@ import (
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/cas/softcas"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/smallstep/assert"
|
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/keyutil"
|
"go.step.sm/crypto/keyutil"
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
"go.step.sm/crypto/x509util"
|
"go.step.sm/crypto/x509util"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/cas/softcas"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"github.com/smallstep/certificates/errs"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -187,14 +189,14 @@ func setExtraExtsCSR(exts []pkix.Extension) func(*x509.CertificateRequest) {
|
||||||
func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
|
func generateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
|
||||||
b, err := x509.MarshalPKIXPublicKey(pub)
|
b, err := x509.MarshalPKIXPublicKey(pub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "error marshaling public key")
|
return nil, fmt.Errorf("error marshaling public key: %w", err)
|
||||||
}
|
}
|
||||||
info := struct {
|
info := struct {
|
||||||
Algorithm pkix.AlgorithmIdentifier
|
Algorithm pkix.AlgorithmIdentifier
|
||||||
SubjectPublicKey asn1.BitString
|
SubjectPublicKey asn1.BitString
|
||||||
}{}
|
}{}
|
||||||
if _, err = asn1.Unmarshal(b, &info); err != nil {
|
if _, err = asn1.Unmarshal(b, &info); err != nil {
|
||||||
return nil, errors.Wrap(err, "error unmarshaling public key")
|
return nil, fmt.Errorf("error unmarshaling public key: %w", err)
|
||||||
}
|
}
|
||||||
hash := sha1.Sum(info.SubjectPublicKey.Bytes)
|
hash := sha1.Sum(info.SubjectPublicKey.Bytes)
|
||||||
return hash[:], nil
|
return hash[:], nil
|
||||||
|
@ -661,8 +663,8 @@ ZYtQ9Ot36qc=
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
||||||
assert.Nil(t, certChain)
|
assert.Nil(t, certChain)
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
|
||||||
|
@ -757,7 +759,7 @@ func TestAuthority_Renew(t *testing.T) {
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
nb1 := now.Add(-time.Minute * 7)
|
nb1 := now.Add(-time.Minute * 7)
|
||||||
na1 := now
|
na1 := now.Add(time.Hour)
|
||||||
so := &provisioner.SignOptions{
|
so := &provisioner.SignOptions{
|
||||||
NotBefore: provisioner.NewTimeDuration(nb1),
|
NotBefore: provisioner.NewTimeDuration(nb1),
|
||||||
NotAfter: provisioner.NewTimeDuration(na1),
|
NotAfter: provisioner.NewTimeDuration(na1),
|
||||||
|
@ -798,7 +800,20 @@ func TestAuthority_Renew(t *testing.T) {
|
||||||
"fail/unauthorized": func() (*renewTest, error) {
|
"fail/unauthorized": func() (*renewTest, error) {
|
||||||
return &renewTest{
|
return &renewTest{
|
||||||
cert: certNoRenew,
|
cert: certNoRenew,
|
||||||
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"),
|
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||||
|
code: http.StatusUnauthorized,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
"fail/WithAuthorizeRenewFunc": func() (*renewTest, error) {
|
||||||
|
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
|
||||||
|
return errs.Unauthorized("not authorized")
|
||||||
|
}))
|
||||||
|
aa.x509CAService = a.x509CAService
|
||||||
|
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
|
||||||
|
return &renewTest{
|
||||||
|
auth: aa,
|
||||||
|
cert: cert,
|
||||||
|
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
@ -820,6 +835,17 @@ func TestAuthority_Renew(t *testing.T) {
|
||||||
cert: cert,
|
cert: cert,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
"ok/WithAuthorizeRenewFunc": func() (*renewTest, error) {
|
||||||
|
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
aa.x509CAService = a.x509CAService
|
||||||
|
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
|
||||||
|
return &renewTest{
|
||||||
|
auth: aa,
|
||||||
|
cert: cert,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, genTestCase := range tests {
|
for name, genTestCase := range tests {
|
||||||
|
@ -836,8 +862,8 @@ func TestAuthority_Renew(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
||||||
assert.Nil(t, certChain)
|
assert.Nil(t, certChain)
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
|
||||||
|
@ -856,7 +882,7 @@ func TestAuthority_Renew(t *testing.T) {
|
||||||
|
|
||||||
expiry := now.Add(time.Minute * 7)
|
expiry := now.Add(time.Minute * 7)
|
||||||
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
|
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
|
||||||
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
|
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
|
||||||
|
|
||||||
tmplt := a.config.AuthorityConfig.Template
|
tmplt := a.config.AuthorityConfig.Template
|
||||||
assert.Equals(t, leaf.Subject.String(),
|
assert.Equals(t, leaf.Subject.String(),
|
||||||
|
@ -956,7 +982,7 @@ func TestAuthority_Rekey(t *testing.T) {
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
nb1 := now.Add(-time.Minute * 7)
|
nb1 := now.Add(-time.Minute * 7)
|
||||||
na1 := now
|
na1 := now.Add(time.Hour)
|
||||||
so := &provisioner.SignOptions{
|
so := &provisioner.SignOptions{
|
||||||
NotBefore: provisioner.NewTimeDuration(nb1),
|
NotBefore: provisioner.NewTimeDuration(nb1),
|
||||||
NotAfter: provisioner.NewTimeDuration(na1),
|
NotAfter: provisioner.NewTimeDuration(na1),
|
||||||
|
@ -998,7 +1024,7 @@ func TestAuthority_Rekey(t *testing.T) {
|
||||||
"fail/unauthorized": func() (*renewTest, error) {
|
"fail/unauthorized": func() (*renewTest, error) {
|
||||||
return &renewTest{
|
return &renewTest{
|
||||||
cert: certNoRenew,
|
cert: certNoRenew,
|
||||||
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"),
|
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||||
code: http.StatusUnauthorized,
|
code: http.StatusUnauthorized,
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
|
@ -1043,8 +1069,8 @@ func TestAuthority_Rekey(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
||||||
assert.Nil(t, certChain)
|
assert.Nil(t, certChain)
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
|
||||||
|
@ -1063,7 +1089,7 @@ func TestAuthority_Rekey(t *testing.T) {
|
||||||
|
|
||||||
expiry := now.Add(time.Minute * 7)
|
expiry := now.Add(time.Minute * 7)
|
||||||
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
|
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
|
||||||
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
|
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
|
||||||
|
|
||||||
tmplt := a.config.AuthorityConfig.Template
|
tmplt := a.config.AuthorityConfig.Template
|
||||||
assert.Equals(t, leaf.Subject.String(),
|
assert.Equals(t, leaf.Subject.String(),
|
||||||
|
@ -1432,8 +1458,8 @@ func TestAuthority_Revoke(t *testing.T) {
|
||||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
|
||||||
if err := tc.auth.Revoke(ctx, tc.opts); err != nil {
|
if err := tc.auth.Revoke(ctx, tc.opts); err != nil {
|
||||||
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
|
|
||||||
|
|
|
@ -12,12 +12,13 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
"go.step.sm/crypto/pemutil"
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
acmeAPI "github.com/smallstep/certificates/acme/api"
|
acmeAPI "github.com/smallstep/certificates/acme/api"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api/render"
|
||||||
"go.step.sm/crypto/jose"
|
|
||||||
"go.step.sm/crypto/pemutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewACMEClient(t *testing.T) {
|
func TestNewACMEClient(t *testing.T) {
|
||||||
|
@ -112,15 +113,15 @@ func TestNewACMEClient(t *testing.T) {
|
||||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||||
switch {
|
switch {
|
||||||
case i == 0:
|
case i == 0:
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
case i == 1:
|
case i == 1:
|
||||||
w.Header().Set("Replay-Nonce", "abc123")
|
w.Header().Set("Replay-Nonce", "abc123")
|
||||||
api.JSONStatus(w, []byte{}, 200)
|
render.JSONStatus(w, []byte{}, 200)
|
||||||
i++
|
i++
|
||||||
default:
|
default:
|
||||||
w.Header().Set("Location", accLocation)
|
w.Header().Set("Location", accLocation)
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -206,7 +207,7 @@ func TestACMEClient_GetNonce(t *testing.T) {
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
})
|
})
|
||||||
|
|
||||||
if nonce, err := ac.GetNonce(); err != nil {
|
if nonce, err := ac.GetNonce(); err != nil {
|
||||||
|
@ -315,7 +316,7 @@ func TestACMEClient_post(t *testing.T) {
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -338,7 +339,7 @@ func TestACMEClient_post(t *testing.T) {
|
||||||
assert.Equals(t, hdr.KeyID, ac.kid)
|
assert.Equals(t, hdr.KeyID, ac.kid)
|
||||||
}
|
}
|
||||||
|
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
})
|
})
|
||||||
|
|
||||||
if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil {
|
if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil {
|
||||||
|
@ -455,7 +456,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -477,7 +478,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, payload, norb)
|
assert.Equals(t, payload, norb)
|
||||||
|
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
})
|
})
|
||||||
|
|
||||||
if res, err := ac.NewOrder(norb); err != nil {
|
if res, err := ac.NewOrder(norb); err != nil {
|
||||||
|
@ -577,7 +578,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -599,7 +600,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, len(payload), 0)
|
assert.Equals(t, len(payload), 0)
|
||||||
|
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
})
|
})
|
||||||
|
|
||||||
if res, err := ac.GetOrder(url); err != nil {
|
if res, err := ac.GetOrder(url); err != nil {
|
||||||
|
@ -699,7 +700,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -721,7 +722,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, len(payload), 0)
|
assert.Equals(t, len(payload), 0)
|
||||||
|
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
})
|
})
|
||||||
|
|
||||||
if res, err := ac.GetAuthz(url); err != nil {
|
if res, err := ac.GetAuthz(url); err != nil {
|
||||||
|
@ -821,7 +822,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -844,7 +845,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
|
||||||
|
|
||||||
assert.Equals(t, len(payload), 0)
|
assert.Equals(t, len(payload), 0)
|
||||||
|
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
})
|
})
|
||||||
|
|
||||||
if res, err := ac.GetChallenge(url); err != nil {
|
if res, err := ac.GetChallenge(url); err != nil {
|
||||||
|
@ -944,7 +945,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -967,7 +968,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
|
||||||
|
|
||||||
assert.Equals(t, payload, []byte("{}"))
|
assert.Equals(t, payload, []byte("{}"))
|
||||||
|
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := ac.ValidateChallenge(url); err != nil {
|
if err := ac.ValidateChallenge(url); err != nil {
|
||||||
|
@ -1071,7 +1072,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1093,7 +1094,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, payload, frb)
|
assert.Equals(t, payload, frb)
|
||||||
|
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := ac.FinalizeOrder(url, csr); err != nil {
|
if err := ac.FinalizeOrder(url, csr); err != nil {
|
||||||
|
@ -1200,7 +1201,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1222,7 +1223,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, len(payload), 0)
|
assert.Equals(t, len(payload), 0)
|
||||||
|
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
})
|
})
|
||||||
|
|
||||||
if res, err := tc.client.GetAccountOrders(); err != nil {
|
if res, err := tc.client.GetAccountOrders(); err != nil {
|
||||||
|
@ -1331,7 +1332,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
api.JSONStatus(w, tc.r1, tc.rc1)
|
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||||
i++
|
i++
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1356,7 +1357,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
|
||||||
if tc.certBytes != nil {
|
if tc.certBytes != nil {
|
||||||
w.Write(tc.certBytes)
|
w.Write(tc.certBytes)
|
||||||
} else {
|
} else {
|
||||||
api.JSONStatus(w, tc.r2, tc.rc2)
|
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -14,11 +14,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/api"
|
|
||||||
"github.com/smallstep/certificates/authority"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/randutil"
|
"go.step.sm/crypto/randutil"
|
||||||
|
|
||||||
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
|
"github.com/smallstep/certificates/authority"
|
||||||
|
"github.com/smallstep/certificates/errs"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newLocalListener() net.Listener {
|
func newLocalListener() net.Listener {
|
||||||
|
@ -79,7 +82,7 @@ func startCAServer(configFile string) (*CA, string, error) {
|
||||||
func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler {
|
func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/version" {
|
if r.URL.Path == "/version" {
|
||||||
api.JSON(w, api.VersionResponse{
|
render.JSON(w, api.VersionResponse{
|
||||||
Version: "test",
|
Version: "test",
|
||||||
RequireClientAuthentication: true,
|
RequireClientAuthentication: true,
|
||||||
})
|
})
|
||||||
|
@ -93,7 +96,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han
|
||||||
}
|
}
|
||||||
isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0
|
isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0
|
||||||
if !isMTLS {
|
if !isMTLS {
|
||||||
api.WriteError(w, errs.Unauthorized("missing peer certificate"))
|
render.Error(w, errs.Unauthorized("missing peer certificate"))
|
||||||
} else {
|
} else {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
@ -408,6 +411,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
|
||||||
server.ServeTLS(listener, "", "")
|
server.ServeTLS(listener, "", "")
|
||||||
}()
|
}()
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
// Create bootstrap client
|
// Create bootstrap client
|
||||||
token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
|
token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
|
||||||
|
@ -419,7 +423,6 @@ func TestBootstrapClientServerRotation(t *testing.T) {
|
||||||
|
|
||||||
// doTest does a request that requires mTLS
|
// doTest does a request that requires mTLS
|
||||||
doTest := func(client *http.Client) error {
|
doTest := func(client *http.Client) error {
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
// test with ca
|
// test with ca
|
||||||
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
|
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
44
ca/ca.go
44
ca/ca.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
|
@ -26,11 +27,14 @@ import (
|
||||||
scepAPI "github.com/smallstep/certificates/scep/api"
|
scepAPI "github.com/smallstep/certificates/scep/api"
|
||||||
"github.com/smallstep/certificates/server"
|
"github.com/smallstep/certificates/server"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
|
"go.step.sm/cli-utils/step"
|
||||||
|
"go.step.sm/crypto/x509util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
configFile string
|
configFile string
|
||||||
linkedCAToken string
|
linkedCAToken string
|
||||||
|
quiet bool
|
||||||
password []byte
|
password []byte
|
||||||
issuerPassword []byte
|
issuerPassword []byte
|
||||||
sshHostPassword []byte
|
sshHostPassword []byte
|
||||||
|
@ -101,6 +105,13 @@ func WithLinkedCAToken(token string) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithQuiet sets the quiet flag.
|
||||||
|
func WithQuiet(quiet bool) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.quiet = quiet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CA is the type used to build the complete certificate authority. It builds
|
// CA is the type used to build the complete certificate authority. It builds
|
||||||
// the HTTP server, set ups the middlewares and the HTTP handlers.
|
// the HTTP server, set ups the middlewares and the HTTP handlers.
|
||||||
type CA struct {
|
type CA struct {
|
||||||
|
@ -288,6 +299,35 @@ func (ca *CA) Run() error {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
errs := make(chan error, 1)
|
errs := make(chan error, 1)
|
||||||
|
|
||||||
|
if !ca.opts.quiet {
|
||||||
|
authorityInfo := ca.auth.GetInfo()
|
||||||
|
log.Printf("Starting %s", step.Version())
|
||||||
|
log.Printf("Documentation: https://u.step.sm/docs/ca")
|
||||||
|
log.Printf("Community Discord: https://u.step.sm/discord")
|
||||||
|
if step.Contexts().GetCurrent() != nil {
|
||||||
|
log.Printf("Current context: %s", step.Contexts().GetCurrent().Name)
|
||||||
|
}
|
||||||
|
log.Printf("Config file: %s", ca.opts.configFile)
|
||||||
|
baseURL := fmt.Sprintf("https://%s%s",
|
||||||
|
authorityInfo.DNSNames[0],
|
||||||
|
ca.config.Address[strings.LastIndex(ca.config.Address, ":"):])
|
||||||
|
log.Printf("The primary server URL is %s", baseURL)
|
||||||
|
log.Printf("Root certificates are available at %s/roots.pem", baseURL)
|
||||||
|
if len(authorityInfo.DNSNames) > 1 {
|
||||||
|
log.Printf("Additional configured hostnames: %s",
|
||||||
|
strings.Join(authorityInfo.DNSNames[1:], ", "))
|
||||||
|
}
|
||||||
|
for _, crt := range authorityInfo.RootX509Certs {
|
||||||
|
log.Printf("X.509 Root Fingerprint: %s", x509util.Fingerprint(crt))
|
||||||
|
}
|
||||||
|
if authorityInfo.SSHCAHostPublicKey != nil {
|
||||||
|
log.Printf("SSH Host CA Key is %s\n", authorityInfo.SSHCAHostPublicKey)
|
||||||
|
}
|
||||||
|
if authorityInfo.SSHCAUserPublicKey != nil {
|
||||||
|
log.Printf("SSH User CA Key: %s\n", authorityInfo.SSHCAUserPublicKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if ca.insecureSrv != nil {
|
if ca.insecureSrv != nil {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -355,6 +395,7 @@ func (ca *CA) Reload() error {
|
||||||
WithSSHUserPassword(ca.opts.sshUserPassword),
|
WithSSHUserPassword(ca.opts.sshUserPassword),
|
||||||
WithIssuerPassword(ca.opts.issuerPassword),
|
WithIssuerPassword(ca.opts.issuerPassword),
|
||||||
WithLinkedCAToken(ca.opts.linkedCAToken),
|
WithLinkedCAToken(ca.opts.linkedCAToken),
|
||||||
|
WithQuiet(ca.opts.quiet),
|
||||||
WithConfigFile(ca.opts.configFile),
|
WithConfigFile(ca.opts.configFile),
|
||||||
WithDatabase(ca.auth.GetDatabase()),
|
WithDatabase(ca.auth.GetDatabase()),
|
||||||
)
|
)
|
||||||
|
@ -450,9 +491,6 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
|
||||||
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
||||||
tlsConfig.ClientCAs = certPool
|
tlsConfig.ClientCAs = certPool
|
||||||
|
|
||||||
// Use server's most preferred ciphersuite
|
|
||||||
tlsConfig.PreferServerCipherSuites = true
|
|
||||||
|
|
||||||
return tlsConfig, nil
|
return tlsConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
35
ca/client.go
35
ca/client.go
|
@ -563,6 +563,11 @@ func (c *Client) retryOnError(r *http.Response) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCaURL returns the configured CA url.
|
||||||
|
func (c *Client) GetCaURL() string {
|
||||||
|
return c.endpoint.String()
|
||||||
|
}
|
||||||
|
|
||||||
// GetRootCAs returns the RootCAs certificate pool from the configured
|
// GetRootCAs returns the RootCAs certificate pool from the configured
|
||||||
// transport.
|
// transport.
|
||||||
func (c *Client) GetRootCAs() *x509.CertPool {
|
func (c *Client) GetRootCAs() *x509.CertPool {
|
||||||
|
@ -723,6 +728,36 @@ retry:
|
||||||
return &sign, nil
|
return &sign, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RenewWithToken performs the renew request to the CA with the given
|
||||||
|
// authorization token and returns the api.SignResponse struct. This method is
|
||||||
|
// generally used to renew an expired certificate.
|
||||||
|
func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) {
|
||||||
|
var retried bool
|
||||||
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
|
||||||
|
req, err := http.NewRequest("POST", u.String(), http.NoBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request")
|
||||||
|
}
|
||||||
|
req.Header.Add("Authorization", "Bearer "+token)
|
||||||
|
retry:
|
||||||
|
resp, err := c.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", u)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
return nil, readError(resp.Body)
|
||||||
|
}
|
||||||
|
var sign api.SignResponse
|
||||||
|
if err := readJSON(resp.Body, &sign); err != nil {
|
||||||
|
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u)
|
||||||
|
}
|
||||||
|
return &sign, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Rekey performs the rekey request to the CA and returns the api.SignResponse
|
// Rekey performs the rekey request to the CA and returns the api.SignResponse
|
||||||
// struct.
|
// struct.
|
||||||
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {
|
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -16,14 +17,16 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"go.step.sm/crypto/x509util"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/api"
|
"github.com/smallstep/certificates/api"
|
||||||
|
"github.com/smallstep/certificates/api/read"
|
||||||
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"go.step.sm/crypto/x509util"
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -179,7 +182,7 @@ func TestClient_Version(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Version()
|
got, err := c.Version()
|
||||||
|
@ -229,7 +232,7 @@ func TestClient_Health(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Health()
|
got, err := c.Health()
|
||||||
|
@ -287,7 +290,7 @@ func TestClient_Root(t *testing.T) {
|
||||||
if req.RequestURI != expected {
|
if req.RequestURI != expected {
|
||||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
||||||
}
|
}
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Root(tt.shasum)
|
got, err := c.Root(tt.shasum)
|
||||||
|
@ -354,10 +357,10 @@ func TestClient_Sign(t *testing.T) {
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
body := new(api.SignRequest)
|
body := new(api.SignRequest)
|
||||||
if err := api.ReadJSON(req.Body, body); err != nil {
|
if err := read.JSON(req.Body, body); err != nil {
|
||||||
e, ok := tt.response.(error)
|
e, ok := tt.response.(error)
|
||||||
assert.Fatal(t, ok, "response expected to be error type")
|
assert.Fatal(t, ok, "response expected to be error type")
|
||||||
api.WriteError(w, e)
|
render.Error(w, e)
|
||||||
return
|
return
|
||||||
} else if !equalJSON(t, body, tt.request) {
|
} else if !equalJSON(t, body, tt.request) {
|
||||||
if tt.request == nil {
|
if tt.request == nil {
|
||||||
|
@ -368,7 +371,7 @@ func TestClient_Sign(t *testing.T) {
|
||||||
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
|
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Sign(tt.request)
|
got, err := c.Sign(tt.request)
|
||||||
|
@ -426,10 +429,10 @@ func TestClient_Revoke(t *testing.T) {
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
body := new(api.RevokeRequest)
|
body := new(api.RevokeRequest)
|
||||||
if err := api.ReadJSON(req.Body, body); err != nil {
|
if err := read.JSON(req.Body, body); err != nil {
|
||||||
e, ok := tt.response.(error)
|
e, ok := tt.response.(error)
|
||||||
assert.Fatal(t, ok, "response expected to be error type")
|
assert.Fatal(t, ok, "response expected to be error type")
|
||||||
api.WriteError(w, e)
|
render.Error(w, e)
|
||||||
return
|
return
|
||||||
} else if !equalJSON(t, body, tt.request) {
|
} else if !equalJSON(t, body, tt.request) {
|
||||||
if tt.request == nil {
|
if tt.request == nil {
|
||||||
|
@ -440,7 +443,7 @@ func TestClient_Revoke(t *testing.T) {
|
||||||
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
|
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Revoke(tt.request, nil)
|
got, err := c.Revoke(tt.request, nil)
|
||||||
|
@ -500,7 +503,7 @@ func TestClient_Renew(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Renew(nil)
|
got, err := c.Renew(nil)
|
||||||
|
@ -516,8 +519,8 @@ func TestClient_Renew(t *testing.T) {
|
||||||
t.Errorf("Client.Renew() = %v, want nil", got)
|
t.Errorf("Client.Renew() = %v, want nil", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
default:
|
default:
|
||||||
|
@ -529,6 +532,74 @@ func TestClient_Renew(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_RenewWithToken(t *testing.T) {
|
||||||
|
ok := &api.SignResponse{
|
||||||
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
||||||
|
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
||||||
|
CertChainPEM: []api.Certificate{
|
||||||
|
{Certificate: parseCertificate(certPEM)},
|
||||||
|
{Certificate: parseCertificate(rootPEM)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
response interface{}
|
||||||
|
responseCode int
|
||||||
|
wantErr bool
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"ok", ok, 200, false, nil},
|
||||||
|
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||||
|
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
|
||||||
|
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(nil)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewClient() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Header.Get("Authorization") != "Bearer token" {
|
||||||
|
render.JSONStatus(w, errs.InternalServer("force"), 500)
|
||||||
|
} else {
|
||||||
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := c.RenewWithToken("token")
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
fmt.Printf("%+v", err)
|
||||||
|
t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
if got != nil {
|
||||||
|
t.Errorf("Client.RenewWithToken() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
sc, ok := err.(render.StatusCodedError)
|
||||||
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||||
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
|
default:
|
||||||
|
if !reflect.DeepEqual(got, tt.response) {
|
||||||
|
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClient_Rekey(t *testing.T) {
|
func TestClient_Rekey(t *testing.T) {
|
||||||
ok := &api.SignResponse{
|
ok := &api.SignResponse{
|
||||||
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
||||||
|
@ -569,7 +640,7 @@ func TestClient_Rekey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Rekey(tt.request, nil)
|
got, err := c.Rekey(tt.request, nil)
|
||||||
|
@ -585,8 +656,8 @@ func TestClient_Rekey(t *testing.T) {
|
||||||
t.Errorf("Client.Renew() = %v, want nil", got)
|
t.Errorf("Client.Renew() = %v, want nil", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
default:
|
default:
|
||||||
|
@ -634,7 +705,7 @@ func TestClient_Provisioners(t *testing.T) {
|
||||||
if req.RequestURI != tt.expectedURI {
|
if req.RequestURI != tt.expectedURI {
|
||||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
|
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
|
||||||
}
|
}
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Provisioners(tt.args...)
|
got, err := c.Provisioners(tt.args...)
|
||||||
|
@ -691,7 +762,7 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
||||||
if req.RequestURI != expected {
|
if req.RequestURI != expected {
|
||||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
||||||
}
|
}
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.ProvisionerKey(tt.kid)
|
got, err := c.ProvisionerKey(tt.kid)
|
||||||
|
@ -706,8 +777,8 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
||||||
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
|
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||||
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||||
default:
|
default:
|
||||||
|
@ -750,7 +821,7 @@ func TestClient_Roots(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Roots()
|
got, err := c.Roots()
|
||||||
|
@ -765,8 +836,8 @@ func TestClient_Roots(t *testing.T) {
|
||||||
if got != nil {
|
if got != nil {
|
||||||
t.Errorf("Client.Roots() = %v, want nil", got)
|
t.Errorf("Client.Roots() = %v, want nil", got)
|
||||||
}
|
}
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
default:
|
default:
|
||||||
|
@ -808,7 +879,7 @@ func TestClient_Federation(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.Federation()
|
got, err := c.Federation()
|
||||||
|
@ -823,8 +894,8 @@ func TestClient_Federation(t *testing.T) {
|
||||||
if got != nil {
|
if got != nil {
|
||||||
t.Errorf("Client.Federation() = %v, want nil", got)
|
t.Errorf("Client.Federation() = %v, want nil", got)
|
||||||
}
|
}
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||||
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||||
default:
|
default:
|
||||||
|
@ -870,7 +941,7 @@ func TestClient_SSHRoots(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.SSHRoots()
|
got, err := c.SSHRoots()
|
||||||
|
@ -885,8 +956,8 @@ func TestClient_SSHRoots(t *testing.T) {
|
||||||
if got != nil {
|
if got != nil {
|
||||||
t.Errorf("Client.SSHKeys() = %v, want nil", got)
|
t.Errorf("Client.SSHKeys() = %v, want nil", got)
|
||||||
}
|
}
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||||
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||||
default:
|
default:
|
||||||
|
@ -970,7 +1041,7 @@ func TestClient_RootFingerprint(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.RootFingerprint()
|
got, err := c.RootFingerprint()
|
||||||
|
@ -1031,7 +1102,7 @@ func TestClient_SSHBastion(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
got, err := c.SSHBastion(tt.request)
|
got, err := c.SSHBastion(tt.request)
|
||||||
|
@ -1047,8 +1118,8 @@ func TestClient_SSHBastion(t *testing.T) {
|
||||||
t.Errorf("Client.SSHBastion() = %v, want nil", got)
|
t.Errorf("Client.SSHBastion() = %v, want nil", got)
|
||||||
}
|
}
|
||||||
if tt.responseCode != 200 {
|
if tt.responseCode != 200 {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(render.StatusCodedError)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -1060,3 +1131,28 @@ func TestClient_SSHBastion(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_GetCaURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
caURL string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"ok", "https://ca.com", "https://ca.com"},
|
||||||
|
{"ok no schema", "ca.com", "https://ca.com"},
|
||||||
|
{"ok with port", "https://ca.com:9000", "https://ca.com:9000"},
|
||||||
|
{"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewClient() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := c.GetCaURL(); got != tt.want {
|
||||||
|
t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -196,7 +197,7 @@ func TestLoadClient(t *testing.T) {
|
||||||
switch {
|
switch {
|
||||||
case gotTransport.TLSClientConfig.GetClientCertificate == nil:
|
case gotTransport.TLSClientConfig.GetClientCertificate == nil:
|
||||||
t.Error("LoadClient() transport does not define GetClientCertificate")
|
t.Error("LoadClient() transport does not define GetClientCertificate")
|
||||||
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()):
|
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs):
|
||||||
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
|
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
|
||||||
default:
|
default:
|
||||||
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
|
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
|
||||||
|
@ -238,3 +239,23 @@ func Test_defaultsConfig_Validate(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:staticcheck,gocritic
|
||||||
|
func equalPools(a, b *x509.CertPool) bool {
|
||||||
|
if reflect.DeepEqual(a, b) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
subjects := a.Subjects()
|
||||||
|
sA := make([]string, len(subjects))
|
||||||
|
for i := range subjects {
|
||||||
|
sA[i] = string(subjects[i])
|
||||||
|
}
|
||||||
|
subjects = b.Subjects()
|
||||||
|
sB := make([]string, len(subjects))
|
||||||
|
for i := range subjects {
|
||||||
|
sB[i] = string(subjects[i])
|
||||||
|
}
|
||||||
|
sort.Strings(sA)
|
||||||
|
sort.Strings(sB)
|
||||||
|
return reflect.DeepEqual(sA, sB)
|
||||||
|
}
|
||||||
|
|
|
@ -346,6 +346,8 @@ func TestIdentity_GetCertPool(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if got != nil {
|
if got != nil {
|
||||||
|
// nolint:staticcheck // we don't have a different way to check
|
||||||
|
// the certificates in the pool.
|
||||||
subjects := got.Subjects()
|
subjects := got.Subjects()
|
||||||
if !reflect.DeepEqual(subjects, tt.wantSubjects) {
|
if !reflect.DeepEqual(subjects, tt.wantSubjects) {
|
||||||
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
|
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
|
||||||
|
|
|
@ -60,7 +60,10 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore)
|
// Use the current time to calculate the initial period. Using a notBefore
|
||||||
|
// in the past might set a renewBefore too large, causing continuous
|
||||||
|
// renewals due to the negative values in nextRenewDuration.
|
||||||
|
period := cert.Leaf.NotAfter.Sub(time.Now().Truncate(time.Second))
|
||||||
if period < minCertDuration {
|
if period < minCertDuration {
|
||||||
return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period)
|
return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period)
|
||||||
}
|
}
|
||||||
|
@ -181,7 +184,7 @@ func (r *TLSRenewer) renewCertificate() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {
|
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {
|
||||||
d := time.Until(notAfter) - r.renewBefore
|
d := time.Until(notAfter).Truncate(time.Second) - r.renewBefore
|
||||||
n := rand.Int63n(int64(r.renewJitter))
|
n := rand.Int63n(int64(r.renewJitter))
|
||||||
d -= time.Duration(n)
|
d -= time.Duration(n)
|
||||||
if d < 0 {
|
if d < 0 {
|
||||||
|
|
2
ca/testdata/ca.json
vendored
2
ca/testdata/ca.json
vendored
|
@ -6,7 +6,7 @@
|
||||||
"password": "password",
|
"password": "password",
|
||||||
"address": "127.0.0.1:0",
|
"address": "127.0.0.1:0",
|
||||||
"dnsNames": ["127.0.0.1"],
|
"dnsNames": ["127.0.0.1"],
|
||||||
"logger": {"format": "text"},
|
"_logger": {"format": "text"},
|
||||||
"tls": {
|
"tls": {
|
||||||
"minVersion": 1.2,
|
"minVersion": 1.2,
|
||||||
"maxVersion": 1.3,
|
"maxVersion": 1.3,
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue