Compare commits

..

1 commit

Author SHA1 Message Date
Herman Slatman
d25b8afe68
Add CNAME chasing for TXT records 2022-10-20 01:18:34 +02:00
199 changed files with 7301 additions and 11040 deletions

View file

@ -20,8 +20,7 @@ jobs:
ci: ci:
uses: smallstep/workflows/.github/workflows/goCI.yml@main uses: smallstep/workflows/.github/workflows/goCI.yml@main
with: with:
only-latest-golang: false os-dependencies: "libpcsclite-dev"
os-dependencies: 'libpcsclite-dev' run-gitleaks: true
run-codeql: true run-codeql: true
test-command: 'V=1 make test'
secrets: inherit secrets: inherit

View file

@ -1,22 +0,0 @@
name: Dependabot auto-merge
on: pull_request
permissions:
contents: write
pull-requests: write
jobs:
dependabot:
runs-on: ubuntu-latest
if: ${{ github.actor == 'dependabot[bot]' }}
steps:
- name: Dependabot metadata
id: metadata
uses: dependabot/fetch-metadata@v1.1.1
with:
github-token: "${{ secrets.GITHUB_TOKEN }}"
- name: Enable auto-merge for Dependabot PRs
run: gh pr merge --auto --merge "$PR_URL"
env:
PR_URL: ${{github.event.pull_request.html_url}}
GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}

View file

@ -13,15 +13,10 @@ jobs:
create_release: create_release:
name: Create Release name: Create Release
needs: ci #needs: ci
runs-on: ubuntu-latest runs-on: ubuntu-20.04
env:
DOCKER_IMAGE: smallstep/step-ca
outputs: outputs:
version: ${{ steps.extract-tag.outputs.VERSION }}
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
docker_tags: ${{ env.DOCKER_TAGS }}
docker_tags_hsm: ${{ env.DOCKER_TAGS_HSM }}
steps: steps:
- name: Is Pre-release - name: Is Pre-release
id: is_prerelease id: is_prerelease
@ -30,19 +25,7 @@ jobs:
echo ${{ github.ref }} | grep "\-rc.*" echo ${{ github.ref }} | grep "\-rc.*"
OUT=$? OUT=$?
if [ $OUT -eq 0 ]; then IS_PRERELEASE=true; else IS_PRERELEASE=false; fi if [ $OUT -eq 0 ]; then IS_PRERELEASE=true; else IS_PRERELEASE=false; fi
echo "IS_PRERELEASE=${IS_PRERELEASE}" >> ${GITHUB_OUTPUT} echo "::set-output name=IS_PRERELEASE::${IS_PRERELEASE}"
- name: Extract Tag Names
id: extract-tag
run: |
VERSION=${GITHUB_REF#refs/tags/v}
echo "VERSION=${VERSION}" >> ${GITHUB_OUTPUT}
echo "DOCKER_TAGS=${{ env.DOCKER_IMAGE }}:${VERSION}" >> ${GITHUB_ENV}
echo "DOCKER_TAGS_HSM=${{ env.DOCKER_IMAGE }}:${VERSION}-hsm" >> ${GITHUB_ENV}
- name: Add Latest Tag
if: steps.is_prerelease.outputs.IS_PRERELEASE == 'false'
run: |
echo "DOCKER_TAGS=${{ env.DOCKER_TAGS }},${{ env.DOCKER_IMAGE }}:latest" >> ${GITHUB_ENV}
echo "DOCKER_TAGS_HSM=${{ env.DOCKER_TAGS_HSM }},${{ env.DOCKER_IMAGE }}:hsm" >> ${GITHUB_ENV}
- name: Create Release - name: Create Release
id: create_release id: create_release
uses: actions/create-release@v1 uses: actions/create-release@v1
@ -55,37 +38,64 @@ jobs:
prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
goreleaser: goreleaser:
name: Upload Assets To Github w/ goreleaser
runs-on: ubuntu-20.04
needs: create_release needs: create_release
permissions: steps:
id-token: write - name: Checkout
contents: write uses: actions/checkout@v3
uses: smallstep/workflows/.github/workflows/goreleaser.yml@main - name: Set up Go
secrets: inherit uses: actions/setup-go@v3
with:
go-version: 1.19
check-latest: true
- name: Install cosign
uses: sigstore/cosign-installer@v2.7.0
with:
cosign-release: 'v1.12.1'
- name: Write cosign key to disk
id: write_key
run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
- name: Get Release Date
id: release_date
run: |
RELEASE_DATE=$(date +"%y-%m-%d")
echo "::set-output name=RELEASE_DATE::${RELEASE_DATE}"
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v3
with:
version: 'latest'
args: release --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.GORELEASER_PAT }}
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}
RELEASE_DATE: ${{ steps.release_date.outputs.RELEASE_DATE }}
build_upload_docker: build_upload_docker:
name: Build & Upload Docker Images name: Build & Upload Docker Images
needs: create_release runs-on: ubuntu-20.04
permissions: needs: ci
id-token: write steps:
contents: write - name: Checkout
uses: smallstep/workflows/.github/workflows/docker-buildx-push.yml@main uses: actions/checkout@v3
- name: Setup Go
uses: actions/setup-go@v3
with: with:
platforms: linux/amd64,linux/386,linux/arm,linux/arm64 go-version: '1.19'
tags: ${{ needs.create_release.outputs.docker_tags }} check-latest: true
docker_image: smallstep/step-ca - name: Install cosign
docker_file: docker/Dockerfile uses: sigstore/cosign-installer@v1.1.0
secrets: inherit
build_upload_docker_hsm:
name: Build & Upload HSM Enabled Docker Images
needs: create_release
permissions:
id-token: write
contents: write
uses: smallstep/workflows/.github/workflows/docker-buildx-push.yml@main
with: with:
platforms: linux/amd64,linux/386,linux/arm,linux/arm64 cosign-release: 'v1.1.0'
tags: ${{ needs.create_release.outputs.docker_tags_hsm }} - name: Write cosign key to disk
docker_image: smallstep/step-ca id: write_key
docker_file: docker/Dockerfile.hsm run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
secrets: inherit - name: Build
id: build
run: |
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
make docker-artifacts
env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}

1
.gitignore vendored
View file

@ -24,4 +24,3 @@ output
vendor vendor
.idea .idea
.envrc .envrc
.vscode

View file

@ -31,12 +31,11 @@ builds:
- -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}} - -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}}
archives: archives:
- &ARCHIVE -
# Can be used to change the archive formats for specific GOOSs. # Can be used to change the archive formats for specific GOOSs.
# Most common use case is to archive as zip on Windows. # Most common use case is to archive as zip on Windows.
# Default is empty. # Default is empty.
name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Version }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}" name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Version }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
rlcp: true
format_overrides: format_overrides:
- goos: windows - goos: windows
format: zip format: zip
@ -45,11 +44,6 @@ archives:
- README.md - README.md
- LICENSE - LICENSE
allow_different_binary_count: true allow_different_binary_count: true
-
<< : *ARCHIVE
id: unversioned
name_template: "{{ .ProjectName }}_{{ .Os }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
nfpms: nfpms:
# Configure nFPM for .deb and .rpm releases # Configure nFPM for .deb and .rpm releases
@ -61,7 +55,7 @@ nfpms:
# List file contents: dpkg -c dist/step_...deb # List file contents: dpkg -c dist/step_...deb
# Package metadata: dpkg --info dist/step_....deb # Package metadata: dpkg --info dist/step_....deb
# #
- &NFPM -
builds: builds:
- step-ca - step-ca
package_name: step-ca package_name: step-ca
@ -81,14 +75,9 @@ nfpms:
contents: contents:
- src: debian/copyright - src: debian/copyright
dst: /usr/share/doc/step-ca/copyright dst: /usr/share/doc/step-ca/copyright
-
<< : *NFPM
id: unversioned
file_name_template: "{{ .PackageName }}_{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }}{{ if .Mips }}_{{ .Mips }}{{ end }}"
source: source:
enabled: true enabled: true
rlcp: true
name_template: '{{ .ProjectName }}_{{ .Version }}' name_template: '{{ .ProjectName }}_{{ .Version }}'
checksum: checksum:
@ -98,9 +87,8 @@ checksum:
signs: signs:
- cmd: cosign - cmd: cosign
signature: "${artifact}.sig" stdin: '{{ .Env.COSIGN_PWD }}'
certificate: "${artifact}.pem" args: ["sign-blob", "-key=/tmp/cosign.key", "-output-signature=${signature}", "${artifact}"]
args: ["sign-blob", "--oidc-issuer=https://token.actions.githubusercontent.com", "--output-certificate=${certificate}", "--output-signature=${signature}", "${artifact}"]
artifacts: all artifacts: all
snapshot: snapshot:
@ -141,17 +129,17 @@ release:
#### Linux #### Linux
- 📦 [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz) - 📦 [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz)
- 📦 [step-ca_{{ .Version }}_amd64.deb](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ .Version }}_amd64.deb) - 📦 [step-ca_{{ .Version }}_amd64.deb](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ .Version }}_amd64.deb)
#### OSX Darwin #### OSX Darwin
- 📦 [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz) - 📦 [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz)
- 📦 [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz) - 📦 [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz)
#### Windows #### Windows
- 📦 [step-ca_windows_{{ .Version }}_amd64.zip](https://dl.smallstep.com/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip) - 📦 [step-ca_windows_{{ .Version }}_arm64.zip](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip)
For more builds across platforms and architectures, see the `Assets` section below. For more builds across platforms and architectures, see the `Assets` section below.
And for packaged versions (Docker, k8s, Homebrew), see our [installation docs](https://smallstep.com/docs/step-ca/installation). And for packaged versions (Docker, k8s, Homebrew), see our [installation docs](https://smallstep.com/docs/step-ca/installation).
@ -166,10 +154,8 @@ release:
``` ```
cosign verify-blob \ cosign verify-blob \
--certificate ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig.pem \ -key https://raw.githubusercontent.com/smallstep/certificates/master/cosign.pub \
--signature ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig \ -signature ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig
--certificate-identity-regexp "https://github\.com/smallstep/certificates/.*" \
--certificate-oidc-issuer https://token.actions.githubusercontent.com \
~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz
``` ```
@ -198,41 +184,3 @@ release:
# - glob: ./path/to/file.txt # - glob: ./path/to/file.txt
# - glob: ./glob/**/to/**/file/**/* # - glob: ./glob/**/to/**/file/**/*
# - glob: ./glob/foo/to/bar/file/foobar/override_from_previous # - glob: ./glob/foo/to/bar/file/foobar/override_from_previous
scoops:
-
ids: [ default ]
# Template for the url which is determined by the given Token (github or gitlab)
# Default for github is "https://github.com/<repo_owner>/<repo_name>/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
# Default for gitlab is "https://gitlab.com/<repo_owner>/<repo_name>/uploads/{{ .ArtifactUploadHash }}/{{ .ArtifactName }}"
# Default for gitea is "https://gitea.com/<repo_owner>/<repo_name>/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
url_template: "http://github.com/smallstep/certificates/releases/download/{{ .Tag }}/{{ .ArtifactName }}"
# Repository to push the app manifest to.
bucket:
owner: smallstep
name: scoop-bucket
# Git author used to commit to the repository.
# Defaults are shown.
commit_author:
name: goreleaserbot
email: goreleaser@smallstep.com
# The project name and current git tag are used in the format string.
commit_msg_template: "Scoop update for {{ .ProjectName }} version {{ .Tag }}"
# Your app's homepage.
# Default is empty.
homepage: "https://smallstep.com/docs/step-ca"
# Skip uploads for prerelease.
skip_upload: auto
# Your app's description.
# Default is empty.
description: "A private certificate authority (X.509 & SSH) & ACME server for secure automated certificate management, so you can use TLS everywhere & SSO for SSH."
# Your app's license
# Default is empty.
license: "Apache-2.0"

View file

@ -1,4 +1,4 @@
#!/usr/bin/env sh #!/usr/bin/env bash
read -r firstline < .VERSION read -r firstline < .VERSION
last_half="${firstline##*tag: }" last_half="${firstline##*tag: }"
if [[ ${last_half::1} == "v" ]]; then if [[ ${last_half::1} == "v" ]]; then

View file

@ -1,200 +1,39 @@
# Changelog # Changelog
All notable changes to this project will be documented in this file. All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## TEMPLATE -- do not alter or remove ### TEMPLATE -- do not alter or remove
--- ---
## [x.y.z] - aaaa-bb-cc ## [x.y.z] - aaaa-bb-cc
### Added ### Added
### Changed ### Changed
### Deprecated ### Deprecated
### Removed ### Removed
### Fixed ### Fixed
### Security ### Security
--- ---
## [Unreleased] ## [Unreleased]
### Fixed
- Improved authentication for ACME requests using kid and provisioner name
(smallstep/certificates#1386).
## [v0.24.2] - 2023-05-11
### Added ### Added
- Added support for ACME device-attest-01 challenge.
- Log SSH certificates (smallstep/certificates#1374)
- CRL endpoints on the HTTP server (smallstep/certificates#1372)
- Dynamic SCEP challenge validation using webhooks (smallstep/certificates#1366)
- For Docker deployments, added DOCKER_STEPCA_INIT_PASSWORD_FILE. Useful for pointing to a Docker Secret in the container (smallstep/certificates#1384)
### Changed
- Depend on [smallstep/go-attestation](https://github.com/smallstep/go-attestation) instead of [google/go-attestation](https://github.com/google/go-attestation)
- Render CRLs into http.ResponseWriter instead of memory (smallstep/certificates#1373)
- Redaction of SCEP static challenge when listing provisioners (smallstep/certificates#1204)
### Fixed
- VaultCAS certificate lifetime (smallstep/certificates#1376)
## [v0.24.1] - 2023-04-14
### Fixed
- Docker image name for HSM support (smallstep/certificates#1348)
## [v0.24.0] - 2023-04-12
### Added
- Add ACME `device-attest-01` support with TPM 2.0
(smallstep/certificates#1063).
- Add support for new Azure SDK, sovereign clouds, and HSM keys on Azure KMS
(smallstep/crypto#192, smallstep/crypto#197, smallstep/crypto#198,
smallstep/certificates#1323, smallstep/certificates#1309).
- Add support for ASN.1 functions on certificate templates
(smallstep/crypto#208, smallstep/certificates#1345)
- Add `DOCKER_STEPCA_INIT_ADDRESS` to configure the address to use in a docker
container (smallstep/certificates#1262).
- Make sure that the CSR used matches the attested key when using AME
`device-attest-01` challenge (smallstep/certificates#1265).
- Add support for compacting the Badger DB (smallstep/certificates#1298).
- Build and release cleanups (smallstep/certificates#1322,
smallstep/certificates#1329, smallstep/certificates#1340).
### Fixed
- Fix support for PKCS #7 RSA-OAEP decryption through
[smallstep/pkcs7#4](https://github.com/smallstep/pkcs7/pull/4), as used in
SCEP.
- Fix RA installation using `scripts/install-step-ra.sh`
(smallstep/certificates#1255).
- Clarify error messages on policy errors (smallstep/certificates#1287,
smallstep/certificates#1278).
- Clarify error message on OIDC email validation (smallstep/certificates#1290).
- Mark the IDP critical in the generated CRL data (smallstep/certificates#1293).
- Disable database if CA is initialized with the `--no-db` flag
(smallstep/certificates#1294).
## [v0.23.2] - 2023-02-02
### Added
- Added [`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) to
docker images, and a new image, `smallstep/step-ca-hsm`, compiled with cgo
(smallstep/certificates#1243).
- Added [`scoop`](https://scoop.sh) packages back to the release
(smallstep/certificates#1250).
- Added optional flag `--pidfile` which allows passing a filename where step-ca
will write its process id (smallstep/certificates#1251).
- Added helpful message on CA startup when config can't be opened
(smallstep/certificates#1252).
- Improved validation and error messages on `device-attest-01` orders
(smallstep/certificates#1235).
### Removed
- The deprecated CLI utils `step-awskms-init`, `step-cloudkms-init`,
`step-pkcs11-init`, `step-yubikey-init` have been removed.
[`step`](https://github.com/smallstep/cli) and
[`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) should be
used instead (smallstep/certificates#1240).
### Fixed
- Fixed remote management flags in docker images (smallstep/certificates#1228).
## [v0.23.1] - 2023-01-10
### Added
- Added configuration property `.crl.idpURL` to be able to set a custom Issuing
Distribution Point in the CRL (smallstep/certificates#1178).
- Added WithContext methods to the CA client (smallstep/certificates#1211).
- Docker: Added environment variables for enabling Remote Management and ACME
provisioner (smallstep/certificates#1201).
- Docker: The entrypoint script now generates and displays an initial JWK
provisioner password by default when the CA is being initialized
(smallstep/certificates#1223).
### Changed
- Ignore SSH principals validation when using an OIDC provisioner. The
provisioner will ignore the principals passed and set the defaults or the ones
including using WebHooks or templates (smallstep/certificates#1206).
## [v0.23.0] - 2022-11-11
### Added
- Added support for ACME device-attest-01 challenge on iOS, iPadOS, tvOS and
YubiKey.
- Ability to disable ACME challenges and attestation formats.
- Added flags to change ACME challenge ports for testing purposes.
- Added name constraints evaluation and enforcement when issuing or renewing - Added name constraints evaluation and enforcement when issuing or renewing
X.509 certificates. X.509 certificates.
- Added provisioner webhooks for augmenting template data and authorizing - Added provisioner webhooks for augmenting template data and authorizing certificate requests before signing.
certificate requests before signing.
- Added automatic migration of provisioners when enabling remote management.
- Added experimental support for CRLs.
- Add certificate renewal support on RA mode. The `step ca renew` command must
use the flag `--mtls=false` to use the token renewal flow.
- Added support for initializing remote management using `step ca init`.
- Added support for renewing X.509 certificates on RAs.
- Added support for using SCEP with keys in a KMS.
- Added client support to set the dialer's local address with the environment variable
`STEP_CLIENT_ADDR`.
### Changed
- Remove the email requirement for issuing SSH certificates with an OIDC
provisioner.
- Root files can contain more than one certificate.
### Fixed ### Fixed
- MySQL DSN parsing issues fixed with upgrade to [smallstep/nosql@v0.5.0](https://github.com/smallstep/nosql/releases/tag/v0.5.0).
- Fixed MySQL DSN parsing issues with an upgrade to
[smallstep/nosql@v0.5.0](https://github.com/smallstep/nosql/releases/tag/v0.5.0).
- Fixed renewal of certificates with missing subject attributes.
- Fixed ACME support with [ejabberd](https://github.com/processone/ejabberd).
### Deprecated
- The CLIs `step-awskms-init`, `step-cloudkms-init`, `step-pkcs11-init`,
`step-yubikey-init` are deprecated. Now you can use
[`step-kms-plugin`](https://github.com/smallstep/step-kms-plugin) in
combination with `step certificates create` to initialize your PKI.
## [0.22.1] - 2022-08-31 ## [0.22.1] - 2022-08-31
### Fixed ### Fixed
- Fixed signature algorithm on EC (root) + RSA (intermediate) PKIs. - Fixed signature algorithm on EC (root) + RSA (intermediate) PKIs.
## [0.22.0] - 2022-08-26 ## [0.22.0] - 2022-08-26
### Added ### Added
- Added automatic configuration of Linked RAs. - Added automatic configuration of Linked RAs.
- Send provisioner configuration on Linked RAs. - Send provisioner configuration on Linked RAs.
### Changed ### Changed
- Certificates signed by an issuer using an RSA key will be signed using the - Certificates signed by an issuer using an RSA key will be signed using the
same algorithm used to sign the issuer certificate. The signature will no same algorithm used to sign the issuer certificate. The signature will no
longer default to PKCS #1. For example, if the issuer certificate was signed longer default to PKCS #1. For example, if the issuer certificate was signed
@ -206,28 +45,20 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Sanitize TLS options. - Sanitize TLS options.
## [0.20.0] - 2022-05-26 ## [0.20.0] - 2022-05-26
### Added ### Added
- Added Kubernetes auth method for Vault RAs. - Added Kubernetes auth method for Vault RAs.
- Added support for reporting provisioners to linkedca. - Added support for reporting provisioners to linkedca.
- Added support for certificate policies on authority level. - Added support for certificate policies on authority level.
- Added a Dockerfile with a step-ca build with HSM support. - Added a Dockerfile with a step-ca build with HSM support.
- A few new WithXX methods for instantiating authorities - A few new WithXX methods for instantiating authorities
### Changed ### Changed
- Context usage in HTTP APIs. - Context usage in HTTP APIs.
- Changed authentication for Vault RAs. - Changed authentication for Vault RAs.
- Error message returned to client when authenticating with expired certificate. - Error message returned to client when authenticating with expired certificate.
- Strip padding from ACME CSRs. - Strip padding from ACME CSRs.
### Deprecated ### Deprecated
- HTTP API handler types. - HTTP API handler types.
### Fixed ### Fixed
- Fixed SSH revocation. - Fixed SSH revocation.
- CA client dial context for js/wasm target. - CA client dial context for js/wasm target.
- Incomplete `extraNames` support in templates. - Incomplete `extraNames` support in templates.
@ -235,9 +66,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Large SCEP request handling. - Large SCEP request handling.
## [0.19.0] - 2022-04-19 ## [0.19.0] - 2022-04-19
### Added ### Added
- Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`. - Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`.
- Added support for `extraNames` in X.509 templates. - Added support for `extraNames` in X.509 templates.
- Added `armv5` builds. - Added `armv5` builds.
@ -252,156 +81,104 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
on startup, the configuration for the current context is used. on startup, the configuration for the current context is used.
- Added startup info logging and option to skip it (`--quiet`). - Added startup info logging and option to skip it (`--quiet`).
- Added support for renaming the CA (Common Name). - Added support for renaming the CA (Common Name).
### Changed ### Changed
- Made SCEP CA URL paths dynamic. - Made SCEP CA URL paths dynamic.
- Support two latest versions of Go (1.17, 1.18). - Support two latest versions of Go (1.17, 1.18).
- Upgrade go.step.sm/crypto to v0.16.1. - Upgrade go.step.sm/crypto to v0.16.1.
- Upgrade go.step.sm/linkedca to v0.15.0. - Upgrade go.step.sm/linkedca to v0.15.0.
### Deprecated ### Deprecated
- Go 1.16 support. - Go 1.16 support.
### Removed ### Removed
### Fixed ### Fixed
- Fixed admin credentials on RAs. - Fixed admin credentials on RAs.
- Fixed ACME HTTP-01 challenges for IPv6 identifiers. - Fixed ACME HTTP-01 challenges for IPv6 identifiers.
- Various improvements under the hood. - Various improvements under the hood.
### Security ### Security
## [0.18.2] - 2022-03-01 ## [0.18.2] - 2022-03-01
### Added ### Added
- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner. - Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner.
- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering - [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering
out database drivers using Go tags. For example, using the Go flag out database drivers using Go tags. For example, using the Go flag
`--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx `--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx
driver for PostgreSQL. driver for PostgreSQL.
### Changed ### Changed
- IPv6 addresses are normalized as IP addresses instead of hostnames. - IPv6 addresses are normalized as IP addresses instead of hostnames.
- More descriptive JWK decryption error message. - More descriptive JWK decryption error message.
- Make the X5C leaf certificate available to the templates using `{{ .AuthorizationCrt }}`. - Make the X5C leaf certificate available to the templates using `{{ .AuthorizationCrt }}`.
### Fixed ### Fixed
- During provisioner add - validate provisioner configuration before storing to DB. - During provisioner add - validate provisioner configuration before storing to DB.
## [0.18.1] - 2022-02-03 ## [0.18.1] - 2022-02-03
### Added ### Added
- Support for ACME revocation. - Support for ACME revocation.
- Replace hash function with an RSA SSH CA to "rsa-sha2-256". - Replace hash function with an RSA SSH CA to "rsa-sha2-256".
- Support Nebula provisioners. - Support Nebula provisioners.
- Example Ansible configurations. - Example Ansible configurations.
- Support PKCS#11 as a decrypter, as used by SCEP. - Support PKCS#11 as a decrypter, as used by SCEP.
### Changed ### Changed
- Automatically create database directory on `step ca init`. - Automatically create database directory on `step ca init`.
- Slightly improve errors reported when a template has invalid content. - Slightly improve errors reported when a template has invalid content.
- Error reporting in logs and to clients. - Error reporting in logs and to clients.
### Fixed ### Fixed
- SCEP renewal using HTTPS on macOS. - SCEP renewal using HTTPS on macOS.
## [0.18.0] - 2021-11-17 ## [0.18.0] - 2021-11-17
### Added ### Added
- Support for multiple certificate authority contexts. - Support for multiple certificate authority contexts.
- Support for generating extractable keys and certificates on a pkcs#11 module. - Support for generating extractable keys and certificates on a pkcs#11 module.
### Changed ### Changed
- Support two latest versions of Go (1.16, 1.17) - Support two latest versions of Go (1.16, 1.17)
### Deprecated ### Deprecated
- go 1.15 support - go 1.15 support
## [0.17.6] - 2021-10-20 ## [0.17.6] - 2021-10-20
### Notes ### Notes
- 0.17.5 failed in CI/CD - 0.17.5 failed in CI/CD
## [0.17.5] - 2021-10-20 ## [0.17.5] - 2021-10-20
### Added ### Added
- Support for Azure Key Vault as a KMS. - Support for Azure Key Vault as a KMS.
- Adapt `pki` package to support key managers. - Adapt `pki` package to support key managers.
- gocritic linter - gocritic linter
### Fixed ### Fixed
- gocritic warnings - gocritic warnings
## [0.17.4] - 2021-09-28 ## [0.17.4] - 2021-09-28
### Fixed ### Fixed
- Support host-only or user-only SSH CA. - Support host-only or user-only SSH CA.
## [0.17.3] - 2021-09-24 ## [0.17.3] - 2021-09-24
### Added ### Added
- go 1.17 to github action test matrix - go 1.17 to github action test matrix
- Support for CloudKMS RSA-PSS signers without using templates. - Support for CloudKMS RSA-PSS signers without using templates.
- Add flags to support individual passwords for the intermediate and SSH keys. - Add flags to support individual passwords for the intermediate and SSH keys.
- Global support for group admins in the OIDC provisioner. - Global support for group admins in the OIDC provisioner.
### Changed ### Changed
- Using go 1.17 for binaries - Using go 1.17 for binaries
### Fixed ### Fixed
- Upgrade go-jose.v2 to fix a bug in the JWK fingerprint of Ed25519 keys. - Upgrade go-jose.v2 to fix a bug in the JWK fingerprint of Ed25519 keys.
### Security ### Security
- Use cosign to sign and upload signatures for multi-arch Docker container. - Use cosign to sign and upload signatures for multi-arch Docker container.
- Add debian checksum - Add debian checksum
## [0.17.2] - 2021-08-30 ## [0.17.2] - 2021-08-30
### Added ### Added
- Additional way to distinguish Azure IID and Azure OIDC tokens. - Additional way to distinguish Azure IID and Azure OIDC tokens.
### Security ### Security
- Sign over all goreleaser github artifacts using cosign - Sign over all goreleaser github artifacts using cosign
## [0.17.1] - 2021-08-26 ## [0.17.1] - 2021-08-26
## [0.17.0] - 2021-08-25 ## [0.17.0] - 2021-08-25
### Added ### Added
- Add support for Linked CAs using protocol buffers and gRPC - Add support for Linked CAs using protocol buffers and gRPC
- `step-ca init` adds support for - `step-ca init` adds support for
- configuring a StepCAS RA - configuring a StepCAS RA
- configuring a Linked CA - configuring a Linked CA
- congifuring a `step-ca` using Helm - congifuring a `step-ca` using Helm
### Changed ### Changed
- Update badger driver to use v2 by default - Update badger driver to use v2 by default
- Update TLS cipher suites to include 1.3 - Update TLS cipher suites to include 1.3
### Security ### Security
- Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512. - Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512.

131
Makefile
View file

@ -1,11 +1,21 @@
PKG?=github.com/smallstep/certificates/cmd/step-ca PKG?=github.com/smallstep/certificates/cmd/step-ca
BINNAME?=step-ca BINNAME?=step-ca
CLOUDKMS_BINNAME?=step-cloudkms-init
CLOUDKMS_PKG?=github.com/smallstep/certificates/cmd/step-cloudkms-init
AWSKMS_BINNAME?=step-awskms-init
AWSKMS_PKG?=github.com/smallstep/certificates/cmd/step-awskms-init
YUBIKEY_BINNAME?=step-yubikey-init
YUBIKEY_PKG?=github.com/smallstep/certificates/cmd/step-yubikey-init
PKCS11_BINNAME?=step-pkcs11-init
PKCS11_PKG?=github.com/smallstep/certificates/cmd/step-pkcs11-init
# Set V to 1 for verbose output from the Makefile # Set V to 1 for verbose output from the Makefile
Q=$(if $V,,@) Q=$(if $V,,@)
PREFIX?= PREFIX?=
SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*") SRC=$(shell find . -type f -name '*.go' -not -path "./vendor/*")
GOOS_OVERRIDE ?= GOOS_OVERRIDE ?=
OUTPUT_ROOT=output/
RELEASE=./.releases
all: lint test build all: lint test build
@ -21,8 +31,6 @@ bootstra%:
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $$(go env GOPATH)/bin latest $Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $$(go env GOPATH)/bin latest
$Q go install golang.org/x/vuln/cmd/govulncheck@latest $Q go install golang.org/x/vuln/cmd/govulncheck@latest
$Q go install gotest.tools/gotestsum@latest $Q go install gotest.tools/gotestsum@latest
$Q go install github.com/goreleaser/goreleaser@latest
$Q go install github.com/sigstore/cosign/v2/cmd/cosign@latest
.PHONY: bootstra% .PHONY: bootstra%
@ -30,8 +38,17 @@ bootstra%:
# Determine the type of `push` and `version` # Determine the type of `push` and `version`
################################################# #################################################
# If TRAVIS_TAG is set then we know this ref has been tagged.
ifdef TRAVIS_TAG
VERSION ?= $(TRAVIS_TAG)
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
ifeq ($(NOT_RC),)
PUSHTYPE := release-candidate
else
PUSHTYPE := release
endif
# GITHUB Actions # GITHUB Actions
ifdef GITHUB_REF else ifdef GITHUB_REF
VERSION ?= $(shell echo $(GITHUB_REF) | sed 's/^refs\/tags\///') VERSION ?= $(shell echo $(GITHUB_REF) | sed 's/^refs\/tags\///')
NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc) NOT_RC := $(shell echo $(VERSION) | grep -v -e -rc)
ifeq ($(NOT_RC),) ifeq ($(NOT_RC),)
@ -44,50 +61,59 @@ VERSION ?= $(shell [ -d .git ] && git describe --tags --always --dirty="-dev")
# If we are not in an active git dir then try reading the version from .VERSION. # If we are not in an active git dir then try reading the version from .VERSION.
# .VERSION contains a slug populated by `git archive`. # .VERSION contains a slug populated by `git archive`.
VERSION := $(or $(VERSION),$(shell ./.version.sh .VERSION)) VERSION := $(or $(VERSION),$(shell ./.version.sh .VERSION))
ifeq ($(TRAVIS_BRANCH),master)
PUSHTYPE := master
else
PUSHTYPE := branch PUSHTYPE := branch
endif
endif endif
VERSION := $(shell echo $(VERSION) | sed 's/^v//') VERSION := $(shell echo $(VERSION) | sed 's/^v//')
DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/./g')
ifdef V ifdef V
$(info TRAVIS_TAG is $(TRAVIS_TAG))
$(info GITHUB_REF is $(GITHUB_REF)) $(info GITHUB_REF is $(GITHUB_REF))
$(info VERSION is $(VERSION)) $(info VERSION is $(VERSION))
$(info DEB_VERSION is $(DEB_VERSION))
$(info PUSHTYPE is $(PUSHTYPE)) $(info PUSHTYPE is $(PUSHTYPE))
endif endif
include make/docker.mk
######################################### #########################################
# Build # Build
######################################### #########################################
DATE := $(shell date -u '+%Y-%m-%d %H:%M UTC') DATE := $(shell date -u '+%Y-%m-%d %H:%M UTC')
LDFLAGS := -ldflags='-w -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"' LDFLAGS := -ldflags='-w -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"'
GOFLAGS := CGO_ENABLED=0
# Always explicitly enable or disable cgo,
# so that go doesn't silently fall back on
# non-cgo when gcc is not found.
ifeq (,$(findstring CGO_ENABLED,$(GO_ENVS)))
ifneq ($(origin GOFLAGS),undefined)
# This section is for backward compatibility with
#
# $ make build GOFLAGS=""
#
# which is how we recommended building step-ca with cgo support
# until June 2023.
GO_ENVS := $(GO_ENVS) CGO_ENABLED=1
else
GO_ENVS := $(GO_ENVS) CGO_ENABLED=0
endif
endif
download: download:
$Q go mod download $Q go mod download
build: $(PREFIX)bin/$(BINNAME) build: $(PREFIX)bin/$(BINNAME) $(PREFIX)bin/$(CLOUDKMS_BINNAME) $(PREFIX)bin/$(AWSKMS_BINNAME) $(PREFIX)bin/$(YUBIKEY_BINNAME) $(PREFIX)bin/$(PKCS11_BINNAME)
@echo "Build Complete!" @echo "Build Complete!"
$(PREFIX)bin/$(BINNAME): download $(call rwildcard,*.go) $(PREFIX)bin/$(BINNAME): download $(call rwildcard,*.go)
$Q mkdir -p $(@D) $Q mkdir -p $(@D)
$Q $(GOOS_OVERRIDE) GOFLAGS="$(GOFLAGS)" $(GO_ENVS) go build -v -o $(PREFIX)bin/$(BINNAME) $(LDFLAGS) $(PKG) $Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(BINNAME) $(LDFLAGS) $(PKG)
$(PREFIX)bin/$(CLOUDKMS_BINNAME): download $(call rwildcard,*.go)
$Q mkdir -p $(@D)
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(CLOUDKMS_BINNAME) $(LDFLAGS) $(CLOUDKMS_PKG)
$(PREFIX)bin/$(AWSKMS_BINNAME): download $(call rwildcard,*.go)
$Q mkdir -p $(@D)
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(AWSKMS_BINNAME) $(LDFLAGS) $(AWSKMS_PKG)
$(PREFIX)bin/$(YUBIKEY_BINNAME): download $(call rwildcard,*.go)
$Q mkdir -p $(@D)
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(YUBIKEY_BINNAME) $(LDFLAGS) $(YUBIKEY_PKG)
$(PREFIX)bin/$(PKCS11_BINNAME): download $(call rwildcard,*.go)
$Q mkdir -p $(@D)
$Q $(GOOS_OVERRIDE) $(GOFLAGS) go build -v -o $(PREFIX)bin/$(PKCS11_BINNAME) $(LDFLAGS) $(PKCS11_PKG)
# Target to force a build of step-ca without running tests # Target to force a build of step-ca without running tests
simple: build simple: build
@ -106,26 +132,19 @@ generate:
######################################### #########################################
# Test # Test
######################################### #########################################
test: testdefault testtpmsimulator combinecoverage test:
$Q $(GOFLAGS) gotestsum -- -coverprofile=coverage.out -short -covermode=atomic ./...
testdefault:
$Q $(GO_ENVS) gotestsum -- -coverprofile=defaultcoverage.out -short -covermode=atomic ./...
testtpmsimulator:
$Q CGO_ENABLED=1 gotestsum -- -coverprofile=tpmsimulatorcoverage.out -short -covermode=atomic -tags tpmsimulator ./acme
testcgo: testcgo:
$Q gotestsum -- -coverprofile=coverage.out -short -covermode=atomic ./... $Q gotestsum -- -coverprofile=coverage.out -short -covermode=atomic ./...
combinecoverage: .PHONY: test testcgo
cat defaultcoverage.out tpmsimulatorcoverage.out > coverage.out
.PHONY: test testdefault testtpmsimulator testcgo combinecoverage
integrate: integration integrate: integration
integration: bin/$(BINNAME) integration: bin/$(BINNAME)
$Q $(GO_ENVS) gotestsum -- -tags=integration ./integration/... $Q $(GOFLAGS) gotestsum -- -tags=integration ./integration/...
.PHONY: integrate integration .PHONY: integrate integration
@ -149,11 +168,15 @@ lint:
INSTALL_PREFIX?=/usr/ INSTALL_PREFIX?=/usr/
install: $(PREFIX)bin/$(BINNAME) install: $(PREFIX)bin/$(BINNAME) $(PREFIX)bin/$(CLOUDKMS_BINNAME) $(PREFIX)bin/$(AWSKMS_BINNAME)
$Q install -D $(PREFIX)bin/$(BINNAME) $(DESTDIR)$(INSTALL_PREFIX)bin/$(BINNAME) $Q install -D $(PREFIX)bin/$(BINNAME) $(DESTDIR)$(INSTALL_PREFIX)bin/$(BINNAME)
$Q install -D $(PREFIX)bin/$(CLOUDKMS_BINNAME) $(DESTDIR)$(INSTALL_PREFIX)bin/$(CLOUDKMS_BINNAME)
$Q install -D $(PREFIX)bin/$(AWSKMS_BINNAME) $(DESTDIR)$(INSTALL_PREFIX)bin/$(AWSKMS_BINNAME)
uninstall: uninstall:
$Q rm -f $(DESTDIR)$(INSTALL_PREFIX)/bin/$(BINNAME) $Q rm -f $(DESTDIR)$(INSTALL_PREFIX)/bin/$(BINNAME)
$Q rm -f $(DESTDIR)$(INSTALL_PREFIX)/bin/$(CLOUDKMS_BINNAME)
$Q rm -f $(DESTDIR)$(INSTALL_PREFIX)/bin/$(AWSKMS_BINNAME)
.PHONY: install uninstall .PHONY: install uninstall
@ -165,6 +188,18 @@ clean:
ifneq ($(BINNAME),"") ifneq ($(BINNAME),"")
$Q rm -f bin/$(BINNAME) $Q rm -f bin/$(BINNAME)
endif endif
ifneq ($(CLOUDKMS_BINNAME),"")
$Q rm -f bin/$(CLOUDKMS_BINNAME)
endif
ifneq ($(AWSKMS_BINNAME),"")
$Q rm -f bin/$(AWSKMS_BINNAME)
endif
ifneq ($(YUBIKEY_BINNAME),"")
$Q rm -f bin/$(YUBIKEY_BINNAME)
endif
ifneq ($(PKCS11_BINNAME),"")
$Q rm -f bin/$(PKCS11_BINNAME)
endif
.PHONY: clean .PHONY: clean
@ -177,3 +212,31 @@ run:
.PHONY: run .PHONY: run
#########################################
# Debian
#########################################
changelog:
$Q echo "step-ca ($(DEB_VERSION)) unstable; urgency=medium" > debian/changelog
$Q echo >> debian/changelog
$Q echo " * See https://github.com/smallstep/certificates/releases" >> debian/changelog
$Q echo >> debian/changelog
$Q echo " -- Smallstep Labs, Inc. <techadmin@smallstep.com> $(shell date -uR)" >> debian/changelog
debian: changelog
$Q mkdir -p $(RELEASE); \
OUTPUT=../step-ca*.deb; \
rm $$OUTPUT; \
dpkg-buildpackage -b -rfakeroot -us -uc && cp $$OUTPUT $(RELEASE)/
distclean: clean
.PHONY: changelog debian distclean
#################################################
# Targets for creating step artifacts
#################################################
docker-artifacts: docker-$(PUSHTYPE)
.PHONY: docker-artifacts

View file

@ -119,12 +119,18 @@ See our installation docs [here](https://smallstep.com/docs/step-ca/installation
## Documentation ## Documentation
* [Official documentation](https://smallstep.com/docs/step-ca) is on smallstep.com Documentation can be found in a handful of different places:
* The `step` command reference is available via `step help`,
[on smallstep.com](https://smallstep.com/docs/step-cli/reference/), 1. On the web at https://smallstep.com/docs/step-ca.
or by running `step help --http=:8080` from the command line
2. On the command line with `step help ca xxx` where `xxx` is the subcommand
you are interested in. Ex: `step help ca provisioner list`.
3. In your browser, by running `step help --http=:8080 ca` from the command line
and visiting http://localhost:8080. and visiting http://localhost:8080.
4. The [docs](./docs/README.md) folder is being deprecated, but it still has some documentation and tutorials.
## Feedback? ## Feedback?
* Tell us what you like and don't like about managing your PKI - we're eager to help solve problems in this space. * Tell us what you like and don't like about managing your PKI - we're eager to help solve problems in this space.

View file

@ -20,16 +20,6 @@ type Account struct {
Status Status `json:"status"` Status Status `json:"status"`
OrdersURL string `json:"orders"` OrdersURL string `json:"orders"`
ExternalAccountBinding interface{} `json:"externalAccountBinding,omitempty"` ExternalAccountBinding interface{} `json:"externalAccountBinding,omitempty"`
LocationPrefix string `json:"-"`
ProvisionerName string `json:"-"`
}
// GetLocation returns the URL location of the given account.
func (a *Account) GetLocation() string {
if a.LocationPrefix == "" {
return ""
}
return a.LocationPrefix + a.ID
} }
// ToLog enables response logging. // ToLog enables response logging.
@ -82,7 +72,6 @@ func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions {
IPRanges: p.X509.Allowed.IPRanges, IPRanges: p.X509.Allowed.IPRanges,
} }
} }
func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions { func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions {
if p == nil { if p == nil {
return nil return nil

View file

@ -66,23 +66,6 @@ func TestKeyToID(t *testing.T) {
} }
} }
func TestAccount_GetLocation(t *testing.T) {
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
type test struct {
acc *Account
exp string
}
tests := map[string]test{
"empty": {acc: &Account{LocationPrefix: ""}, exp: ""},
"not-empty": {acc: &Account{ID: "bar", LocationPrefix: locationPrefix}, exp: locationPrefix + "bar"},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
assert.Equals(t, tc.acc.GetLocation(), tc.exp)
})
}
}
func TestAccount_IsValid(t *testing.T) { func TestAccount_IsValid(t *testing.T) {
type test struct { type test struct {
acc *Account acc *Account
@ -152,6 +135,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
if assert.True(t, errors.As(err, &ae)) { if assert.True(t, errors.As(err, &ae)) {
assert.Equals(t, ae.Type, tt.err.Type) assert.Equals(t, ae.Type, tt.err.Type)
assert.Equals(t, ae.Detail, tt.err.Detail) assert.Equals(t, ae.Detail, tt.err.Detail)
assert.Equals(t, ae.Identifier, tt.err.Identifier)
assert.Equals(t, ae.Subproblems, tt.err.Subproblems) assert.Equals(t, ae.Subproblems, tt.err.Subproblems)
} }
} else { } else {

View file

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"net/http" "net/http"
@ -68,12 +67,6 @@ func (u *UpdateAccountRequest) Validate() error {
} }
} }
// getAccountLocationPath returns the current account URL location.
// Returned location will be of the form: https://<ca-url>/acme/<provisioner>/account/<accID>
func getAccountLocationPath(ctx context.Context, linker acme.Linker, accID string) string {
return linker.GetLink(ctx, acme.AccountLinkType, accID)
}
// NewAccount is the handler resource for creating new ACME accounts. // NewAccount is the handler resource for creating new ACME accounts.
func NewAccount(w http.ResponseWriter, r *http.Request) { func NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
@ -135,8 +128,6 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
Key: jwk, Key: jwk,
Contact: nar.Contact, Contact: nar.Contact,
Status: acme.StatusValid, Status: acme.StatusValid,
LocationPrefix: getAccountLocationPath(ctx, linker, ""),
ProvisionerName: prov.GetName(),
} }
if err := db.CreateAccount(ctx, acc); err != nil { if err := db.CreateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating account")) render.Error(w, acme.WrapErrorISE(err, "error creating account"))
@ -161,7 +152,7 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
linker.LinkAccount(ctx, acc) linker.LinkAccount(ctx, acc)
w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID)) w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
render.JSONStatus(w, acc, httpStatus) render.JSONStatus(w, acc, httpStatus)
} }

View file

@ -34,20 +34,27 @@ var (
type fakeProvisioner struct{} type fakeProvisioner struct{}
func (*fakeProvisioner) AuthorizeOrderIdentifier(context.Context, provisioner.ACMEIdentifier) error { func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
return nil return nil
} }
func (*fakeProvisioner) AuthorizeSign(context.Context, string) ([]provisioner.SignOption, error) {
func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
return nil, nil return nil, nil
} }
func (*fakeProvisioner) IsChallengeEnabled(context.Context, provisioner.ACMEChallenge) bool {
func (*fakeProvisioner) IsChallengeEnabled(ctx context.Context, challenge provisioner.ACMEChallenge) bool {
return true return true
} }
func (*fakeProvisioner) IsAttestationFormatEnabled(context.Context, provisioner.ACMEAttestationFormat) bool {
func (*fakeProvisioner) IsAttestationFormatEnabled(ctx context.Context, format provisioner.ACMEAttestationFormat) bool {
return true return true
} }
func (*fakeProvisioner) GetAttestationRoots() (*x509.CertPool, bool) { return nil, false }
func (*fakeProvisioner) AuthorizeRevoke(context.Context, string) error { return nil } func (*fakeProvisioner) GetAttestationRoots() (*x509.CertPool, bool) {
return nil, false
}
func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil }
func (*fakeProvisioner) GetID() string { return "" } func (*fakeProvisioner) GetID() string { return "" }
func (*fakeProvisioner) GetName() string { return "" } func (*fakeProvisioner) GetName() string { return "" }
func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 } func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 }
@ -362,7 +369,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -381,6 +388,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -801,7 +809,7 @@ func TestHandler_NewAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -820,6 +828,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -1004,7 +1013,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -1023,6 +1032,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {

View file

@ -866,6 +866,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
assert.Equals(t, ae.Status, tc.err.Status) assert.Equals(t, ae.Status, tc.err.Status)
assert.HasPrefix(t, ae.Err.Error(), tc.err.Err.Error()) assert.HasPrefix(t, ae.Err.Error(), tc.err.Err.Error())
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
} }
} else { } else {
@ -1144,6 +1145,7 @@ func Test_validateEABJWS(t *testing.T) {
assert.Equals(t, tc.err.Status, err.Status) assert.Equals(t, tc.err.Status, err.Status)
assert.HasPrefix(t, err.Err.Error(), tc.err.Err.Error()) assert.HasPrefix(t, err.Err.Error(), tc.err.Err.Error())
assert.Equals(t, tc.err.Detail, err.Detail) assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, tc.err.Identifier, err.Identifier)
assert.Equals(t, tc.err.Subproblems, err.Subproblems) assert.Equals(t, tc.err.Subproblems, err.Subproblems)
} else { } else {
assert.Nil(t, err) assert.Nil(t, err)

View file

@ -95,7 +95,7 @@ func (h *handler) Route(r api.Router) {
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil { if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
ctx = authority.NewContext(ctx, ca) ctx = authority.NewContext(ctx, ca)
} }
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker, "") ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
}) })
@ -205,7 +205,7 @@ type Directory struct {
NewOrder string `json:"newOrder"` NewOrder string `json:"newOrder"`
RevokeCert string `json:"revokeCert"` RevokeCert string `json:"revokeCert"`
KeyChange string `json:"keyChange"` KeyChange string `json:"keyChange"`
Meta *Meta `json:"meta,omitempty"` Meta Meta `json:"meta"`
} }
// ToLog enables response logging for the Directory type. // ToLog enables response logging for the Directory type.
@ -228,52 +228,21 @@ func GetDirectory(w http.ResponseWriter, r *http.Request) {
} }
linker := acme.MustLinkerFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
render.JSON(w, &Directory{ render.JSON(w, &Directory{
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType), NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType), NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType), NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType), RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType), KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
Meta: createMetaObject(acmeProv), Meta: Meta{
ExternalAccountRequired: acmeProv.RequireEAB,
},
}) })
} }
// createMetaObject creates a Meta object if the ACME provisioner
// has one or more properties that are written in the ACME directory output.
// It returns nil if none of the properties are set.
func createMetaObject(p *provisioner.ACME) *Meta {
if shouldAddMetaObject(p) {
return &Meta{
TermsOfService: p.TermsOfService,
Website: p.Website,
CaaIdentities: p.CaaIdentities,
ExternalAccountRequired: p.RequireEAB,
}
}
return nil
}
// shouldAddMetaObject returns whether or not the ACME provisioner
// has properties configured that must be added to the ACME directory object.
func shouldAddMetaObject(p *provisioner.ACME) bool {
switch {
case p.TermsOfService != "":
return true
case p.Website != "":
return true
case len(p.CaaIdentities) > 0:
return true
case p.RequireEAB:
return true
default:
return false
}
}
// NotImplemented returns a 501 and is generally a placeholder for functionality which // NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such. // MAY be added at some point in the future but is not in any way a guarantee of such.
func NotImplemented(w http.ResponseWriter, _ *http.Request) { func NotImplemented(w http.ResponseWriter, r *http.Request) {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
} }
@ -394,6 +363,6 @@ func GetCertificate(w http.ResponseWriter, r *http.Request) {
} }
api.LogCertificate(w, cert.Leaf) api.LogCertificate(w, cert.Leaf)
w.Header().Set("Content-Type", "application/pem-certificate-chain") w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
w.Write(certBytes) w.Write(certBytes)
} }

View file

@ -18,13 +18,10 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
) )
type mockClient struct { type mockClient struct {
@ -132,35 +129,7 @@ func TestHandler_GetDirectory(t *testing.T) {
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName),
KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName),
Meta: &Meta{ Meta: Meta{
ExternalAccountRequired: true,
},
}
return test{
ctx: ctx,
dir: expDir,
statusCode: 200,
}
},
"ok/full-meta": func(t *testing.T) test {
prov := newACMEProv(t)
prov.TermsOfService = "https://terms.ca.local/"
prov.Website = "https://ca.local/"
prov.CaaIdentities = []string{"ca.local"}
prov.RequireEAB = true
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := acme.NewProvisionerContext(context.Background(), prov)
expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName),
KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName),
Meta: &Meta{
TermsOfService: "https://terms.ca.local/",
Website: "https://ca.local/",
CaaIdentities: []string{"ca.local"},
ExternalAccountRequired: true, ExternalAccountRequired: true,
}, },
} }
@ -193,6 +162,7 @@ func TestHandler_GetDirectory(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -346,7 +316,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -365,6 +335,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -507,11 +478,12 @@ func TestHandler_GetCertificate(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.HasPrefix(t, ae.Detail, tc.err.Detail) assert.HasPrefix(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes)) assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes))
assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/pem-certificate-chain; charset=utf-8"})
} }
}) })
} }
@ -746,7 +718,7 @@ func TestHandler_GetChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil, "") ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -765,6 +737,7 @@ func TestHandler_GetChallenge(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -778,89 +751,3 @@ func TestHandler_GetChallenge(t *testing.T) {
}) })
} }
} }
func Test_createMetaObject(t *testing.T) {
tests := []struct {
name string
p *provisioner.ACME
want *Meta
}{
{
name: "no-meta",
p: &provisioner.ACME{
Type: "ACME",
Name: "acme",
},
want: nil,
},
{
name: "terms-of-service",
p: &provisioner.ACME{
Type: "ACME",
Name: "acme",
TermsOfService: "https://terms.ca.local",
},
want: &Meta{
TermsOfService: "https://terms.ca.local",
},
},
{
name: "website",
p: &provisioner.ACME{
Type: "ACME",
Name: "acme",
Website: "https://ca.local",
},
want: &Meta{
Website: "https://ca.local",
},
},
{
name: "caa",
p: &provisioner.ACME{
Type: "ACME",
Name: "acme",
CaaIdentities: []string{"ca.local", "ca.remote"},
},
want: &Meta{
CaaIdentities: []string{"ca.local", "ca.remote"},
},
},
{
name: "require-eab",
p: &provisioner.ACME{
Type: "ACME",
Name: "acme",
RequireEAB: true,
},
want: &Meta{
ExternalAccountRequired: true,
},
},
{
name: "full-meta",
p: &provisioner.ACME{
Type: "ACME",
Name: "acme",
TermsOfService: "https://terms.ca.local",
Website: "https://ca.local",
CaaIdentities: []string{"ca.local", "ca.remote"},
RequireEAB: true,
},
want: &Meta{
TermsOfService: "https://terms.ca.local",
Website: "https://ca.local",
CaaIdentities: []string{"ca.local", "ca.remote"},
ExternalAccountRequired: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := createMetaObject(tt.p)
if !cmp.Equal(tt.want, got) {
t.Errorf("createMetaObject() diff =\n%s", cmp.Diff(tt.want, got))
}
})
}
}

View file

@ -7,7 +7,6 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"path"
"strings" "strings"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
@ -17,6 +16,7 @@ import (
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/nosql"
) )
type nextHTTP = func(http.ResponseWriter, *http.Request) type nextHTTP = func(http.ResponseWriter, *http.Request)
@ -293,6 +293,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
@ -300,16 +301,19 @@ func lookupJWK(next nextHTTP) nextHTTP {
return return
} }
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if kid == "" { if !strings.HasPrefix(kid, kidPrefix) {
render.Error(w, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'")) render.Error(w, acme.NewError(acme.ErrorMalformedType,
"kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid))
return return
} }
accID := path.Base(kid) accID := strings.TrimPrefix(kid, kidPrefix)
acc, err := db.GetAccount(ctx, accID) acc, err := db.GetAccount(ctx, accID)
switch { switch {
case acme.IsErrNotFound(err): case nosql.IsErrNotFound(err):
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return return
case err != nil: case err != nil:
@ -320,45 +324,6 @@ func lookupJWK(next nextHTTP) nextHTTP {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active")) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
if storedLocation := acc.GetLocation(); storedLocation != "" {
if kid != storedLocation {
// ACME accounts should have a stored location equivalent to the
// kid in the ACME request.
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"kid does not match stored account location; expected %s, but got %s",
storedLocation, kid))
return
}
// Verify that the provisioner with which the account was created
// matches the provisioner in the request URL.
reqProv := acme.MustProvisionerFromContext(ctx)
reqProvName := reqProv.GetName()
accProvName := acc.ProvisionerName
if reqProvName != accProvName {
// Provisioner in the URL must match the provisioner with
// which the account was created.
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
"account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s",
accProvName, reqProvName))
return
}
} else {
// This code will only execute for old ACME accounts that do
// not have a cached location. The following validation was
// the original implementation of the `kid` check which has
// since been deprecated. However, the code will remain to
// ensure consistent behavior for old ACME accounts.
linker := acme.MustLinkerFromContext(ctx)
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
if !strings.HasPrefix(kid, kidPrefix) {
render.Error(w, acme.NewError(acme.ErrorMalformedType,
"kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid))
return
}
}
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, jwkContextKey, acc.Key) ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))

View file

@ -17,13 +17,14 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/nosql/database"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
) )
var testBody = []byte("foo") var testBody = []byte("foo")
func testNext(w http.ResponseWriter, _ *http.Request) { func testNext(w http.ResponseWriter, r *http.Request) {
w.Write(testBody) w.Write(testBody)
} }
@ -92,6 +93,7 @@ func TestHandler_addNonce(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -145,6 +147,7 @@ func TestHandler_addDirLink(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -249,6 +252,7 @@ func TestHandler_verifyContentType(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -316,6 +320,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -327,7 +332,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
type errReader int type errReader int
func (errReader) Read([]byte) (int, error) { func (errReader) Read(p []byte) (n int, err error) {
return 0, errors.New("force") return 0, errors.New("force")
} }
func (errReader) Close() error { func (errReader) Close() error {
@ -405,6 +410,7 @@ func TestHandler_parseJWS(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -600,6 +606,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -677,7 +684,31 @@ func TestHandler_lookupJWK(t *testing.T) {
linker: acme.NewLinker("test.ca.smallstep.com", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"), err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
}
},
"fail/bad-kid-prefix": func(t *testing.T) test {
_so := new(jose.SignerOptions)
_so.WithHeader("kid", "foo")
_signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, _so)
assert.FatalError(t, err)
_jws, err := _signer.Sign([]byte("baz"))
assert.FatalError(t, err)
_raw, err := _jws.CompactSerialize()
assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _parsed)
return test{
db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx,
statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
} }
}, },
"fail/account-not-found": func(t *testing.T) test { "fail/account-not-found": func(t *testing.T) test {
@ -688,7 +719,7 @@ func TestHandler_lookupJWK(t *testing.T) {
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, accID) assert.Equals(t, accID, accID)
return nil, acme.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
ctx: ctx, ctx: ctx,
@ -729,77 +760,7 @@ func TestHandler_lookupJWK(t *testing.T) {
err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"), err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"),
} }
}, },
"fail/account-with-location-prefix/bad-kid": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{LocationPrefix: "foobar", Status: "valid"}
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID)
return acc, nil
},
},
ctx: ctx,
statusCode: http.StatusUnauthorized,
err: acme.NewError(acme.ErrorUnauthorizedType, "kid does not match stored account location; expected foobar, but %q", prefix+accID),
}
},
"fail/account-with-location-prefix/bad-provisioner": func(t *testing.T) test {
acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: "other"}
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID)
return acc, nil
},
},
ctx: ctx,
next: func(w http.ResponseWriter, r *http.Request) {
_acc, err := accountFromContext(r.Context())
assert.FatalError(t, err)
assert.Equals(t, _acc, acc)
_jwk, err := jwkFromContext(r.Context())
assert.FatalError(t, err)
assert.Equals(t, _jwk, jwk)
w.Write(testBody)
},
statusCode: http.StatusUnauthorized,
err: acme.NewError(acme.ErrorUnauthorizedType,
"account provisioner does not match requested provisioner; account provisioner = %s, reqested provisioner = %s",
prov.GetName(), "other"),
}
},
"ok/account-with-location-prefix": func(t *testing.T) test {
acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: prov.GetName()}
ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID)
return acc, nil
},
},
ctx: ctx,
next: func(w http.ResponseWriter, r *http.Request) {
_acc, err := accountFromContext(r.Context())
assert.FatalError(t, err)
assert.Equals(t, _acc, acc)
_jwk, err := jwkFromContext(r.Context())
assert.FatalError(t, err)
assert.Equals(t, _jwk, jwk)
w.Write(testBody)
},
statusCode: http.StatusOK,
}
},
"ok/account-without-location-prefix": func(t *testing.T) test {
acc := &acme.Account{Status: "valid", Key: jwk} acc := &acme.Account{Status: "valid", Key: jwk}
ctx := acme.NewProvisionerContext(context.Background(), prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
@ -847,6 +808,7 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -1046,6 +1008,7 @@ func TestHandler_extractJWK(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -1421,6 +1384,7 @@ func TestHandler_validateJWS(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -1603,6 +1567,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -1687,6 +1652,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {

View file

@ -392,7 +392,7 @@ func challengeTypes(az *acme.Authorization) []acme.ChallengeType {
case acme.IP: case acme.IP:
chTypes = []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01} chTypes = []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}
case acme.DNS: case acme.DNS:
chTypes = []acme.ChallengeType{acme.DNS01, acme.NNS01} chTypes = []acme.ChallengeType{acme.DNS01}
// HTTP and TLS challenges can only be used for identifiers without wildcards. // HTTP and TLS challenges can only be used for identifiers without wildcards.
if !az.Wildcard { if !az.Wildcard {
chTypes = append(chTypes, []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}...) chTypes = append(chTypes, []acme.ChallengeType{acme.HTTP01, acme.TLSALPN01}...)

View file

@ -486,6 +486,7 @@ func TestHandler_GetOrder(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -1845,6 +1846,7 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -2142,6 +2144,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {

View file

@ -151,7 +151,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations // the identifiers in the certificate are extracted and compared against the (valid) Authorizations
// that are stored for the ACME Account. If these sets match, the Account is considered authorized // that are stored for the ACME Account. If these sets match, the Account is considered authorized
// to revoke the certificate. If this check fails, the client will receive an unauthorized error. // to revoke the certificate. If this check fails, the client will receive an unauthorized error.
func isAccountAuthorized(_ context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
if !account.IsValid() { if !account.IsValid() {
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
} }

View file

@ -258,7 +258,7 @@ func jwkEncode(pub crypto.PublicKey) (string, error) {
// jwsFinal constructs the final JWS object. // jwsFinal constructs the final JWS object.
// Implementation taken from github.com/mholt/acmez, which seems to be based on // Implementation taken from github.com/mholt/acmez, which seems to be based on
// https://github.com/golang/crypto/blob/master/acme/jws.go. // https://github.com/golang/crypto/blob/master/acme/jws.go.
func jwsFinal(_ crypto.Hash, sig []byte, phead, payload string) ([]byte, error) { func jwsFinal(sha crypto.Hash, sig []byte, phead, payload string) ([]byte, error) {
enc := struct { enc := struct {
Protected string `json:"protected"` Protected string `json:"protected"`
Payload string `json:"payload"` Payload string `json:"payload"`
@ -281,7 +281,7 @@ type mockCA struct {
MockAreSANsallowed func(ctx context.Context, sans []string) error MockAreSANsallowed func(ctx context.Context, sans []string) error
} }
func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil return nil, nil
} }
@ -1090,6 +1090,7 @@ func TestHandler_RevokeCert(t *testing.T) {
assert.Equals(t, ae.Type, tc.err.Type) assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else { } else {
@ -1229,6 +1230,7 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
assert.Equals(t, acmeErr.Type, tc.err.Type) assert.Equals(t, acmeErr.Type, tc.err.Type)
assert.Equals(t, acmeErr.Status, tc.err.Status) assert.Equals(t, acmeErr.Status, tc.err.Status)
assert.Equals(t, acmeErr.Detail, tc.err.Detail) assert.Equals(t, acmeErr.Detail, tc.err.Detail)
assert.Equals(t, acmeErr.Identifier, tc.err.Identifier)
assert.Equals(t, acmeErr.Subproblems, tc.err.Subproblems) assert.Equals(t, acmeErr.Subproblems, tc.err.Subproblems)
}) })
@ -1321,6 +1323,7 @@ func Test_wrapUnauthorizedError(t *testing.T) {
assert.Equals(t, acmeErr.Type, tc.want.Type) assert.Equals(t, acmeErr.Type, tc.want.Type)
assert.Equals(t, acmeErr.Status, tc.want.Status) assert.Equals(t, acmeErr.Status, tc.want.Status)
assert.Equals(t, acmeErr.Detail, tc.want.Detail) assert.Equals(t, acmeErr.Detail, tc.want.Detail)
assert.Equals(t, acmeErr.Identifier, tc.want.Identifier)
assert.Equals(t, acmeErr.Subproblems, tc.want.Subproblems) assert.Equals(t, acmeErr.Subproblems, tc.want.Subproblems)
}) })
} }

View file

@ -11,7 +11,6 @@ type Authorization struct {
ID string `json:"-"` ID string `json:"-"`
AccountID string `json:"-"` AccountID string `json:"-"`
Token string `json:"-"` Token string `json:"-"`
Fingerprint string `json:"-"`
Identifier Identifier `json:"identifier"` Identifier Identifier `json:"identifier"`
Status Status `json:"status"` Status Status `json:"status"`
Challenges []*Challenge `json:"challenges"` Challenges []*Challenge `json:"challenges"`

View file

@ -26,16 +26,9 @@ import (
"time" "time"
"github.com/fxamacker/cbor/v2" "github.com/fxamacker/cbor/v2"
"github.com/google/go-tpm/tpm2"
"golang.org/x/exp/slices"
"github.com/smallstep/go-attestation/attest"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil"
) )
type ChallengeType string type ChallengeType string
@ -49,20 +42,6 @@ const (
TLSALPN01 ChallengeType = "tls-alpn-01" TLSALPN01 ChallengeType = "tls-alpn-01"
// DEVICEATTEST01 is the device-attest-01 ACME challenge type // DEVICEATTEST01 is the device-attest-01 ACME challenge type
DEVICEATTEST01 ChallengeType = "device-attest-01" DEVICEATTEST01 ChallengeType = "device-attest-01"
// NNS01 is the nns-01 ACME challenge type
NNS01 ChallengeType = "nns-01"
)
var (
// InsecurePortHTTP01 is the port used to verify http-01 challenges. If not set it
// defaults to 80.
InsecurePortHTTP01 int
// InsecurePortTLSALPN01 is the port used to verify tls-alpn-01 challenges. If not
// set it defaults to 443.
//
// This variable can be used for testing purposes.
InsecurePortTLSALPN01 int
) )
// Challenge represents an ACME response Challenge type. // Challenge represents an ACME response Challenge type.
@ -88,9 +67,10 @@ func (ch *Challenge) ToLog() (interface{}, error) {
return string(b), nil return string(b), nil
} }
// Validate attempts to validate the Challenge. Stores changes to the Challenge // Validate attempts to validate the challenge. Stores changes to the Challenge
// type using the DB interface. If the Challenge is validated, the 'status' and // type using the DB interface.
// 'validated' attributes are updated. // satisfactorily validated, the 'status' and 'validated' attributes are
// updated.
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, payload []byte) error { func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, payload []byte) error {
// If already valid or invalid then return without performing validation. // If already valid or invalid then return without performing validation.
if ch.Status != StatusPending { if ch.Status != StatusPending {
@ -105,8 +85,6 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
return tlsalpn01Validate(ctx, ch, db, jwk) return tlsalpn01Validate(ctx, ch, db, jwk)
case DEVICEATTEST01: case DEVICEATTEST01:
return deviceAttest01Validate(ctx, ch, db, jwk, payload) return deviceAttest01Validate(ctx, ch, db, jwk, payload)
case NNS01:
return nns01Validate(ctx, ch, db, jwk)
default: default:
return NewErrorISE("unexpected challenge type '%s'", ch.Type) return NewErrorISE("unexpected challenge type '%s'", ch.Type)
} }
@ -115,12 +93,6 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error { func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
// Append insecure port if set.
// Only used for testing purposes.
if InsecurePortHTTP01 != 0 {
u.Host += ":" + strconv.Itoa(InsecurePortHTTP01)
}
vc := MustClientFromContext(ctx) vc := MustClientFromContext(ctx)
resp, err := vc.Get(u.String()) resp, err := vc.Get(u.String())
if err != nil { if err != nil {
@ -193,14 +165,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
InsecureSkipVerify: true, //nolint:gosec // we expect a self-signed challenge certificate InsecureSkipVerify: true, //nolint:gosec // we expect a self-signed challenge certificate
} }
var hostPort string hostPort := net.JoinHostPort(ch.Value, "443")
// Allow to change TLS port for testing purposes.
if port := InsecurePortTLSALPN01; port == 0 {
hostPort = net.JoinHostPort(ch.Value, "443")
} else {
hostPort = net.JoinHostPort(ch.Value, strconv.Itoa(port))
}
vc := MustClientFromContext(ctx) vc := MustClientFromContext(ctx)
conn, err := vc.TLSDial("tcp", hostPort, config) conn, err := vc.TLSDial("tcp", hostPort, config)
@ -345,26 +310,20 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK
return nil return nil
} }
type payloadType struct { type Payload struct {
AttObj string `json:"attObj"` AttObj string `json:"attObj"`
Error string `json:"error"` Error string `json:"error"`
} }
type attestationObject struct { type AttestationObject struct {
Format string `json:"fmt"` Format string `json:"fmt"`
AttStatement map[string]interface{} `json:"attStmt,omitempty"` AttStatement map[string]interface{} `json:"attStmt,omitempty"`
} }
// TODO(bweeks): move attestation verification to a shared package. // TODO(bweeks): move attestation verification to a shared package.
// TODO(bweeks): define new error type for failed attestation validation.
func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error { func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error {
// Load authorization to store the key fingerprint. var p Payload
az, err := db.GetAuthorization(ctx, ch.AuthorizationID)
if err != nil {
return WrapErrorISE(err, "error loading authorization")
}
// Parse payload.
var p payloadType
if err := json.Unmarshal(payload, &p); err != nil { if err := json.Unmarshal(payload, &p); err != nil {
return WrapErrorISE(err, "error unmarshalling JSON") return WrapErrorISE(err, "error unmarshalling JSON")
} }
@ -378,7 +337,7 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
return WrapErrorISE(err, "error base64 decoding attObj") return WrapErrorISE(err, "error base64 decoding attObj")
} }
att := attestationObject{} att := AttestationObject{}
if err := cbor.Unmarshal(attObj, &att); err != nil { if err := cbor.Unmarshal(attObj, &att); err != nil {
return WrapErrorISE(err, "error unmarshalling CBOR") return WrapErrorISE(err, "error unmarshalling CBOR")
} }
@ -402,6 +361,7 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
} }
return WrapErrorISE(err, "error validating attestation") return WrapErrorISE(err, "error validating attestation")
} }
// Validate nonce with SHA-256 of the token. // Validate nonce with SHA-256 of the token.
if len(data.Nonce) != 0 { if len(data.Nonce) != 0 {
sum := sha256.Sum256([]byte(ch.Token)) sum := sha256.Sum256([]byte(ch.Token))
@ -417,9 +377,6 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
if data.UDID != ch.Value && data.SerialNumber != ch.Value { if data.UDID != ch.Value && data.SerialNumber != ch.Value {
return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "permanent identifier does not match")) return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "permanent identifier does not match"))
} }
// Update attestation key fingerprint to compare against the CSR
az.Fingerprint = data.Fingerprint
case "step": case "step":
data, err := doStepAttestationFormat(ctx, prov, ch, jwk, &att) data, err := doStepAttestationFormat(ctx, prov, ch, jwk, &att)
if err != nil { if err != nil {
@ -433,53 +390,13 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
return WrapErrorISE(err, "error validating attestation") return WrapErrorISE(err, "error validating attestation")
} }
// Validate the YubiKey serial number from the attestation // Validate Apple's ClientIdentifier (Identifier.Value) with device
// certificate with the challenged Order value. // identifiers.
// //
// Note: We might want to use an external service for this. // Note: We might want to use an external service for this.
if data.SerialNumber != ch.Value { if data.SerialNumber != ch.Value {
subproblem := NewSubproblemWithIdentifier( return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "permanent identifier does not match"))
ErrorMalformedType,
Identifier{Type: "permanent-identifier", Value: ch.Value},
"challenge identifier %q doesn't match the attested hardware identifier %q", ch.Value, data.SerialNumber,
)
return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "permanent identifier does not match").AddSubproblems(subproblem))
} }
// Update attestation key fingerprint to compare against the CSR
az.Fingerprint = data.Fingerprint
case "tpm":
data, err := doTPMAttestationFormat(ctx, prov, ch, jwk, &att)
if err != nil {
// TODO(hs): we should provide more details in the error reported to the client;
// "Attestation statement cannot be verified" is VERY generic. Also holds true for the other formats.
var acmeError *Error
if errors.As(err, &acmeError) {
if acmeError.Status == 500 {
return acmeError
}
return storeError(ctx, db, ch, true, acmeError)
}
return WrapErrorISE(err, "error validating attestation")
}
// TODO(hs): currently this will allow a request for which no PermanentIdentifiers have been
// extracted from the AK certificate. This is currently the case for AK certs from the CLI, as we
// haven't implemented a way for AK certs requested by the CLI to always contain the requested
// PermanentIdentifier. Omitting the check below doesn't allow just any request, as the Order can
// still fail if the challenge value isn't equal to the CSR subject.
if len(data.PermanentIdentifiers) > 0 && !slices.Contains(data.PermanentIdentifiers, ch.Value) { // TODO(hs): add support for HardwareModuleName
subproblem := NewSubproblemWithIdentifier(
ErrorMalformedType,
Identifier{Type: "permanent-identifier", Value: ch.Value},
"challenge identifier %q doesn't match any of the attested hardware identifiers %q", ch.Value, data.PermanentIdentifiers,
)
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType, "permanent identifier does not match").AddSubproblems(subproblem))
}
// Update attestation key fingerprint to compare against the CSR
az.Fingerprint = data.Fingerprint
default: default:
return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "unexpected attestation object format")) return storeError(ctx, db, ch, true, NewError(ErrorBadAttestationStatementType, "unexpected attestation object format"))
} }
@ -489,362 +406,12 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
ch.Error = nil ch.Error = nil
ch.ValidatedAt = clock.Now().Format(time.RFC3339) ch.ValidatedAt = clock.Now().Format(time.RFC3339)
// Store the fingerprint in the authorization.
//
// TODO: add method to update authorization and challenge atomically.
if az.Fingerprint != "" {
if err := db.UpdateAuthorization(ctx, az); err != nil {
return WrapErrorISE(err, "error updating authorization")
}
}
if err := db.UpdateChallenge(ctx, ch); err != nil { if err := db.UpdateChallenge(ctx, ch); err != nil {
return WrapErrorISE(err, "error updating challenge") return WrapErrorISE(err, "error updating challenge")
} }
return nil return nil
} }
func nns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
domain := strings.TrimPrefix(ch.Value, "*.")
nnsCtx, ok := GetNNSContext(ctx)
if !ok {
return errors.New("error retrieving NNS context")
}
nns := NNS{}
err := nns.Dial(nnsCtx.nnsServer)
if err != nil {
return err
}
defer nns.Close()
txtRecords, err := nns.GetTXTRecords("acme-challenge." + domain)
if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorNNSType, err,
"error looking up TXT records for domain %s", domain))
}
expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk)
if err != nil {
return err
}
h := sha256.Sum256([]byte(expectedKeyAuth))
expected := base64.RawURLEncoding.EncodeToString(h[:])
var found bool
for _, r := range txtRecords {
if r == expected {
found = true
break
}
}
if !found {
return storeError(ctx, db, ch, false, NewError(ErrorRejectedIdentifierType,
"keyAuthorization does not match; expected %s, but got %s", expectedKeyAuth, txtRecords))
}
// Update and store the challenge.
ch.Status = StatusValid
ch.Error = nil
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil {
return WrapErrorISE(err, "error updating challenge")
}
return nil
}
var (
oidSubjectAlternativeName = asn1.ObjectIdentifier{2, 5, 29, 17}
)
type tpmAttestationData struct {
Certificate *x509.Certificate
VerifiedChains [][]*x509.Certificate
PermanentIdentifiers []string
Fingerprint string
}
// coseAlgorithmIdentifier models a COSEAlgorithmIdentifier.
// Also see https://www.w3.org/TR/webauthn-2/#sctn-alg-identifier.
type coseAlgorithmIdentifier int32
const (
coseAlgES256 coseAlgorithmIdentifier = -7
coseAlgRS256 coseAlgorithmIdentifier = -257
)
func doTPMAttestationFormat(_ context.Context, prov Provisioner, ch *Challenge, jwk *jose.JSONWebKey, att *attestationObject) (*tpmAttestationData, error) {
ver, ok := att.AttStatement["ver"].(string)
if !ok {
return nil, NewError(ErrorBadAttestationStatementType, "ver not present")
}
if ver != "2.0" {
return nil, NewError(ErrorBadAttestationStatementType, "version %q is not supported", ver)
}
x5c, ok := att.AttStatement["x5c"].([]interface{})
if !ok {
return nil, NewError(ErrorBadAttestationStatementType, "x5c not present")
}
if len(x5c) == 0 {
return nil, NewError(ErrorBadAttestationStatementType, "x5c is empty")
}
akCertBytes, ok := x5c[0].([]byte)
if !ok {
return nil, NewError(ErrorBadAttestationStatementType, "x5c is malformed")
}
akCert, err := x509.ParseCertificate(akCertBytes)
if err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is malformed")
}
intermediates := x509.NewCertPool()
for _, v := range x5c[1:] {
intCertBytes, vok := v.([]byte)
if !vok {
return nil, NewError(ErrorBadAttestationStatementType, "x5c is malformed")
}
intCert, err := x509.ParseCertificate(intCertBytes)
if err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is malformed")
}
intermediates.AddCert(intCert)
}
// TODO(hs): this can be removed when permanent-identifier/hardware-module-name are handled correctly in
// the stdlib in https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/x509/parser.go;drc=b5b2cf519fe332891c165077f3723ee74932a647;l=362,
// but I doubt that will happen.
if len(akCert.UnhandledCriticalExtensions) > 0 {
unhandledCriticalExtensions := akCert.UnhandledCriticalExtensions[:0]
for _, extOID := range akCert.UnhandledCriticalExtensions {
if !extOID.Equal(oidSubjectAlternativeName) {
// critical extensions other than the Subject Alternative Name remain unhandled
unhandledCriticalExtensions = append(unhandledCriticalExtensions, extOID)
}
}
akCert.UnhandledCriticalExtensions = unhandledCriticalExtensions
}
roots, ok := prov.GetAttestationRoots()
if !ok {
return nil, NewErrorISE("no root CA bundle available to verify the attestation certificate")
}
// verify that the AK certificate was signed by a trusted root,
// chained to by the intermediates provided by the client. As part
// of building the verified certificate chain, the signature over the
// AK certificate is checked to be a valid signature of one of the
// provided intermediates. Signatures over the intermediates are in
// turn also verified to be valid signatures from one of the trusted
// roots.
verifiedChains, err := akCert.Verify(x509.VerifyOptions{
Roots: roots,
Intermediates: intermediates,
CurrentTime: time.Now().Truncate(time.Second),
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
})
if err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "x5c is not valid")
}
// validate additional AK certificate requirements
if err := validateAKCertificate(akCert); err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "AK certificate is not valid")
}
// TODO(hs): implement revocation check; Verify() doesn't perform CRL check nor OCSP lookup.
sans, err := x509util.ParseSubjectAlternativeNames(akCert)
if err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed parsing AK certificate Subject Alternative Names")
}
permanentIdentifiers := make([]string, len(sans.PermanentIdentifiers))
for i, pi := range sans.PermanentIdentifiers {
permanentIdentifiers[i] = pi.Identifier
}
// extract and validate pubArea, sig, certInfo and alg properties from the request body
pubArea, ok := att.AttStatement["pubArea"].([]byte)
if !ok {
return nil, NewError(ErrorBadAttestationStatementType, "invalid pubArea in attestation statement")
}
if len(pubArea) == 0 {
return nil, NewError(ErrorBadAttestationStatementType, "pubArea is empty")
}
sig, ok := att.AttStatement["sig"].([]byte)
if !ok {
return nil, NewError(ErrorBadAttestationStatementType, "invalid sig in attestation statement")
}
if len(sig) == 0 {
return nil, NewError(ErrorBadAttestationStatementType, "sig is empty")
}
certInfo, ok := att.AttStatement["certInfo"].([]byte)
if !ok {
return nil, NewError(ErrorBadAttestationStatementType, "invalid certInfo in attestation statement")
}
if len(certInfo) == 0 {
return nil, NewError(ErrorBadAttestationStatementType, "certInfo is empty")
}
alg, ok := att.AttStatement["alg"].(int64)
if !ok {
return nil, NewError(ErrorBadAttestationStatementType, "invalid alg in attestation statement")
}
// only RS256 and ES256 are allowed
coseAlg := coseAlgorithmIdentifier(alg)
if coseAlg != coseAlgRS256 && coseAlg != coseAlgES256 {
return nil, NewError(ErrorBadAttestationStatementType, "invalid alg %d in attestation statement", alg)
}
// set the hash algorithm to use to SHA256
hash := crypto.SHA256
// recreate the generated key certification parameter values and verify
// the attested key using the public key of the AK.
certificationParameters := &attest.CertificationParameters{
Public: pubArea, // the public key that was attested
CreateAttestation: certInfo, // the attested properties of the key
CreateSignature: sig, // signature over the attested properties
}
verifyOpts := attest.VerifyOpts{
Public: akCert.PublicKey, // public key of the AK that attested the key
Hash: hash,
}
if err = certificationParameters.Verify(verifyOpts); err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "invalid certification parameters")
}
// decode the "certInfo" data. This won't fail, as it's also done as part of Verify().
tpmCertInfo, err := tpm2.DecodeAttestationData(certInfo)
if err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed decoding attestation data")
}
keyAuth, err := KeyAuthorization(ch.Token, jwk)
if err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed creating key auth digest")
}
hashedKeyAuth := sha256.Sum256([]byte(keyAuth))
// verify the WebAuthn object contains the expect key authorization digest, which is carried
// within the encoded `certInfo` property of the attestation statement.
if subtle.ConstantTimeCompare(hashedKeyAuth[:], []byte(tpmCertInfo.ExtraData)) == 0 {
return nil, NewError(ErrorBadAttestationStatementType, "key authorization does not match")
}
// decode the (attested) public key and determine its fingerprint. This won't fail, as it's also done as part of Verify().
pub, err := tpm2.DecodePublic(pubArea)
if err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed decoding pubArea")
}
publicKey, err := pub.Key()
if err != nil {
return nil, WrapError(ErrorBadAttestationStatementType, err, "failed getting public key")
}
data := &tpmAttestationData{
Certificate: akCert,
VerifiedChains: verifiedChains,
PermanentIdentifiers: permanentIdentifiers,
}
if data.Fingerprint, err = keyutil.Fingerprint(publicKey); err != nil {
return nil, WrapErrorISE(err, "error calculating key fingerprint")
}
// TODO(hs): pass more attestation data, so that that can be used/recorded too?
return data, nil
}
var (
oidExtensionExtendedKeyUsage = asn1.ObjectIdentifier{2, 5, 29, 37}
oidTCGKpAIKCertificate = asn1.ObjectIdentifier{2, 23, 133, 8, 3}
)
// validateAKCertifiate validates the X.509 AK certificate to be
// in accordance with the required properties. The requirements come from:
// https://www.w3.org/TR/webauthn-2/#sctn-tpm-cert-requirements.
//
// - Version MUST be set to 3.
// - Subject field MUST be set to empty.
// - The Subject Alternative Name extension MUST be set as defined
// in [TPMv2-EK-Profile] section 3.2.9.
// - The Extended Key Usage extension MUST contain the OID 2.23.133.8.3
// ("joint-iso-itu-t(2) internationalorganizations(23) 133 tcg-kp(8) tcg-kp-AIKCertificate(3)").
// - The Basic Constraints extension MUST have the CA component set to false.
// - An Authority Information Access (AIA) extension with entry id-ad-ocsp
// and a CRL Distribution Point extension [RFC5280] are both OPTIONAL as
// the status of many attestation certificates is available through metadata
// services. See, for example, the FIDO Metadata Service.
func validateAKCertificate(c *x509.Certificate) error {
if c.Version != 3 {
return fmt.Errorf("AK certificate has invalid version %d; only version 3 is allowed", c.Version)
}
if c.Subject.String() != "" {
return fmt.Errorf("AK certificate subject must be empty; got %q", c.Subject)
}
if c.IsCA {
return errors.New("AK certificate must not be a CA")
}
if err := validateAKCertificateExtendedKeyUsage(c); err != nil {
return err
}
return validateAKCertificateSubjectAlternativeNames(c)
}
// validateAKCertificateSubjectAlternativeNames checks if the AK certificate
// has TPM hardware details set.
func validateAKCertificateSubjectAlternativeNames(c *x509.Certificate) error {
sans, err := x509util.ParseSubjectAlternativeNames(c)
if err != nil {
return fmt.Errorf("failed parsing AK certificate Subject Alternative Names: %w", err)
}
details := sans.TPMHardwareDetails
manufacturer, model, version := details.Manufacturer, details.Model, details.Version
switch {
case manufacturer == "":
return errors.New("missing TPM manufacturer")
case model == "":
return errors.New("missing TPM model")
case version == "":
return errors.New("missing TPM version")
}
return nil
}
// validateAKCertificateExtendedKeyUsage checks if the AK certificate
// has the "tcg-kp-AIKCertificate" Extended Key Usage set.
func validateAKCertificateExtendedKeyUsage(c *x509.Certificate) error {
var (
valid = false
ekus []asn1.ObjectIdentifier
)
for _, ext := range c.Extensions {
if ext.Id.Equal(oidExtensionExtendedKeyUsage) {
if _, err := asn1.Unmarshal(ext.Value, &ekus); err != nil || !ekus[0].Equal(oidTCGKpAIKCertificate) {
return errors.New("AK certificate is missing Extended Key Usage value tcg-kp-AIKCertificate (2.23.133.8.3)")
}
valid = true
}
}
if !valid {
return errors.New("AK certificate is missing Extended Key Usage extension")
}
return nil
}
// Apple Enterprise Attestation Root CA from // Apple Enterprise Attestation Root CA from
// https://www.apple.com/certificateauthority/private/ // https://www.apple.com/certificateauthority/private/
const appleEnterpriseAttestationRootCA = `-----BEGIN CERTIFICATE----- const appleEnterpriseAttestationRootCA = `-----BEGIN CERTIFICATE-----
@ -875,10 +442,9 @@ type appleAttestationData struct {
UDID string UDID string
SEPVersion string SEPVersion string
Certificate *x509.Certificate Certificate *x509.Certificate
Fingerprint string
} }
func doAppleAttestationFormat(_ context.Context, prov Provisioner, _ *Challenge, att *attestationObject) (*appleAttestationData, error) { func doAppleAttestationFormat(ctx context.Context, prov Provisioner, ch *Challenge, att *AttestationObject) (*appleAttestationData, error) {
// Use configured or default attestation roots if none is configured. // Use configured or default attestation roots if none is configured.
roots, ok := prov.GetAttestationRoots() roots, ok := prov.GetAttestationRoots()
if !ok { if !ok {
@ -932,9 +498,6 @@ func doAppleAttestationFormat(_ context.Context, prov Provisioner, _ *Challenge,
data := &appleAttestationData{ data := &appleAttestationData{
Certificate: leaf, Certificate: leaf,
} }
if data.Fingerprint, err = keyutil.Fingerprint(leaf.PublicKey); err != nil {
return nil, WrapErrorISE(err, "error calculating key fingerprint")
}
for _, ext := range leaf.Extensions { for _, ext := range leaf.Extensions {
switch { switch {
case ext.Id.Equal(oidAppleSerialNumber): case ext.Id.Equal(oidAppleSerialNumber):
@ -980,10 +543,9 @@ var oidYubicoSerialNumber = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 41482, 3, 7}
type stepAttestationData struct { type stepAttestationData struct {
Certificate *x509.Certificate Certificate *x509.Certificate
SerialNumber string SerialNumber string
Fingerprint string
} }
func doStepAttestationFormat(_ context.Context, prov Provisioner, ch *Challenge, jwk *jose.JSONWebKey, att *attestationObject) (*stepAttestationData, error) { func doStepAttestationFormat(ctx context.Context, prov Provisioner, ch *Challenge, jwk *jose.JSONWebKey, att *AttestationObject) (*stepAttestationData, error) {
// Use configured or default attestation roots if none is configured. // Use configured or default attestation roots if none is configured.
roots, ok := prov.GetAttestationRoots() roots, ok := prov.GetAttestationRoots()
if !ok { if !ok {
@ -1076,9 +638,6 @@ func doStepAttestationFormat(_ context.Context, prov Provisioner, ch *Challenge,
data := &stepAttestationData{ data := &stepAttestationData{
Certificate: leaf, Certificate: leaf,
} }
if data.Fingerprint, err = keyutil.Fingerprint(leaf.PublicKey); err != nil {
return nil, WrapErrorISE(err, "error calculating key fingerprint")
}
for _, ext := range leaf.Extensions { for _, ext := range leaf.Extensions {
if !ext.Id.Equal(oidYubicoSerialNumber) { if !ext.Id.Equal(oidYubicoSerialNumber) {
continue continue
@ -1142,10 +701,10 @@ func uitoa(val uint) string {
var buf [20]byte // big enough for 64bit value base 10 var buf [20]byte // big enough for 64bit value base 10
i := len(buf) - 1 i := len(buf) - 1
for val >= 10 { for val >= 10 {
v := val / 10 q := val / 10
buf[i] = byte('0' + val - v*10) buf[i] = byte('0' + val - q*10)
i-- i--
val = v val = q
} }
// val < 10 // val < 10
buf[i] = byte('0' + val) buf[i] = byte('0' + val)

File diff suppressed because it is too large Load diff

View file

@ -1,860 +0,0 @@
//go:build tpmsimulator
// +build tpmsimulator
package acme
import (
"context"
"crypto"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/url"
"testing"
"github.com/fxamacker/cbor/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/smallstep/go-attestation/attest"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/minica"
"go.step.sm/crypto/tpm"
"go.step.sm/crypto/tpm/simulator"
tpmstorage "go.step.sm/crypto/tpm/storage"
"go.step.sm/crypto/x509util"
)
func newSimulatedTPM(t *testing.T) *tpm.TPM {
t.Helper()
tmpDir := t.TempDir()
tpm, err := tpm.New(withSimulator(t), tpm.WithStore(tpmstorage.NewDirstore(tmpDir))) // TODO: provide in-memory storage implementation instead
require.NoError(t, err)
return tpm
}
func withSimulator(t *testing.T) tpm.NewTPMOption {
t.Helper()
var sim simulator.Simulator
t.Cleanup(func() {
if sim == nil {
return
}
err := sim.Close()
require.NoError(t, err)
})
sim, err := simulator.New()
require.NoError(t, err)
err = sim.Open()
require.NoError(t, err)
return tpm.WithSimulator(sim)
}
func generateKeyID(t *testing.T, pub crypto.PublicKey) []byte {
t.Helper()
b, err := x509.MarshalPKIXPublicKey(pub)
require.NoError(t, err)
hash := sha256.Sum256(b)
return hash[:]
}
func mustAttestTPM(t *testing.T, keyAuthorization string, permanentIdentifiers []string) ([]byte, crypto.Signer, *x509.Certificate) {
t.Helper()
aca, err := minica.New(
minica.WithName("TPM Testing"),
minica.WithGetSignerFunc(
func() (crypto.Signer, error) {
return keyutil.GenerateSigner("RSA", "", 2048)
},
),
)
require.NoError(t, err)
// prepare simulated TPM and create an AK
stpm := newSimulatedTPM(t)
eks, err := stpm.GetEKs(context.Background())
require.NoError(t, err)
ak, err := stpm.CreateAK(context.Background(), "first-ak")
require.NoError(t, err)
require.NotNil(t, ak)
// extract the AK public key // TODO(hs): replace this when there's a simpler method to get the AK public key (e.g. ak.Public())
ap, err := ak.AttestationParameters(context.Background())
require.NoError(t, err)
akp, err := attest.ParseAKPublic(attest.TPMVersion20, ap.Public)
require.NoError(t, err)
// create template and sign certificate for the AK public key
keyID := generateKeyID(t, eks[0].Public())
template := &x509.Certificate{
PublicKey: akp.Public,
IsCA: false,
UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
}
sans := []x509util.SubjectAlternativeName{}
uris := []*url.URL{{Scheme: "urn", Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID)}}
for _, pi := range permanentIdentifiers {
sans = append(sans, x509util.SubjectAlternativeName{
Type: x509util.PermanentIdentifierType,
Value: pi,
})
}
asn1Value := []byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55"))
sans = append(sans, x509util.SubjectAlternativeName{
Type: x509util.DirectoryNameType,
ASN1Value: asn1Value,
})
ext, err := createSubjectAltNameExtension(nil, nil, nil, uris, sans, true)
require.NoError(t, err)
ext.Set(template)
akCert, err := aca.Sign(template)
require.NoError(t, err)
require.NotNil(t, akCert)
// create a new key attested by the AK, while including
// the key authorization bytes as qualifying data.
keyAuthSum := sha256.Sum256([]byte(keyAuthorization))
config := tpm.AttestKeyConfig{
Algorithm: "RSA",
Size: 2048,
QualifyingData: keyAuthSum[:],
}
key, err := stpm.AttestKey(context.Background(), "first-ak", "first-key", config)
require.NoError(t, err)
require.NotNil(t, key)
require.Equal(t, "first-key", key.Name())
require.NotEqual(t, 0, len(key.Data()))
require.Equal(t, "first-ak", key.AttestedBy())
require.True(t, key.WasAttested())
require.True(t, key.WasAttestedBy(ak))
signer, err := key.Signer(context.Background())
require.NoError(t, err)
// prepare the attestation object with the AK certificate chain,
// the attested key, its metadata and the signature signed by the
// AK.
params, err := key.CertificationParameters(context.Background())
require.NoError(t, err)
attObj, err := cbor.Marshal(struct {
Format string `json:"fmt"`
AttStatement map[string]interface{} `json:"attStmt,omitempty"`
}{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
})
require.NoError(t, err)
// marshal the ACME payload
payload, err := json.Marshal(struct {
AttObj string `json:"attObj"`
}{
AttObj: base64.RawURLEncoding.EncodeToString(attObj),
})
require.NoError(t, err)
return payload, signer, aca.Root
}
func Test_deviceAttest01ValidateWithTPMSimulator(t *testing.T) {
type args struct {
ctx context.Context
ch *Challenge
db DB
jwk *jose.JSONWebKey
payload []byte
}
type test struct {
args args
wantErr *Error
}
tests := map[string]func(t *testing.T) test{
"ok/doTPMAttestationFormat-storeError": func(t *testing.T) test {
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
payload, _, root := mustAttestTPM(t, keyAuth, nil) // TODO: value(s) for AK cert?
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
// parse payload, set invalid "ver", remarshal
var p payloadType
err := json.Unmarshal(payload, &p)
require.NoError(t, err)
attObj, err := base64.RawURLEncoding.DecodeString(p.AttObj)
require.NoError(t, err)
att := attestationObject{}
err = cbor.Unmarshal(attObj, &att)
require.NoError(t, err)
att.AttStatement["ver"] = "bogus"
attObj, err = cbor.Marshal(struct {
Format string `json:"fmt"`
AttStatement map[string]interface{} `json:"attStmt,omitempty"`
}{
Format: "tpm",
AttStatement: att.AttStatement,
})
require.NoError(t, err)
payload, err = json.Marshal(struct {
AttObj string `json:"attObj"`
}{
AttObj: base64.RawURLEncoding.EncodeToString(attObj),
})
require.NoError(t, err)
return test{
args: args{
ctx: ctx,
jwk: jwk,
ch: &Challenge{
ID: "chID",
AuthorizationID: "azID",
Token: "token",
Type: "device-attest-01",
Status: StatusPending,
Value: "device.id.12345678",
},
payload: payload,
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
assert.Equal(t, "azID", id)
return &Authorization{ID: "azID"}, nil
},
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "device.id.12345678", updch.Value)
err := NewError(ErrorBadAttestationStatementType, `version "bogus" is not supported`)
assert.EqualError(t, updch.Error.Err, err.Err.Error())
assert.Equal(t, err.Type, updch.Error.Type)
assert.Equal(t, err.Detail, updch.Error.Detail)
assert.Equal(t, err.Status, updch.Error.Status)
assert.Equal(t, err.Subproblems, updch.Error.Subproblems)
return nil
},
},
},
wantErr: nil,
}
},
"ok with invalid PermanentIdentifier SAN": func(t *testing.T) test {
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
payload, _, root := mustAttestTPM(t, keyAuth, []string{"device.id.12345678"}) // TODO: value(s) for AK cert?
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
return test{
args: args{
ctx: ctx,
jwk: jwk,
ch: &Challenge{
ID: "chID",
AuthorizationID: "azID",
Token: "token",
Type: "device-attest-01",
Status: StatusPending,
Value: "device.id.99999999",
},
payload: payload,
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
assert.Equal(t, "azID", id)
return &Authorization{ID: "azID"}, nil
},
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "device.id.99999999", updch.Value)
err := NewError(ErrorRejectedIdentifierType, `permanent identifier does not match`).
AddSubproblems(NewSubproblemWithIdentifier(
ErrorMalformedType,
Identifier{Type: "permanent-identifier", Value: "device.id.99999999"},
`challenge identifier "device.id.99999999" doesn't match any of the attested hardware identifiers ["device.id.12345678"]`,
))
assert.EqualError(t, updch.Error.Err, err.Err.Error())
assert.Equal(t, err.Type, updch.Error.Type)
assert.Equal(t, err.Detail, updch.Error.Detail)
assert.Equal(t, err.Status, updch.Error.Status)
assert.Equal(t, err.Subproblems, updch.Error.Subproblems)
return nil
},
},
},
wantErr: nil,
}
},
"ok": func(t *testing.T) test {
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
payload, signer, root := mustAttestTPM(t, keyAuth, nil) // TODO: value(s) for AK cert?
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
return test{
args: args{
ctx: ctx,
jwk: jwk,
ch: &Challenge{
ID: "chID",
AuthorizationID: "azID",
Token: "token",
Type: "device-attest-01",
Status: StatusPending,
Value: "device.id.12345678",
},
payload: payload,
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
assert.Equal(t, "azID", id)
return &Authorization{ID: "azID"}, nil
},
MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error {
fingerprint, err := keyutil.Fingerprint(signer.Public())
assert.NoError(t, err)
assert.Equal(t, "azID", az.ID)
assert.Equal(t, fingerprint, az.Fingerprint)
return nil
},
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "device.id.12345678", updch.Value)
return nil
},
},
},
wantErr: nil,
}
},
"ok with PermanentIdentifier SAN": func(t *testing.T) test {
jwk, keyAuth := mustAccountAndKeyAuthorization(t, "token")
payload, signer, root := mustAttestTPM(t, keyAuth, []string{"device.id.12345678"}) // TODO: value(s) for AK cert?
caRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: root.Raw})
ctx := NewProvisionerContext(context.Background(), mustAttestationProvisioner(t, caRoot))
return test{
args: args{
ctx: ctx,
jwk: jwk,
ch: &Challenge{
ID: "chID",
AuthorizationID: "azID",
Token: "token",
Type: "device-attest-01",
Status: StatusPending,
Value: "device.id.12345678",
},
payload: payload,
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
assert.Equal(t, "azID", id)
return &Authorization{ID: "azID"}, nil
},
MockUpdateAuthorization: func(ctx context.Context, az *Authorization) error {
fingerprint, err := keyutil.Fingerprint(signer.Public())
assert.NoError(t, err)
assert.Equal(t, "azID", az.ID)
assert.Equal(t, fingerprint, az.Fingerprint)
return nil
},
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
assert.Equal(t, "chID", updch.ID)
assert.Equal(t, "token", updch.Token)
assert.Equal(t, StatusValid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "device.id.12345678", updch.Value)
return nil
},
},
},
wantErr: nil,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := deviceAttest01Validate(tc.args.ctx, tc.args.ch, tc.args.db, tc.args.jwk, tc.args.payload); err != nil {
assert.Error(t, tc.wantErr)
assert.EqualError(t, err, tc.wantErr.Error())
return
}
assert.Nil(t, tc.wantErr)
})
}
}
func newBadAttestationStatementError(msg string) *Error {
return &Error{
Type: "urn:ietf:params:acme:error:badAttestationStatement",
Status: 400,
Err: errors.New(msg),
}
}
func newInternalServerError(msg string) *Error {
return &Error{
Type: "urn:ietf:params:acme:error:serverInternal",
Status: 500,
Err: errors.New(msg),
}
}
var (
oidPermanentIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3}
oidHardwareModuleNameIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 4}
)
func Test_doTPMAttestationFormat(t *testing.T) {
ctx := context.Background()
aca, err := minica.New(
minica.WithName("TPM Testing"),
minica.WithGetSignerFunc(
func() (crypto.Signer, error) {
return keyutil.GenerateSigner("RSA", "", 2048)
},
),
)
require.NoError(t, err)
acaRoot := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: aca.Root.Raw})
// prepare simulated TPM and create an AK
stpm := newSimulatedTPM(t)
eks, err := stpm.GetEKs(context.Background())
require.NoError(t, err)
ak, err := stpm.CreateAK(context.Background(), "first-ak")
require.NoError(t, err)
require.NotNil(t, ak)
// extract the AK public key // TODO(hs): replace this when there's a simpler method to get the AK public key (e.g. ak.Public())
ap, err := ak.AttestationParameters(context.Background())
require.NoError(t, err)
akp, err := attest.ParseAKPublic(attest.TPMVersion20, ap.Public)
require.NoError(t, err)
// create template and sign certificate for the AK public key
keyID := generateKeyID(t, eks[0].Public())
template := &x509.Certificate{
PublicKey: akp.Public,
IsCA: false,
UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
}
sans := []x509util.SubjectAlternativeName{}
uris := []*url.URL{{Scheme: "urn", Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID)}}
asn1Value := []byte(fmt.Sprintf(`{"extraNames":[{"type": %q, "value": %q},{"type": %q, "value": %q},{"type": %q, "value": %q}]}`, oidTPMManufacturer, "1414747215", oidTPMModel, "SLB 9670 TPM2.0", oidTPMVersion, "7.55"))
sans = append(sans, x509util.SubjectAlternativeName{
Type: x509util.DirectoryNameType,
ASN1Value: asn1Value,
})
ext, err := createSubjectAltNameExtension(nil, nil, nil, uris, sans, true)
require.NoError(t, err)
ext.Set(template)
akCert, err := aca.Sign(template)
require.NoError(t, err)
require.NotNil(t, akCert)
invalidTemplate := &x509.Certificate{
PublicKey: akp.Public,
IsCA: false,
UnknownExtKeyUsage: []asn1.ObjectIdentifier{oidTCGKpAIKCertificate},
}
invalidAKCert, err := aca.Sign(invalidTemplate)
require.NoError(t, err)
require.NotNil(t, invalidAKCert)
// generate a JWK and the key authorization value
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
require.NoError(t, err)
keyAuthorization, err := KeyAuthorization("token", jwk)
require.NoError(t, err)
// create a new key attested by the AK, while including
// the key authorization bytes as qualifying data.
keyAuthSum := sha256.Sum256([]byte(keyAuthorization))
config := tpm.AttestKeyConfig{
Algorithm: "RSA",
Size: 2048,
QualifyingData: keyAuthSum[:],
}
key, err := stpm.AttestKey(context.Background(), "first-ak", "first-key", config)
require.NoError(t, err)
require.NotNil(t, key)
params, err := key.CertificationParameters(context.Background())
require.NoError(t, err)
signer, err := key.Signer(context.Background())
require.NoError(t, err)
fingerprint, err := keyutil.Fingerprint(signer.Public())
require.NoError(t, err)
// attest another key and get its certification parameters
anotherKey, err := stpm.AttestKey(context.Background(), "first-ak", "another-key", config)
require.NoError(t, err)
require.NotNil(t, key)
anotherKeyParams, err := anotherKey.CertificationParameters(context.Background())
require.NoError(t, err)
type args struct {
ctx context.Context
prov Provisioner
ch *Challenge
jwk *jose.JSONWebKey
att *attestationObject
}
tests := []struct {
name string
args args
want *tpmAttestationData
expErr *Error
}{
{"ok", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, nil},
{"fail ver not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("ver not present")},
{"fail ver type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": []interface{}{},
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("ver not present")},
{"fail bogus ver", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "bogus",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError(`version "bogus" is not supported`)},
{"fail x5c not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("x5c not present")},
{"fail x5c type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": [][]byte{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("x5c not present")},
{"fail x5c empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("x5c is empty")},
{"fail leaf type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "step",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{"leaf", aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("x5c is malformed")},
{"fail leaf parse", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "step",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw[:100], aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("x5c is malformed: x509: malformed certificate")},
{"fail intermediate type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "step",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, "intermediate"},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("x5c is malformed")},
{"fail intermediate parse", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "step",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw[:100]},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("x5c is malformed: x509: malformed certificate")},
{"fail roots", args{ctx, mustAttestationProvisioner(t, nil), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newInternalServerError("no root CA bundle available to verify the attestation certificate")},
{"fail verify", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "step",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("x5c is not valid: x509: certificate signed by unknown authority")},
{"fail validateAKCertificate", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{invalidAKCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("AK certificate is not valid: missing TPM manufacturer")},
{"fail pubArea not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
},
}}, nil, newBadAttestationStatementError("invalid pubArea in attestation statement")},
{"fail pubArea type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": []interface{}{},
},
}}, nil, newBadAttestationStatementError("invalid pubArea in attestation statement")},
{"fail pubArea empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": []byte{},
},
}}, nil, newBadAttestationStatementError("pubArea is empty")},
{"fail sig not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("invalid sig in attestation statement")},
{"fail sig type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": []interface{}{},
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("invalid sig in attestation statement")},
{"fail sig empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": []byte{},
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("sig is empty")},
{"fail certInfo not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("invalid certInfo in attestation statement")},
{"fail certInfo type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": []interface{}{},
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("invalid certInfo in attestation statement")},
{"fail certInfo empty", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": []byte{},
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("certInfo is empty")},
{"fail alg not present", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("invalid alg in attestation statement")},
{"fail alg type", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(0), // invalid alg
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("invalid alg 0 in attestation statement")},
{"fail attestation verification", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": anotherKeyParams.Public,
},
}}, nil, newBadAttestationStatementError("invalid certification parameters: certification refers to a different key")},
{"fail keyAuthorization", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "token"}, &jose.JSONWebKey{Key: []byte("not an asymmetric key")}, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), // RS256
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newInternalServerError("failed creating key auth digest: error generating JWK thumbprint: square/go-jose: unknown key type '[]uint8'")},
{"fail different keyAuthorization", args{ctx, mustAttestationProvisioner(t, acaRoot), &Challenge{Token: "aDifferentToken"}, jwk, &attestationObject{
Format: "tpm",
AttStatement: map[string]interface{}{
"ver": "2.0",
"x5c": []interface{}{akCert.Raw, aca.Intermediate.Raw},
"alg": int64(-257), //
"sig": params.CreateSignature,
"certInfo": params.CreateAttestation,
"pubArea": params.Public,
},
}}, nil, newBadAttestationStatementError("key authorization does not match")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := doTPMAttestationFormat(tt.args.ctx, tt.args.prov, tt.args.ch, tt.args.jwk, tt.args.att)
if tt.expErr != nil {
var ae *Error
if assert.True(t, errors.As(err, &ae)) {
assert.EqualError(t, err, tt.expErr.Error())
assert.Equal(t, ae.StatusCode(), tt.expErr.StatusCode())
assert.Equal(t, ae.Type, tt.expErr.Type)
}
assert.Nil(t, got)
return
}
assert.NoError(t, err)
if assert.NotNil(t, got) {
assert.Equal(t, akCert, got.Certificate)
assert.Equal(t, [][]*x509.Certificate{
{
akCert, aca.Intermediate, aca.Root,
},
}, got.VerifiedChains)
assert.Equal(t, fingerprint, got.Fingerprint)
assert.Empty(t, got.PermanentIdentifiers) // currently expected to be always empty
}
})
}
}

View file

@ -72,7 +72,12 @@ func (c *client) Get(url string) (*http.Response, error) {
} }
func (c *client) LookupTxt(name string) ([]string, error) { func (c *client) LookupTxt(name string) ([]string, error) {
return net.LookupTXT(name) // chase CNAME records, if any
cname, err := net.LookupCNAME(name)
if err != nil {
return nil, err
}
return net.LookupTXT(cname)
} }
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {

View file

@ -29,12 +29,10 @@ type CertificateAuthority interface {
} }
// NewContext adds the given acme components to the context. // NewContext adds the given acme components to the context.
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker, func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
nnsServer string) context.Context {
ctx = NewDatabaseContext(ctx, db) ctx = NewDatabaseContext(ctx, db)
ctx = NewClientContext(ctx, client) ctx = NewClientContext(ctx, client)
ctx = NewLinkerContext(ctx, linker) ctx = NewLinkerContext(ctx, linker)
ctx = NewNNSContext(ctx, nnsServer)
// Prerequisite checker is optional. // Prerequisite checker is optional.
if fn != nil { if fn != nil {
ctx = NewPrerequisitesCheckerContext(ctx, fn) ctx = NewPrerequisitesCheckerContext(ctx, fn)
@ -48,7 +46,7 @@ type PrerequisitesChecker func(ctx context.Context) (bool, error)
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns // DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
// always true. // always true.
func DefaultPrerequisitesChecker(context.Context) (bool, error) { func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
return true, nil return true, nil
} }

View file

@ -12,12 +12,6 @@ import (
// account. // account.
var ErrNotFound = errors.New("not found") var ErrNotFound = errors.New("not found")
// IsErrNotFound returns true if the error is a "not found" error. Returns false
// otherwise.
func IsErrNotFound(err error) bool {
return errors.Is(err, ErrNotFound)
}
// DB is the DB interface expected by the step-ca ACME API. // DB is the DB interface expected by the step-ca ACME API.
type DB interface { type DB interface {
CreateAccount(ctx context.Context, acc *Account) error CreateAccount(ctx context.Context, acc *Account) error

View file

@ -17,8 +17,6 @@ type dbAccount struct {
Key *jose.JSONWebKey `json:"key"` Key *jose.JSONWebKey `json:"key"`
Contact []string `json:"contact,omitempty"` Contact []string `json:"contact,omitempty"`
Status acme.Status `json:"status"` Status acme.Status `json:"status"`
LocationPrefix string `json:"locationPrefix"`
ProvisionerName string `json:"provisionerName"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
DeactivatedAt time.Time `json:"deactivatedAt"` DeactivatedAt time.Time `json:"deactivatedAt"`
} }
@ -28,7 +26,7 @@ func (dba *dbAccount) clone() *dbAccount {
return &nu return &nu
} }
func (db *DB) getAccountIDByKeyID(_ context.Context, kid string) (string, error) { func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) {
id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
if err != nil { if err != nil {
if nosqlDB.IsErrNotFound(err) { if nosqlDB.IsErrNotFound(err) {
@ -40,7 +38,7 @@ func (db *DB) getAccountIDByKeyID(_ context.Context, kid string) (string, error)
} }
// getDBAccount retrieves and unmarshals dbAccount. // getDBAccount retrieves and unmarshals dbAccount.
func (db *DB) getDBAccount(_ context.Context, id string) (*dbAccount, error) { func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) {
data, err := db.db.Get(accountTable, []byte(id)) data, err := db.db.Get(accountTable, []byte(id))
if err != nil { if err != nil {
if nosqlDB.IsErrNotFound(err) { if nosqlDB.IsErrNotFound(err) {
@ -68,8 +66,6 @@ func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error)
Contact: dbacc.Contact, Contact: dbacc.Contact,
Key: dbacc.Key, Key: dbacc.Key,
ID: dbacc.ID, ID: dbacc.ID,
LocationPrefix: dbacc.LocationPrefix,
ProvisionerName: dbacc.ProvisionerName,
}, nil }, nil
} }
@ -96,8 +92,6 @@ func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
Contact: acc.Contact, Contact: acc.Contact,
Status: acc.Status, Status: acc.Status,
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
LocationPrefix: acc.LocationPrefix,
ProvisionerName: acc.ProvisionerName,
} }
kid, err := acme.KeyToID(dba.Key) kid, err := acme.KeyToID(dba.Key)

View file

@ -197,8 +197,6 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
func TestDB_GetAccount(t *testing.T) { func TestDB_GetAccount(t *testing.T) {
accID := "accID" accID := "accID"
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
provisionerName := "foo"
type test struct { type test struct {
db nosql.DB db nosql.DB
err error err error
@ -230,8 +228,6 @@ func TestDB_GetAccount(t *testing.T) {
DeactivatedAt: now, DeactivatedAt: now,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
Key: jwk, Key: jwk,
LocationPrefix: locationPrefix,
ProvisionerName: provisionerName,
} }
b, err := json.Marshal(dbacc) b, err := json.Marshal(dbacc)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -270,8 +266,6 @@ func TestDB_GetAccount(t *testing.T) {
assert.Equals(t, acc.ID, tc.dbacc.ID) assert.Equals(t, acc.ID, tc.dbacc.ID)
assert.Equals(t, acc.Status, tc.dbacc.Status) assert.Equals(t, acc.Status, tc.dbacc.Status)
assert.Equals(t, acc.Contact, tc.dbacc.Contact) assert.Equals(t, acc.Contact, tc.dbacc.Contact)
assert.Equals(t, acc.LocationPrefix, tc.dbacc.LocationPrefix)
assert.Equals(t, acc.ProvisionerName, tc.dbacc.ProvisionerName)
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID) assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
} }
}) })
@ -385,7 +379,6 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
} }
func TestDB_CreateAccount(t *testing.T) { func TestDB_CreateAccount(t *testing.T) {
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
type test struct { type test struct {
db nosql.DB db nosql.DB
acc *acme.Account acc *acme.Account
@ -400,7 +393,6 @@ func TestDB_CreateAccount(t *testing.T) {
Status: acme.StatusValid, Status: acme.StatusValid,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
Key: jwk, Key: jwk,
LocationPrefix: locationPrefix,
} }
return test{ return test{
db: &db.MockNoSQLDB{ db: &db.MockNoSQLDB{
@ -424,7 +416,6 @@ func TestDB_CreateAccount(t *testing.T) {
Status: acme.StatusValid, Status: acme.StatusValid,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
Key: jwk, Key: jwk,
LocationPrefix: locationPrefix,
} }
return test{ return test{
db: &db.MockNoSQLDB{ db: &db.MockNoSQLDB{
@ -448,7 +439,6 @@ func TestDB_CreateAccount(t *testing.T) {
Status: acme.StatusValid, Status: acme.StatusValid,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
Key: jwk, Key: jwk,
LocationPrefix: locationPrefix,
} }
return test{ return test{
db: &db.MockNoSQLDB{ db: &db.MockNoSQLDB{
@ -466,8 +456,6 @@ func TestDB_CreateAccount(t *testing.T) {
assert.FatalError(t, json.Unmarshal(nu, dbacc)) assert.FatalError(t, json.Unmarshal(nu, dbacc))
assert.Equals(t, dbacc.ID, string(key)) assert.Equals(t, dbacc.ID, string(key))
assert.Equals(t, dbacc.Contact, acc.Contact) assert.Equals(t, dbacc.Contact, acc.Contact)
assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix)
assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName)
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
@ -494,7 +482,6 @@ func TestDB_CreateAccount(t *testing.T) {
Status: acme.StatusValid, Status: acme.StatusValid,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
Key: jwk, Key: jwk,
LocationPrefix: locationPrefix,
} }
return test{ return test{
db: &db.MockNoSQLDB{ db: &db.MockNoSQLDB{
@ -513,8 +500,6 @@ func TestDB_CreateAccount(t *testing.T) {
assert.FatalError(t, json.Unmarshal(nu, dbacc)) assert.FatalError(t, json.Unmarshal(nu, dbacc))
assert.Equals(t, dbacc.ID, string(key)) assert.Equals(t, dbacc.ID, string(key))
assert.Equals(t, dbacc.Contact, acc.Contact) assert.Equals(t, dbacc.Contact, acc.Contact)
assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix)
assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName)
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID) assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt)) assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
@ -559,8 +544,6 @@ func TestDB_UpdateAccount(t *testing.T) {
CreatedAt: now, CreatedAt: now,
DeactivatedAt: now, DeactivatedAt: now,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
LocationPrefix: "foo",
ProvisionerName: "alpha",
Key: jwk, Key: jwk,
} }
b, err := json.Marshal(dbacc) b, err := json.Marshal(dbacc)
@ -663,9 +646,7 @@ func TestDB_UpdateAccount(t *testing.T) {
acc := &acme.Account{ acc := &acme.Account{
ID: accID, ID: accID,
Status: acme.StatusDeactivated, Status: acme.StatusDeactivated,
Contact: []string{"baz", "zap"}, Contact: []string{"foo", "bar"},
LocationPrefix: "bar",
ProvisionerName: "beta",
Key: jwk, Key: jwk,
} }
return test{ return test{
@ -685,10 +666,7 @@ func TestDB_UpdateAccount(t *testing.T) {
assert.FatalError(t, json.Unmarshal(nu, dbNew)) assert.FatalError(t, json.Unmarshal(nu, dbNew))
assert.Equals(t, dbNew.ID, dbacc.ID) assert.Equals(t, dbNew.ID, dbacc.ID)
assert.Equals(t, dbNew.Status, acc.Status) assert.Equals(t, dbNew.Status, acc.Status)
assert.Equals(t, dbNew.Contact, acc.Contact) assert.Equals(t, dbNew.Contact, dbacc.Contact)
// LocationPrefix should not change.
assert.Equals(t, dbNew.LocationPrefix, dbacc.LocationPrefix)
assert.Equals(t, dbNew.ProvisionerName, dbacc.ProvisionerName)
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID) assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt) assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now)) assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
@ -708,7 +686,12 @@ func TestDB_UpdateAccount(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
assert.Nil(t, tc.err) if assert.Nil(t, tc.err) {
assert.Equals(t, tc.acc.ID, dbacc.ID)
assert.Equals(t, tc.acc.Status, dbacc.Status)
assert.Equals(t, tc.acc.Contact, dbacc.Contact)
assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID)
}
} }
}) })
} }

View file

@ -17,7 +17,6 @@ type dbAuthz struct {
Identifier acme.Identifier `json:"identifier"` Identifier acme.Identifier `json:"identifier"`
Status acme.Status `json:"status"` Status acme.Status `json:"status"`
Token string `json:"token"` Token string `json:"token"`
Fingerprint string `json:"fingerprint,omitempty"`
ChallengeIDs []string `json:"challengeIDs"` ChallengeIDs []string `json:"challengeIDs"`
Wildcard bool `json:"wildcard"` Wildcard bool `json:"wildcard"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
@ -32,7 +31,7 @@ func (ba *dbAuthz) clone() *dbAuthz {
// getDBAuthz retrieves and unmarshals a database representation of the // getDBAuthz retrieves and unmarshals a database representation of the
// ACME Authorization type. // ACME Authorization type.
func (db *DB) getDBAuthz(_ context.Context, id string) (*dbAuthz, error) { func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) {
data, err := db.db.Get(authzTable, []byte(id)) data, err := db.db.Get(authzTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id) return nil, acme.NewError(acme.ErrorMalformedType, "authz %s not found", id)
@ -70,7 +69,6 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat
Wildcard: dbaz.Wildcard, Wildcard: dbaz.Wildcard,
ExpiresAt: dbaz.ExpiresAt, ExpiresAt: dbaz.ExpiresAt,
Token: dbaz.Token, Token: dbaz.Token,
Fingerprint: dbaz.Fingerprint,
Error: dbaz.Error, Error: dbaz.Error,
}, nil }, nil
} }
@ -99,7 +97,6 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e
Identifier: az.Identifier, Identifier: az.Identifier,
ChallengeIDs: chIDs, ChallengeIDs: chIDs,
Token: az.Token, Token: az.Token,
Fingerprint: az.Fingerprint,
Wildcard: az.Wildcard, Wildcard: az.Wildcard,
} }
@ -114,14 +111,14 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) e
} }
nu := old.clone() nu := old.clone()
nu.Status = az.Status nu.Status = az.Status
nu.Fingerprint = az.Fingerprint
nu.Error = az.Error nu.Error = az.Error
return db.save(ctx, old.ID, nu, old, "authz", authzTable) return db.save(ctx, old.ID, nu, old, "authz", authzTable)
} }
// GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database. // GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database.
func (db *DB) GetAuthorizationsByAccountID(_ context.Context, accountID string) ([]*acme.Authorization, error) { func (db *DB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*acme.Authorization, error) {
entries, err := db.db.List(authzTable) entries, err := db.db.List(authzTable)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error listing authz") return nil, errors.Wrapf(err, "error listing authz")
@ -147,7 +144,6 @@ func (db *DB) GetAuthorizationsByAccountID(_ context.Context, accountID string)
Wildcard: dbaz.Wildcard, Wildcard: dbaz.Wildcard,
ExpiresAt: dbaz.ExpiresAt, ExpiresAt: dbaz.ExpiresAt,
Token: dbaz.Token, Token: dbaz.Token,
Fingerprint: dbaz.Fingerprint,
Error: dbaz.Error, Error: dbaz.Error,
}) })
} }

View file

@ -473,7 +473,6 @@ func TestDB_UpdateAuthorization(t *testing.T) {
ExpiresAt: now.Add(5 * time.Minute), ExpiresAt: now.Add(5 * time.Minute),
ChallengeIDs: []string{"foo", "bar"}, ChallengeIDs: []string{"foo", "bar"},
Wildcard: true, Wildcard: true,
Fingerprint: "fingerprint",
} }
b, err := json.Marshal(dbaz) b, err := json.Marshal(dbaz)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -553,7 +552,6 @@ func TestDB_UpdateAuthorization(t *testing.T) {
Token: dbaz.Token, Token: dbaz.Token,
Wildcard: dbaz.Wildcard, Wildcard: dbaz.Wildcard,
ExpiresAt: dbaz.ExpiresAt, ExpiresAt: dbaz.ExpiresAt,
Fingerprint: "fingerprint",
Error: acme.NewError(acme.ErrorMalformedType, "malformed"), Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
} }
return test{ return test{
@ -584,7 +582,6 @@ func TestDB_UpdateAuthorization(t *testing.T) {
assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard)
assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt)
assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt)
assert.Equals(t, dbNew.Fingerprint, dbaz.Fingerprint)
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error()) assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "The request message was malformed").Error())
return nu, true, nil return nu, true, nil
}, },

View file

@ -69,7 +69,7 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err
// GetCertificate retrieves and unmarshals an ACME certificate type from the // GetCertificate retrieves and unmarshals an ACME certificate type from the
// datastore. // datastore.
func (db *DB) GetCertificate(_ context.Context, id string) (*acme.Certificate, error) { func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) {
b, err := db.db.Get(certTable, []byte(id)) b, err := db.db.Get(certTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id) return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id)

View file

@ -6,10 +6,8 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/nosql"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/nosql"
) )
type dbChallenge struct { type dbChallenge struct {
@ -21,7 +19,7 @@ type dbChallenge struct {
Value string `json:"value"` Value string `json:"value"`
ValidatedAt string `json:"validatedAt"` ValidatedAt string `json:"validatedAt"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
Error *acme.Error `json:"error"` // TODO(hs): a bit dangerous; should become db-specific type Error *acme.Error `json:"error"`
} }
func (dbc *dbChallenge) clone() *dbChallenge { func (dbc *dbChallenge) clone() *dbChallenge {
@ -29,7 +27,7 @@ func (dbc *dbChallenge) clone() *dbChallenge {
return &u return &u
} }
func (db *DB) getDBChallenge(_ context.Context, id string) (*dbChallenge, error) { func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) {
data, err := db.db.Get(challengeTable, []byte(id)) data, err := db.db.Get(challengeTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id) return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id)
@ -69,7 +67,6 @@ func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error {
// GetChallenge retrieves and unmarshals an ACME challenge type from the database. // GetChallenge retrieves and unmarshals an ACME challenge type from the database.
// Implements the acme.DB GetChallenge interface. // Implements the acme.DB GetChallenge interface.
func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) { func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) {
_ = authzID // unused input
dbch, err := db.getDBChallenge(ctx, id) dbch, err := db.getDBChallenge(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -35,7 +35,7 @@ type dbExternalAccountKeyReference struct {
} }
// getDBExternalAccountKey retrieves and unmarshals dbExternalAccountKey. // getDBExternalAccountKey retrieves and unmarshals dbExternalAccountKey.
func (db *DB) getDBExternalAccountKey(_ context.Context, id string) (*dbExternalAccountKey, error) { func (db *DB) getDBExternalAccountKey(ctx context.Context, id string) (*dbExternalAccountKey, error) {
data, err := db.db.Get(externalAccountKeyTable, []byte(id)) data, err := db.db.Get(externalAccountKeyTable, []byte(id))
if err != nil { if err != nil {
if nosqlDB.IsErrNotFound(err) { if nosqlDB.IsErrNotFound(err) {
@ -160,8 +160,6 @@ func (db *DB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID
// GetExternalAccountKeys retrieves all External Account Binding keys for a provisioner // GetExternalAccountKeys retrieves all External Account Binding keys for a provisioner
func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*acme.ExternalAccountKey, string, error) { func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*acme.ExternalAccountKey, string, error) {
_, _ = cursor, limit // unused input
externalAccountKeyMutex.RLock() externalAccountKeyMutex.RLock()
defer externalAccountKeyMutex.RUnlock() defer externalAccountKeyMutex.RUnlock()
@ -229,7 +227,7 @@ func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerI
return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID) return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID)
} }
func (db *DB) GetExternalAccountKeyByAccountID(context.Context, string, string) (*acme.ExternalAccountKey, error) { func (db *DB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
//nolint:nilnil // legacy //nolint:nilnil // legacy
return nil, nil return nil, nil
} }

View file

@ -39,7 +39,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
// DeleteNonce verifies that the nonce is valid (by checking if it exists), // DeleteNonce verifies that the nonce is valid (by checking if it exists),
// and if so, consumes the nonce resource by deleting it from the database. // and if so, consumes the nonce resource by deleting it from the database.
func (db *DB) DeleteNonce(_ context.Context, nonce acme.Nonce) error { func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error {
err := db.db.Update(&database.Tx{ err := db.db.Update(&database.Tx{
Operations: []*database.TxEntry{ Operations: []*database.TxEntry{
{ {

View file

@ -48,7 +48,7 @@ func New(db nosqlDB.DB) (*DB, error) {
// save writes the new data to the database, overwriting the old data if it // save writes the new data to the database, overwriting the old data if it
// existed. // existed.
func (db *DB) save(_ context.Context, id string, nu, old interface{}, typ string, table []byte) error { func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
var ( var (
err error err error
newB []byte newB []byte

View file

@ -35,7 +35,7 @@ func (a *dbOrder) clone() *dbOrder {
} }
// getDBOrder retrieves and unmarshals an ACME Order type from the database. // getDBOrder retrieves and unmarshals an ACME Order type from the database.
func (db *DB) getDBOrder(_ context.Context, id string) (*dbOrder, error) { func (db *DB) getDBOrder(ctx context.Context, id string) (*dbOrder, error) {
b, err := db.db.Get(orderTable, []byte(id)) b, err := db.db.Get(orderTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id) return nil, acme.NewError(acme.ErrorMalformedType, "order %s not found", id)

View file

@ -65,8 +65,6 @@ const (
ErrorUserActionRequiredType ErrorUserActionRequiredType
// ErrorNotImplementedType operation is not implemented // ErrorNotImplementedType operation is not implemented
ErrorNotImplementedType ErrorNotImplementedType
// ErrorNNSType was a problem with a NNS query during identifier validation
ErrorNNSType
) )
// String returns the string representation of the acme problem type, // String returns the string representation of the acme problem type,
@ -123,8 +121,6 @@ func (ap ProblemType) String() string {
return "userActionRequired" return "userActionRequired"
case ErrorNotImplementedType: case ErrorNotImplementedType:
return "notImplemented" return "notImplemented"
case ErrorNNSType:
return "nns"
default: default:
return fmt.Sprintf("unsupported type ACME error type '%d'", int(ap)) return fmt.Sprintf("unsupported type ACME error type '%d'", int(ap))
} }
@ -274,61 +270,21 @@ var (
} }
) )
// Error represents an ACME Error // Error represents an ACME
type Error struct { type Error struct {
Type string `json:"type"` Type string `json:"type"`
Detail string `json:"detail"` Detail string `json:"detail"`
Subproblems []Subproblem `json:"subproblems,omitempty"` Subproblems []interface{} `json:"subproblems,omitempty"`
Identifier interface{} `json:"identifier,omitempty"`
Err error `json:"-"` Err error `json:"-"`
Status int `json:"-"` Status int `json:"-"`
} }
// Subproblem represents an ACME subproblem. It's fairly
// similar to an ACME error, but differs in that it can't
// include subproblems itself, the error is reflected
// in the Detail property and doesn't have a Status.
type Subproblem struct {
Type string `json:"type"`
Detail string `json:"detail"`
// The "identifier" field MUST NOT be present at the top level in ACME
// problem documents. It can only be present in subproblems.
// Subproblems need not all have the same type, and they do not need to
// match the top level type.
Identifier *Identifier `json:"identifier,omitempty"`
}
// AddSubproblems adds the Subproblems to Error. It
// returns the Error, allowing for fluent addition.
func (e *Error) AddSubproblems(subproblems ...Subproblem) *Error {
e.Subproblems = append(e.Subproblems, subproblems...)
return e
}
// NewError creates a new Error type. // NewError creates a new Error type.
func NewError(pt ProblemType, msg string, args ...interface{}) *Error { func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
return newError(pt, errors.Errorf(msg, args...)) return newError(pt, errors.Errorf(msg, args...))
} }
// NewSubproblem creates a new Subproblem. The msg and args
// are used to create a new error, which is set as the Detail, allowing
// for more detailed error messages to be returned to the ACME client.
func NewSubproblem(pt ProblemType, msg string, args ...interface{}) Subproblem {
e := newError(pt, fmt.Errorf(msg, args...))
s := Subproblem{
Type: e.Type,
Detail: e.Err.Error(),
}
return s
}
// NewSubproblemWithIdentifier creates a new Subproblem with a specific ACME
// Identifier. It calls NewSubproblem and sets the Identifier.
func NewSubproblemWithIdentifier(pt ProblemType, identifier Identifier, msg string, args ...interface{}) Subproblem {
s := NewSubproblem(pt, msg, args...)
s.Identifier = &identifier
return s
}
func newError(pt ProblemType, err error) *Error { func newError(pt ProblemType, err error) *Error {
meta, ok := errorMap[pt] meta, ok := errorMap[pt]
if !ok { if !ok {

View file

@ -1,122 +0,0 @@
package acme
import (
"context"
"errors"
"fmt"
"net/url"
"git.frostfs.info/TrueCloudLab/frostfs-contract/nns"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/rpcclient"
"github.com/nspcc-dev/neo-go/pkg/rpcclient/invoker"
"github.com/nspcc-dev/neo-go/pkg/rpcclient/unwrap"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
)
// multiSchemeClient unites invoker.RPCInvoke and common interface of
// rpcclient.Client and rpcclient.WSClient.
type multiSchemeClient interface {
invoker.RPCInvoke
// Init turns client to "ready-to-work" state.
Init() error
// Close closes connections.
Close()
// GetContractStateByID returns state of the NNS contract on 1 input.
GetContractStateByID(int32) (*state.Contract, error)
}
// NNS is used to interact with NNS contract.
// Before work, the connection to the NNS server must be established using Dial method.
type NNS struct {
nnsContract util.Uint160
client multiSchemeClient
}
// NNSContext is used to store info about NNS server.
type NNSContext struct {
nnsServer string
}
type nnsKey struct{}
// NewNNSContext adds new NNSContext with given params to the context.
func NewNNSContext(ctx context.Context, nnsServer string) context.Context {
return context.WithValue(ctx, nnsKey{}, NNSContext{nnsServer: nnsServer})
}
// GetNNSContext returns NNSContext from the given context.
func GetNNSContext(ctx context.Context) (NNSContext, bool) {
c, ok := ctx.Value(nnsKey{}).(NNSContext)
return c, ok
}
// Dial connects to the address of the NNS server.
// If URL address scheme is 'ws' or 'wss', then WebSocket protocol is used, otherwise HTTP.
func (n *NNS) Dial(address string) error {
var err error
uri, err := url.Parse(address)
if err == nil && (uri.Scheme == "ws" || uri.Scheme == "wss") {
n.client, err = rpcclient.NewWS(context.Background(), address, rpcclient.WSOptions{})
if err != nil {
return fmt.Errorf("create Neo WebSocket client: %w", err)
}
} else {
n.client, err = rpcclient.New(context.Background(), address, rpcclient.Options{})
if err != nil {
return fmt.Errorf("create Neo HTTP client: %w", err)
}
}
if err = n.client.Init(); err != nil {
return fmt.Errorf("initialize Neo client: %w", err)
}
nnsContract, err := n.client.GetContractStateByID(1)
if err != nil {
return fmt.Errorf("get NNS contract state: %w", err)
}
n.nnsContract = nnsContract.Hash
return nil
}
// Close closes connections of multiSchemeClient.
func (n *NNS) Close() {
n.client.Close()
}
// GetTXTRecords returns TXT records of the provided domain by calling `getRecords` method of NNS contract.
func (n *NNS) GetTXTRecords(name string) ([]string, error) {
params, err := smartcontract.NewParametersFromValues(name, int64(nns.TXT))
if err != nil {
return make([]string, 0), fmt.Errorf("create slice of params: %w", err)
}
item, err := unwrap.Item(n.client.InvokeFunction(n.nnsContract, "getRecords", params, nil))
if err != nil {
return make([]string, 0), fmt.Errorf("contract invocation: %w", err)
}
if _, ok := item.(stackitem.Null); !ok {
arr, ok := item.Value().([]stackitem.Item)
if !ok {
return make([]string, 0), errors.New("invalid cast to stack item slice")
}
var result = make([]string, 0, len(arr))
for i := range arr {
recordValue, err := arr[i].TryBytes()
if err != nil {
return make([]string, 0), fmt.Errorf("convert array item to byte slice: %w", err)
}
result = append(result, string(recordValue))
}
return result, nil
}
return make([]string, 0), errors.New("records not found")
}

View file

@ -3,7 +3,6 @@ package acme
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/subtle"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"net" "net"
@ -12,7 +11,6 @@ import (
"time" "time"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
) )
@ -127,27 +125,6 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
return nil return nil
} }
// getKeyFingerprint returns a fingerprint from the list of authorizations. This
// fingerprint is used on the device-attest-01 flow to verify the attestation
// certificate public key with the CSR public key.
//
// There's no point on reading all the authorizations as there will be only one
// for a permanent identifier.
func (o *Order) getAuthorizationFingerprint(ctx context.Context, db DB) (string, error) {
for _, azID := range o.AuthorizationIDs {
az, err := db.GetAuthorization(ctx, azID)
if err != nil {
return "", WrapErrorISE(err, "error getting authorization %q", azID)
}
// There's no point on reading all the authorizations as there will
// be only one for a permanent identifier.
if az.Fingerprint != "" {
return az.Fingerprint, nil
}
}
return "", nil
}
// Finalize signs a certificate if the necessary conditions for Order completion // Finalize signs a certificate if the necessary conditions for Order completion
// have been met. // have been met.
// //
@ -173,24 +150,6 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID) return NewErrorISE("unexpected status %s for order %s", o.Status, o.ID)
} }
// Get key fingerprint if any. And then compare it with the CSR fingerprint.
//
// In device-attest-01 challenges we should check that the keys in the CSR
// and the attestation certificate are the same.
fingerprint, err := o.getAuthorizationFingerprint(ctx, db)
if err != nil {
return err
}
if fingerprint != "" {
fp, err := keyutil.Fingerprint(csr.PublicKey)
if err != nil {
return WrapErrorISE(err, "error calculating key fingerprint")
}
if subtle.ConstantTimeCompare([]byte(fingerprint), []byte(fp)) == 0 {
return NewError(ErrorUnauthorizedType, "order %s csr does not match the attested key", o.ID)
}
}
// canonicalize the CSR to allow for comparison // canonicalize the CSR to allow for comparison
csr = canonicalize(csr) csr = canonicalize(csr)
@ -206,15 +165,6 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
for i := range o.Identifiers { for i := range o.Identifiers {
if o.Identifiers[i].Type == PermanentIdentifier { if o.Identifiers[i].Type == PermanentIdentifier {
permanentIdentifier = o.Identifiers[i].Value permanentIdentifier = o.Identifiers[i].Value
// the first (and only) Permanent Identifier that gets added to the certificate
// should be equal to the Subject Common Name if it's set. If not equal, the CSR
// is rejected, because the Common Name hasn't been challenged in that case. This
// could result in unauthorized access if a relying system relies on the Common
// Name in its authorization logic.
if csr.Subject.CommonName != "" && csr.Subject.CommonName != permanentIdentifier {
return NewError(ErrorBadCSRType, "CSR Subject Common Name does not match identifiers exactly: "+
"CSR Subject Common Name = %s, Order Permanent Identifier = %s", csr.Subject.CommonName, permanentIdentifier)
}
break break
} }
} }

View file

@ -2,12 +2,9 @@ package acme
import ( import (
"context" "context"
"crypto"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"net/url" "net/url"
"reflect" "reflect"
@ -19,7 +16,6 @@ import (
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
) )
@ -301,7 +297,7 @@ func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface
return m.ret1.(provisioner.Interface), m.err return m.ret1.(provisioner.Interface), m.err
} }
func (m *mockSignAuth) IsRevoked(string) (bool, error) { func (m *mockSignAuth) IsRevoked(sn string) (bool, error) {
return false, nil return false, nil
} }
@ -310,14 +306,6 @@ func (m *mockSignAuth) Revoke(context.Context, *authority.RevokeOptions) error {
} }
func TestOrder_Finalize(t *testing.T) { func TestOrder_Finalize(t *testing.T) {
mustSigner := func(kty, crv string, size int) crypto.Signer {
s, err := keyutil.GenerateSigner(kty, crv, size)
if err != nil {
t.Fatal(err)
}
return s
}
type test struct { type test struct {
o *Order o *Order
err *Error err *Error
@ -398,72 +386,6 @@ func TestOrder_Finalize(t *testing.T) {
err: NewErrorISE("unrecognized order status: %s", o.Status), err: NewErrorISE("unrecognized order status: %s", o.Status),
} }
}, },
"fail/non-matching-permanent-identifier-common-name": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusReady,
ExpiresAt: now.Add(5 * time.Minute),
AuthorizationIDs: []string{"a", "b"},
Identifiers: []Identifier{
{Type: "permanent-identifier", Value: "a-permanent-identifier"},
},
}
signer := mustSigner("EC", "P-256", 0)
fingerprint, err := keyutil.Fingerprint(signer.Public())
if err != nil {
t.Fatal(err)
}
csr := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: "a-different-identifier",
},
PublicKey: signer.Public(),
ExtraExtensions: []pkix.Extension{
{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
Value: []byte("a-permanent-identifier"),
},
},
}
return test{
o: o,
csr: csr,
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
switch id {
case "a":
return &Authorization{
ID: id,
Status: StatusValid,
}, nil
case "b":
return &Authorization{
ID: id,
Fingerprint: fingerprint,
Status: StatusValid,
}, nil
default:
assert.FatalError(t, errors.Errorf("unexpected authorization %s", id))
return nil, errors.New("force")
}
},
MockUpdateOrder: func(ctx context.Context, o *Order) error {
return nil
},
},
err: &Error{
Type: "urn:ietf:params:acme:error:badCSR",
Detail: "The CSR is unacceptable",
Status: 400,
Err: fmt.Errorf("CSR Subject Common Name does not match identifiers exactly: "+
"CSR Subject Common Name = %s, Order Permanent Identifier = %s", csr.Subject.CommonName, "a-permanent-identifier"),
},
}
},
"fail/error-provisioner-auth": func(t *testing.T) test { "fail/error-provisioner-auth": func(t *testing.T) test {
now := clock.Now() now := clock.Now()
o := &Order{ o := &Order{
@ -493,11 +415,6 @@ func TestOrder_Finalize(t *testing.T) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{ID: id, Status: StatusValid}, nil
},
},
err: NewErrorISE("error retrieving authorization options from ACME provisioner: force"), err: NewErrorISE("error retrieving authorization options from ACME provisioner: force"),
} }
}, },
@ -537,11 +454,6 @@ func TestOrder_Finalize(t *testing.T) {
} }
}, },
}, },
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{ID: id, Status: StatusValid}, nil
},
},
err: NewErrorISE("error creating template options from ACME provisioner: error unmarshaling template data: invalid character 'o' in literal false (expecting 'a')"), err: NewErrorISE("error creating template options from ACME provisioner: error unmarshaling template data: invalid character 'o' in literal false (expecting 'a')"),
} }
}, },
@ -583,11 +495,6 @@ func TestOrder_Finalize(t *testing.T) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{ID: id, Status: StatusValid}, nil
},
},
err: NewErrorISE("error signing certificate for order oID: force"), err: NewErrorISE("error signing certificate for order oID: force"),
} }
}, },
@ -634,9 +541,6 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
db: &MockDB{ db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{ID: id, Status: StatusValid}, nil
},
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.AccountID, o.AccountID)
assert.Equals(t, cert.OrderID, o.ID) assert.Equals(t, cert.OrderID, o.ID)
@ -691,9 +595,6 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
db: &MockDB{ db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{ID: id, Status: StatusValid}, nil
},
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
cert.ID = "certID" cert.ID = "certID"
assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.AccountID, o.AccountID)
@ -716,297 +617,6 @@ func TestOrder_Finalize(t *testing.T) {
err: NewErrorISE("error updating order oID: force"), err: NewErrorISE("error updating order oID: force"),
} }
}, },
"fail/csr-fingerprint": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusReady,
ExpiresAt: now.Add(5 * time.Minute),
AuthorizationIDs: []string{"a", "b"},
Identifiers: []Identifier{
{Type: "permanent-identifier", Value: "a-permanent-identifier"},
},
}
signer := mustSigner("EC", "P-256", 0)
csr := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: "a-permanent-identifier",
},
PublicKey: signer.Public(),
ExtraExtensions: []pkix.Extension{
{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
Value: []byte("a-permanent-identifier"),
},
},
}
leaf := &x509.Certificate{
Subject: pkix.Name{CommonName: "a-permanent-identifier"},
PublicKey: signer.Public(),
ExtraExtensions: []pkix.Extension{
{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
Value: []byte("a-permanent-identifier"),
},
},
}
inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}}
root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}}
return test{
o: o,
csr: csr,
prov: &MockProvisioner{
MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
assert.Equals(t, token, "")
return nil, nil
},
MgetOptions: func() *provisioner.Options {
return nil
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil
},
},
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{
ID: id,
Fingerprint: "other-fingerprint",
Status: StatusValid,
}, nil
},
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
cert.ID = "certID"
assert.Equals(t, cert.AccountID, o.AccountID)
assert.Equals(t, cert.OrderID, o.ID)
assert.Equals(t, cert.Leaf, leaf)
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
return nil
},
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
assert.Equals(t, updo.CertificateID, "certID")
assert.Equals(t, updo.Status, StatusValid)
assert.Equals(t, updo.ID, o.ID)
assert.Equals(t, updo.AccountID, o.AccountID)
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
assert.Equals(t, updo.Identifiers, o.Identifiers)
return nil
},
},
err: NewError(ErrorUnauthorizedType, "order oID csr does not match the attested key"),
}
},
"ok/permanent-identifier": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusReady,
ExpiresAt: now.Add(5 * time.Minute),
AuthorizationIDs: []string{"a", "b"},
Identifiers: []Identifier{
{Type: "permanent-identifier", Value: "a-permanent-identifier"},
},
}
signer := mustSigner("EC", "P-256", 0)
fingerprint, err := keyutil.Fingerprint(signer.Public())
if err != nil {
t.Fatal(err)
}
csr := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: "a-permanent-identifier",
},
PublicKey: signer.Public(),
ExtraExtensions: []pkix.Extension{
{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
Value: []byte("a-permanent-identifier"),
},
},
}
leaf := &x509.Certificate{
Subject: pkix.Name{CommonName: "a-permanent-identifier"},
PublicKey: signer.Public(),
ExtraExtensions: []pkix.Extension{
{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
Value: []byte("a-permanent-identifier"),
},
},
}
inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}}
root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}}
return test{
o: o,
csr: csr,
prov: &MockProvisioner{
MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
assert.Equals(t, token, "")
return nil, nil
},
MgetOptions: func() *provisioner.Options {
return nil
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil
},
},
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
switch id {
case "a":
return &Authorization{
ID: id,
Status: StatusValid,
}, nil
case "b":
return &Authorization{
ID: id,
Fingerprint: fingerprint,
Status: StatusValid,
}, nil
default:
assert.FatalError(t, errors.Errorf("unexpected authorization %s", id))
return nil, errors.New("force")
}
},
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
cert.ID = "certID"
assert.Equals(t, cert.AccountID, o.AccountID)
assert.Equals(t, cert.OrderID, o.ID)
assert.Equals(t, cert.Leaf, leaf)
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
return nil
},
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
assert.Equals(t, updo.CertificateID, "certID")
assert.Equals(t, updo.Status, StatusValid)
assert.Equals(t, updo.ID, o.ID)
assert.Equals(t, updo.AccountID, o.AccountID)
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
assert.Equals(t, updo.Identifiers, o.Identifiers)
return nil
},
},
}
},
"ok/permanent-identifier-only": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusReady,
ExpiresAt: now.Add(5 * time.Minute),
AuthorizationIDs: []string{"a", "b"},
Identifiers: []Identifier{
{Type: "dns", Value: "foo.internal"},
{Type: "permanent-identifier", Value: "a-permanent-identifier"},
},
}
signer := mustSigner("EC", "P-256", 0)
fingerprint, err := keyutil.Fingerprint(signer.Public())
if err != nil {
t.Fatal(err)
}
csr := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: "a-permanent-identifier",
},
DNSNames: []string{"foo.internal"},
PublicKey: signer.Public(),
ExtraExtensions: []pkix.Extension{
{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
Value: []byte("a-permanent-identifier"),
},
},
}
leaf := &x509.Certificate{
Subject: pkix.Name{CommonName: "a-permanent-identifier"},
PublicKey: signer.Public(),
ExtraExtensions: []pkix.Extension{
{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 8, 3},
Value: []byte("a-permanent-identifier"),
},
},
}
inter := &x509.Certificate{Subject: pkix.Name{CommonName: "inter"}}
root := &x509.Certificate{Subject: pkix.Name{CommonName: "root"}}
return test{
o: o,
csr: csr,
prov: &MockProvisioner{
MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
assert.Equals(t, token, "")
return nil, nil
},
MgetOptions: func() *provisioner.Options {
return nil
},
},
// TODO(hs): we should work on making the mocks more realistic. Ideally, we should get rid of
// the mock entirely, relying on an instances of provisioner, authority and DB (possibly hardest), so
// that behavior of the tests is what an actual CA would do. We could gradually phase them out by
// using the mocking functions as a wrapper for actual test helpers generated per test case or per
// function that's tested.
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil
},
},
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{
ID: id,
Fingerprint: fingerprint,
Status: StatusValid,
}, nil
},
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
cert.ID = "certID"
assert.Equals(t, cert.AccountID, o.AccountID)
assert.Equals(t, cert.OrderID, o.ID)
assert.Equals(t, cert.Leaf, leaf)
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
return nil
},
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
assert.Equals(t, updo.CertificateID, "certID")
assert.Equals(t, updo.Status, StatusValid)
assert.Equals(t, updo.ID, o.ID)
assert.Equals(t, updo.AccountID, o.AccountID)
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
assert.Equals(t, updo.AuthorizationIDs, o.AuthorizationIDs)
assert.Equals(t, updo.Identifiers, o.Identifiers)
return nil
},
},
}
},
"ok/new-cert-dns": func(t *testing.T) test { "ok/new-cert-dns": func(t *testing.T) test {
now := clock.Now() now := clock.Now()
o := &Order{ o := &Order{
@ -1050,9 +660,6 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
db: &MockDB{ db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{ID: id, Status: StatusValid}, nil
},
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
cert.ID = "certID" cert.ID = "certID"
assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.AccountID, o.AccountID)
@ -1114,9 +721,6 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
db: &MockDB{ db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{ID: id, Status: StatusValid}, nil
},
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
cert.ID = "certID" cert.ID = "certID"
assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.AccountID, o.AccountID)
@ -1181,9 +785,6 @@ func TestOrder_Finalize(t *testing.T) {
}, },
}, },
db: &MockDB{ db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{ID: id, Status: StatusValid}, nil
},
MockCreateCertificate: func(ctx context.Context, cert *Certificate) error { MockCreateCertificate: func(ctx context.Context, cert *Certificate) error {
cert.ID = "certID" cert.ID = "certID"
assert.Equals(t, cert.AccountID, o.AccountID) assert.Equals(t, cert.AccountID, o.AccountID)
@ -1891,55 +1492,3 @@ func TestOrder_sans(t *testing.T) {
}) })
} }
} }
func TestOrder_getAuthorizationFingerprint(t *testing.T) {
ctx := context.Background()
type fields struct {
AuthorizationIDs []string
}
type args struct {
ctx context.Context
db DB
}
tests := []struct {
name string
fields fields
args args
want string
wantErr bool
}{
{"ok", fields{[]string{"az1", "az2"}}, args{ctx, &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return &Authorization{ID: id, Status: StatusValid}, nil
},
}}, "", false},
{"ok fingerprint", fields{[]string{"az1", "az2"}}, args{ctx, &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
if id == "az1" {
return &Authorization{ID: id, Status: StatusValid}, nil
}
return &Authorization{ID: id, Fingerprint: "fingerprint", Status: StatusValid}, nil
},
}}, "fingerprint", false},
{"fail", fields{[]string{"az1", "az2"}}, args{ctx, &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
return nil, errors.New("force")
},
}}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := &Order{
AuthorizationIDs: tt.fields.AuthorizationIDs,
}
got, err := o.getAuthorizationFingerprint(tt.args.ctx, tt.args.db)
if (err != nil) != tt.wantErr {
t.Errorf("Order.getAuthorizationFingerprint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Order.getAuthorizationFingerprint() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,7 +1,6 @@
package api package api
import ( import (
"bytes"
"context" "context"
"crypto" "crypto"
"crypto/dsa" //nolint:staticcheck // support legacy algorithms "crypto/dsa" //nolint:staticcheck // support legacy algorithms
@ -21,8 +20,6 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/sshutil"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/log" "github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
@ -43,7 +40,6 @@ type Authority interface {
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error) LoadProvisionerByCertificate(*x509.Certificate) (provisioner.Interface, error)
LoadProvisionerByName(string) (provisioner.Interface, error) LoadProvisionerByName(string) (provisioner.Interface, error)
@ -53,7 +49,6 @@ type Authority interface {
GetRoots() ([]*x509.Certificate, error) GetRoots() ([]*x509.Certificate, error)
GetFederation() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error)
Version() authority.Version Version() authority.Version
GetCertificateRevocationList() ([]byte, error)
} }
// mustAuthority will be replaced on unit tests. // mustAuthority will be replaced on unit tests.
@ -227,39 +222,8 @@ type RootResponse struct {
// ProvisionersResponse is the response object that returns the list of // ProvisionersResponse is the response object that returns the list of
// provisioners. // provisioners.
type ProvisionersResponse struct { type ProvisionersResponse struct {
Provisioners provisioner.List Provisioners provisioner.List `json:"provisioners"`
NextCursor string
}
// MarshalJSON implements json.Marshaler. It marshals the ProvisionersResponse
// into a byte slice.
//
// Special treatment is given to the SCEP provisioner, as it contains a
// challenge secret that MUST NOT be leaked in (public) HTTP responses. The
// challenge value is thus redacted in HTTP responses.
func (p ProvisionersResponse) MarshalJSON() ([]byte, error) {
for _, item := range p.Provisioners {
scepProv, ok := item.(*provisioner.SCEP)
if !ok {
continue
}
old := scepProv.ChallengePassword
scepProv.ChallengePassword = "*** REDACTED ***"
defer func(p string) { //nolint:gocritic // defer in loop required to restore initial state of provisioners
scepProv.ChallengePassword = p
}(old)
}
var list = struct {
Provisioners []provisioner.Interface `json:"provisioners"`
NextCursor string `json:"nextCursor"` NextCursor string `json:"nextCursor"`
}{
Provisioners: []provisioner.Interface(p.Provisioners),
NextCursor: p.NextCursor,
}
return json.Marshal(list)
} }
// ProvisionerKeyResponse is the response object that returns the encrypted key // ProvisionerKeyResponse is the response object that returns the encrypted key
@ -291,7 +255,7 @@ func (h *caHandler) Route(r Router) {
// New creates a new RouterHandler with the CA endpoints. // New creates a new RouterHandler with the CA endpoints.
// //
// Deprecated: Use api.Route(r Router) // Deprecated: Use api.Route(r Router)
func New(Authority) RouterHandler { func New(auth Authority) RouterHandler {
return &caHandler{} return &caHandler{}
} }
@ -303,7 +267,6 @@ func Route(r Router) {
r.MethodFunc("POST", "/renew", Renew) r.MethodFunc("POST", "/renew", Renew)
r.MethodFunc("POST", "/rekey", Rekey) r.MethodFunc("POST", "/rekey", Rekey)
r.MethodFunc("POST", "/revoke", Revoke) r.MethodFunc("POST", "/revoke", Revoke)
r.MethodFunc("GET", "/crl", CRL)
r.MethodFunc("GET", "/provisioners", Provisioners) r.MethodFunc("GET", "/provisioners", Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey) r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
r.MethodFunc("GET", "/roots", Roots) r.MethodFunc("GET", "/roots", Roots)
@ -338,7 +301,7 @@ func Version(w http.ResponseWriter, r *http.Request) {
} }
// Health is an HTTP handler that returns the status of the server. // Health is an HTTP handler that returns the status of the server.
func Health(w http.ResponseWriter, _ *http.Request) { func Health(w http.ResponseWriter, r *http.Request) {
render.JSON(w, HealthResponse{Status: "ok"}) render.JSON(w, HealthResponse{Status: "ok"})
} }
@ -472,7 +435,7 @@ func logOtt(w http.ResponseWriter, token string) {
} }
} }
// LogCertificate adds certificate fields to the log message. // LogCertificate add certificate fields to the log message.
func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) { func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
if rl, ok := w.(logging.ResponseLogger); ok { if rl, ok := w.(logging.ResponseLogger); ok {
m := map[string]interface{}{ m := map[string]interface{}{
@ -504,41 +467,6 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
} }
} }
// LogSSHCertificate adds SSH certificate fields to the log message.
func LogSSHCertificate(w http.ResponseWriter, cert *ssh.Certificate) {
if rl, ok := w.(logging.ResponseLogger); ok {
mak := bytes.TrimSpace(ssh.MarshalAuthorizedKey(cert))
var certificate string
parts := strings.Split(string(mak), " ")
if len(parts) > 1 {
certificate = parts[1]
}
var userOrHost string
if cert.CertType == ssh.HostCert {
userOrHost = "host"
} else {
userOrHost = "user"
}
certificateType := fmt.Sprintf("%s %s certificate", parts[0], userOrHost) // e.g. ecdsa-sha2-nistp256-cert-v01@openssh.com user certificate
m := map[string]interface{}{
"serial": cert.Serial,
"principals": cert.ValidPrincipals,
"valid-from": time.Unix(int64(cert.ValidAfter), 0).Format(time.RFC3339),
"valid-to": time.Unix(int64(cert.ValidBefore), 0).Format(time.RFC3339),
"certificate": certificate,
"certificate-type": certificateType,
}
fingerprint, err := sshutil.FormatFingerprint(mak, sshutil.DefaultFingerprint)
if err == nil {
fpParts := strings.Split(fingerprint, " ")
if len(fpParts) > 3 {
m["public-key"] = fmt.Sprintf("%s %s", fpParts[1], fpParts[len(fpParts)-1])
}
}
rl.WithFields(m)
}
}
// ParseCursor parses the cursor and limit from the request query params. // ParseCursor parses the cursor and limit from the request query params.
func ParseCursor(r *http.Request) (cursor string, limit int, err error) { func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
q := r.URL.Query() q := r.URL.Query()

View file

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"crypto" "crypto"
"crypto/dsa" //nolint:staticcheck // support legacy algorithms "crypto/dsa" //nolint
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/elliptic" "crypto/elliptic"
@ -28,15 +28,12 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
sassert "github.com/stretchr/testify/assert" "golang.org/x/crypto/ssh"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
squarejose "gopkg.in/square/go-jose.v2"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -195,7 +192,6 @@ type mockAuthority struct {
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
loadProvisionerByName func(name string) (provisioner.Interface, error) loadProvisionerByName func(name string) (provisioner.Interface, error)
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
@ -203,7 +199,6 @@ type mockAuthority struct {
getEncryptedKey func(kid string) (string, error) getEncryptedKey func(kid string) (string, error)
getRoots func() ([]*x509.Certificate, error) getRoots func() ([]*x509.Certificate, error)
getFederation func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error)
getCRL func() ([]byte, error)
signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
@ -217,14 +212,6 @@ type mockAuthority struct {
version func() authority.Version version func() authority.Version
} }
func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) {
if m.getCRL != nil {
return m.getCRL()
}
return m.ret1.([]byte), m.err
}
// TODO: remove once Authorize is deprecated. // TODO: remove once Authorize is deprecated.
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
if m.authorize != nil { if m.authorize != nil {
@ -268,13 +255,6 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
} }
func (m *mockAuthority) RenewContext(ctx context.Context, oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
if m.renewContext != nil {
return m.renewContext(ctx, oldcert, pk)
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
if m.rekey != nil { if m.rekey != nil {
return m.rekey(oldcert, pk) return m.rekey(oldcert, pk)
@ -792,45 +772,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) (
return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err
} }
func Test_CRLGeneration(t *testing.T) {
tests := []struct {
name string
err error
statusCode int
expected []byte
}{
{"empty", nil, http.StatusOK, nil},
}
chiCtx := chi.NewRouteContext()
req := httptest.NewRequest("GET", "http://example.com/crl", nil)
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockMustAuthority(t, &mockAuthority{ret1: tt.expected, err: tt.err})
w := httptest.NewRecorder()
CRL(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.CRL StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err)
}
if tt.statusCode == 200 {
if !bytes.Equal(bytes.TrimSpace(body), tt.expected) {
t.Errorf("caHandler.Root CRL = %s, wants %s", body, tt.expected)
}
}
})
}
}
func Test_caHandler_Route(t *testing.T) { func Test_caHandler_Route(t *testing.T) {
type fields struct { type fields struct {
Authority Authority Authority Authority
@ -1567,122 +1508,3 @@ func mustCertificate(t *testing.T, pub, priv interface{}) *x509.Certificate {
} }
return cert return cert
} }
func TestProvisionersResponse_MarshalJSON(t *testing.T) {
k := map[string]any{
"use": "sig",
"kty": "EC",
"kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc",
"crv": "P-256",
"alg": "ES256",
"x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8",
"y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y",
}
key := squarejose.JSONWebKey{}
b, err := json.Marshal(k)
assert.FatalError(t, err)
err = json.Unmarshal(b, &key)
assert.FatalError(t, err)
r := ProvisionersResponse{
Provisioners: provisioner.List{
&provisioner.SCEP{
Name: "scep",
Type: "scep",
ChallengePassword: "not-so-secret",
MinimumPublicKeyLength: 2048,
EncryptionAlgorithmIdentifier: 2,
},
&provisioner.JWK{
EncryptedKey: "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg",
Key: &key,
Name: "step-cli",
Type: "JWK",
},
},
NextCursor: "next",
}
expected := map[string]any{
"provisioners": []map[string]any{
{
"type": "scep",
"name": "scep",
"challenge": "*** REDACTED ***",
"minimumPublicKeyLength": 2048,
"encryptionAlgorithmIdentifier": 2,
},
{
"type": "JWK",
"name": "step-cli",
"key": map[string]any{
"use": "sig",
"kty": "EC",
"kid": "4UELJx8e0aS9m0CH3fZ0EB7D5aUPICb759zALHFejvc",
"crv": "P-256",
"alg": "ES256",
"x": "7ZdAAMZCFU4XwgblI5RfZouBi8lYmF6DlZusNNnsbm8",
"y": "sQr2JdzwD2fgyrymBEXWsxDxFNjjqN64qLLSbLdLZ9Y",
},
"encryptedKey": "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg",
},
},
"nextCursor": "next",
}
expBytes, err := json.Marshal(expected)
sassert.NoError(t, err)
br, err := r.MarshalJSON()
sassert.NoError(t, err)
sassert.JSONEq(t, string(expBytes), string(br))
keyCopy := key
expList := provisioner.List{
&provisioner.SCEP{
Name: "scep",
Type: "scep",
ChallengePassword: "not-so-secret",
MinimumPublicKeyLength: 2048,
EncryptionAlgorithmIdentifier: 2,
},
&provisioner.JWK{
EncryptedKey: "eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJlbmMiOiJBMTI4R0NNIiwicDJjIjoxMDAwMDAsInAycyI6IlhOdmYxQjgxSUlLMFA2NUkwcmtGTGcifQ.XaN9zcPQeWt49zchUDm34FECUTHfQTn_.tmNHPQDqR3ebsWfd.9WZr3YVdeOyJh36vvx0VlRtluhvYp4K7jJ1KGDr1qypwZ3ziBVSNbYYQ71du7fTtrnfG1wgGTVR39tWSzBU-zwQ5hdV3rpMAaEbod5zeW6SHd95H3Bvcb43YiiqJFNL5sGZzFb7FqzVmpsZ1efiv6sZaGDHtnCAL6r12UG5EZuqGfM0jGCZitUz2m9TUKXJL5DJ7MOYbFfkCEsUBPDm_TInliSVn2kMJhFa0VOe5wZk5YOuYM3lNYW64HGtbf-llN2Xk-4O9TfeSPizBx9ZqGpeu8pz13efUDT2WL9tWo6-0UE-CrG0bScm8lFTncTkHcu49_a5NaUBkYlBjEiw.thPcx3t1AUcWuEygXIY3Fg",
Key: &keyCopy,
Name: "step-cli",
Type: "JWK",
},
}
// MarshalJSON must not affect the struct properties itself
sassert.Equal(t, expList, r.Provisioners)
}
const (
fixtureECDSACertificate = `ecdsa-sha2-nistp256-cert-v01@openssh.com AAAAKGVjZHNhLXNoYTItbmlzdHAyNTYtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLnkvSk4odlo3b1R+RDw+LmorL3RkN354IilCIVFVen4AAAAIbmlzdHAyNTYAAABBBHjKHss8WM2ffMYlavisoLXR0I6UEIU+cidV1ogEH1U6+/SYaFPrlzQo0tGLM5CNkMbhInbyasQsrHzn8F1Rt7nHg5/tcSf9qwAAAAEAAAAGaGVybWFuAAAACgAAAAZoZXJtYW4AAAAAY8kvJwAAAABjyhBjAAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAGgAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAAhuaXN0cDI1NgAAAEEE/ayqpPrZZF5uA1UlDt4FreTf15agztQIzpxnWq/XoxAHzagRSkFGkdgFpjgsfiRpP8URHH3BZScqc0ZDCTxhoQAAAGQAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAEkAAAAhAJuP1wCVwoyrKrEtHGfFXrVbRHySDjvXtS1tVTdHyqymAAAAIBa/CSSzfZb4D2NLP+eEmOOMJwSjYOiNM8fiOoAaqglI herman`
)
func TestLogSSHCertificate(t *testing.T) {
out, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fixtureECDSACertificate))
require.NoError(t, err)
cert, ok := out.(*ssh.Certificate)
require.True(t, ok)
w := httptest.NewRecorder()
rl := logging.NewResponseLogger(w)
LogSSHCertificate(rl, cert)
sassert.Equal(t, 200, w.Result().StatusCode)
fields := rl.Fields()
sassert.Equal(t, uint64(14376510277651266987), fields["serial"])
sassert.Equal(t, []string{"herman"}, fields["principals"])
sassert.Equal(t, "ecdsa-sha2-nistp256-cert-v01@openssh.com user certificate", fields["certificate-type"])
sassert.Equal(t, time.Unix(1674129191, 0).Format(time.RFC3339), fields["valid-from"])
sassert.Equal(t, time.Unix(1674186851, 0).Format(time.RFC3339), fields["valid-to"])
sassert.Equal(t, "AAAAKGVjZHNhLXNoYTItbmlzdHAyNTYtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLnkvSk4odlo3b1R+RDw+LmorL3RkN354IilCIVFVen4AAAAIbmlzdHAyNTYAAABBBHjKHss8WM2ffMYlavisoLXR0I6UEIU+cidV1ogEH1U6+/SYaFPrlzQo0tGLM5CNkMbhInbyasQsrHzn8F1Rt7nHg5/tcSf9qwAAAAEAAAAGaGVybWFuAAAACgAAAAZoZXJtYW4AAAAAY8kvJwAAAABjyhBjAAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAGgAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAAhuaXN0cDI1NgAAAEEE/ayqpPrZZF5uA1UlDt4FreTf15agztQIzpxnWq/XoxAHzagRSkFGkdgFpjgsfiRpP8URHH3BZScqc0ZDCTxhoQAAAGQAAAATZWNkc2Etc2hhMi1uaXN0cDI1NgAAAEkAAAAhAJuP1wCVwoyrKrEtHGfFXrVbRHySDjvXtS1tVTdHyqymAAAAIBa/CSSzfZb4D2NLP+eEmOOMJwSjYOiNM8fiOoAaqglI", fields["certificate"])
sassert.Equal(t, "SHA256:RvkDPGwl/G9d7LUFm1kmWhvOD9I/moPq4yxcb0STwr0 (ECDSA-CERT)", fields["public-key"])
}

View file

@ -1,32 +0,0 @@
package api
import (
"encoding/pem"
"net/http"
"github.com/smallstep/certificates/api/render"
)
// CRL is an HTTP handler that returns the current CRL in DER or PEM format
func CRL(w http.ResponseWriter, r *http.Request) {
crlBytes, err := mustAuthority(r.Context()).GetCertificateRevocationList()
if err != nil {
render.Error(w, err)
return
}
_, formatAsPEM := r.URL.Query()["pem"]
if formatAsPEM {
w.Header().Add("Content-Type", "application/x-pem-file")
w.Header().Add("Content-Disposition", "attachment; filename=\"crl.pem\"")
_ = pem.Encode(w, &pem.Block{
Type: "X509 CRL",
Bytes: crlBytes,
})
} else {
w.Header().Add("Content-Type", "application/pkix-crl")
w.Header().Add("Content-Disposition", "attachment; filename=\"crl.der\"")
w.Write(crlBytes)
}
}

View file

@ -7,6 +7,8 @@ import (
"os" "os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/logging"
) )
// StackTracedError is the set of errors implementing the StackTrace function. // StackTracedError is the set of errors implementing the StackTrace function.
@ -19,21 +21,16 @@ type StackTracedError interface {
StackTrace() errors.StackTrace StackTrace() errors.StackTrace
} }
type fieldCarrier interface {
WithFields(map[string]any)
Fields() map[string]any
}
// Error adds to the response writer the given error if it implements // Error adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error // logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package. // using the log package.
func Error(rw http.ResponseWriter, err error) { func Error(rw http.ResponseWriter, err error) {
fc, ok := rw.(fieldCarrier) rl, ok := rw.(logging.ResponseLogger)
if !ok { if !ok {
return return
} }
fc.WithFields(map[string]any{ rl.WithFields(map[string]interface{}{
"error": err, "error": err,
}) })
@ -42,8 +39,8 @@ func Error(rw http.ResponseWriter, err error) {
} }
var st StackTracedError var st StackTracedError
if errors.As(err, &st) { if !errors.As(err, &st) {
fc.WithFields(map[string]any{ rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", st.StackTrace()), "stack-trace": fmt.Sprintf("%+v", st.StackTrace()),
}) })
} }
@ -51,9 +48,9 @@ func Error(rw http.ResponseWriter, err error) {
// EnabledResponse log the response object if it implements the EnableLogger // EnabledResponse log the response object if it implements the EnableLogger
// interface. // interface.
func EnabledResponse(rw http.ResponseWriter, v any) { func EnabledResponse(rw http.ResponseWriter, v interface{}) {
type enableLogger interface { type enableLogger interface {
ToLog() (any, error) ToLog() (interface{}, error)
} }
if el, ok := v.(enableLogger); ok { if el, ok := v.(enableLogger); ok {
@ -64,8 +61,8 @@ func EnabledResponse(rw http.ResponseWriter, v any) {
return return
} }
if rl, ok := rw.(fieldCarrier); ok { if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]any{ rl.WithFields(map[string]interface{}{
"response": out, "response": out,
}) })
} }

View file

@ -1,78 +1,43 @@
package log package log
import ( import (
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"testing" "testing"
"unsafe"
pkgerrors "github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
type stackTracedError struct{}
func (stackTracedError) Error() string {
return "a stacktraced error"
}
func (stackTracedError) StackTrace() pkgerrors.StackTrace {
f := struct{}{}
return pkgerrors.StackTrace{ // fake stacktrace
pkgerrors.Frame(unsafe.Pointer(&f)),
pkgerrors.Frame(unsafe.Pointer(&f)),
}
}
func TestError(t *testing.T) { func TestError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct { tests := []struct {
name string name string
error args args
rw http.ResponseWriter withFields bool
isFieldCarrier bool
stepDebug bool
expectStackTrace bool
}{ }{
{"noLogger", nil, nil, false, false, false}, {"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false}, {"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
{"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
{"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false},
{"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
{"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
{"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if tt.stepDebug { Error(tt.args.rw, tt.args.err)
t.Setenv("STEPDEBUG", "1") if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else { } else {
t.Setenv("STEPDEBUG", "0") t.Error("ResponseWriter does not implement logging.ResponseLogger")
} }
Error(tt.rw, tt.error)
// return early if test case doesn't use logger
if !tt.isFieldCarrier {
return
}
fields := tt.rw.(logging.ResponseLogger).Fields()
// expect the error field to be (not) set and to be the same error that was fed to Error
if tt.error == nil {
assert.Nil(t, fields["error"])
} else {
assert.Same(t, tt.error, fields["error"])
}
// check if stack-trace is set when expected
if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace {
t.Error(`ResponseLogger["stack-trace"] not set`)
} else if !tt.expectStackTrace && hasStackTrace {
t.Error(`ResponseLogger["stack-trace"] was set`)
} }
}) })
} }

View file

@ -2,6 +2,7 @@
package render package render
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"net/http" "net/http"
@ -23,25 +24,14 @@ func JSON(w http.ResponseWriter, v interface{}) {
// JSONStatus sets the Content-Type of w to application/json unless one is // JSONStatus sets the Content-Type of w to application/json unless one is
// specified. // specified.
func JSONStatus(w http.ResponseWriter, v interface{}, status int) { func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(v); err != nil {
panic(err)
}
setContentTypeUnlessPresent(w, "application/json") setContentTypeUnlessPresent(w, "application/json")
w.WriteHeader(status) w.WriteHeader(status)
_, _ = b.WriteTo(w)
if err := json.NewEncoder(w).Encode(v); err != nil {
var errUnsupportedType *json.UnsupportedTypeError
if errors.As(err, &errUnsupportedType) {
panic(err)
}
var errUnsupportedValue *json.UnsupportedValueError
if errors.As(err, &errUnsupportedValue) {
panic(err)
}
var errMarshalError *json.MarshalerError
if errors.As(err, &errMarshalError) {
panic(err)
}
}
log.EnabledResponse(w, v) log.EnabledResponse(w, v)
} }

View file

@ -1,10 +1,8 @@
package render package render
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"math"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv" "strconv"
@ -28,43 +26,10 @@ func TestJSON(t *testing.T) {
assert.Empty(t, rw.Fields()) assert.Empty(t, rw.Fields())
} }
func TestJSONPanicsOnUnsupportedType(t *testing.T) { func TestJSONPanics(t *testing.T) {
jsonPanicTest[json.UnsupportedTypeError](t, make(chan struct{})) assert.Panics(t, func() {
} JSON(httptest.NewRecorder(), make(chan struct{}))
})
func TestJSONPanicsOnUnsupportedValue(t *testing.T) {
jsonPanicTest[json.UnsupportedValueError](t, math.NaN())
}
func TestJSONPanicsOnMarshalerError(t *testing.T) {
var v erroneousJSONMarshaler
jsonPanicTest[json.MarshalerError](t, v)
}
type erroneousJSONMarshaler struct{}
func (erroneousJSONMarshaler) MarshalJSON() ([]byte, error) {
return nil, assert.AnError
}
func jsonPanicTest[T json.UnsupportedTypeError | json.UnsupportedValueError | json.MarshalerError](t *testing.T, v any) {
t.Helper()
defer func() {
var err error
if r := recover(); r == nil {
t.Fatal("expected panic")
} else if e, ok := r.(error); !ok {
t.Fatalf("did not panic with an error (%T)", r)
} else {
err = e
}
var e *T
assert.ErrorAs(t, err, &e)
}()
JSON(httptest.NewRecorder(), v)
} }
type renderableError struct { type renderableError struct {

View file

@ -6,7 +6,6 @@ import (
"strings" "strings"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -18,22 +17,14 @@ const (
// Renew uses the information of certificate in the TLS connection to create a // Renew uses the information of certificate in the TLS connection to create a
// new one. // new one.
func Renew(w http.ResponseWriter, r *http.Request) { func Renew(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() cert, err := getPeerCertificate(r)
// Get the leaf certificate from the peer or the token.
cert, token, err := getPeerCertificate(r)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
// The token can be used by RAs to renew a certificate. a := mustAuthority(r.Context())
if token != "" { certChain, err := a.Renew(cert)
ctx = authority.NewTokenContext(ctx, token)
}
a := mustAuthority(ctx)
certChain, err := a.RenewContext(ctx, cert, nil)
if err != nil { if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return return
@ -53,16 +44,15 @@ func Renew(w http.ResponseWriter, r *http.Request) {
}, http.StatusCreated) }, http.StatusCreated)
} }
func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) { func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], "", nil return r.TLS.PeerCertificates[0], nil
} }
if s := r.Header.Get(authorizationHeader); s != "" { if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
ctx := r.Context() ctx := r.Context()
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
return peer, parts[1], err
} }
} }
return nil, "", errs.BadRequest("missing client certificate") return nil, errs.BadRequest("missing client certificate")
} }

View file

@ -88,7 +88,6 @@ func Sign(w http.ResponseWriter, r *http.Request) {
if len(certChainPEM) > 1 { if len(certChainPEM) > 1 {
caPEM = certChainPEM[1] caPEM = certChainPEM[1]
} }
LogCertificate(w, certChain[0]) LogCertificate(w, certChain[0])
render.JSONStatus(w, &SignResponse{ render.JSONStatus(w, &SignResponse{
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],

View file

@ -338,7 +338,6 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
identityCertificate = certChainToPEM(certChain) identityCertificate = certChainToPEM(certChain)
} }
LogSSHCertificate(w, cert)
render.JSONStatus(w, &SSHSignResponse{ render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{cert}, Certificate: SSHCertificate{cert},
AddUserCertificate: addUserCertificate, AddUserCertificate: addUserCertificate,

View file

@ -89,7 +89,6 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
return return
} }
LogSSHCertificate(w, newCert)
render.JSONStatus(w, &SSHRekeyResponse{ render.JSONStatus(w, &SSHRekeyResponse{
Certificate: SSHCertificate{newCert}, Certificate: SSHCertificate{newCert},
IdentityCertificate: identity, IdentityCertificate: identity,

View file

@ -81,7 +81,6 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
return return
} }
LogSSHCertificate(w, newCert)
render.JSONStatus(w, &SSHSignResponse{ render.JSONStatus(w, &SSHSignResponse{
Certificate: SSHCertificate{newCert}, Certificate: SSHCertificate{newCert},
IdentityCertificate: identity, IdentityCertificate: identity,

View file

@ -69,17 +69,17 @@ func NewACMEAdminResponder() ACMEAdminResponder {
} }
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint // GetExternalAccountKeys writes the response for the EAB keys GET endpoint
func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, _ *http.Request) { func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// CreateExternalAccountKey writes the response for the EAB key POST endpoint // CreateExternalAccountKey writes the response for the EAB key POST endpoint
func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, _ *http.Request) { func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, _ *http.Request) { func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }

View file

@ -57,9 +57,9 @@ func validateWebhook(webhook *linkedca.Webhook) error {
// kind // kind
switch webhook.Kind { switch webhook.Kind {
case linkedca.Webhook_ENRICHING, linkedca.Webhook_AUTHORIZING, linkedca.Webhook_SCEPCHALLENGE: case linkedca.Webhook_ENRICHING, linkedca.Webhook_AUTHORIZING:
default: default:
return admin.NewError(admin.ErrorBadRequestType, "webhook kind %q is invalid", webhook.Kind) return admin.NewError(admin.ErrorBadRequestType, "webhook kind is invalid")
} }
return nil return nil

View file

@ -180,26 +180,6 @@ func TestWebhookAdminResponder_CreateProvisionerWebhook(t *testing.T) {
statusCode: 400, statusCode: 400,
} }
}, },
"fail/unsupported-webhook-kind": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Name: "provName",
}
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
adminErr := admin.NewError(admin.ErrorBadRequestType, `(line 5:13): invalid value for enum type: "UNSUPPORTED"`)
adminErr.Message = `(line 5:13): invalid value for enum type: "UNSUPPORTED"`
body := []byte(`
{
"name": "metadata",
"url": "https://example.com",
"kind": "UNSUPPORTED",
}`)
return test{
ctx: ctx,
body: body,
err: adminErr,
statusCode: 400,
}
},
"fail/auth.UpdateProvisioner-error": func(t *testing.T) test { "fail/auth.UpdateProvisioner-error": func(t *testing.T) test {
adm := &linkedca.Admin{ adm := &linkedca.Admin{
Subject: "step", Subject: "step",

View file

@ -40,7 +40,7 @@ func (dba *dbAdmin) clone() *dbAdmin {
return &u return &u
} }
func (db *DB) getDBAdminBytes(_ context.Context, id string) ([]byte, error) { func (db *DB) getDBAdminBytes(ctx context.Context, id string) ([]byte, error) {
data, err := db.db.Get(adminsTable, []byte(id)) data, err := db.db.Get(adminsTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id) return nil, admin.NewError(admin.ErrorNotFoundType, "admin %s not found", id)
@ -102,7 +102,7 @@ func (db *DB) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error)
// GetAdmins retrieves and unmarshals all active (not deleted) admins // GetAdmins retrieves and unmarshals all active (not deleted) admins
// from the database. // from the database.
// TODO should we be paginating? // TODO should we be paginating?
func (db *DB) GetAdmins(context.Context) ([]*linkedca.Admin, error) { func (db *DB) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
dbEntries, err := db.db.List(adminsTable) dbEntries, err := db.db.List(adminsTable)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error loading admins") return nil, errors.Wrap(err, "error loading admins")
@ -115,11 +115,13 @@ func (db *DB) GetAdmins(context.Context) ([]*linkedca.Admin, error) {
if errors.As(err, &ae) { if errors.As(err, &ae) {
if ae.IsType(admin.ErrorDeletedType) || ae.IsType(admin.ErrorAuthorityMismatchType) { if ae.IsType(admin.ErrorDeletedType) || ae.IsType(admin.ErrorAuthorityMismatchType) {
continue continue
} } else {
return nil, err return nil, err
} }
} else {
return nil, err return nil, err
} }
}
if adm.AuthorityId != db.authorityID { if adm.AuthorityId != db.authorityID {
continue continue
} }

View file

@ -36,7 +36,7 @@ func New(db nosqlDB.DB, authorityID string) (*DB, error) {
// save writes the new data to the database, overwriting the old data if it // save writes the new data to the database, overwriting the old data if it
// existed. // existed.
func (db *DB) save(_ context.Context, id string, nu, old interface{}, typ string, table []byte) error { func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
var ( var (
err error err error
newB []byte newB []byte

View file

@ -71,7 +71,7 @@ func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy {
return dbToLinked(dbap.Policy) return dbToLinked(dbap.Policy)
} }
func (db *DB) getDBAuthorityPolicyBytes(_ context.Context, authorityID string) ([]byte, error) { func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) {
data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID)) data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found") return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found")

View file

@ -70,7 +70,7 @@ func (dbp *dbProvisioner) convert2linkedca() (*linkedca.Provisioner, error) {
}, nil }, nil
} }
func (db *DB) getDBProvisionerBytes(_ context.Context, id string) ([]byte, error) { func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) {
data, err := db.db.Get(provisionersTable, []byte(id)) data, err := db.db.Get(provisionersTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id) return nil, admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found", id)
@ -132,7 +132,7 @@ func (db *DB) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisio
// GetProvisioners retrieves and unmarshals all active (not deleted) provisioners // GetProvisioners retrieves and unmarshals all active (not deleted) provisioners
// from the database. // from the database.
func (db *DB) GetProvisioners(_ context.Context) ([]*linkedca.Provisioner, error) { func (db *DB) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
dbEntries, err := db.db.List(provisionersTable) dbEntries, err := db.db.List(provisionersTable)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error loading provisioners") return nil, errors.Wrap(err, "error loading provisioners")
@ -145,11 +145,13 @@ func (db *DB) GetProvisioners(_ context.Context) ([]*linkedca.Provisioner, error
if errors.As(err, &ae) { if errors.As(err, &ae) {
if ae.IsType(admin.ErrorDeletedType) || ae.IsType(admin.ErrorAuthorityMismatchType) { if ae.IsType(admin.ErrorDeletedType) || ae.IsType(admin.ErrorAuthorityMismatchType) {
continue continue
} } else {
return nil, err return nil, err
} }
} else {
return nil, err return nil, err
} }
}
if prov.AuthorityId != db.authorityID { if prov.AuthorityId != db.authorityID {
continue continue
} }

View file

@ -73,12 +73,7 @@ type Authority struct {
sshCAUserFederatedCerts []ssh.PublicKey sshCAUserFederatedCerts []ssh.PublicKey
sshCAHostFederatedCerts []ssh.PublicKey sshCAHostFederatedCerts []ssh.PublicKey
// CRL vars // Do not re-initialize
crlTicker *time.Ticker
crlStopper chan struct{}
crlMutex sync.Mutex
// If true, do not re-initialize
initOnce bool initOnce bool
startTime time.Time startTime time.Time
@ -96,11 +91,8 @@ type Authority struct {
adminMutex sync.RWMutex adminMutex sync.RWMutex
// If true, do not initialize the authority // Do Not initialize the authority
skipInit bool skipInit bool
// If true, do not output initialization logs
quietInit bool
} }
// Info contains information about the authority. // Info contains information about the authority.
@ -413,13 +405,13 @@ func (a *Authority) init() error {
// Read root certificates and store them in the certificates map. // Read root certificates and store them in the certificates map.
if len(a.rootX509Certs) == 0 { if len(a.rootX509Certs) == 0 {
a.rootX509Certs = make([]*x509.Certificate, 0, len(a.config.Root)) a.rootX509Certs = make([]*x509.Certificate, len(a.config.Root))
for _, path := range a.config.Root { for i, path := range a.config.Root {
crts, err := pemutil.ReadCertificateBundle(path) crt, err := pemutil.ReadCertificate(path)
if err != nil { if err != nil {
return err return err
} }
a.rootX509Certs = append(a.rootX509Certs, crts...) a.rootX509Certs[i] = crt
} }
} }
for _, crt := range a.rootX509Certs { for _, crt := range a.rootX509Certs {
@ -434,13 +426,13 @@ func (a *Authority) init() error {
// Read federated certificates and store them in the certificates map. // Read federated certificates and store them in the certificates map.
if len(a.federatedX509Certs) == 0 { if len(a.federatedX509Certs) == 0 {
a.federatedX509Certs = make([]*x509.Certificate, 0, len(a.config.FederatedRoots)) a.federatedX509Certs = make([]*x509.Certificate, len(a.config.FederatedRoots))
for _, path := range a.config.FederatedRoots { for i, path := range a.config.FederatedRoots {
crts, err := pemutil.ReadCertificateBundle(path) crt, err := pemutil.ReadCertificate(path)
if err != nil { if err != nil {
return err return err
} }
a.federatedX509Certs = append(a.federatedX509Certs, crts...) a.federatedX509Certs[i] = crt
} }
} }
for _, crt := range a.federatedX509Certs { for _, crt := range a.federatedX509Certs {
@ -545,101 +537,6 @@ func (a *Authority) init() error {
tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts...) tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts...)
} }
if a.config.AuthorityConfig.EnableAdmin {
// Initialize step-ca Admin Database if it's not already initialized using
// WithAdminDB.
if a.adminDB == nil {
if linkedcaClient != nil {
a.adminDB = linkedcaClient
} else {
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
if err != nil {
return err
}
}
}
provs, err := a.adminDB.GetProvisioners(ctx)
if err != nil {
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
}
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
// Migration will currently only be kicked off once, because either one or more provisioners
// are migrated or a default JWK provisioner will be created in the DB. It won't run for
// linked or hosted deployments. Not for linked, because that case is explicitly checked
// for above. Not for hosted, because there'll be at least an existing OIDC provisioner.
var firstJWKProvisioner *linkedca.Provisioner
if len(a.config.AuthorityConfig.Provisioners) > 0 {
// Existing provisioners detected; try migrating them to DB storage.
a.initLogf("Starting migration of provisioners")
for _, p := range a.config.AuthorityConfig.Provisioners {
lp, err := ProvisionerToLinkedca(p)
if err != nil {
return admin.WrapErrorISE(err, "error transforming provisioner %q while migrating", p.GetName())
}
// Store the provisioner to be migrated
if err := a.adminDB.CreateProvisioner(ctx, lp); err != nil {
return admin.WrapErrorISE(err, "error creating provisioner %q while migrating", p.GetName())
}
// Mark the first JWK provisioner, so that it can be used for administration purposes
if firstJWKProvisioner == nil && lp.Type == linkedca.Provisioner_JWK {
firstJWKProvisioner = lp
a.initLogf("Migrated JWK provisioner %q with admin permissions", p.GetName())
} else {
a.initLogf("Migrated %s provisioner %q", p.GetType(), p.GetName())
}
}
c := a.config
if c.WasLoadedFromFile() {
// The provisioners in the configuration file can be deleted from
// the file by editing it. Automatic rewriting of the file was considered
// to be too surprising for users and not the right solution for all
// use cases, so we leave it up to users to this themselves.
a.initLogf("Provisioners that were migrated can now be removed from `ca.json` by editing it")
}
a.initLogf("Finished migrating provisioners")
}
// Create first JWK provisioner for remote administration purposes if none exists yet
if firstJWKProvisioner == nil {
firstJWKProvisioner, err = CreateFirstProvisioner(ctx, a.adminDB, string(a.password))
if err != nil {
return admin.WrapErrorISE(err, "error creating first provisioner")
}
a.initLogf("Created JWK provisioner %q with admin permissions", firstJWKProvisioner.GetName())
}
// Create first super admin, belonging to the first JWK provisioner
// TODO(hs): pass a user-provided first super admin subject to here. With `ca init` it's
// added to the DB immediately if using remote management. But when migrating from
// ca.json to the DB, this option doesn't exist. Adding a flag just to do it during
// migration isn't nice. We could opt for a user to change it afterwards. There exist
// cases in which creation of `step` could lock out a user from API access. This is the
// case if `step` isn't allowed to be signed by Name Constraints or the X.509 policy.
// We have protection for that when creating and updating a policy, but if a policy or
// Name Constraints are in use at the time of migration, that could lock the user out.
superAdminSubject := "step"
if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{
ProvisionerId: firstJWKProvisioner.Id,
Subject: superAdminSubject,
Type: linkedca.Admin_SUPER_ADMIN,
}); err != nil {
return admin.WrapErrorISE(err, "error creating first admin")
}
a.initLogf("Created super admin %q for JWK provisioner %q", superAdminSubject, firstJWKProvisioner.GetName())
}
}
// Load Provisioners and Admins
if err := a.ReloadAdminResources(ctx); err != nil {
return err
}
// Check if a KMS with decryption capability is required and available // Check if a KMS with decryption capability is required and available
if a.requiresDecrypter() { if a.requiresDecrypter() {
if _, ok := a.keyManager.(kmsapi.Decrypter); !ok { if _, ok := a.keyManager.(kmsapi.Decrypter); !ok {
@ -684,6 +581,47 @@ func (a *Authority) init() error {
// TODO: mimick the x509CAService GetCertificateAuthority here too? // TODO: mimick the x509CAService GetCertificateAuthority here too?
} }
if a.config.AuthorityConfig.EnableAdmin {
// Initialize step-ca Admin Database if it's not already initialized using
// WithAdminDB.
if a.adminDB == nil {
if linkedcaClient != nil {
a.adminDB = linkedcaClient
} else {
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
if err != nil {
return err
}
}
}
provs, err := a.adminDB.GetProvisioners(ctx)
if err != nil {
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
}
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
// Create First Provisioner
prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password))
if err != nil {
return admin.WrapErrorISE(err, "error creating first provisioner")
}
// Create first admin
if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{
ProvisionerId: prov.Id,
Subject: "step",
Type: linkedca.Admin_SUPER_ADMIN,
}); err != nil {
return admin.WrapErrorISE(err, "error creating first admin")
}
}
}
// Load Provisioners and Admins
if err := a.ReloadAdminResources(ctx); err != nil {
return err
}
// Load X509 constraints engine. // Load X509 constraints engine.
// //
// This is currently only available in CA mode. // This is currently only available in CA mode.
@ -716,18 +654,6 @@ func (a *Authority) init() error {
a.templates.Data["Step"] = tmplVars a.templates.Data["Step"] = tmplVars
} }
// Start the CRL generator, we can assume the configuration is validated.
if a.config.CRL.IsEnabled() {
// Default cache duration to the default one
if v := a.config.CRL.CacheDuration; v == nil || v.Duration <= 0 {
a.config.CRL.CacheDuration = config.DefaultCRLCacheDuration
}
// Start CRL generator
if err := a.startCRLGenerator(); err != nil {
return err
}
}
// JWT numeric dates are seconds. // JWT numeric dates are seconds.
a.startTime = time.Now().Truncate(time.Second) a.startTime = time.Now().Truncate(time.Second)
// Set flag indicating that initialization has been completed, and should // Set flag indicating that initialization has been completed, and should
@ -737,14 +663,6 @@ func (a *Authority) init() error {
return nil return nil
} }
// initLogf is used to log initialization information. The output
// can be disabled by starting the CA with the `--quiet` flag.
func (a *Authority) initLogf(format string, v ...any) {
if !a.quietInit {
log.Printf(format, v...)
}
}
// GetID returns the define authority id or a zero uuid. // GetID returns the define authority id or a zero uuid.
func (a *Authority) GetID() string { func (a *Authority) GetID() string {
const zeroUUID = "00000000-0000-0000-0000-000000000000" const zeroUUID = "00000000-0000-0000-0000-000000000000"
@ -794,11 +712,6 @@ func (a *Authority) IsAdminAPIEnabled() bool {
// Shutdown safely shuts down any clients, databases, etc. held by the Authority. // Shutdown safely shuts down any clients, databases, etc. held by the Authority.
func (a *Authority) Shutdown() error { func (a *Authority) Shutdown() error {
if a.crlTicker != nil {
a.crlTicker.Stop()
close(a.crlStopper)
}
if err := a.keyManager.Close(); err != nil { if err := a.keyManager.Close(); err != nil {
log.Printf("error closing the key manager: %v", err) log.Printf("error closing the key manager: %v", err)
} }
@ -807,11 +720,6 @@ func (a *Authority) Shutdown() error {
// CloseForReload closes internal services, to allow a safe reload. // CloseForReload closes internal services, to allow a safe reload.
func (a *Authority) CloseForReload() { func (a *Authority) CloseForReload() {
if a.crlTicker != nil {
a.crlTicker.Stop()
close(a.crlStopper)
}
if err := a.keyManager.Close(); err != nil { if err := a.keyManager.Close(); err != nil {
log.Printf("error closing the key manager: %v", err) log.Printf("error closing the key manager: %v", err)
} }
@ -852,49 +760,11 @@ func (a *Authority) requiresSCEPService() bool {
return false return false
} }
// GetSCEPService returns the configured SCEP Service. // GetSCEPService returns the configured SCEP Service
// // TODO: this function is intended to exist temporarily
// TODO: this function is intended to exist temporarily in order to make SCEP // in order to make SCEP work more easily. It can be
// work more easily. It can be made more correct by using the right // made more correct by using the right interfaces/abstractions
// interfaces/abstractions after it works as expected. // after it works as expected.
func (a *Authority) GetSCEPService() *scep.Service { func (a *Authority) GetSCEPService() *scep.Service {
return a.scepService return a.scepService
} }
func (a *Authority) startCRLGenerator() error {
if !a.config.CRL.IsEnabled() {
return nil
}
// Check that there is a valid CRL in the DB right now. If it doesn't exist
// or is expired, generate one now
_, ok := a.db.(db.CertificateRevocationListDB)
if !ok {
return errors.Errorf("CRL Generation requested, but database does not support CRL generation")
}
// Always create a new CRL on startup in case the CA has been down and the
// time to next expected CRL update is less than the cache duration.
if err := a.GenerateCertificateRevocationList(); err != nil {
return errors.Wrap(err, "could not generate a CRL")
}
a.crlStopper = make(chan struct{}, 1)
a.crlTicker = time.NewTicker(a.config.CRL.TickerDuration())
go func() {
for {
select {
case <-a.crlTicker.C:
log.Println("Regenerating CRL")
if err := a.GenerateCertificateRevocationList(); err != nil {
log.Printf("error regenerating the CRL: %v", err)
}
case <-a.crlStopper:
return
}
}
}()
return nil
}

View file

@ -6,10 +6,8 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"encoding/pem"
"net" "net"
"os" "os"
"path/filepath"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -20,7 +18,6 @@ import (
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/minica"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
@ -175,130 +172,6 @@ func TestAuthorityNew(t *testing.T) {
} }
} }
func TestAuthorityNew_bundles(t *testing.T) {
ca0, err := minica.New()
if err != nil {
t.Fatal(err)
}
ca1, err := minica.New()
if err != nil {
t.Fatal(err)
}
ca2, err := minica.New()
if err != nil {
t.Fatal(err)
}
rootPath := t.TempDir()
writeCert := func(fn string, certs ...*x509.Certificate) error {
var b []byte
for _, crt := range certs {
b = append(b, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: crt.Raw,
})...)
}
return os.WriteFile(filepath.Join(rootPath, fn), b, 0600)
}
writeKey := func(fn string, signer crypto.Signer) error {
_, err := pemutil.Serialize(signer, pemutil.ToFile(filepath.Join(rootPath, fn), 0600))
return err
}
if err := writeCert("root0.crt", ca0.Root); err != nil {
t.Fatal(err)
}
if err := writeCert("int0.crt", ca0.Intermediate); err != nil {
t.Fatal(err)
}
if err := writeKey("int0.key", ca0.Signer); err != nil {
t.Fatal(err)
}
if err := writeCert("root1.crt", ca1.Root); err != nil {
t.Fatal(err)
}
if err := writeCert("int1.crt", ca1.Intermediate); err != nil {
t.Fatal(err)
}
if err := writeKey("int1.key", ca1.Signer); err != nil {
t.Fatal(err)
}
if err := writeCert("bundle0.crt", ca0.Root, ca1.Root); err != nil {
t.Fatal(err)
}
if err := writeCert("bundle1.crt", ca1.Root, ca2.Root); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
config *config.Config
wantErr bool
}{
{"ok ca0", &config.Config{
Address: "127.0.0.1:443",
Root: []string{filepath.Join(rootPath, "root0.crt")},
IntermediateCert: filepath.Join(rootPath, "int0.crt"),
IntermediateKey: filepath.Join(rootPath, "int0.key"),
DNSNames: []string{"127.0.0.1"},
AuthorityConfig: &AuthConfig{},
}, false},
{"ok bundle", &config.Config{
Address: "127.0.0.1:443",
Root: []string{filepath.Join(rootPath, "bundle0.crt")},
IntermediateCert: filepath.Join(rootPath, "int0.crt"),
IntermediateKey: filepath.Join(rootPath, "int0.key"),
DNSNames: []string{"127.0.0.1"},
AuthorityConfig: &AuthConfig{},
}, false},
{"ok federated ca1", &config.Config{
Address: "127.0.0.1:443",
Root: []string{filepath.Join(rootPath, "root0.crt")},
FederatedRoots: []string{filepath.Join(rootPath, "root1.crt")},
IntermediateCert: filepath.Join(rootPath, "int0.crt"),
IntermediateKey: filepath.Join(rootPath, "int0.key"),
DNSNames: []string{"127.0.0.1"},
AuthorityConfig: &AuthConfig{},
}, false},
{"ok federated bundle", &config.Config{
Address: "127.0.0.1:443",
Root: []string{filepath.Join(rootPath, "root0.crt")},
FederatedRoots: []string{filepath.Join(rootPath, "bundle1.crt")},
IntermediateCert: filepath.Join(rootPath, "int0.crt"),
IntermediateKey: filepath.Join(rootPath, "int0.key"),
DNSNames: []string{"127.0.0.1"},
AuthorityConfig: &AuthConfig{},
}, false},
{"fail root", &config.Config{
Address: "127.0.0.1:443",
Root: []string{filepath.Join(rootPath, "missing.crt")},
IntermediateCert: filepath.Join(rootPath, "int0.crt"),
IntermediateKey: filepath.Join(rootPath, "int0.key"),
DNSNames: []string{"127.0.0.1"},
AuthorityConfig: &AuthConfig{},
}, true},
{"fail federated", &config.Config{
Address: "127.0.0.1:443",
Root: []string{filepath.Join(rootPath, "root0.crt")},
FederatedRoots: []string{filepath.Join(rootPath, "missing.crt")},
IntermediateCert: filepath.Join(rootPath, "int0.crt"),
IntermediateKey: filepath.Join(rootPath, "int0.key"),
DNSNames: []string{"127.0.0.1"},
AuthorityConfig: &AuthConfig{},
}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := New(tt.config)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestAuthority_GetDatabase(t *testing.T) { func TestAuthority_GetDatabase(t *testing.T) {
auth := testAuthority(t) auth := testAuthority(t)
authWithDatabase, err := New(auth.config, WithDatabase(auth.db)) authWithDatabase, err := New(auth.config, WithDatabase(auth.db))

View file

@ -286,7 +286,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
// extra extension cannot be found, authorize the renewal by default. // extra extension cannot be found, authorize the renewal by default.
// //
// TODO(mariano): should we authorize by default? // TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
serial := cert.SerialNumber.String() serial := cert.SerialNumber.String()
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
@ -308,14 +308,14 @@ func (a *Authority) authorizeRenew(ctx context.Context, cert *x509.Certificate)
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
} }
} }
if err := p.AuthorizeRenew(ctx, cert); err != nil { if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
} }
return nil return nil
} }
// authorizeSSHCertificate returns an error if the given certificate is revoked. // authorizeSSHCertificate returns an error if the given certificate is revoked.
func (a *Authority) authorizeSSHCertificate(_ context.Context, cert *ssh.Certificate) error { func (a *Authority) authorizeSSHCertificate(ctx context.Context, cert *ssh.Certificate) error {
var err error var err error
var isRevoked bool var isRevoked bool
@ -394,7 +394,7 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error
// AuthorizeRenewToken validates the renew token and returns the leaf // AuthorizeRenewToken validates the renew token and returns the leaf
// certificate in the x5cInsecure header. // certificate in the x5cInsecure header.
func (a *Authority) AuthorizeRenewToken(_ context.Context, ott string) (*x509.Certificate, error) { func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
var claims jose.Claims var claims jose.Claims
jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs) jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs)
if err != nil { if err != nil {
@ -434,7 +434,7 @@ func (a *Authority) AuthorizeRenewToken(_ context.Context, ott string) (*x509.Ce
} }
audiences := a.config.GetAudiences().Renew audiences := a.config.GetAudiences().Renew
if !matchesAudience(claims.Audience, audiences) && !isRAProvisioner(p) { if !matchesAudience(claims.Audience, audiences) {
return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)")) return nil, errs.InternalServerErr(jose.ErrInvalidAudience, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
} }

View file

@ -876,7 +876,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := genTestCase(t) tc := genTestCase(t)
err := tc.auth.authorizeRenew(context.Background(), tc.cert) err := tc.auth.authorizeRenew(tc.cert)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
var sc render.StatusCodedError var sc render.StatusCodedError
@ -1459,37 +1459,6 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
}) })
return nil return nil
})) }))
a4 := testAuthority(t)
a4.db = &db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
return &db.CertificateData{
Provisioner: &db.ProvisionerData{ID: "Max:IMi94WBNI6gP5cNHXlZYNUzvMjGdHyBRmFoo-lCEaqk", Name: "Max"},
RaInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
}, nil
},
}
t4, c4 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://ra.example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"}, Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com", Subject: "test.example.com",
@ -1658,7 +1627,6 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
{"ok", a1, args{ctx, t1}, c1, false}, {"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false}, {"ok expired cert", a1, args{ctx, t2}, c2, false},
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false}, {"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
{"ok ra provisioner", a4, args{ctx, t4}, c4, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true}, {"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true}, {"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true}, {"fail token signature", a1, args{ctx, badSigner}, nil, true},

View file

@ -35,13 +35,8 @@ var (
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally // DefaultEnableSSHCA enable SSH CA features per provisioner or globally
// for all provisioners. // for all provisioners.
DefaultEnableSSHCA = false DefaultEnableSSHCA = false
// DefaultCRLCacheDuration is the default cache duration for the CRL. // GlobalProvisionerClaims default claims for the Authority. Can be overridden
DefaultCRLCacheDuration = &provisioner.Duration{Duration: 24 * time.Hour} // by provisioner specific claims.
// DefaultCRLExpiredDuration is the default duration in which expired
// certificates will remain in the CRL after expiration.
DefaultCRLExpiredDuration = time.Hour
// GlobalProvisionerClaims is the default duration that expired certificates
// remain in the CRL after expiration.
GlobalProvisionerClaims = provisioner.Claims{ GlobalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
@ -77,62 +72,7 @@ type Config struct {
Password string `json:"password,omitempty"` Password string `json:"password,omitempty"`
Templates *templates.Templates `json:"templates,omitempty"` Templates *templates.Templates `json:"templates,omitempty"`
CommonName string `json:"commonName,omitempty"` CommonName string `json:"commonName,omitempty"`
CRL *CRLConfig `json:"crl,omitempty"`
SkipValidation bool `json:"-"` SkipValidation bool `json:"-"`
NNSServer string `json:"nnsServer,omitempty"`
// Keeps record of the filename the Config is read from
loadedFromFilepath string
}
// CRLConfig represents config options for CRL generation
type CRLConfig struct {
Enabled bool `json:"enabled"`
GenerateOnRevoke bool `json:"generateOnRevoke,omitempty"`
CacheDuration *provisioner.Duration `json:"cacheDuration,omitempty"`
RenewPeriod *provisioner.Duration `json:"renewPeriod,omitempty"`
IDPurl string `json:"idpURL,omitempty"`
}
// IsEnabled returns if the CRL is enabled.
func (c *CRLConfig) IsEnabled() bool {
return c != nil && c.Enabled
}
// Validate validates the CRL configuration.
func (c *CRLConfig) Validate() error {
if c == nil {
return nil
}
if c.CacheDuration != nil && c.CacheDuration.Duration < 0 {
return errors.New("crl.cacheDuration must be greater than or equal to 0")
}
if c.RenewPeriod != nil && c.RenewPeriod.Duration < 0 {
return errors.New("crl.renewPeriod must be greater than or equal to 0")
}
if c.RenewPeriod != nil && c.CacheDuration != nil &&
c.RenewPeriod.Duration > c.CacheDuration.Duration {
return errors.New("crl.cacheDuration must be greater than or equal to crl.renewPeriod")
}
return nil
}
// TickerDuration the renewal ticker duration. This is set by renewPeriod, of it
// is not set is ~2/3 of cacheDuration.
func (c *CRLConfig) TickerDuration() time.Duration {
if !c.IsEnabled() {
return 0
}
if c.RenewPeriod != nil && c.RenewPeriod.Duration > 0 {
return c.RenewPeriod.Duration
}
return (c.CacheDuration.Duration / 3) * 2
} }
// ASN1DN contains ASN1.DN attributes that are used in Subject and Issuer // ASN1DN contains ASN1.DN attributes that are used in Subject and Issuer
@ -183,7 +123,7 @@ func (c *AuthConfig) init() {
} }
// Validate validates the authority configuration. // Validate validates the authority configuration.
func (c *AuthConfig) Validate(provisioner.Audiences) error { func (c *AuthConfig) Validate(audiences provisioner.Audiences) error {
if c == nil { if c == nil {
return errors.New("authority cannot be undefined") return errors.New("authority cannot be undefined")
} }
@ -223,10 +163,6 @@ func LoadConfiguration(filename string) (*Config, error) {
return nil, errors.Wrapf(err, "error parsing %s", filename) return nil, errors.Wrapf(err, "error parsing %s", filename)
} }
// store filename that was read to populate Config
c.loadedFromFilepath = filename
// initialize the Config
c.Init() c.Init()
return &c, nil return &c, nil
@ -247,9 +183,6 @@ func (c *Config) Init() {
if c.CommonName == "" { if c.CommonName == "" {
c.CommonName = "Step Online CA" c.CommonName = "Step Online CA"
} }
if c.CRL != nil && c.CRL.Enabled && c.CRL.CacheDuration == nil {
c.CRL.CacheDuration = DefaultCRLCacheDuration
}
c.AuthorityConfig.init() c.AuthorityConfig.init()
} }
@ -266,30 +199,6 @@ func (c *Config) Save(filename string) error {
return errors.Wrapf(enc.Encode(c), "error writing %s", filename) return errors.Wrapf(enc.Encode(c), "error writing %s", filename)
} }
// Commit saves the current configuration to the same
// file it was initially loaded from.
//
// TODO(hs): rename Save() to WriteTo() and replace this
// with Save()? Or is Commit clear enough.
func (c *Config) Commit() error {
if !c.WasLoadedFromFile() {
return errors.New("cannot commit configuration if not loaded from file")
}
return c.Save(c.loadedFromFilepath)
}
// WasLoadedFromFile returns whether or not the Config was
// loaded from a file.
func (c *Config) WasLoadedFromFile() bool {
return c.loadedFromFilepath != ""
}
// Filepath returns the path to the file the Config was
// loaded from.
func (c *Config) Filepath() string {
return c.loadedFromFilepath
}
// Validate validates the configuration. // Validate validates the configuration.
func (c *Config) Validate() error { func (c *Config) Validate() error {
switch { switch {
@ -360,11 +269,6 @@ func (c *Config) Validate() error {
return err return err
} }
// Validate crl config: nil is ok
if err := c.CRL.Validate(); err != nil {
return err
}
return c.AuthorityConfig.Validate(c.GetAudiences()) return c.AuthorityConfig.Validate(c.GetAudiences())
} }

View file

@ -265,20 +265,8 @@ func (c *linkedCaClient) GetCertificateData(serial string) (*db.CertificateData,
ID: p.Id, Name: p.Name, Type: p.Type.String(), ID: p.Id, Name: p.Name, Type: p.Type.String(),
} }
} }
var raInfo *provisioner.RAInfo
if p := resp.RaProvisioner; p != nil && p.Provisioner != nil {
raInfo = &provisioner.RAInfo{
AuthorityID: p.AuthorityId,
ProvisionerID: p.Provisioner.Id,
ProvisionerType: p.Provisioner.Type.String(),
ProvisionerName: p.Provisioner.Name,
}
}
return &db.CertificateData{ return &db.CertificateData{
Provisioner: pd, Provisioner: pd,
RaInfo: raInfo,
}, nil }, nil
} }
@ -381,19 +369,19 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
} }
func (c *linkedCaClient) CreateAuthorityPolicy(_ context.Context, _ *linkedca.Policy) error { func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet") return errors.New("not implemented yet")
} }
func (c *linkedCaClient) GetAuthorityPolicy(context.Context) (*linkedca.Policy, error) { func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
return nil, errors.New("not implemented yet") return nil, errors.New("not implemented yet")
} }
func (c *linkedCaClient) UpdateAuthorityPolicy(_ context.Context, _ *linkedca.Policy) error { func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet") return errors.New("not implemented yet")
} }
func (c *linkedCaClient) DeleteAuthorityPolicy(context.Context) error { func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error {
return errors.New("not implemented yet") return errors.New("not implemented yet")
} }

View file

@ -86,14 +86,6 @@ func WithDatabase(d db.AuthDB) Option {
} }
} }
// WithQuietInit disables log output when the authority is initialized.
func WithQuietInit() Option {
return func(a *Authority) error {
a.quietInit = true
return nil
}
}
// WithWebhookClient sets the http.Client to be used for outbound requests. // WithWebhookClient sets the http.Client to be used for outbound requests.
func WithWebhookClient(c *http.Client) Option { func WithWebhookClient(c *http.Client) Option {
return func(a *Authority) error { return func(a *Authority) error {

View file

@ -154,7 +154,7 @@ func (a *Authority) checkProvisionerPolicy(ctx context.Context, provName string,
// checkPolicy checks if a new or updated policy configuration results in the user // checkPolicy checks if a new or updated policy configuration results in the user
// locking themselves or other admins out of the CA. // locking themselves or other admins out of the CA.
func (a *Authority) checkPolicy(_ context.Context, currentAdmin *linkedca.Admin, otherAdmins []*linkedca.Admin, p *linkedca.Policy) error { func (a *Authority) checkPolicy(ctx context.Context, currentAdmin *linkedca.Admin, otherAdmins []*linkedca.Admin, p *linkedca.Policy) error {
// convert the policy; return early if nil // convert the policy; return early if nil
policyOptions := authPolicy.LinkedToCertificates(p) policyOptions := authPolicy.LinkedToCertificates(p)
if policyOptions == nil { if policyOptions == nil {
@ -248,7 +248,7 @@ func isAllowed(engine authPolicy.X509Policy, sans []string) error {
if isNamePolicyError && policyErr.Reason == policy.NotAllowed { if isNamePolicyError && policyErr.Reason == policy.NotAllowed {
return &PolicyError{ return &PolicyError{
Typ: AdminLockOut, Typ: AdminLockOut,
Err: fmt.Errorf("the provided policy would lock out %s from the CA. Please create an x509 policy to include %s as an allowed DNS name", sans, sans), Err: fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans),
} }
} }
return &PolicyError{ return &PolicyError{

View file

@ -80,7 +80,7 @@ func TestAuthority_checkPolicy(t *testing.T) {
}, },
err: &PolicyError{ err: &PolicyError{
Typ: AdminLockOut, Typ: AdminLockOut,
Err: errors.New("the provided policy would lock out [step] from the CA. Please create an x509 policy to include [step] as an allowed DNS name"), Err: errors.New("the provided policy would lock out [step] from the CA. Please update your policy to include [step] as an allowed name"),
}, },
} }
}, },
@ -127,7 +127,7 @@ func TestAuthority_checkPolicy(t *testing.T) {
}, },
err: &PolicyError{ err: &PolicyError{
Typ: AdminLockOut, Typ: AdminLockOut,
Err: errors.New("the provided policy would lock out [otherAdmin] from the CA. Please create an x509 policy to include [otherAdmin] as an allowed DNS name"), Err: errors.New("the provided policy would lock out [otherAdmin] from the CA. Please update your policy to include [otherAdmin] as an allowed name"),
}, },
} }
}, },

View file

@ -26,8 +26,6 @@ const (
TLS_ALPN_01 ACMEChallenge = "tls-alpn-01" TLS_ALPN_01 ACMEChallenge = "tls-alpn-01"
// DEVICE_ATTEST_01 is the device-attest-01 ACME challenge. // DEVICE_ATTEST_01 is the device-attest-01 ACME challenge.
DEVICE_ATTEST_01 ACMEChallenge = "device-attest-01" DEVICE_ATTEST_01 ACMEChallenge = "device-attest-01"
// NNS_01 is the nns-01 ACME challenge.
NNS_01 ACMEChallenge = "nns-01"
) )
// String returns a normalized version of the challenge. // String returns a normalized version of the challenge.
@ -38,7 +36,7 @@ func (c ACMEChallenge) String() string {
// Validate returns an error if the acme challenge is not a valid one. // Validate returns an error if the acme challenge is not a valid one.
func (c ACMEChallenge) Validate() error { func (c ACMEChallenge) Validate() error {
switch ACMEChallenge(c.String()) { switch ACMEChallenge(c.String()) {
case HTTP_01, DNS_01, TLS_ALPN_01, DEVICE_ATTEST_01, NNS_01: case HTTP_01, DNS_01, TLS_ALPN_01, DEVICE_ATTEST_01:
return nil return nil
default: default:
return fmt.Errorf("acme challenge %q is not supported", c) return fmt.Errorf("acme challenge %q is not supported", c)
@ -50,7 +48,7 @@ func (c ACMEChallenge) Validate() error {
type ACMEAttestationFormat string type ACMEAttestationFormat string
const ( const (
// APPLE is the format used to enable device-attest-01 on Apple devices. // APPLE is the format used to enable device-attest-01 on apple devices.
APPLE ACMEAttestationFormat = "apple" APPLE ACMEAttestationFormat = "apple"
// STEP is the format used to enable device-attest-01 on devices that // STEP is the format used to enable device-attest-01 on devices that
@ -59,7 +57,7 @@ const (
// TODO(mariano): should we rename this to something else. // TODO(mariano): should we rename this to something else.
STEP ACMEAttestationFormat = "step" STEP ACMEAttestationFormat = "step"
// TPM is the format used to enable device-attest-01 with TPMs. // TPM is the format used to enable device-attest-01 on TPMs.
TPM ACMEAttestationFormat = "tpm" TPM ACMEAttestationFormat = "tpm"
) )
@ -86,17 +84,6 @@ type ACME struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
ForceCN bool `json:"forceCN,omitempty"` ForceCN bool `json:"forceCN,omitempty"`
// TermsOfService contains a URL pointing to the ACME server's
// terms of service. Defaults to empty.
TermsOfService string `json:"termsOfService,omitempty"`
// Website contains an URL pointing to more information about
// the ACME server. Defaults to empty.
Website string `json:"website,omitempty"`
// CaaIdentities is an array of hostnames that the ACME server
// identifies itself with. These hostnames can be used by ACME
// clients to determine the correct issuer domain name to use
// when configuring CAA records. Defaults to empty array.
CaaIdentities []string `json:"caaIdentities,omitempty"`
// RequireEAB makes the provisioner require ACME EAB to be provided // RequireEAB makes the provisioner require ACME EAB to be provided
// by clients when creating a new Account. If set to true, the provided // by clients when creating a new Account. If set to true, the provided
// EAB will be verified. If set to false and an EAB is provided, it is // EAB will be verified. If set to false and an EAB is provided, it is
@ -135,7 +122,7 @@ func (p *ACME) GetIDForToken() string {
} }
// GetTokenID returns the identifier of the token. // GetTokenID returns the identifier of the token.
func (p *ACME) GetTokenID(string) (string, error) { func (p *ACME) GetTokenID(ott string) (string, error) {
return "", errors.New("acme provisioner does not implement GetTokenID") return "", errors.New("acme provisioner does not implement GetTokenID")
} }
@ -186,7 +173,7 @@ func (p *ACME) Init(config Config) (err error) {
} }
// Parse attestation roots. // Parse attestation roots.
// The pool will be nil if there are no roots. // The pool will be nil if the there are not roots.
if rest := p.AttestationRoots; len(rest) > 0 { if rest := p.AttestationRoots; len(rest) > 0 {
var block *pem.Block var block *pem.Block
var hasCert bool var hasCert bool
@ -230,7 +217,7 @@ type ACMEIdentifier struct {
// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a // AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a
// certificate for an ACME Order Identifier. // certificate for an ACME Order Identifier.
func (p *ACME) AuthorizeOrderIdentifier(_ context.Context, identifier ACMEIdentifier) error { func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error {
x509Policy := p.ctl.getPolicy().getX509() x509Policy := p.ctl.getPolicy().getX509()
// identifier is allowed if no policy is configured // identifier is allowed if no policy is configured
@ -255,7 +242,7 @@ func (p *ACME) AuthorizeOrderIdentifier(_ context.Context, identifier ACMEIdenti
// AuthorizeSign does not do any validation, because all validation is handled // AuthorizeSign does not do any validation, because all validation is handled
// in the ACME protocol. This method returns a list of modifiers / constraints // in the ACME protocol. This method returns a list of modifiers / constraints
// on the resulting certificate. // on the resulting certificate.
func (p *ACME) AuthorizeSign(context.Context, string) ([]SignOption, error) { func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
opts := []SignOption{ opts := []SignOption{
p, p,
// modifiers / withOptions // modifiers / withOptions
@ -276,7 +263,7 @@ func (p *ACME) AuthorizeSign(context.Context, string) ([]SignOption, error) {
// the CA. It can be used to authorize revocation of a certificate. With the // the CA. It can be used to authorize revocation of a certificate. With the
// ACME protocol, revocation authorization is specified and performed as part // ACME protocol, revocation authorization is specified and performed as part
// of the client/server interaction, so this is a no-op. // of the client/server interaction, so this is a no-op.
func (p *ACME) AuthorizeRevoke(context.Context, string) error { func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
return nil return nil
} }
@ -291,9 +278,9 @@ func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// IsChallengeEnabled checks if the given challenge is enabled. By default // IsChallengeEnabled checks if the given challenge is enabled. By default
// http-01, dns-01 and tls-alpn-01 are enabled, to disable any of them the // http-01, dns-01 and tls-alpn-01 are enabled, to disable any of them the
// Challenge provisioner property should have at least one element. // Challenge provisioner property should have at least one element.
func (p *ACME) IsChallengeEnabled(_ context.Context, challenge ACMEChallenge) bool { func (p *ACME) IsChallengeEnabled(ctx context.Context, challenge ACMEChallenge) bool {
enabledChallenges := []ACMEChallenge{ enabledChallenges := []ACMEChallenge{
HTTP_01, DNS_01, TLS_ALPN_01, NNS_01, HTTP_01, DNS_01, TLS_ALPN_01,
} }
if len(p.Challenges) > 0 { if len(p.Challenges) > 0 {
enabledChallenges = p.Challenges enabledChallenges = p.Challenges
@ -309,7 +296,7 @@ func (p *ACME) IsChallengeEnabled(_ context.Context, challenge ACMEChallenge) bo
// IsAttestationFormatEnabled checks if the given attestation format is enabled. // IsAttestationFormatEnabled checks if the given attestation format is enabled.
// By default apple, step and tpm are enabled, to disable any of them the // By default apple, step and tpm are enabled, to disable any of them the
// AttestationFormat provisioner property should have at least one element. // AttestationFormat provisioner property should have at least one element.
func (p *ACME) IsAttestationFormatEnabled(_ context.Context, format ACMEAttestationFormat) bool { func (p *ACME) IsAttestationFormatEnabled(ctx context.Context, format ACMEAttestationFormat) bool {
enabledFormats := []ACMEAttestationFormat{ enabledFormats := []ACMEAttestationFormat{
APPLE, STEP, TPM, APPLE, STEP, TPM,
} }

View file

@ -24,7 +24,6 @@ import (
"go.step.sm/linkedca" "go.step.sm/linkedca"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/webhook"
) )
// awsIssuer is the string used as issuer in the generated tokens. // awsIssuer is the string used as issuer in the generated tokens.
@ -74,14 +73,6 @@ const awsMetadataTokenTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds" //nolin
// The fifth certificate is used in: // The fifth certificate is used in:
// //
// me-south-1 // me-south-1
//
// The sixth certificate is used in:
//
// me-central-1
//
// The seventh certificate is used in:
//
// ap-southeast-3
const awsCertificate = `-----BEGIN CERTIFICATE----- const awsCertificate = `-----BEGIN CERTIFICATE-----
MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV
BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw
@ -163,34 +154,6 @@ DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQBhkNTBIFgWFd+ZhC/LhRUY
4OjEiykmbEp6hlzQ79T0Tfbn5A4NYDI2icBP0+hmf6qSnIhwJF6typyd1yPK5Fqt 4OjEiykmbEp6hlzQ79T0Tfbn5A4NYDI2icBP0+hmf6qSnIhwJF6typyd1yPK5Fqt
NTpxxcXmUKquX+pHmIkK1LKDO8rNE84jqxrxRsfDi6by82fjVYf2pgjJW8R1FAw+ NTpxxcXmUKquX+pHmIkK1LKDO8rNE84jqxrxRsfDi6by82fjVYf2pgjJW8R1FAw+
mL5WQRFexbfB5aXhcMo0AA== mL5WQRFexbfB5aXhcMo0AA==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIICMzCCAZygAwIBAgIGAXjRrnDjMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT
AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl
MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMTA0MTQxODM5
MzNaGA8yMjAwMDQxNDE4MzkzM1owXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh
c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv
biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDc
aTgW/KyA6zyruJQrYy00a6wqLA7eeUzk3bMiTkLsTeDQfrkaZMfBAjGaaOymRo1C
3qzE4rIenmahvUplu9ZmLwL1idWXMRX2RlSvIt+d2SeoKOKQWoc2UOFZMHYxDue7
zkyk1CIRaBukTeY13/RIrlc6X61zJ5BBtZXlHwayjQIDAQABMA0GCSqGSIb3DQEB
BQUAA4GBABTqTy3R6RXKPW45FA+cgo7YZEj/Cnz5YaoUivRRdX2A83BHuBTvJE2+
WX00FTEj4hRVjameE1nENoO8Z7fUVloAFDlDo69fhkJeSvn51D1WRrPnoWGgEfr1
+OfK1bAcKTtfkkkP9r4RdwSjKzO5Zu/B+Wqm3kVEz/QNcz6npmA6
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIICMzCCAZygAwIBAgIGAXbVDG2yMA0GCSqGSIb3DQEBBQUAMFwxCzAJBgNVBAYT
AlVTMRkwFwYDVQQIDBBXYXNoaW5ndG9uIFN0YXRlMRAwDgYDVQQHDAdTZWF0dGxl
MSAwHgYDVQQKDBdBbWF6b24gV2ViIFNlcnZpY2VzIExMQzAgFw0yMTAxMDYwMDE1
MzBaGA8yMjAwMDEwNjAwMTUzMFowXDELMAkGA1UEBhMCVVMxGTAXBgNVBAgMEFdh
c2hpbmd0b24gU3RhdGUxEDAOBgNVBAcMB1NlYXR0bGUxIDAeBgNVBAoMF0FtYXpv
biBXZWIgU2VydmljZXMgTExDMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn
CS/Vbt0gQ1ebWcur2hSO7PnJifE4OPxQ7RgSAlc4/spJp1sDP+ZrS0LO1ZJfKhXf
1R9S3AUwLnsc7b+IuVXdY5LK9RKqu64nyXP5dx170zoL8loEyCSuRR2fs+04i2Qs
WBVP+KFNAn7P5L1EHRjkgTO8kjNKviwRV+OkP9ab5wIDAQABMA0GCSqGSIb3DQEB
BQUAA4GBAI4WUy6+DKh0JDSzQEZNyBgNlSoSuC2owtMxCwGB6nBfzzfcekWvs6eo
fLTSGovrReX7MtVgrcJBZjmPIentw5dWUs+87w/g9lNwUnUt0ZHYyh2tuBG6hVJu
UEwDJ/z3wDd6wQviLOTF3MITawt9P8siR1hXqLJNxpjRQFZrgHqi
-----END CERTIFICATE-----` -----END CERTIFICATE-----`
// awsSignatureAlgorithm is the signature algorithm used to verify the identity // awsSignatureAlgorithm is the signature algorithm used to verify the identity
@ -472,7 +435,7 @@ func (p *AWS) Init(config Config) (err error) {
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation. // will be used on certificate creation.
func (p *AWS) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
payload, err := p.authorizeToken(token) payload, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign")
@ -522,11 +485,7 @@ func (p *AWS) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro
commonNameValidator(payload.Claims.Subject), commonNameValidator(payload.Claims.Subject),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
p.ctl.newWebhookController( p.ctl.newWebhookController(data, linkedca.Webhook_X509),
data,
linkedca.Webhook_X509,
webhook.WithAuthorizationPrincipal(doc.InstanceID),
),
), nil ), nil
} }
@ -749,7 +708,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) { func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.ctl.Claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
} }
@ -809,10 +768,6 @@ func (p *AWS) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
// Ensure that all principal names are allowed // Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
// Call webhooks // Call webhooks
p.ctl.newWebhookController( p.ctl.newWebhookController(data, linkedca.Webhook_SSH),
data,
linkedca.Webhook_SSH,
webhook.WithAuthorizationPrincipal(doc.InstanceID),
),
), nil ), nil
} }

View file

@ -20,19 +20,13 @@ import (
"go.step.sm/linkedca" "go.step.sm/linkedca"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/webhook"
) )
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens. // azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
const azureOIDCBaseURL = "https://login.microsoftonline.com" const azureOIDCBaseURL = "https://login.microsoftonline.com"
//nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance. //nolint:gosec // azureIdentityTokenURL is the URL to get the identity token for an instance.
const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token" const azureIdentityTokenURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F"
const azureIdentityTokenAPIVersion = "2018-02-01"
// azureInstanceComputeURL is the URL to get the instance compute metadata.
const azureInstanceComputeURL = "http://169.254.169.254/metadata/instance/compute/azEnvironment"
// azureDefaultAudience is the default audience used. // azureDefaultAudience is the default audience used.
const azureDefaultAudience = "https://management.azure.com/" const azureDefaultAudience = "https://management.azure.com/"
@ -41,27 +35,15 @@ const azureDefaultAudience = "https://management.azure.com/"
// Using case insensitive as resourceGroups appears as resourcegroups. // Using case insensitive as resourceGroups appears as resourcegroups.
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`) var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
// azureEnvironments is the list of all Azure environments.
var azureEnvironments = map[string]string{
"AzurePublicCloud": "https://management.azure.com/",
"AzureCloud": "https://management.azure.com/",
"AzureUSGovernmentCloud": "https://management.usgovcloudapi.net/",
"AzureUSGovernment": "https://management.usgovcloudapi.net/",
"AzureChinaCloud": "https://management.chinacloudapi.cn/",
"AzureGermanCloud": "https://management.microsoftazure.de/",
}
type azureConfig struct { type azureConfig struct {
oidcDiscoveryURL string oidcDiscoveryURL string
identityTokenURL string identityTokenURL string
instanceComputeURL string
} }
func newAzureConfig(tenantID string) *azureConfig { func newAzureConfig(tenantID string) *azureConfig {
return &azureConfig{ return &azureConfig{
oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration", oidcDiscoveryURL: azureOIDCBaseURL + "/" + tenantID + "/.well-known/openid-configuration",
identityTokenURL: azureIdentityTokenURL, identityTokenURL: azureIdentityTokenURL,
instanceComputeURL: azureInstanceComputeURL,
} }
} }
@ -121,7 +103,6 @@ type Azure struct {
oidcConfig openIDConfiguration oidcConfig openIDConfiguration
keyStore *keyStore keyStore *keyStore
ctl *Controller ctl *Controller
environment string
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -183,35 +164,14 @@ func (p *Azure) GetEncryptedKey() (kid, key string, ok bool) {
// GetIdentityToken retrieves from the metadata service the identity token and // GetIdentityToken retrieves from the metadata service the identity token and
// returns it. // returns it.
func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) { func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
_, _ = subject, caURL // unused input
// Initialize the config if this method is used from the cli. // Initialize the config if this method is used from the cli.
p.assertConfig() p.assertConfig()
// default to AzurePublicCloud to keep existing behavior
identityTokenResource := azureEnvironments["AzurePublicCloud"]
var err error
p.environment, err = p.getAzureEnvironment()
if err != nil {
return "", errors.Wrap(err, "error getting azure environment")
}
if resource, ok := azureEnvironments[p.environment]; ok {
identityTokenResource = resource
}
req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody) req, err := http.NewRequest("GET", p.config.identityTokenURL, http.NoBody)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error creating request") return "", errors.Wrap(err, "error creating request")
} }
req.Header.Set("Metadata", "true") req.Header.Set("Metadata", "true")
query := req.URL.Query()
query.Add("resource", identityTokenResource)
query.Add("api-version", azureIdentityTokenAPIVersion)
req.URL.RawQuery = query.Encode()
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?") return "", errors.Wrap(err, "error getting identity token, are you in a Azure VM?")
@ -316,7 +276,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation. // will be used on certificate creation.
func (p *Azure) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
_, name, group, subscription, identityObjectID, err := p.authorizeToken(token) _, name, group, subscription, identityObjectID, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign")
@ -404,11 +364,7 @@ func (p *Azure) AuthorizeSign(_ context.Context, token string) ([]SignOption, er
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
p.ctl.newWebhookController( p.ctl.newWebhookController(data, linkedca.Webhook_X509),
data,
linkedca.Webhook_X509,
webhook.WithAuthorizationPrincipal(identityObjectID),
),
), nil ), nil
} }
@ -421,12 +377,12 @@ func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) { func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.ctl.Claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName())
} }
_, name, _, _, identityObjectID, err := p.authorizeToken(token) _, name, _, _, _, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSSHSign")
} }
@ -478,11 +434,7 @@ func (p *Azure) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption,
// Ensure that all principal names are allowed // Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
// Call webhooks // Call webhooks
p.ctl.newWebhookController( p.ctl.newWebhookController(data, linkedca.Webhook_SSH),
data,
linkedca.Webhook_SSH,
webhook.WithAuthorizationPrincipal(identityObjectID),
),
), nil ), nil
} }
@ -492,37 +444,3 @@ func (p *Azure) assertConfig() {
p.config = newAzureConfig(p.TenantID) p.config = newAzureConfig(p.TenantID)
} }
} }
// getAzureEnvironment returns the Azure environment for the current instance
func (p *Azure) getAzureEnvironment() (string, error) {
if p.environment != "" {
return p.environment, nil
}
req, err := http.NewRequest("GET", p.config.instanceComputeURL, http.NoBody)
if err != nil {
return "", errors.Wrap(err, "error creating request")
}
req.Header.Add("Metadata", "True")
query := req.URL.Query()
query.Add("format", "text")
query.Add("api-version", "2021-02-01")
req.URL.RawQuery = query.Encode()
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", errors.Wrap(err, "error getting azure instance environment, are you in a Azure VM?")
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return "", errors.Wrap(err, "error reading azure environment response")
}
if resp.StatusCode >= 400 {
return "", errors.Errorf("error getting azure environment: status=%d, response=%s", resp.StatusCode, b)
}
return string(b), nil
}

View file

@ -100,14 +100,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
srvIdentity := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wantResource := r.URL.Query().Get("want_resource")
resource := r.URL.Query().Get("resource")
if wantResource == "" || resource != wantResource {
http.Error(w, fmt.Sprintf("Azure query param resource = %s, wantResource %s", resource, wantResource), http.StatusBadRequest)
return
}
switch r.URL.Path { switch r.URL.Path {
case "/bad-request": case "/bad-request":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
@ -118,27 +111,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
fmt.Fprintf(w, `{"access_token":"%s"}`, t1) fmt.Fprintf(w, `{"access_token":"%s"}`, t1)
} }
})) }))
defer srvIdentity.Close() defer srv.Close()
srvInstance := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/bad-request":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case "/AzureChinaCloud":
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzureChinaCloud"))
case "/AzureGermanCloud":
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzureGermanCloud"))
case "/AzureUSGovernmentCloud":
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzureUSGovernmentCloud"))
default:
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzurePublicCloud"))
}
}))
defer srvInstance.Close()
type args struct { type args struct {
subject string subject string
@ -149,27 +122,18 @@ func TestAzure_GetIdentityToken(t *testing.T) {
azure *Azure azure *Azure
args args args args
identityTokenURL string identityTokenURL string
instanceComputeURL string
wantEnvironment string
want string want string
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false}, {"ok", p1, args{"subject", "caURL"}, srv.URL, t1, false},
{"ok azure china", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzurePublicCloud", t1, false}, {"fail request", p1, args{"subject", "caURL"}, srv.URL + "/bad-request", "", true},
{"ok azure germany", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureGermanCloud", t1, false}, {"fail unmarshal", p1, args{"subject", "caURL"}, srv.URL + "/bad-json", "", true},
{"ok azure us gov", p1, args{"subject", "caURL"}, srvIdentity.URL, srvInstance.URL, "AzureUSGovernmentCloud", t1, false}, {"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", "", true},
{"fail instance request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL + "/bad-request", "AzurePublicCloud", "", true}, {"fail connect", p1, args{"subject", "caURL"}, "foobarzar", "", true},
{"fail request", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-request", srvInstance.URL, "AzurePublicCloud", "", true},
{"fail unmarshal", p1, args{"subject", "caURL"}, srvIdentity.URL + "/bad-json", srvInstance.URL, "AzurePublicCloud", "", true},
{"fail url", p1, args{"subject", "caURL"}, "://ca.smallstep.com", srvInstance.URL, "AzurePublicCloud", "", true},
{"fail connect", p1, args{"subject", "caURL"}, "foobarzar", srvInstance.URL, "AzurePublicCloud", "", true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// reset environment between tests to avoid caching issues tt.azure.config.identityTokenURL = tt.identityTokenURL
p1.environment = ""
tt.azure.config.identityTokenURL = tt.identityTokenURL + "?want_resource=" + azureEnvironments[tt.wantEnvironment]
tt.azure.config.instanceComputeURL = tt.instanceComputeURL + "/" + tt.wantEnvironment
got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL) got, err := tt.azure.GetIdentityToken(tt.args.subject, tt.args.caURL)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.GetIdentityToken() error = %v, wantErr %v", err, tt.wantErr)

View file

@ -10,7 +10,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -78,7 +77,7 @@ func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificat
return DefaultAuthorizeSSHRenew(ctx, c, cert) return DefaultAuthorizeSSHRenew(ctx, c, cert)
} }
func (c *Controller) newWebhookController(templateData WebhookSetter, certType linkedca.Webhook_CertType, opts ...webhook.RequestBodyOption) *WebhookController { func (c *Controller) newWebhookController(templateData WebhookSetter, certType linkedca.Webhook_CertType) *WebhookController {
client := c.webhookClient client := c.webhookClient
if client == nil { if client == nil {
client = http.DefaultClient client = http.DefaultClient
@ -88,7 +87,6 @@ func (c *Controller) newWebhookController(templateData WebhookSetter, certType l
client: client, client: client,
webhooks: c.webhooks, webhooks: c.webhooks,
certType: certType, certType: certType,
options: opts,
} }
} }
@ -113,7 +111,7 @@ type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Ce
// DefaultIdentityFunc return a default identity depending on the provisioner // DefaultIdentityFunc return a default identity depending on the provisioner
// type. For OIDC email is always present and the usernames might // type. For OIDC email is always present and the usernames might
// contain empty strings. // contain empty strings.
func DefaultIdentityFunc(_ context.Context, p Interface, email string) (*Identity, error) { func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) { switch k := p.(type) {
case *OIDC: case *OIDC:
// OIDC principals would be: // OIDC principals would be:
@ -142,7 +140,7 @@ func DefaultIdentityFunc(_ context.Context, p Interface, email string) (*Identit
// will return an error if the provisioner has the renewal disabled, if the // will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after // certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled. // expiry is disabled.
func DefaultAuthorizeRenew(_ context.Context, p *Controller, cert *x509.Certificate) error { func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error {
if p.Claimer.IsDisableRenewal() { if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
} }
@ -164,7 +162,7 @@ func DefaultAuthorizeRenew(_ context.Context, p *Controller, cert *x509.Certific
// will return an error if the provisioner has the renewal disabled, if the // will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after // certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled. // expiry is disabled.
func DefaultAuthorizeSSHRenew(_ context.Context, p *Controller, cert *ssh.Certificate) error { func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
if p.Claimer.IsDisableRenewal() { if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName()) return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
} }

View file

@ -4,18 +4,15 @@ import (
"context" "context"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"net/http"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/webhook"
) )
var trueValue = true var trueValue = true
@ -452,39 +449,16 @@ func TestDefaultAuthorizeSSHRenew(t *testing.T) {
} }
func Test_newWebhookController(t *testing.T) { func Test_newWebhookController(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock())
if err != nil {
t.Fatal(err)
}
opts := []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}
type args struct {
templateData WebhookSetter
certType linkedca.Webhook_CertType
opts []webhook.RequestBodyOption
}
tests := []struct {
name string
args args
want *WebhookController
}{
{"ok", args{x509util.TemplateData{"foo": "bar"}, linkedca.Webhook_X509, nil}, &WebhookController{
TemplateData: x509util.TemplateData{"foo": "bar"},
certType: linkedca.Webhook_X509,
client: http.DefaultClient,
}},
{"ok with options", args{x509util.TemplateData{"foo": "bar"}, linkedca.Webhook_SSH, opts}, &WebhookController{
TemplateData: x509util.TemplateData{"foo": "bar"},
certType: linkedca.Webhook_SSH,
client: http.DefaultClient,
options: opts,
}},
}
for _, tt := range tests {
c := &Controller{} c := &Controller{}
got := c.newWebhookController(tt.args.templateData, tt.args.certType, tt.args.opts...) data := x509util.TemplateData{"foo": "bar"}
if !reflect.DeepEqual(got, tt.want) { ctl := c.newWebhookController(data, linkedca.Webhook_X509)
t.Errorf("newWebhookController() = %v, want %v", got, tt.want) if !reflect.DeepEqual(ctl.TemplateData, data) {
t.Error("Failed to set templateData")
} }
if ctl.certType != linkedca.Webhook_X509 {
t.Error("Failed to set certType")
}
if ctl.client == nil {
t.Error("Failed to set client")
} }
} }

View file

@ -21,7 +21,6 @@ import (
"go.step.sm/linkedca" "go.step.sm/linkedca"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/webhook"
) )
// gcpCertsURL is the url that serves Google OAuth2 public keys. // gcpCertsURL is the url that serves Google OAuth2 public keys.
@ -170,8 +169,6 @@ func (p *GCP) GetIdentityURL(audience string) string {
// GetIdentityToken does an HTTP request to the identity url. // GetIdentityToken does an HTTP request to the identity url.
func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) { func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
_ = subject // unused input
audience, err := generateSignAudience(caURL, p.GetIDForToken()) audience, err := generateSignAudience(caURL, p.GetIDForToken())
if err != nil { if err != nil {
return "", err return "", err
@ -223,7 +220,7 @@ func (p *GCP) Init(config Config) (err error) {
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
// will be used on certificate creation. // will be used on certificate creation.
func (p *GCP) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign")
@ -276,11 +273,7 @@ func (p *GCP) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
p.ctl.newWebhookController( p.ctl.newWebhookController(data, linkedca.Webhook_X509),
data,
linkedca.Webhook_X509,
webhook.WithAuthorizationPrincipal(ce.InstanceID),
),
), nil ), nil
} }
@ -387,7 +380,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) { func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.ctl.Claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName())
} }
@ -447,10 +440,6 @@ func (p *GCP) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
// Ensure that all principal names are allowed // Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil), newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
// Call webhooks // Call webhooks
p.ctl.newWebhookController( p.ctl.newWebhookController(data, linkedca.Webhook_SSH),
data,
linkedca.Webhook_SSH,
webhook.WithAuthorizationPrincipal(ce.InstanceID),
),
), nil ), nil
} }

View file

@ -143,14 +143,14 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(_ context.Context, token string) error { func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke) // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
@ -209,7 +209,7 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.ctl.Claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName())
} }
@ -286,7 +286,7 @@ func (p *JWK) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
} }
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(_ context.Context, token string) error { func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
// TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke) // TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")

View file

@ -72,7 +72,7 @@ func (p *K8sSA) GetIDForToken() string {
} }
// GetTokenID returns an unimplemented error and does not use the input ott. // GetTokenID returns an unimplemented error and does not use the input ott.
func (p *K8sSA) GetTokenID(string) (string, error) { func (p *K8sSA) GetTokenID(ott string) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
@ -148,7 +148,6 @@ func (p *K8sSA) Init(config Config) (err error) {
// claims for case specific downstream parsing. // claims for case specific downstream parsing.
// e.g. a Sign request will auth/validate different fields than a Revoke request. // e.g. a Sign request will auth/validate different fields than a Revoke request.
func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, error) { func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload, error) {
_ = audiences // unused input
jwt, err := jose.ParseSigned(token) jwt, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, return nil, errs.Wrap(http.StatusUnauthorized, err,
@ -208,13 +207,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *K8sSA) AuthorizeRevoke(_ context.Context, token string) error { func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *K8sSA) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
@ -254,7 +253,7 @@ func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) erro
} }
// AuthorizeSSHSign validates an request for an SSH certificate. // AuthorizeSSHSign validates an request for an SSH certificate.
func (p *K8sSA) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.ctl.Claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
} }

View file

@ -116,7 +116,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) {
} }
// AuthorizeSign returns the list of SignOption for a Sign request. // AuthorizeSign returns the list of SignOption for a Sign request.
func (p *Nebula) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, err return nil, err
@ -171,7 +171,7 @@ func (p *Nebula) AuthorizeSign(_ context.Context, token string) ([]SignOption, e
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
// Currently the Nebula provisioner only grants host SSH certificates. // Currently the Nebula provisioner only grants host SSH certificates.
func (p *Nebula) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) { func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.ctl.Claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
} }
@ -275,12 +275,12 @@ func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) erro
} }
// AuthorizeRevoke returns an error if the token is not valid. // AuthorizeRevoke returns an error if the token is not valid.
func (p *Nebula) AuthorizeRevoke(_ context.Context, token string) error { func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error {
return p.validateToken(token, p.ctl.Audiences.Revoke) return p.validateToken(token, p.ctl.Audiences.Revoke)
} }
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid. // AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
func (p *Nebula) AuthorizeSSHRevoke(_ context.Context, token string) error { func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error {
if !p.ctl.Claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
} }
@ -291,12 +291,12 @@ func (p *Nebula) AuthorizeSSHRevoke(_ context.Context, token string) error {
} }
// AuthorizeSSHRenew returns an unauthorized error. // AuthorizeSSHRenew returns an unauthorized error.
func (p *Nebula) AuthorizeSSHRenew(context.Context, string) (*ssh.Certificate, error) { func (p *Nebula) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
return nil, errs.Unauthorized("nebula provisioner does not support SSH renew") return nil, errs.Unauthorized("nebula provisioner does not support SSH renew")
} }
// AuthorizeSSHRekey returns an unauthorized error. // AuthorizeSSHRekey returns an unauthorized error.
func (p *Nebula) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, []SignOption, error) { func (p *Nebula) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
return nil, nil, errs.Unauthorized("nebula provisioner does not support SSH rekey") return nil, nil, errs.Unauthorized("nebula provisioner does not support SSH rekey")
} }

View file

@ -18,7 +18,7 @@ func (p *noop) GetIDForToken() string {
return "noop" return "noop"
} }
func (p *noop) GetTokenID(string) (string, error) { func (p *noop) GetTokenID(token string) (string, error) {
return "", nil return "", nil
} }
@ -33,35 +33,35 @@ func (p *noop) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }
func (p *noop) Init(Config) error { func (p *noop) Init(config Config) error {
return nil return nil
} }
func (p *noop) AuthorizeSign(context.Context, string) ([]SignOption, error) { func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{p}, nil return []SignOption{p}, nil
} }
func (p *noop) AuthorizeRenew(context.Context, *x509.Certificate) error { func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
return nil return nil
} }
func (p *noop) AuthorizeRevoke(context.Context, string) error { func (p *noop) AuthorizeRevoke(ctx context.Context, token string) error {
return nil return nil
} }
func (p *noop) AuthorizeSSHSign(context.Context, string) ([]SignOption, error) { func (p *noop) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{p}, nil return []SignOption{p}, nil
} }
func (p *noop) AuthorizeSSHRenew(context.Context, string) (*ssh.Certificate, error) { func (p *noop) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
//nolint:nilnil // fine for noop //nolint:nilnil // fine for noop
return nil, nil return nil, nil
} }
func (p *noop) AuthorizeSSHRevoke(context.Context, string) error { func (p *noop) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return nil return nil
} }
func (p *noop) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, []SignOption, error) { func (p *noop) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
return nil, []SignOption{}, nil return nil, []SignOption{}, nil
} }

View file

@ -230,7 +230,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
} }
} }
if !found { if !found {
return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email %q is not allowed", p.Email) return errs.Unauthorized("validatePayload: failed to validate oidc token payload: email is not allowed")
} }
} }
@ -292,7 +292,7 @@ func (o *OIDC) authorizeToken(token string) (*openIDPayload, error) {
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
// Only tokens generated by an admin have the right to revoke a certificate. // Only tokens generated by an admin have the right to revoke a certificate.
func (o *OIDC) AuthorizeRevoke(_ context.Context, token string) error { func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeRevoke")
@ -307,7 +307,7 @@ func (o *OIDC) AuthorizeRevoke(_ context.Context, token string) error {
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (o *OIDC) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSign")
@ -385,13 +385,16 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
} }
var data sshutil.TemplateData var data sshutil.TemplateData
var principals []string
if claims.Email == "" { if claims.Email == "" {
// If email is empty, use the Subject claim instead to create minimal // If email is empty, use the Subject claim instead to create minimal data for the template to use
// data for the template to use.
data = sshutil.CreateTemplateData(sshutil.UserCert, claims.Subject, nil) data = sshutil.CreateTemplateData(sshutil.UserCert, claims.Subject, nil)
if v, err := unsafeParseSigned(token); err == nil { if v, err := unsafeParseSigned(token); err == nil {
data.SetToken(v) data.SetToken(v)
} }
principals = nil
} else { } else {
// Get the identity using either the default identityFunc or one injected // Get the identity using either the default identityFunc or one injected
// externally. Note that the PreferredUsername might be empty. // externally. Note that the PreferredUsername might be empty.
@ -414,6 +417,8 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
for k, v := range iden.Permissions.CriticalOptions { for k, v := range iden.Permissions.CriticalOptions {
data.AddCriticalOption(k, v) data.AddCriticalOption(k, v)
} }
principals = iden.Usernames
} }
// Use the default template unless no-templates are configured and email is // Use the default template unless no-templates are configured and email is
@ -442,6 +447,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
} else { } else {
signOptions = append(signOptions, sshCertOptionsValidator(SignSSHOptions{ signOptions = append(signOptions, sshCertOptionsValidator(SignSSHOptions{
CertType: SSHUserCert, CertType: SSHUserCert,
Principals: principals,
})) }))
} }
@ -463,7 +469,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
} }
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (o *OIDC) AuthorizeSSHRevoke(_ context.Context, token string) error { func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHRevoke")

View file

@ -13,7 +13,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert" "github.com/smallstep/assert"
@ -222,37 +221,39 @@ func TestOIDC_authorizeToken(t *testing.T) {
args args args args
code int code int
wantIssuer string wantIssuer string
expErr error wantErr bool
}{ }{
{"ok1", p1, args{t1}, http.StatusOK, issuer, nil}, {"ok1", p1, args{t1}, http.StatusOK, issuer, false},
{"ok tenantid", p2, args{t2}, http.StatusOK, tenantIssuer, nil}, {"ok tenantid", p2, args{t2}, http.StatusOK, tenantIssuer, false},
{"ok admin", p3, args{t3}, http.StatusOK, issuer, nil}, {"ok admin", p3, args{t3}, http.StatusOK, issuer, false},
{"ok domain", p3, args{t4}, http.StatusOK, issuer, nil}, {"ok domain", p3, args{t4}, http.StatusOK, issuer, false},
{"ok no email", p3, args{t5}, http.StatusOK, issuer, nil}, {"ok no email", p3, args{t5}, http.StatusOK, issuer, false},
{"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: email "name@example.com" is not allowed`)}, {"fail-domain", p3, args{failDomain}, http.StatusUnauthorized, "", true},
{"fail-key", p1, args{failKey}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; cannot validate oidc token`)}, {"fail-key", p1, args{failKey}, http.StatusUnauthorized, "", true},
{"fail-token", p1, args{failTok}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; error parsing oidc token: invalid character '~' looking for beginning of value`)}, {"fail-token", p1, args{failTok}, http.StatusUnauthorized, "", true},
{"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; error parsing oidc token claims: invalid character '~' looking for beginning of value`)}, {"fail-claims", p1, args{failClaims}, http.StatusUnauthorized, "", true},
{"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: square/go-jose/jwt: validation failed, invalid issuer claim (iss)`)}, {"fail-issuer", p1, args{failIss}, http.StatusUnauthorized, "", true},
{"fail-audience", p1, args{failAud}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: square/go-jose/jwt: validation failed, invalid audience claim (aud)`)}, {"fail-audience", p1, args{failAud}, http.StatusUnauthorized, "", true},
{"fail-signature", p1, args{failSig}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken; cannot validate oidc token`)}, {"fail-signature", p1, args{failSig}, http.StatusUnauthorized, "", true},
{"fail-expired", p1, args{failExp}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: square/go-jose/jwt: validation failed, token is expired (exp)`)}, {"fail-expired", p1, args{failExp}, http.StatusUnauthorized, "", true},
{"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, "", errors.New(`oidc.AuthorizeToken: validatePayload: failed to validate oidc token payload: square/go-jose/jwt: validation failed, token not valid yet (nbf)`)}, {"fail-not-before", p1, args{failNbf}, http.StatusUnauthorized, "", true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.authorizeToken(tt.args.token) got, err := tt.prov.authorizeToken(tt.args.token)
if tt.expErr != nil { if (err != nil) != tt.wantErr {
require.Error(t, err) fmt.Println(tt)
require.EqualError(t, err, tt.expErr.Error()) t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
var sc render.StatusCodedError var sc render.StatusCodedError
require.ErrorAs(t, err, &sc, "error does not implement StatusCodedError interface") assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
require.Equal(t, tt.code, sc.StatusCode()) assert.Equals(t, sc.StatusCode(), tt.code)
require.Nil(t, got) assert.Nil(t, got)
} else { } else {
require.NotNil(t, got) assert.NotNil(t, got)
require.Equal(t, tt.wantIssuer, got.Issuer) assert.Equals(t, got.Issuer, tt.wantIssuer)
} }
}) })
} }
@ -338,6 +339,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case emailOnlyIdentity:
assert.Equals(t, string(v), "name@smallstep.com")
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine) assert.Equals(t, nil, v.policyEngine)
case *WebhookController: case *WebhookController:
@ -579,9 +582,6 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
{"ok-principals", p1, args{t1, SignSSHOptions{Principals: []string{"name"}}, pub}, {"ok-principals", p1, args{t1, SignSSHOptions{Principals: []string{"name"}}, pub},
&SignSSHOptions{CertType: "user", Principals: []string{"name", "name@smallstep.com"}, &SignSSHOptions{CertType: "user", Principals: []string{"name", "name@smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"ok-principals-ignore-passed", p1, args{t1, SignSSHOptions{Principals: []string{"root"}}, pub},
&SignSSHOptions{CertType: "user", Principals: []string{"name", "name@smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"ok-principals-getIdentity", p4, args{okGetIdentityToken, SignSSHOptions{Principals: []string{"mariano"}}, pub}, {"ok-principals-getIdentity", p4, args{okGetIdentityToken, SignSSHOptions{Principals: []string{"mariano"}}, pub},
&SignSSHOptions{CertType: "user", Principals: []string{"max", "mariano"}, &SignSSHOptions{CertType: "user", Principals: []string{"max", "mariano"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
@ -600,6 +600,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, http.StatusOK, false, false},
{"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true}, {"fail-rsa1024", p1, args{t1, SignSSHOptions{}, rsa1024.Public()}, expectedUserOptions, http.StatusOK, false, true},
{"fail-user-host", p1, args{t1, SignSSHOptions{CertType: "host"}, pub}, nil, http.StatusOK, false, true}, {"fail-user-host", p1, args{t1, SignSSHOptions{CertType: "host"}, pub}, nil, http.StatusOK, false, true},
{"fail-user-principals", p1, args{t1, SignSSHOptions{Principals: []string{"root"}}, pub}, nil, http.StatusOK, false, true},
{"fail-getIdentity", p5, args{failGetIdentityToken, SignSSHOptions{}, pub}, nil, http.StatusInternalServerError, true, false}, {"fail-getIdentity", p5, args{failGetIdentityToken, SignSSHOptions{}, pub}, nil, http.StatusInternalServerError, true, false},
{"fail-sshCA-disabled", p6, args{"foo", SignSSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false}, {"fail-sshCA-disabled", p6, args{"foo", SignSSHOptions{}, pub}, nil, http.StatusUnauthorized, true, false},
// Missing parametrs // Missing parametrs

View file

@ -10,9 +10,8 @@ import (
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
) )
// Interface is the interface that all provisioner types must implement. // Interface is the interface that all provisioner types must implement.
@ -298,43 +297,43 @@ type base struct{}
// AuthorizeSign returns an unimplemented error. Provisioners should overwrite // AuthorizeSign returns an unimplemented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing x509 Certificates. // this method if they will support authorizing tokens for signing x509 Certificates.
func (b *base) AuthorizeSign(context.Context, string) ([]SignOption, error) { func (b *base) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented") return nil, errs.Unauthorized("provisioner.AuthorizeSign not implemented")
} }
// AuthorizeRevoke returns an unimplemented error. Provisioners should overwrite // AuthorizeRevoke returns an unimplemented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking x509 Certificates. // this method if they will support authorizing tokens for revoking x509 Certificates.
func (b *base) AuthorizeRevoke(context.Context, string) error { func (b *base) AuthorizeRevoke(ctx context.Context, token string) error {
return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented") return errs.Unauthorized("provisioner.AuthorizeRevoke not implemented")
} }
// AuthorizeRenew returns an unimplemented error. Provisioners should overwrite // AuthorizeRenew returns an unimplemented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing x509 Certificates. // this method if they will support authorizing tokens for renewing x509 Certificates.
func (b *base) AuthorizeRenew(context.Context, *x509.Certificate) error { func (b *base) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
return errs.Unauthorized("provisioner.AuthorizeRenew not implemented") return errs.Unauthorized("provisioner.AuthorizeRenew not implemented")
} }
// AuthorizeSSHSign returns an unimplemented error. Provisioners should overwrite // AuthorizeSSHSign returns an unimplemented error. Provisioners should overwrite
// this method if they will support authorizing tokens for signing SSH Certificates. // this method if they will support authorizing tokens for signing SSH Certificates.
func (b *base) AuthorizeSSHSign(context.Context, string) ([]SignOption, error) { func (b *base) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented") return nil, errs.Unauthorized("provisioner.AuthorizeSSHSign not implemented")
} }
// AuthorizeRevoke returns an unimplemented error. Provisioners should overwrite // AuthorizeRevoke returns an unimplemented error. Provisioners should overwrite
// this method if they will support authorizing tokens for revoking SSH Certificates. // this method if they will support authorizing tokens for revoking SSH Certificates.
func (b *base) AuthorizeSSHRevoke(context.Context, string) error { func (b *base) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented") return errs.Unauthorized("provisioner.AuthorizeSSHRevoke not implemented")
} }
// AuthorizeSSHRenew returns an unimplemented error. Provisioners should overwrite // AuthorizeSSHRenew returns an unimplemented error. Provisioners should overwrite
// this method if they will support authorizing tokens for renewing SSH Certificates. // this method if they will support authorizing tokens for renewing SSH Certificates.
func (b *base) AuthorizeSSHRenew(context.Context, string) (*ssh.Certificate, error) { func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented") return nil, errs.Unauthorized("provisioner.AuthorizeSSHRenew not implemented")
} }
// AuthorizeSSHRekey returns an unimplemented error. Provisioners should overwrite // AuthorizeSSHRekey returns an unimplemented error. Provisioners should overwrite
// this method if they will support authorizing tokens for rekeying SSH Certificates. // this method if they will support authorizing tokens for rekeying SSH Certificates.
func (b *base) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, []SignOption, error) { func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented") return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
} }

View file

@ -2,16 +2,10 @@ package provisioner
import ( import (
"context" "context"
"crypto/subtle"
"fmt"
"net/http"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"github.com/smallstep/certificates/webhook"
) )
// SCEP is the SCEP provisioner type, an entity that can authorize the // SCEP is the SCEP provisioner type, an entity that can authorize the
@ -39,8 +33,8 @@ type SCEP struct {
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
ctl *Controller ctl *Controller
secretChallengePassword string
encryptionAlgorithm int encryptionAlgorithm int
challengeValidationController *challengeValidationController
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -73,7 +67,7 @@ func (s *SCEP) GetEncryptedKey() (string, string, bool) {
} }
// GetTokenID returns the identifier of the token. // GetTokenID returns the identifier of the token.
func (s *SCEP) GetTokenID(string) (string, error) { func (s *SCEP) GetTokenID(ott string) (string, error) {
return "", errors.New("scep provisioner does not implement GetTokenID") return "", errors.New("scep provisioner does not implement GetTokenID")
} }
@ -88,67 +82,6 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration {
return s.ctl.Claimer.DefaultTLSCertDuration() return s.ctl.Claimer.DefaultTLSCertDuration()
} }
type challengeValidationController struct {
client *http.Client
webhooks []*Webhook
}
// newChallengeValidationController creates a new challengeValidationController
// that performs challenge validation through webhooks.
func newChallengeValidationController(client *http.Client, webhooks []*Webhook) *challengeValidationController {
scepHooks := []*Webhook{}
for _, wh := range webhooks {
if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() {
continue
}
if !isCertTypeOK(wh) {
continue
}
scepHooks = append(scepHooks, wh)
}
return &challengeValidationController{
client: client,
webhooks: scepHooks,
}
}
var (
ErrSCEPChallengeInvalid = errors.New("webhook server did not allow request")
)
// Validate executes zero or more configured webhooks to
// validate the SCEP challenge. If at least one of them indicates
// the challenge value is accepted, validation succeeds. In
// that case, the other webhooks will be skipped. If none of
// the webhooks indicates the value of the challenge was accepted,
// an error is returned.
func (c *challengeValidationController) Validate(ctx context.Context, challenge, transactionID string) error {
for _, wh := range c.webhooks {
req := &webhook.RequestBody{
SCEPChallenge: challenge,
SCEPTransactionID: transactionID,
}
resp, err := wh.DoWithContext(ctx, c.client, req, nil) // TODO(hs): support templated URL? Requires some refactoring
if err != nil {
return fmt.Errorf("failed executing webhook request: %w", err)
}
if resp.Allow {
return nil // return early when response is positive
}
}
return ErrSCEPChallengeInvalid
}
// isCertTypeOK returns whether or not the webhook can be used
// with the SCEP challenge validation webhook controller.
func isCertTypeOK(wh *Webhook) bool {
if wh.CertType == linkedca.Webhook_ALL.String() || wh.CertType == "" {
return true
}
return linkedca.Webhook_X509.String() == wh.CertType
}
// Init initializes and validates the fields of a SCEP type. // Init initializes and validates the fields of a SCEP type.
func (s *SCEP) Init(config Config) (err error) { func (s *SCEP) Init(config Config) (err error) {
switch { switch {
@ -158,6 +91,10 @@ func (s *SCEP) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty") return errors.New("provisioner name cannot be empty")
} }
// Mask the actual challenge value, so it won't be marshaled
s.secretChallengePassword = s.ChallengePassword
s.ChallengePassword = "*** redacted ***"
// Default to 2048 bits minimum public key length (for CSRs) if not set // Default to 2048 bits minimum public key length (for CSRs) if not set
if s.MinimumPublicKeyLength == 0 { if s.MinimumPublicKeyLength == 0 {
s.MinimumPublicKeyLength = 2048 s.MinimumPublicKeyLength = 2048
@ -172,11 +109,6 @@ func (s *SCEP) Init(config Config) (err error) {
return errors.New("only encryption algorithm identifiers from 0 to 4 are valid") return errors.New("only encryption algorithm identifiers from 0 to 4 are valid")
} }
s.challengeValidationController = newChallengeValidationController(
config.WebhookClient,
s.GetOptions().GetWebhooks(),
)
// TODO: add other, SCEP specific, options? // TODO: add other, SCEP specific, options?
s.ctl, err = NewController(s, s.Claims, config, s.Options) s.ctl, err = NewController(s, s.Claims, config, s.Options)
@ -186,7 +118,7 @@ func (s *SCEP) Init(config Config) (err error) {
// AuthorizeSign does not do any verification, because all verification is handled // AuthorizeSign does not do any verification, because all verification is handled
// in the SCEP protocol. This method returns a list of modifiers / constraints // in the SCEP protocol. This method returns a list of modifiers / constraints
// on the resulting certificate. // on the resulting certificate.
func (s *SCEP) AuthorizeSign(context.Context, string) ([]SignOption, error) { func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{ return []SignOption{
s, s,
// modifiers / withOptions // modifiers / withOptions
@ -201,6 +133,11 @@ func (s *SCEP) AuthorizeSign(context.Context, string) ([]SignOption, error) {
}, nil }, nil
} }
// GetChallengePassword returns the challenge password
func (s *SCEP) GetChallengePassword() string {
return s.secretChallengePassword
}
// GetCapabilities returns the CA capabilities // GetCapabilities returns the CA capabilities
func (s *SCEP) GetCapabilities() []string { func (s *SCEP) GetCapabilities() []string {
return s.Capabilities return s.Capabilities
@ -219,43 +156,3 @@ func (s *SCEP) ShouldIncludeRootInChain() bool {
func (s *SCEP) GetContentEncryptionAlgorithm() int { func (s *SCEP) GetContentEncryptionAlgorithm() int {
return s.encryptionAlgorithm return s.encryptionAlgorithm
} }
// ValidateChallenge validates the provided challenge. It starts by
// selecting the validation method to use, then performs validation
// according to that method.
func (s *SCEP) ValidateChallenge(ctx context.Context, challenge, transactionID string) error {
if s.challengeValidationController == nil {
return fmt.Errorf("provisioner %q wasn't initialized", s.Name)
}
switch s.selectValidationMethod() {
case validationMethodWebhook:
return s.challengeValidationController.Validate(ctx, challenge, transactionID)
default:
if subtle.ConstantTimeCompare([]byte(s.ChallengePassword), []byte(challenge)) == 0 {
return errors.New("invalid challenge password provided")
}
return nil
}
}
type validationMethod string
const (
validationMethodNone validationMethod = "none"
validationMethodStatic validationMethod = "static"
validationMethodWebhook validationMethod = "webhook"
)
// selectValidationMethod returns the method to validate SCEP
// challenges. If a webhook is configured with kind `SCEPCHALLENGE`,
// the webhook method will be used. If a challenge password is set,
// the static method is used. It will default to the `none` method.
func (s *SCEP) selectValidationMethod() validationMethod {
if len(s.challengeValidationController.webhooks) > 0 {
return validationMethodWebhook
}
if s.ChallengePassword != "" {
return validationMethodStatic
}
return validationMethodNone
}

View file

@ -1,342 +0,0 @@
package provisioner
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/linkedca"
)
func Test_challengeValidationController_Validate(t *testing.T) {
type request struct {
Challenge string `json:"scepChallenge"`
TransactionID string `json:"scepTransactionID"`
}
type response struct {
Allow bool `json:"allow"`
}
nokServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req := &request{}
err := json.NewDecoder(r.Body).Decode(req)
require.NoError(t, err)
assert.Equal(t, "not-allowed", req.Challenge)
assert.Equal(t, "transaction-1", req.TransactionID)
b, err := json.Marshal(response{Allow: false})
require.NoError(t, err)
w.WriteHeader(200)
w.Write(b)
}))
okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req := &request{}
err := json.NewDecoder(r.Body).Decode(req)
require.NoError(t, err)
assert.Equal(t, "challenge", req.Challenge)
assert.Equal(t, "transaction-1", req.TransactionID)
b, err := json.Marshal(response{Allow: true})
require.NoError(t, err)
w.WriteHeader(200)
w.Write(b)
}))
type fields struct {
client *http.Client
webhooks []*Webhook
}
type args struct {
challenge string
transactionID string
}
tests := []struct {
name string
fields fields
args args
server *httptest.Server
expErr error
}{
{
name: "fail/no-webhook",
fields: fields{http.DefaultClient, nil},
args: args{"no-webhook", "transaction-1"},
expErr: errors.New("webhook server did not allow request"),
},
{
name: "fail/wrong-cert-type",
fields: fields{http.DefaultClient, []*Webhook{
{
Kind: linkedca.Webhook_SCEPCHALLENGE.String(),
CertType: linkedca.Webhook_SSH.String(),
},
}},
args: args{"wrong-cert-type", "transaction-1"},
expErr: errors.New("webhook server did not allow request"),
},
{
name: "fail/wrong-secret-value",
fields: fields{http.DefaultClient, []*Webhook{
{
ID: "webhook-id-1",
Name: "webhook-name-1",
Secret: "{{}}",
Kind: linkedca.Webhook_SCEPCHALLENGE.String(),
CertType: linkedca.Webhook_X509.String(),
URL: okServer.URL,
},
}},
args: args{
challenge: "wrong-secret-value",
transactionID: "transaction-1",
},
expErr: errors.New("failed executing webhook request: illegal base64 data at input byte 0"),
},
{
name: "fail/not-allowed",
fields: fields{http.DefaultClient, []*Webhook{
{
ID: "webhook-id-1",
Name: "webhook-name-1",
Secret: "MTIzNAo=",
Kind: linkedca.Webhook_SCEPCHALLENGE.String(),
CertType: linkedca.Webhook_X509.String(),
URL: nokServer.URL,
},
}},
args: args{
challenge: "not-allowed",
transactionID: "transaction-1",
},
server: nokServer,
expErr: errors.New("webhook server did not allow request"),
},
{
name: "ok",
fields: fields{http.DefaultClient, []*Webhook{
{
ID: "webhook-id-1",
Name: "webhook-name-1",
Secret: "MTIzNAo=",
Kind: linkedca.Webhook_SCEPCHALLENGE.String(),
CertType: linkedca.Webhook_X509.String(),
URL: okServer.URL,
},
}},
args: args{
challenge: "challenge",
transactionID: "transaction-1",
},
server: okServer,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := newChallengeValidationController(tt.fields.client, tt.fields.webhooks)
if tt.server != nil {
defer tt.server.Close()
}
ctx := context.Background()
err := c.Validate(ctx, tt.args.challenge, tt.args.transactionID)
if tt.expErr != nil {
assert.EqualError(t, err, tt.expErr.Error())
return
}
assert.NoError(t, err)
})
}
}
func TestController_isCertTypeOK(t *testing.T) {
assert.True(t, isCertTypeOK(&Webhook{CertType: linkedca.Webhook_X509.String()}))
assert.True(t, isCertTypeOK(&Webhook{CertType: linkedca.Webhook_ALL.String()}))
assert.True(t, isCertTypeOK(&Webhook{CertType: ""}))
assert.False(t, isCertTypeOK(&Webhook{CertType: linkedca.Webhook_SSH.String()}))
}
func Test_selectValidationMethod(t *testing.T) {
tests := []struct {
name string
p *SCEP
want validationMethod
}{
{"webhooks", &SCEP{
Name: "SCEP",
Type: "SCEP",
Options: &Options{
Webhooks: []*Webhook{
{
Kind: linkedca.Webhook_SCEPCHALLENGE.String(),
},
},
},
}, "webhook"},
{"challenge", &SCEP{
Name: "SCEP",
Type: "SCEP",
ChallengePassword: "pass",
}, "static"},
{"challenge-with-different-webhook", &SCEP{
Name: "SCEP",
Type: "SCEP",
Options: &Options{
Webhooks: []*Webhook{
{
Kind: linkedca.Webhook_AUTHORIZING.String(),
},
},
},
ChallengePassword: "pass",
}, "static"},
{"none", &SCEP{
Name: "SCEP",
Type: "SCEP",
}, "none"},
{"none-with-different-webhook", &SCEP{
Name: "SCEP",
Type: "SCEP",
Options: &Options{
Webhooks: []*Webhook{
{
Kind: linkedca.Webhook_AUTHORIZING.String(),
},
},
},
}, "none"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.p.Init(Config{Claims: globalProvisionerClaims})
require.NoError(t, err)
got := tt.p.selectValidationMethod()
assert.Equal(t, tt.want, got)
})
}
}
func TestSCEP_ValidateChallenge(t *testing.T) {
type request struct {
Challenge string `json:"scepChallenge"`
TransactionID string `json:"scepTransactionID"`
}
type response struct {
Allow bool `json:"allow"`
}
okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req := &request{}
err := json.NewDecoder(r.Body).Decode(req)
require.NoError(t, err)
assert.Equal(t, "webhook-challenge", req.Challenge)
assert.Equal(t, "webhook-transaction-1", req.TransactionID)
b, err := json.Marshal(response{Allow: true})
require.NoError(t, err)
w.WriteHeader(200)
w.Write(b)
}))
type args struct {
challenge string
transactionID string
}
tests := []struct {
name string
p *SCEP
server *httptest.Server
args args
expErr error
}{
{"ok/webhooks", &SCEP{
Name: "SCEP",
Type: "SCEP",
Options: &Options{
Webhooks: []*Webhook{
{
ID: "webhook-id-1",
Name: "webhook-name-1",
Secret: "MTIzNAo=",
Kind: linkedca.Webhook_SCEPCHALLENGE.String(),
CertType: linkedca.Webhook_X509.String(),
URL: okServer.URL,
},
},
},
}, okServer, args{"webhook-challenge", "webhook-transaction-1"},
nil,
},
{"fail/webhooks-secret-configuration", &SCEP{
Name: "SCEP",
Type: "SCEP",
Options: &Options{
Webhooks: []*Webhook{
{
ID: "webhook-id-1",
Name: "webhook-name-1",
Secret: "{{}}",
Kind: linkedca.Webhook_SCEPCHALLENGE.String(),
CertType: linkedca.Webhook_X509.String(),
URL: okServer.URL,
},
},
},
}, nil, args{"webhook-challenge", "webhook-transaction-1"},
errors.New("failed executing webhook request: illegal base64 data at input byte 0"),
},
{"ok/static-challenge", &SCEP{
Name: "SCEP",
Type: "SCEP",
Options: &Options{},
ChallengePassword: "secret-static-challenge",
}, nil, args{"secret-static-challenge", "static-transaction-1"},
nil,
},
{"fail/wrong-static-challenge", &SCEP{
Name: "SCEP",
Type: "SCEP",
Options: &Options{},
ChallengePassword: "secret-static-challenge",
}, nil, args{"the-wrong-challenge-secret", "static-transaction-1"},
errors.New("invalid challenge password provided"),
},
{"ok/no-challenge", &SCEP{
Name: "SCEP",
Type: "SCEP",
Options: &Options{},
ChallengePassword: "",
}, nil, args{"", "static-transaction-1"},
nil,
},
{"fail/no-challenge-but-provided", &SCEP{
Name: "SCEP",
Type: "SCEP",
Options: &Options{},
ChallengePassword: "",
}, nil, args{"a-challenge-value", "static-transaction-1"},
errors.New("invalid challenge password provided"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.server != nil {
defer tt.server.Close()
}
err := tt.p.Init(Config{Claims: globalProvisionerClaims, WebhookClient: http.DefaultClient})
require.NoError(t, err)
ctx := context.Background()
err = tt.p.ValidateChallenge(ctx, tt.args.challenge, tt.args.transactionID)
if tt.expErr != nil {
assert.EqualError(t, err, tt.expErr.Error())
return
}
assert.NoError(t, err)
})
}
}

View file

@ -83,6 +83,31 @@ type AttestationData struct {
PermanentIdentifier string PermanentIdentifier string
} }
// emailOnlyIdentity is a CertificateRequestValidator that checks that the only
// SAN provided is the given email address.
type emailOnlyIdentity string
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
switch {
case len(req.DNSNames) > 0:
return errs.Forbidden("certificate request cannot contain DNS names")
case len(req.IPAddresses) > 0:
return errs.Forbidden("certificate request cannot contain IP addresses")
case len(req.URIs) > 0:
return errs.Forbidden("certificate request cannot contain URIs")
case len(req.EmailAddresses) == 0:
return errs.Forbidden("certificate request does not contain any email address")
case len(req.EmailAddresses) > 1:
return errs.Forbidden("certificate request contains too many email addresses")
case req.EmailAddresses[0] == "":
return errs.Forbidden("certificate request cannot contain an empty email address")
case req.EmailAddresses[0] != string(e):
return errs.Forbidden("certificate request does not contain the valid email address - got %s, want %s", req.EmailAddresses[0], e)
default:
return nil
}
}
// defaultPublicKeyValidator validates the public key of a certificate request. // defaultPublicKeyValidator validates the public key of a certificate request.
type defaultPublicKeyValidator struct{} type defaultPublicKeyValidator struct{}

View file

@ -16,6 +16,38 @@ import (
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
func Test_emailOnlyIdentity_Valid(t *testing.T) {
uri, err := url.Parse("https://example.com/1.0/getUser")
if err != nil {
t.Fatal(err)
}
type args struct {
req *x509.CertificateRequest
}
tests := []struct {
name string
e emailOnlyIdentity
args args
wantErr bool
}{
{"ok", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com"}}}, false},
{"DNSNames", "name@smallstep.com", args{&x509.CertificateRequest{DNSNames: []string{"foo.bar.zar"}}}, true},
{"IPAddresses", "name@smallstep.com", args{&x509.CertificateRequest{IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}}}, true},
{"URIs", "name@smallstep.com", args{&x509.CertificateRequest{URIs: []*url.URL{uri}}}, true},
{"no-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{}}}, true},
{"empty-email", "", args{&x509.CertificateRequest{EmailAddresses: []string{""}}}, true},
{"multiple-emails", "name@smallstep.com", args{&x509.CertificateRequest{EmailAddresses: []string{"name@smallstep.com", "foo@smallstep.com"}}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.e.Valid(tt.args.req); (err != nil) != tt.wantErr {
t.Errorf("emailOnlyIdentity.Valid() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_defaultPublicKeyValidator_Valid(t *testing.T) { func Test_defaultPublicKeyValidator_Valid(t *testing.T) {
_shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr") _shortRSA, err := pemutil.Read("./testdata/certs/short-rsa.csr")
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -125,6 +125,35 @@ func (o SignSSHOptions) match(got SignSSHOptions) error {
return nil return nil
} }
// sshCertPrincipalsModifier is an SSHCertModifier that sets the
// principals to the SSH certificate.
type sshCertPrincipalsModifier []string
// Modify the ValidPrincipals value of the cert.
func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.ValidPrincipals = []string(o)
return nil
}
// sshCertKeyIDModifier is an SSHCertModifier that sets the given
// Key ID in the SSH certificate.
type sshCertKeyIDModifier string
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.KeyId = string(m)
return nil
}
// sshCertTypeModifier is an SSHCertModifier that sets the
// certificate type.
type sshCertTypeModifier string
// Modify sets the CertType for the ssh certificate.
func (m sshCertTypeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.CertType = sshCertTypeUInt32(string(m))
return nil
}
// sshCertValidAfterModifier is an SSHCertModifier that sets the // sshCertValidAfterModifier is an SSHCertModifier that sets the
// ValidAfter in the SSH certificate. // ValidAfter in the SSH certificate.
type sshCertValidAfterModifier uint64 type sshCertValidAfterModifier uint64
@ -143,6 +172,51 @@ func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptio
return nil return nil
} }
// sshCertDefaultsModifier implements a SSHCertModifier that
// modifies the certificate with the given options if they are not set.
type sshCertDefaultsModifier SignSSHOptions
// Modify implements the SSHCertModifier interface.
func (m sshCertDefaultsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
if cert.CertType == 0 {
cert.CertType = sshCertTypeUInt32(m.CertType)
}
if len(cert.ValidPrincipals) == 0 {
cert.ValidPrincipals = m.Principals
}
if cert.ValidAfter == 0 && !m.ValidAfter.IsZero() {
cert.ValidAfter = uint64(m.ValidAfter.Unix())
}
if cert.ValidBefore == 0 && !m.ValidBefore.IsZero() {
cert.ValidBefore = uint64(m.ValidBefore.Unix())
}
return nil
}
// sshDefaultExtensionModifier implements an SSHCertModifier that sets
// the default extensions in an SSH certificate.
type sshDefaultExtensionModifier struct{}
func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
switch cert.CertType {
// Default to no extensions for HostCert.
case ssh.HostCert:
return nil
case ssh.UserCert:
if cert.Extensions == nil {
cert.Extensions = make(map[string]string)
}
cert.Extensions["permit-X11-forwarding"] = ""
cert.Extensions["permit-agent-forwarding"] = ""
cert.Extensions["permit-port-forwarding"] = ""
cert.Extensions["permit-pty"] = ""
cert.Extensions["permit-user-rc"] = ""
return nil
default:
return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType)
}
}
// sshDefaultDuration is an SSHCertModifier that sets the certificate // sshDefaultDuration is an SSHCertModifier that sets the certificate
// ValidAfter and ValidBefore if they have not been set. It will fail if a // ValidAfter and ValidBefore if they have not been set. It will fail if a
// CertType has not been set or is not valid. // CertType has not been set or is not valid.
@ -311,7 +385,7 @@ type sshCertDefaultValidator struct{}
// Valid returns an error if the given certificate does not contain the // Valid returns an error if the given certificate does not contain the
// necessary fields. We skip ValidPrincipals and Extensions as with custom // necessary fields. We skip ValidPrincipals and Extensions as with custom
// templates you can set them empty. // templates you can set them empty.
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
switch { switch {
case len(cert.Nonce) == 0: case len(cert.Nonce) == 0:
return errs.Forbidden("ssh certificate nonce cannot be empty") return errs.Forbidden("ssh certificate nonce cannot be empty")
@ -346,7 +420,7 @@ type sshDefaultPublicKeyValidator struct{}
// TODO: this is the only validator that checks the key type. We should execute // TODO: this is the only validator that checks the key type. We should execute
// this before the signing. We should add a new validations interface or extend // this before the signing. We should add a new validations interface or extend
// SSHCertOptionsValidator with the key. // SSHCertOptionsValidator with the key.
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
if cert.Key == nil { if cert.Key == nil {
return errs.BadRequest("ssh certificate key cannot be nil") return errs.BadRequest("ssh certificate key cannot be nil")
} }

View file

@ -202,6 +202,97 @@ func TestSSHOptions_Match(t *testing.T) {
} }
} }
func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertPrincipalsModifier
cert *ssh.Certificate
expected []string
}
tests := map[string]func() test{
"ok": func() test {
a := []string{"foo", "bar"}
return test{
modifier: sshCertPrincipalsModifier(a),
cert: new(ssh.Certificate),
expected: a,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
assert.Equals(t, tc.cert.ValidPrincipals, tc.expected)
}
})
}
}
func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertKeyIDModifier
cert *ssh.Certificate
expected string
}
tests := map[string]func() test{
"ok": func() test {
a := "foo"
return test{
modifier: sshCertKeyIDModifier(a),
cert: new(ssh.Certificate),
expected: a,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
assert.Equals(t, tc.cert.KeyId, tc.expected)
}
})
}
}
func Test_sshCertTypeModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertTypeModifier
cert *ssh.Certificate
expected uint32
}
tests := map[string]func() test{
"ok/user": func() test {
return test{
modifier: sshCertTypeModifier("user"),
cert: new(ssh.Certificate),
expected: ssh.UserCert,
}
},
"ok/host": func() test {
return test{
modifier: sshCertTypeModifier("host"),
cert: new(ssh.Certificate),
expected: ssh.HostCert,
}
},
"ok/default": func() test {
return test{
modifier: sshCertTypeModifier("foo"),
cert: new(ssh.Certificate),
expected: 0,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
assert.Equals(t, tc.cert.CertType, tc.expected)
}
})
}
}
func Test_sshCertValidAfterModifier_Modify(t *testing.T) { func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
type test struct { type test struct {
modifier sshCertValidAfterModifier modifier sshCertValidAfterModifier
@ -227,6 +318,176 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
} }
} }
func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
type test struct {
modifier sshCertDefaultsModifier
cert *ssh.Certificate
valid func(*ssh.Certificate)
}
tests := map[string]func() test{
"ok/changes": func() test {
n := time.Now()
va := NewTimeDuration(n.Add(1 * time.Minute))
vb := NewTimeDuration(n.Add(5 * time.Minute))
so := SignSSHOptions{
Principals: []string{"foo", "bar"},
CertType: "host",
ValidAfter: va,
ValidBefore: vb,
}
return test{
modifier: sshCertDefaultsModifier(so),
cert: new(ssh.Certificate),
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.ValidPrincipals, so.Principals)
assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
},
}
},
"ok/no-changes": func() test {
n := time.Now()
so := SignSSHOptions{
Principals: []string{"foo", "bar"},
CertType: "host",
ValidAfter: NewTimeDuration(n.Add(15 * time.Minute)),
ValidBefore: NewTimeDuration(n.Add(25 * time.Minute)),
}
return test{
modifier: sshCertDefaultsModifier(so),
cert: &ssh.Certificate{
CertType: uint32(ssh.UserCert),
ValidPrincipals: []string{"zap", "zoop"},
ValidAfter: 15,
ValidBefore: 25,
},
valid: func(cert *ssh.Certificate) {
assert.Equals(t, cert.ValidPrincipals, []string{"zap", "zoop"})
assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
assert.Equals(t, cert.ValidAfter, uint64(15))
assert.Equals(t, cert.ValidBefore, uint64(25))
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
tc.valid(tc.cert)
}
})
}
}
func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
type test struct {
modifier sshDefaultExtensionModifier
cert *ssh.Certificate
valid func(*ssh.Certificate)
err error
}
tests := map[string]func() test{
"fail/unexpected-cert-type": func() test {
cert := &ssh.Certificate{CertType: 3}
return test{
modifier: sshDefaultExtensionModifier{},
cert: cert,
err: errors.New("ssh certificate has an unknown type '3'"),
}
},
"ok/host": func() test {
cert := &ssh.Certificate{CertType: ssh.HostCert}
return test{
modifier: sshDefaultExtensionModifier{},
cert: cert,
valid: func(cert *ssh.Certificate) {
assert.Len(t, 0, cert.Extensions)
},
}
},
"ok/user/extensions-exists": func() test {
cert := &ssh.Certificate{CertType: ssh.UserCert, Permissions: ssh.Permissions{Extensions: map[string]string{
"foo": "bar",
}}}
return test{
modifier: sshDefaultExtensionModifier{},
cert: cert,
valid: func(cert *ssh.Certificate) {
val, ok := cert.Extensions["foo"]
assert.True(t, ok)
assert.Equals(t, val, "bar")
val, ok = cert.Extensions["permit-X11-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-agent-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-port-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-pty"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-user-rc"]
assert.True(t, ok)
assert.Equals(t, val, "")
},
}
},
"ok/user/no-extensions": func() test {
return test{
modifier: sshDefaultExtensionModifier{},
cert: &ssh.Certificate{CertType: ssh.UserCert},
valid: func(cert *ssh.Certificate) {
_, ok := cert.Extensions["foo"]
assert.False(t, ok)
val, ok := cert.Extensions["permit-X11-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-agent-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-port-forwarding"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-pty"]
assert.True(t, ok)
assert.Equals(t, val, "")
val, ok = cert.Extensions["permit-user-rc"]
assert.True(t, ok)
assert.Equals(t, val, "")
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run()
if err := tc.modifier.Modify(tc.cert, SignSSHOptions{}); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
tc.valid(tc.cert)
}
}
})
}
}
func Test_sshCertDefaultValidator_Valid(t *testing.T) { func Test_sshCertDefaultValidator_Valid(t *testing.T) {
pub, _, err := keyutil.GenerateDefaultKeyPair() pub, _, err := keyutil.GenerateDefaultKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -187,7 +187,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity
// AuthorizeSSHRevoke validates the authorization token and extracts/validates // AuthorizeSSHRevoke validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRevoke(_ context.Context, token string) error { func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
@ -213,7 +213,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
// AuthorizeSSHRekey validates the authorization token and extracts/validates // AuthorizeSSHRekey validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRekey(_ context.Context, token string) (*ssh.Certificate, []SignOption, error) { func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true)
if err != nil { if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")

View file

@ -665,9 +665,6 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
AccessToken: tok, AccessToken: tok,
}) })
} }
case "/metadata/instance/compute/azEnvironment":
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte("AzurePublicCloud"))
default: default:
http.NotFound(w, r) http.NotFound(w, r)
} }
@ -675,7 +672,6 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
srv.Start() srv.Start()
az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration" az.config.oidcDiscoveryURL = srv.URL + "/" + az.TenantID + "/.well-known/openid-configuration"
az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token" az.config.identityTokenURL = srv.URL + "/metadata/identity/oauth2/token"
az.config.instanceComputeURL = srv.URL + "/metadata/instance/compute/azEnvironment"
return az, srv, nil return az, srv, nil
} }

View file

@ -30,7 +30,6 @@ type WebhookController struct {
client *http.Client client *http.Client
webhooks []*Webhook webhooks []*Webhook
certType linkedca.Webhook_CertType certType linkedca.Webhook_CertType
options []webhook.RequestBodyOption
TemplateData WebhookSetter TemplateData WebhookSetter
} }
@ -40,14 +39,6 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
if wc == nil { if wc == nil {
return nil return nil
} }
// Apply extra options in the webhook controller
for _, fn := range wc.options {
if err := fn(req); err != nil {
return err
}
}
for _, wh := range wc.webhooks { for _, wh := range wc.webhooks {
if wh.Kind != linkedca.Webhook_ENRICHING.String() { if wh.Kind != linkedca.Webhook_ENRICHING.String() {
continue continue
@ -72,14 +63,6 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
if wc == nil { if wc == nil {
return nil return nil
} }
// Apply extra options in the webhook controller
for _, fn := range wc.options {
if err := fn(req); err != nil {
return err
}
}
for _, wh := range wc.webhooks { for _, wh := range wc.webhooks {
if wh.Kind != linkedca.Webhook_AUTHORIZING.String() { if wh.Kind != linkedca.Webhook_AUTHORIZING.String() {
continue continue
@ -124,13 +107,6 @@ type Webhook struct {
} }
func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return w.DoWithContext(ctx, client, reqBody, data)
}
func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL)
if err != nil { if err != nil {
return nil, err return nil, err
@ -153,6 +129,8 @@ func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBod
reqBody.Token = tmpl[sshutil.TokenKey] reqBody.Token = tmpl[sshutil.TokenKey]
} }
*/ */
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
reqBody.Timestamp = time.Now() reqBody.Timestamp = time.Now()

View file

@ -4,7 +4,6 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -17,7 +16,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/webhook" "github.com/smallstep/certificates/webhook"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"go.step.sm/linkedca" "go.step.sm/linkedca"
) )
@ -98,18 +96,12 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
} }
func TestWebhookController_Enrich(t *testing.T) { func TestWebhookController_Enrich(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock())
if err != nil {
t.Fatal(err)
}
type test struct { type test struct {
ctl *WebhookController ctl *WebhookController
req *webhook.RequestBody req *webhook.RequestBody
responses []*webhook.ResponseBody responses []*webhook.ResponseBody
expectErr bool expectErr bool
expectTemplateData any expectTemplateData any
assertRequest func(t *testing.T, req *webhook.RequestBody)
} }
tests := map[string]test{ tests := map[string]test{
"ok/no enriching webhooks": { "ok/no enriching webhooks": {
@ -178,29 +170,6 @@ func TestWebhookController_Enrich(t *testing.T) {
}, },
}, },
}, },
"ok/with options": {
ctl: &WebhookController{
client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
},
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false,
expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}},
assertRequest: func(t *testing.T, req *webhook.RequestBody) {
key, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
assert.FatalError(t, err)
assert.Equals(t, &webhook.X5CCertificate{
Raw: cert.Raw,
PublicKey: key,
PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(),
NotBefore: cert.NotBefore,
NotAfter: cert.NotAfter,
}, req.X5CCertificate)
},
},
"deny": { "deny": {
ctl: &WebhookController{ ctl: &WebhookController{
client: http.DefaultClient, client: http.DefaultClient,
@ -212,20 +181,6 @@ func TestWebhookController_Enrich(t *testing.T) {
expectErr: true, expectErr: true,
expectTemplateData: x509util.TemplateData{}, expectTemplateData: x509util.TemplateData{},
}, },
"fail/with options": {
ctl: &WebhookController{
client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(&x509.Certificate{
PublicKey: []byte("bad"),
})},
},
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
expectTemplateData: x509util.TemplateData{},
},
} }
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
@ -245,25 +200,16 @@ func TestWebhookController_Enrich(t *testing.T) {
t.Fatalf("Got err %v, want %v", err, test.expectErr) t.Fatalf("Got err %v, want %v", err, test.expectErr)
} }
assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData)
if test.assertRequest != nil {
test.assertRequest(t, test.req)
}
}) })
} }
} }
func TestWebhookController_Authorize(t *testing.T) { func TestWebhookController_Authorize(t *testing.T) {
cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock())
if err != nil {
t.Fatal(err)
}
type test struct { type test struct {
ctl *WebhookController ctl *WebhookController
req *webhook.RequestBody req *webhook.RequestBody
responses []*webhook.ResponseBody responses []*webhook.ResponseBody
expectErr bool expectErr bool
assertRequest func(t *testing.T, req *webhook.RequestBody)
} }
tests := map[string]test{ tests := map[string]test{
"ok/no enriching webhooks": { "ok/no enriching webhooks": {
@ -294,27 +240,6 @@ func TestWebhookController_Authorize(t *testing.T) {
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: false, expectErr: false,
}, },
"ok/with options": {
ctl: &WebhookController{
client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
},
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false,
assertRequest: func(t *testing.T, req *webhook.RequestBody) {
key, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
assert.FatalError(t, err)
assert.Equals(t, &webhook.X5CCertificate{
Raw: cert.Raw,
PublicKey: key,
PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(),
NotBefore: cert.NotBefore,
NotAfter: cert.NotAfter,
}, req.X5CCertificate)
},
},
"deny": { "deny": {
ctl: &WebhookController{ ctl: &WebhookController{
client: http.DefaultClient, client: http.DefaultClient,
@ -324,18 +249,6 @@ func TestWebhookController_Authorize(t *testing.T) {
responses: []*webhook.ResponseBody{{Allow: false}}, responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true, expectErr: true,
}, },
"fail/with options": {
ctl: &WebhookController{
client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(&x509.Certificate{
PublicKey: []byte("bad"),
})},
},
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
},
} }
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
@ -354,9 +267,6 @@ func TestWebhookController_Authorize(t *testing.T) {
if (err != nil) != test.expectErr { if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr) t.Fatalf("Got err %v, want %v", err, test.expectErr)
} }
if test.assertRequest != nil {
test.assertRequest(t, test.req)
}
}) })
} }
} }

View file

@ -15,7 +15,6 @@ import (
"go.step.sm/linkedca" "go.step.sm/linkedca"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/webhook"
) )
// x5cPayload extends jwt.Claims with step attributes. // x5cPayload extends jwt.Claims with step attributes.
@ -188,13 +187,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *X5C) AuthorizeRevoke(_ context.Context, token string) error { func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
@ -216,8 +215,7 @@ func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro
// The X509 certificate will be available using the template variable // The X509 certificate will be available using the template variable
// AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be // AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be
// used to get all the domains. // used to get all the domains.
x5cLeaf := claims.chains[0][0] data.SetAuthorizationCertificate(claims.chains[0][0])
data.SetAuthorizationCertificate(x5cLeaf)
templateOptions, err := TemplateOptions(p.Options, data) templateOptions, err := TemplateOptions(p.Options, data)
if err != nil { if err != nil {
@ -240,7 +238,7 @@ func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro
newProvisionerExtensionOption(TypeX5C, p.Name, ""), newProvisionerExtensionOption(TypeX5C, p.Name, ""),
profileLimitDuration{ profileLimitDuration{
p.ctl.Claimer.DefaultTLSCertDuration(), p.ctl.Claimer.DefaultTLSCertDuration(),
x5cLeaf.NotBefore, x5cLeaf.NotAfter, claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter,
}, },
// validators // validators
commonNameValidator(claims.Subject), commonNameValidator(claims.Subject),
@ -248,12 +246,7 @@ func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
p.ctl.newWebhookController( p.ctl.newWebhookController(data, linkedca.Webhook_X509),
data,
linkedca.Webhook_X509,
webhook.WithX5CCertificate(x5cLeaf),
webhook.WithAuthorizationPrincipal(x5cLeaf.Subject.CommonName),
),
}, nil }, nil
} }
@ -263,7 +256,7 @@ func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *X5C) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.ctl.Claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName())
} }
@ -312,8 +305,7 @@ func (p *X5C) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
// The X509 certificate will be available using the template variable // The X509 certificate will be available using the template variable
// AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be // AuthorizationCrt. For example {{ .AuthorizationCrt.DNSNames }} can be
// used to get all the domains. // used to get all the domains.
x5cLeaf := claims.chains[0][0] data.SetAuthorizationCertificate(claims.chains[0][0])
data.SetAuthorizationCertificate(x5cLeaf)
templateOptions, err := TemplateSSHOptions(p.Options, data) templateOptions, err := TemplateSSHOptions(p.Options, data)
if err != nil { if err != nil {
@ -333,7 +325,7 @@ func (p *X5C) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
return append(signOptions, return append(signOptions,
p, p,
// Checks the validity bounds, and set the validity if has not been set. // Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.ctl.Claimer, x5cLeaf.NotAfter}, &sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter},
// Validate public key. // Validate public key.
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
@ -343,11 +335,6 @@ func (p *X5C) AuthorizeSSHSign(_ context.Context, token string) ([]SignOption, e
// Ensure that all principal names are allowed // Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()), newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
// Call webhooks // Call webhooks
p.ctl.newWebhookController( p.ctl.newWebhookController(data, linkedca.Webhook_SSH),
data,
linkedca.Webhook_SSH,
webhook.WithX5CCertificate(x5cLeaf),
webhook.WithAuthorizationPrincipal(x5cLeaf.Subject.CommonName),
),
), nil ), nil
} }

View file

@ -12,7 +12,6 @@ import (
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"go.step.sm/linkedca"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
@ -498,8 +497,6 @@ func TestX5C_AuthorizeSign(t *testing.T) {
assert.Equals(t, nil, v.policyEngine) assert.Equals(t, nil, v.policyEngine)
case *WebhookController: case *WebhookController:
assert.Len(t, 0, v.webhooks) assert.Len(t, 0, v.webhooks)
assert.Equals(t, linkedca.Webhook_X509, v.certType)
assert.Len(t, 2, v.options)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
@ -793,6 +790,8 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix()) assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())
case sshCertValidBeforeModifier: case sshCertValidBeforeModifier:
assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix()) assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidBefore.RelativeTime(nw).Unix())
case sshCertDefaultsModifier:
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
case *sshLimitDuration: case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
@ -804,8 +803,6 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc: case *sshDefaultPublicKeyValidator, *sshCertDefaultValidator, sshCertificateOptionsFunc:
case *WebhookController: case *WebhookController:
assert.Len(t, 0, v.webhooks) assert.Len(t, 0, v.webhooks)
assert.Equals(t, linkedca.Webhook_SSH, v.certType)
assert.Len(t, 2, v.options)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }

View file

@ -48,22 +48,6 @@ func wrapProvisioner(p provisioner.Interface, attData *provisioner.AttestationDa
} }
} }
// wrapRAProvisioner wraps the given provisioner with RA information.
func wrapRAProvisioner(p provisioner.Interface, raInfo *provisioner.RAInfo) *wrappedProvisioner {
return &wrappedProvisioner{
Interface: p,
raInfo: raInfo,
}
}
// isRAProvisioner returns if the given provisioner is an RA provisioner.
func isRAProvisioner(p provisioner.Interface) bool {
if rap, ok := p.(raProvisioner); ok {
return rap.RAInfo() != nil
}
return false
}
// wrappedProvisioner implements raProvisioner and attProvisioner. // wrappedProvisioner implements raProvisioner and attProvisioner.
type wrappedProvisioner struct { type wrappedProvisioner struct {
provisioner.Interface provisioner.Interface
@ -135,9 +119,6 @@ func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (pr
} }
if err == nil && data != nil && data.Provisioner != nil { if err == nil && data != nil && data.Provisioner != nil {
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok { if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
if data.RaInfo != nil {
return wrapRAProvisioner(p, data.RaInfo), nil
}
return p, nil return p, nil
} }
} }
@ -880,9 +861,6 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
Type: p.Type.String(), Type: p.Type.String(),
Name: p.Name, Name: p.Name,
ForceCN: cfg.ForceCn, ForceCN: cfg.ForceCn,
TermsOfService: cfg.TermsOfService,
Website: cfg.Website,
CaaIdentities: cfg.CaaIdentities,
RequireEAB: cfg.RequireEab, RequireEAB: cfg.RequireEab,
Challenges: challengesToCertificates(cfg.Challenges), Challenges: challengesToCertificates(cfg.Challenges),
AttestationFormats: attestationFormatsToCertificates(cfg.AttestationFormats), AttestationFormats: attestationFormatsToCertificates(cfg.AttestationFormats),
@ -1141,10 +1119,6 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro
Data: &linkedca.ProvisionerDetails_ACME{ Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{ ACME: &linkedca.ACMEProvisioner{
ForceCn: p.ForceCN, ForceCn: p.ForceCN,
TermsOfService: p.TermsOfService,
Website: p.Website,
CaaIdentities: p.CaaIdentities,
RequireEab: p.RequireEAB,
Challenges: challengesToLinkedca(p.Challenges), Challenges: challengesToLinkedca(p.Challenges),
AttestationFormats: attestationFormatsToLinkedca(p.AttestationFormats), AttestationFormats: attestationFormatsToLinkedca(p.AttestationFormats),
AttestationRoots: provisionerPEMToLinkedca(p.AttestationRoots), AttestationRoots: provisionerPEMToLinkedca(p.AttestationRoots),
@ -1223,7 +1197,7 @@ func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, erro
Data: &linkedca.ProvisionerDetails_SCEP{ Data: &linkedca.ProvisionerDetails_SCEP{
SCEP: &linkedca.SCEPProvisioner{ SCEP: &linkedca.SCEPProvisioner{
ForceCn: p.ForceCN, ForceCn: p.ForceCN,
Challenge: p.ChallengePassword, Challenge: p.GetChallengePassword(),
Capabilities: p.Capabilities, Capabilities: p.Capabilities,
MinimumPublicKeyLength: int32(p.MinimumPublicKeyLength), MinimumPublicKeyLength: int32(p.MinimumPublicKeyLength),
IncludeRoot: p.IncludeRoot, IncludeRoot: p.IncludeRoot,

View file

@ -9,17 +9,14 @@ import (
"testing" "testing"
"time" "time"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/linkedca"
"github.com/stretchr/testify/require"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/linkedca"
) )
func TestGetEncryptedKey(t *testing.T) { func TestGetEncryptedKey(t *testing.T) {
@ -32,9 +29,9 @@ func TestGetEncryptedKey(t *testing.T) {
tests := map[string]func(t *testing.T) *ek{ tests := map[string]func(t *testing.T) *ek{
"ok": func(t *testing.T) *ek { "ok": func(t *testing.T) *ek {
c, err := LoadConfiguration("../ca/testdata/ca.json") c, err := LoadConfiguration("../ca/testdata/ca.json")
require.NoError(t, err) assert.FatalError(t, err)
a, err := New(c) a, err := New(c)
require.NoError(t, err) assert.FatalError(t, err)
return &ek{ return &ek{
a: a, a: a,
kid: c.AuthorityConfig.Provisioners[1].(*provisioner.JWK).Key.KeyID, kid: c.AuthorityConfig.Provisioners[1].(*provisioner.JWK).Key.KeyID,
@ -42,9 +39,9 @@ func TestGetEncryptedKey(t *testing.T) {
}, },
"fail-not-found": func(t *testing.T) *ek { "fail-not-found": func(t *testing.T) *ek {
c, err := LoadConfiguration("../ca/testdata/ca.json") c, err := LoadConfiguration("../ca/testdata/ca.json")
require.NoError(t, err) assert.FatalError(t, err)
a, err := New(c) a, err := New(c)
require.NoError(t, err) assert.FatalError(t, err)
return &ek{ return &ek{
a: a, a: a,
kid: "foo", kid: "foo",
@ -98,16 +95,9 @@ func TestGetProvisioners(t *testing.T) {
tests := map[string]func(t *testing.T) *gp{ tests := map[string]func(t *testing.T) *gp{
"ok": func(t *testing.T) *gp { "ok": func(t *testing.T) *gp {
c, err := LoadConfiguration("../ca/testdata/ca.json") c, err := LoadConfiguration("../ca/testdata/ca.json")
require.NoError(t, err) assert.FatalError(t, err)
a, err := New(c) a, err := New(c)
require.NoError(t, err) assert.FatalError(t, err)
return &gp{a: a}
},
"ok/rsa": func(t *testing.T) *gp {
c, err := LoadConfiguration("../ca/testdata/rsaca.json")
require.NoError(t, err)
a, err := New(c)
require.NoError(t, err)
return &gp{a: a} return &gp{a: a}
}, },
} }
@ -121,13 +111,13 @@ func TestGetProvisioners(t *testing.T) {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
var sc render.StatusCodedError var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, tc.code, sc.StatusCode()) assert.Equals(t, sc.StatusCode(), tc.code)
} }
assert.HasPrefix(t, tc.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, tc.a.config.AuthorityConfig.Provisioners, ps) assert.Equals(t, ps, tc.a.config.AuthorityConfig.Provisioners)
assert.Equals(t, "", next) assert.Equals(t, "", next)
} }
} }
@ -137,20 +127,20 @@ func TestGetProvisioners(t *testing.T) {
func TestAuthority_LoadProvisionerByCertificate(t *testing.T) { func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
_, priv, err := keyutil.GenerateDefaultKeyPair() _, priv, err := keyutil.GenerateDefaultKeyPair()
require.NoError(t, err) assert.FatalError(t, err)
csr := getCSR(t, priv) csr := getCSR(t, priv)
sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate { sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate {
key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
require.NoError(t, err) assert.FatalError(t, err)
token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key) token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
require.NoError(t, err) assert.FatalError(t, err)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
opts, err := a.Authorize(ctx, token) opts, err := a.Authorize(ctx, token)
require.NoError(t, err) assert.FatalError(t, err)
opts = append(opts, extraOpts...) opts = append(opts, extraOpts...)
certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...) certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
require.NoError(t, err) assert.FatalError(t, err)
return certs[0] return certs[0]
} }
getProvisioner := func(a *Authority, name string) provisioner.Interface { getProvisioner := func(a *Authority, name string) provisioner.Interface {
@ -179,7 +169,9 @@ func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
}, },
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
p, err := a1.LoadProvisionerByName("dev") p, err := a1.LoadProvisionerByName("dev")
require.NoError(t, err) if err != nil {
t.Fatal(err)
}
return &db.CertificateData{ return &db.CertificateData{
Provisioner: &db.ProvisionerData{ Provisioner: &db.ProvisionerData{
ID: p.GetID(), ID: p.GetID(),
@ -194,7 +186,9 @@ func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
a2.adminDB = &mockAdminDB{ a2.adminDB = &mockAdminDB{
MGetCertificateData: (func(s string) (*db.CertificateData, error) { MGetCertificateData: (func(s string) (*db.CertificateData, error) {
p, err := a2.LoadProvisionerByName("dev") p, err := a2.LoadProvisionerByName("dev")
require.NoError(t, err) if err != nil {
t.Fatal(err)
}
return &db.CertificateData{ return &db.CertificateData{
Provisioner: &db.ProvisionerData{ Provisioner: &db.ProvisionerData{
ID: p.GetID(), ID: p.GetID(),
@ -339,54 +333,3 @@ func TestProvisionerWebhookToLinkedca(t *testing.T) {
}) })
} }
} }
func Test_wrapRAProvisioner(t *testing.T) {
type args struct {
p provisioner.Interface
raInfo *provisioner.RAInfo
}
tests := []struct {
name string
args args
want *wrappedProvisioner
}{
{"ok", args{&provisioner.JWK{Name: "jwt"}, &provisioner.RAInfo{ProvisionerName: "ra"}}, &wrappedProvisioner{
Interface: &provisioner.JWK{Name: "jwt"},
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := wrapRAProvisioner(tt.args.p, tt.args.raInfo); !reflect.DeepEqual(got, tt.want) {
t.Errorf("wrapRAProvisioner() = %v, want %v", got, tt.want)
}
})
}
}
func Test_isRAProvisioner(t *testing.T) {
type args struct {
p provisioner.Interface
}
tests := []struct {
name string
args args
want bool
}{
{"true", args{&wrappedProvisioner{
Interface: &provisioner.JWK{Name: "jwt"},
raInfo: &provisioner.RAInfo{ProvisionerName: "ra"},
}}, true},
{"nil ra", args{&wrappedProvisioner{
Interface: &provisioner.JWK{Name: "jwt"},
}}, false},
{"not ra", args{&provisioner.JWK{Name: "jwt"}}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isRAProvisioner(tt.args.p); got != tt.want {
t.Errorf("isRAProvisioner() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -52,7 +52,7 @@ func (a *Authority) GetSSHFederation(context.Context) (*config.SSHKeys, error) {
} }
// GetSSHConfig returns rendered templates for clients (user) or servers (host). // GetSSHConfig returns rendered templates for clients (user) or servers (host).
func (a *Authority) GetSSHConfig(_ context.Context, typ string, data map[string]string) ([]templates.Output, error) { func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil { if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
return nil, errs.NotFound("getSSHConfig: ssh is not configured") return nil, errs.NotFound("getSSHConfig: ssh is not configured")
} }
@ -146,7 +146,7 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*
} }
// SignSSH creates a signed SSH certificate with the given public key and options. // SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var ( var (
certOptions []sshutil.Option certOptions []sshutil.Option
mods []provisioner.SSHCertModifier mods []provisioner.SSHCertModifier
@ -663,7 +663,11 @@ func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.Certifica
if err != nil { if err != nil {
return err return err
} }
return webhookCtl.Enrich(whEnrichReq) if err := webhookCtl.Enrich(whEnrichReq); err != nil {
return err
}
return nil
} }
func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error { func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error {
@ -676,5 +680,9 @@ func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Cert
if err != nil { if err != nil {
return err return err
} }
return webhookCtl.Authorize(whAuthBody) if err := webhookCtl.Authorize(whAuthBody); err != nil {
return err
}
return nil
} }

Some files were not shown because too many files have changed in this diff Show more