forked from TrueCloudLab/certificates
Merge branch 'master' into hs/acme-revocation
This commit is contained in:
commit
3151255a25
153 changed files with 6603 additions and 1745 deletions
89
.github/workflows/release.yml
vendored
89
.github/workflows/release.yml
vendored
|
@ -12,7 +12,7 @@ jobs:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: [ '1.15', '1.16' ]
|
go: [ '1.15', '1.16', '1.17' ]
|
||||||
outputs:
|
outputs:
|
||||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||||
steps:
|
steps:
|
||||||
|
@ -62,8 +62,15 @@ jobs:
|
||||||
needs: test
|
needs: test
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
outputs:
|
outputs:
|
||||||
|
debversion: ${{ steps.extract-tag.outputs.DEB_VERSION }}
|
||||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||||
steps:
|
steps:
|
||||||
|
-
|
||||||
|
name: Extract Tag Names
|
||||||
|
id: extract-tag
|
||||||
|
run: |
|
||||||
|
DEB_VERSION=$(echo ${GITHUB_REF#refs/tags/v} | sed 's/-/./')
|
||||||
|
echo "::set-output name=DEB_VERSION::${DEB_VERSION}"
|
||||||
-
|
-
|
||||||
name: Is Pre-release
|
name: Is Pre-release
|
||||||
id: is_prerelease
|
id: is_prerelease
|
||||||
|
@ -99,62 +106,71 @@ jobs:
|
||||||
name: Set up Go
|
name: Set up Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: 1.16
|
go-version: 1.17
|
||||||
-
|
|
||||||
name: Run GoReleaser
|
|
||||||
uses: goreleaser/goreleaser-action@56f5b77f7fa4a8fe068bf22b732ec036cc9bc13f # v2.4.1
|
|
||||||
with:
|
|
||||||
version: latest
|
|
||||||
args: release --rm-dist
|
|
||||||
env:
|
|
||||||
GITHUB_TOKEN: ${{ secrets.PAT }}
|
|
||||||
|
|
||||||
release_deb:
|
|
||||||
name: Build & Upload Debian Package To Github
|
|
||||||
runs-on: ubuntu-20.04
|
|
||||||
needs: create_release
|
|
||||||
steps:
|
|
||||||
-
|
|
||||||
name: Checkout
|
|
||||||
uses: actions/checkout@v2
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
-
|
|
||||||
name: Set up Go
|
|
||||||
uses: actions/setup-go@v2
|
|
||||||
with:
|
|
||||||
go-version: '1.16'
|
|
||||||
-
|
-
|
||||||
name: APT Install
|
name: APT Install
|
||||||
id: aptInstall
|
id: aptInstall
|
||||||
run: sudo apt-get -y install build-essential debhelper fakeroot
|
run: sudo apt-get -y install build-essential debhelper fakeroot
|
||||||
-
|
-
|
||||||
name: Build Debian package
|
name: Build Debian package
|
||||||
id: build
|
id: make_debian
|
||||||
run: |
|
run: |
|
||||||
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
|
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
|
||||||
make debian
|
make debian
|
||||||
|
# need to restore the git state otherwise goreleaser fails due to dirty state
|
||||||
|
git restore debian/changelog
|
||||||
|
git clean -fd
|
||||||
-
|
-
|
||||||
name: Upload Debian Package
|
name: Install cosign
|
||||||
id: upload_deb
|
uses: sigstore/cosign-installer@v1.1.0
|
||||||
|
with:
|
||||||
|
cosign-release: 'v1.1.0'
|
||||||
|
-
|
||||||
|
name: Write cosign key to disk
|
||||||
|
id: write_key
|
||||||
|
run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
|
||||||
|
-
|
||||||
|
name: Get Release Date
|
||||||
|
id: release_date
|
||||||
run: |
|
run: |
|
||||||
tag_name="${GITHUB_REF##*/}"
|
RELEASE_DATE=$(date +"%y-%m-%d")
|
||||||
hub release edit $(find ./.releases -type f -printf "-a %p ") -m "" "$tag_name"
|
echo "::set-output name=RELEASE_DATE::${RELEASE_DATE}"
|
||||||
|
-
|
||||||
|
name: Run GoReleaser
|
||||||
|
uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0
|
||||||
|
with:
|
||||||
|
version: latest
|
||||||
|
args: release --rm-dist
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.PAT }}
|
||||||
|
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}
|
||||||
|
DEB_VERSION: ${{ needs.create_release.outputs.debversion }}
|
||||||
|
RELEASE_DATE: ${{ steps.release_date.outputs.RELEASE_DATE }}
|
||||||
|
|
||||||
build_upload_docker:
|
build_upload_docker:
|
||||||
name: Build & Upload Docker Images
|
name: Build & Upload Docker Images
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
needs: test
|
needs: test
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
-
|
||||||
|
name: Checkout
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
- name: Setup Go
|
-
|
||||||
|
name: Setup Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: '1.16'
|
go-version: '1.17'
|
||||||
- name: Build
|
-
|
||||||
|
name: Install cosign
|
||||||
|
uses: sigstore/cosign-installer@v1.1.0
|
||||||
|
with:
|
||||||
|
cosign-release: 'v1.1.0'
|
||||||
|
-
|
||||||
|
name: Write cosign key to disk
|
||||||
|
id: write_key
|
||||||
|
run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
|
||||||
|
-
|
||||||
|
name: Build
|
||||||
id: build
|
id: build
|
||||||
run: |
|
run: |
|
||||||
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
|
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
|
||||||
|
@ -162,3 +178,4 @@ jobs:
|
||||||
env:
|
env:
|
||||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||||
DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
|
DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
|
||||||
|
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}
|
||||||
|
|
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
@ -14,7 +14,7 @@ jobs:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go: [ '1.15', '1.16' ]
|
go: [ '1.15', '1.16', '1.17' ]
|
||||||
steps:
|
steps:
|
||||||
-
|
-
|
||||||
name: Checkout
|
name: Checkout
|
||||||
|
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -14,8 +14,8 @@
|
||||||
|
|
||||||
# Others
|
# Others
|
||||||
*.swp
|
*.swp
|
||||||
.travis-releases
|
.releases
|
||||||
coverage.txt
|
coverage.txt
|
||||||
vendor
|
|
||||||
output
|
output
|
||||||
|
vendor
|
||||||
.idea
|
.idea
|
||||||
|
|
|
@ -36,22 +36,30 @@ linters-settings:
|
||||||
- performance
|
- performance
|
||||||
- style
|
- style
|
||||||
- experimental
|
- experimental
|
||||||
|
- diagnostic
|
||||||
disabled-checks:
|
disabled-checks:
|
||||||
- wrapperFunc
|
- commentFormatting
|
||||||
- dupImport # https://github.com/go-critic/go-critic/issues/845
|
- commentedOutCode
|
||||||
|
- evalOrder
|
||||||
|
- hugeParam
|
||||||
|
- octalLiteral
|
||||||
|
- rangeValCopy
|
||||||
|
- tooManyResultsChecker
|
||||||
|
- unnamedResult
|
||||||
|
|
||||||
linters:
|
linters:
|
||||||
disable-all: true
|
disable-all: true
|
||||||
enable:
|
enable:
|
||||||
- gofmt
|
|
||||||
- revive
|
|
||||||
- govet
|
|
||||||
- misspell
|
|
||||||
- ineffassign
|
|
||||||
- deadcode
|
- deadcode
|
||||||
|
- gocritic
|
||||||
|
- gofmt
|
||||||
|
- gosimple
|
||||||
|
- govet
|
||||||
|
- ineffassign
|
||||||
|
- misspell
|
||||||
|
- revive
|
||||||
- staticcheck
|
- staticcheck
|
||||||
- unused
|
- unused
|
||||||
- gosimple
|
|
||||||
|
|
||||||
run:
|
run:
|
||||||
skip-dirs:
|
skip-dirs:
|
||||||
|
|
154
.goreleaser.yml
154
.goreleaser.yml
|
@ -1,34 +1,27 @@
|
||||||
# This is an example .goreleaser.yml file with some sane defaults.
|
# This is an example .goreleaser.yml file with some sane defaults.
|
||||||
# Make sure to check the documentation at http://goreleaser.com
|
# Make sure to check the documentation at http://goreleaser.com
|
||||||
project_name: step-ca
|
project_name: step-ca
|
||||||
|
|
||||||
before:
|
before:
|
||||||
hooks:
|
hooks:
|
||||||
# You may remove this if you don't use go modules.
|
# You may remove this if you don't use go modules.
|
||||||
- go mod download
|
- go mod download
|
||||||
|
|
||||||
builds:
|
builds:
|
||||||
-
|
-
|
||||||
id: step-ca
|
id: step-ca
|
||||||
env:
|
env:
|
||||||
- CGO_ENABLED=0
|
- CGO_ENABLED=0
|
||||||
goos:
|
targets:
|
||||||
- linux
|
- darwin_amd64
|
||||||
- darwin
|
- darwin_arm64
|
||||||
- windows
|
- freebsd_amd64
|
||||||
goarch:
|
- linux_386
|
||||||
- amd64
|
- linux_amd64
|
||||||
- arm
|
- linux_arm64
|
||||||
- arm64
|
- linux_arm_6
|
||||||
- 386
|
- linux_arm_7
|
||||||
goarm:
|
- windows_amd64
|
||||||
- 6
|
|
||||||
- 7
|
|
||||||
ignore:
|
|
||||||
- goos: windows
|
|
||||||
goarch: 386
|
|
||||||
- goos: windows
|
|
||||||
goarm: 6
|
|
||||||
- goos: windows
|
|
||||||
goarm: 7
|
|
||||||
flags:
|
flags:
|
||||||
- -trimpath
|
- -trimpath
|
||||||
main: ./cmd/step-ca/main.go
|
main: ./cmd/step-ca/main.go
|
||||||
|
@ -39,25 +32,16 @@ builds:
|
||||||
id: step-cloudkms-init
|
id: step-cloudkms-init
|
||||||
env:
|
env:
|
||||||
- CGO_ENABLED=0
|
- CGO_ENABLED=0
|
||||||
goos:
|
targets:
|
||||||
- linux
|
- darwin_amd64
|
||||||
- darwin
|
- darwin_arm64
|
||||||
- windows
|
- freebsd_amd64
|
||||||
goarch:
|
- linux_386
|
||||||
- amd64
|
- linux_amd64
|
||||||
- arm
|
- linux_arm64
|
||||||
- arm64
|
- linux_arm_6
|
||||||
- 386
|
- linux_arm_7
|
||||||
goarm:
|
- windows_amd64
|
||||||
- 6
|
|
||||||
- 7
|
|
||||||
ignore:
|
|
||||||
- goos: windows
|
|
||||||
goarch: 386
|
|
||||||
- goos: windows
|
|
||||||
goarm: 6
|
|
||||||
- goos: windows
|
|
||||||
goarm: 7
|
|
||||||
flags:
|
flags:
|
||||||
- -trimpath
|
- -trimpath
|
||||||
main: ./cmd/step-cloudkms-init/main.go
|
main: ./cmd/step-cloudkms-init/main.go
|
||||||
|
@ -68,31 +52,23 @@ builds:
|
||||||
id: step-awskms-init
|
id: step-awskms-init
|
||||||
env:
|
env:
|
||||||
- CGO_ENABLED=0
|
- CGO_ENABLED=0
|
||||||
goos:
|
targets:
|
||||||
- linux
|
- darwin_amd64
|
||||||
- darwin
|
- darwin_arm64
|
||||||
- windows
|
- freebsd_amd64
|
||||||
goarch:
|
- linux_386
|
||||||
- amd64
|
- linux_amd64
|
||||||
- arm
|
- linux_arm64
|
||||||
- arm64
|
- linux_arm_6
|
||||||
- 386
|
- linux_arm_7
|
||||||
goarm:
|
- windows_amd64
|
||||||
- 6
|
|
||||||
- 7
|
|
||||||
ignore:
|
|
||||||
- goos: windows
|
|
||||||
goarch: 386
|
|
||||||
- goos: windows
|
|
||||||
goarm: 6
|
|
||||||
- goos: windows
|
|
||||||
goarm: 7
|
|
||||||
flags:
|
flags:
|
||||||
- -trimpath
|
- -trimpath
|
||||||
main: ./cmd/step-awskms-init/main.go
|
main: ./cmd/step-awskms-init/main.go
|
||||||
binary: bin/step-awskms-init
|
binary: bin/step-awskms-init
|
||||||
ldflags:
|
ldflags:
|
||||||
- -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}}
|
- -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}}
|
||||||
|
|
||||||
archives:
|
archives:
|
||||||
-
|
-
|
||||||
# Can be used to change the archive formats for specific GOOSs.
|
# Can be used to change the archive formats for specific GOOSs.
|
||||||
|
@ -106,13 +82,25 @@ archives:
|
||||||
files:
|
files:
|
||||||
- README.md
|
- README.md
|
||||||
- LICENSE
|
- LICENSE
|
||||||
|
|
||||||
source:
|
source:
|
||||||
enabled: true
|
enabled: true
|
||||||
name_template: '{{ .ProjectName }}_{{ .Version }}'
|
name_template: '{{ .ProjectName }}_{{ .Version }}'
|
||||||
|
|
||||||
checksum:
|
checksum:
|
||||||
name_template: 'checksums.txt'
|
name_template: 'checksums.txt'
|
||||||
|
extra_files:
|
||||||
|
- glob: ./.releases/*
|
||||||
|
|
||||||
|
signs:
|
||||||
|
- cmd: cosign
|
||||||
|
stdin: '{{ .Env.COSIGN_PWD }}'
|
||||||
|
args: ["sign-blob", "-key=/tmp/cosign.key", "-output=${signature}", "${artifact}"]
|
||||||
|
artifacts: all
|
||||||
|
|
||||||
snapshot:
|
snapshot:
|
||||||
name_template: "{{ .Tag }}-next"
|
name_template: "{{ .Tag }}-next"
|
||||||
|
|
||||||
release:
|
release:
|
||||||
# Repo in which the release will be created.
|
# Repo in which the release will be created.
|
||||||
# Default is extracted from the origin remote URL or empty if its private hosted.
|
# Default is extracted from the origin remote URL or empty if its private hosted.
|
||||||
|
@ -139,7 +127,55 @@ release:
|
||||||
|
|
||||||
# You can change the name of the release.
|
# You can change the name of the release.
|
||||||
# Default is `{{.Tag}}`
|
# Default is `{{.Tag}}`
|
||||||
#name_template: "{{.ProjectName}}-v{{.Version}} {{.Env.USER}}"
|
name_template: "Step CA {{ .Tag }} ({{ .Env.RELEASE_DATE }})"
|
||||||
|
|
||||||
|
# Header template for the release body.
|
||||||
|
# Defaults to empty.
|
||||||
|
header: |
|
||||||
|
## Official Release Artifacts
|
||||||
|
|
||||||
|
#### Linux
|
||||||
|
|
||||||
|
- 📦 [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz)
|
||||||
|
- 📦 [step-ca_{{ .Env.DEB_VERSION }}_amd64.deb](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ .Env.DEB_VERSION }}_amd64.deb)
|
||||||
|
|
||||||
|
#### OSX Darwin
|
||||||
|
|
||||||
|
- 📦 [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz)
|
||||||
|
- 📦 [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz)
|
||||||
|
|
||||||
|
#### Windows
|
||||||
|
|
||||||
|
- 📦 [step-ca_windows_{{ .Version }}_arm64.zip](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip)
|
||||||
|
|
||||||
|
For more builds across platforms and architectures, see the `Assets` section below.
|
||||||
|
And for packaged versions (Docker, k8s, Homebrew), see our [installation docs](https://smallstep.com/docs/step-ca/installation).
|
||||||
|
|
||||||
|
Don't see the artifact you need? Open an issue [here](https://github.com/smallstep/certificates/issues/new/choose).
|
||||||
|
|
||||||
|
## Signatures and Checksums
|
||||||
|
|
||||||
|
`step-ca` uses [sigstore/cosign](https://github.com/sigstore/cosign) for signing and verifying release artifacts.
|
||||||
|
|
||||||
|
Below is an example using `cosign` to verify a release artifact:
|
||||||
|
|
||||||
|
```
|
||||||
|
cosign verify-blob \
|
||||||
|
-key https://raw.githubusercontent.com/smallstep/certificates/master/cosign.pub \
|
||||||
|
-signature ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig
|
||||||
|
~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz
|
||||||
|
```
|
||||||
|
|
||||||
|
The `checksums.txt` file (in the `Assets` section below) contains a checksum for every artifact in the release.
|
||||||
|
|
||||||
|
# Footer template for the release body.
|
||||||
|
# Defaults to empty.
|
||||||
|
footer: |
|
||||||
|
## Thanks!
|
||||||
|
|
||||||
|
Those were the changes on {{ .Tag }}!
|
||||||
|
|
||||||
|
Come join us on [Discord](https://discord.gg/X2RKGwEbV9) to ask questions, chat about PKI, or get a sneak peak at the freshest PKI memes.
|
||||||
|
|
||||||
# You can disable this pipe in order to not upload any artifacts.
|
# You can disable this pipe in order to not upload any artifacts.
|
||||||
# Defaults to false.
|
# Defaults to false.
|
||||||
|
@ -149,6 +185,8 @@ release:
|
||||||
# The filename on the release will be the last part of the path (base). If
|
# The filename on the release will be the last part of the path (base). If
|
||||||
# another file with the same name exists, the latest one found will be used.
|
# another file with the same name exists, the latest one found will be used.
|
||||||
# Defaults to empty.
|
# Defaults to empty.
|
||||||
|
extra_files:
|
||||||
|
- glob: ./.releases/*
|
||||||
#extra_files:
|
#extra_files:
|
||||||
# - glob: ./path/to/file.txt
|
# - glob: ./path/to/file.txt
|
||||||
# - glob: ./glob/**/to/**/file/**/*
|
# - glob: ./glob/**/to/**/file/**/*
|
||||||
|
|
54
CHANGELOG.md
54
CHANGELOG.md
|
@ -4,10 +4,62 @@ All notable changes to this project will be documented in this file.
|
||||||
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
||||||
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## [Unreleased - 0.0.1] - DATE
|
## [Unreleased - 0.17.7] - DATE
|
||||||
### Added
|
### Added
|
||||||
|
- Support for generate extractable keys and certificates on a pkcs#11 module.
|
||||||
### Changed
|
### Changed
|
||||||
### Deprecated
|
### Deprecated
|
||||||
### Removed
|
### Removed
|
||||||
### Fixed
|
### Fixed
|
||||||
### Security
|
### Security
|
||||||
|
|
||||||
|
## [0.17.6] - 2021-10-20
|
||||||
|
### Notes
|
||||||
|
- 0.17.5 failed in CI/CD
|
||||||
|
|
||||||
|
## [0.17.5] - 2021-10-20
|
||||||
|
### Added
|
||||||
|
- Support for Azure Key Vault as a KMS.
|
||||||
|
- Adapt `pki` package to support key managers.
|
||||||
|
- gocritic linter
|
||||||
|
### Fixed
|
||||||
|
- gocritic warnings
|
||||||
|
|
||||||
|
## [0.17.4] - 2021-09-28
|
||||||
|
### Fixed
|
||||||
|
- Support host-only or user-only SSH CA.
|
||||||
|
|
||||||
|
## [0.17.3] - 2021-09-24
|
||||||
|
### Added
|
||||||
|
- go 1.17 to github action test matrix
|
||||||
|
- Support for CloudKMS RSA-PSS signers without using templates.
|
||||||
|
- Add flags to support individual passwords for the intermediate and SSH keys.
|
||||||
|
- Global support for group admins in the OIDC provisioner.
|
||||||
|
### Changed
|
||||||
|
- Using go 1.17 for binaries
|
||||||
|
### Fixed
|
||||||
|
- Upgrade go-jose.v2 to fix a bug in the JWK fingerprint of Ed25519 keys.
|
||||||
|
### Security
|
||||||
|
- Use cosign to sign and upload signatures for multi-arch Docker container.
|
||||||
|
- Add debian checksum
|
||||||
|
|
||||||
|
## [0.17.2] - 2021-08-30
|
||||||
|
### Added
|
||||||
|
- Additional way to distinguish Azure IID and Azure OIDC tokens.
|
||||||
|
### Security
|
||||||
|
- Sign over all goreleaser github artifacts using cosign
|
||||||
|
|
||||||
|
## [0.17.1] - 2021-08-26
|
||||||
|
|
||||||
|
## [0.17.0] - 2021-08-25
|
||||||
|
### Added
|
||||||
|
- Add support for Linked CAs using protocol buffers and gRPC
|
||||||
|
- `step-ca init` adds support for
|
||||||
|
- configuring a StepCAS RA
|
||||||
|
- configuring a Linked CA
|
||||||
|
- congifuring a `step-ca` using Helm
|
||||||
|
### Changed
|
||||||
|
- Update badger driver to use v2 by default
|
||||||
|
- Update TLS cipher suites to include 1.3
|
||||||
|
### Security
|
||||||
|
- Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512.
|
||||||
|
|
6
Makefile
6
Makefile
|
@ -29,7 +29,7 @@ ci: testcgo build
|
||||||
|
|
||||||
bootstra%:
|
bootstra%:
|
||||||
# Using a released version of golangci-lint to take into account custom replacements in their go.mod
|
# Using a released version of golangci-lint to take into account custom replacements in their go.mod
|
||||||
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v1.39.0
|
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v1.42.0
|
||||||
|
|
||||||
.PHONY: bootstra%
|
.PHONY: bootstra%
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ PUSHTYPE := branch
|
||||||
endif
|
endif
|
||||||
|
|
||||||
VERSION := $(shell echo $(VERSION) | sed 's/^v//')
|
VERSION := $(shell echo $(VERSION) | sed 's/^v//')
|
||||||
DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/~/g')
|
DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/./g')
|
||||||
|
|
||||||
ifdef V
|
ifdef V
|
||||||
$(info TRAVIS_TAG is $(TRAVIS_TAG))
|
$(info TRAVIS_TAG is $(TRAVIS_TAG))
|
||||||
|
@ -154,7 +154,7 @@ fmt:
|
||||||
$Q gofmt -l -w $(SRC)
|
$Q gofmt -l -w $(SRC)
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
$Q $(GOFLAGS) LOG_LEVEL=error golangci-lint run --timeout=30m
|
$Q golangci-lint run --timeout=30m
|
||||||
|
|
||||||
lintcgo:
|
lintcgo:
|
||||||
$Q LOG_LEVEL=error golangci-lint run --timeout=30m
|
$Q LOG_LEVEL=error golangci-lint run --timeout=30m
|
||||||
|
|
36
README.md
36
README.md
|
@ -18,7 +18,14 @@ You can use it to:
|
||||||
|
|
||||||
Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [safe, sane defaults](https://smallstep.com/docs/step-ca/certificate-authority-server-production#sane-cryptographic-defaults).
|
Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [safe, sane defaults](https://smallstep.com/docs/step-ca/certificate-authority-server-production#sane-cryptographic-defaults).
|
||||||
|
|
||||||
**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions).**
|
---
|
||||||
|
|
||||||
|
**Don't want to run your own CA?**
|
||||||
|
To get up and running quickly, or as an alternative to running your own `step-ca` server, consider creating a [free hosted smallstep Certificate Manager authority](https://info.smallstep.com/certificate-manager-early-access-mvp/).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions) or [Join our Discord](https://u.step.sm/discord).**
|
||||||
|
|
||||||
[Website](https://smallstep.com/certificates) |
|
[Website](https://smallstep.com/certificates) |
|
||||||
[Documentation](https://smallstep.com/docs) |
|
[Documentation](https://smallstep.com/docs) |
|
||||||
|
@ -27,7 +34,6 @@ Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [
|
||||||
[Contributor's Guide](./docs/CONTRIBUTING.md)
|
[Contributor's Guide](./docs/CONTRIBUTING.md)
|
||||||
|
|
||||||
[![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest)
|
[![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest)
|
||||||
[![CA Image](https://images.microbadger.com/badges/image/smallstep/step-ca.svg)](https://microbadger.com/images/smallstep/step-ca)
|
|
||||||
[![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates)
|
[![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates)
|
||||||
[![Build Status](https://travis-ci.com/smallstep/certificates.svg?branch=master)](https://travis-ci.com/smallstep/certificates)
|
[![Build Status](https://travis-ci.com/smallstep/certificates.svg?branch=master)](https://travis-ci.com/smallstep/certificates)
|
||||||
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
|
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
|
||||||
|
@ -58,10 +64,11 @@ You can issue certificates in exchange for:
|
||||||
- ID tokens from Okta, GSuite, Azure AD, Auth0.
|
- ID tokens from Okta, GSuite, Azure AD, Auth0.
|
||||||
- ID tokens from an OAuth OIDC service that you host, like [Keycloak](https://www.keycloak.org/) or [Dex](https://github.com/dexidp/dex)
|
- ID tokens from an OAuth OIDC service that you host, like [Keycloak](https://www.keycloak.org/) or [Dex](https://github.com/dexidp/dex)
|
||||||
- [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure
|
- [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure
|
||||||
- [Single-use, short-lived JWK tokens]() issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
|
- [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
|
||||||
- A trusted X.509 certificate (X5C provisioner)
|
- A trusted X.509 certificate (X5C provisioner)
|
||||||
- Expiring SSH host certificates needing rotation (the SSHPOP provisioner)
|
- A SCEP challenge (SCEP provisioner)
|
||||||
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/configuration#jwk)
|
- An SSH host certificates needing renewal (the SSHPOP provisioner)
|
||||||
|
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners)
|
||||||
|
|
||||||
### 🏔 Your own private ACME server
|
### 🏔 Your own private ACME server
|
||||||
|
|
||||||
|
@ -74,16 +81,17 @@ ACME is the protocol used by Let's Encrypt to automate the issuance of HTTPS cer
|
||||||
- For `tls-alpn-01`, respond to the challenge at the TLS layer ([as Caddy does](https://caddy.community/t/caddy-supports-the-acme-tls-alpn-challenge/4860)) to prove that you control the web server
|
- For `tls-alpn-01`, respond to the challenge at the TLS layer ([as Caddy does](https://caddy.community/t/caddy-supports-the-acme-tls-alpn-challenge/4860)) to prove that you control the web server
|
||||||
|
|
||||||
- Works with any ACME client. We've written examples for:
|
- Works with any ACME client. We've written examples for:
|
||||||
- [certbot](https://smallstep.com/blog/private-acme-server/#certbotuploadsacme-certbotpng-certbot-example)
|
- [certbot](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#certbot)
|
||||||
- [acme.sh](https://smallstep.com/blog/private-acme-server/#acmeshuploadsacme-acme-shpng-acmesh-example)
|
- [acme.sh](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#acmesh)
|
||||||
- [Caddy](https://smallstep.com/blog/private-acme-server/#caddyuploadsacme-caddypng-caddy-example)
|
- [win-acme](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#win-acme)
|
||||||
- [Traefik](https://smallstep.com/blog/private-acme-server/#traefikuploadsacme-traefikpng-traefik-example)
|
- [Caddy](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#caddy-v2)
|
||||||
- [Apache](https://smallstep.com/blog/private-acme-server/#apacheuploadsacme-apachepng-apache-example)
|
- [Traefik](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#traefik)
|
||||||
- [nginx](https://smallstep.com/blog/private-acme-server/#nginxuploadsacme-nginxpng-nginx-example)
|
- [Apache](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#apache)
|
||||||
|
- [nginx](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#nginx)
|
||||||
- Get certificates programmatically using ACME, using these libraries:
|
- Get certificates programmatically using ACME, using these libraries:
|
||||||
- [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/blog/private-acme-server/#golanguploadsacme-golangpng-go-example))
|
- [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#golang))
|
||||||
- certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/blog/private-acme-server/#pythonuploadsacme-pythonpng-python-example))
|
- certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#python))
|
||||||
- [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/blog/private-acme-server/#nodejsuploadsacme-node-jspng-nodejs-example))
|
- [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#node))
|
||||||
- Our own [`step` CLI tool](https://github.com/smallstep/cli) is also an ACME client!
|
- Our own [`step` CLI tool](https://github.com/smallstep/cli) is also an ACME client!
|
||||||
- See our [ACME tutorial](https://smallstep.com/docs/tutorials/acme-challenge) for more
|
- See our [ACME tutorial](https://smallstep.com/docs/tutorials/acme-challenge) for more
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ type NewAccountRequest struct {
|
||||||
|
|
||||||
func validateContacts(cs []string) error {
|
func validateContacts(cs []string) error {
|
||||||
for _, c := range cs {
|
for _, c := range cs {
|
||||||
if len(c) == 0 {
|
if c == "" {
|
||||||
return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
|
return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -178,7 +178,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
|
||||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
|
u := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
|
||||||
|
|
||||||
oids := []string{"foo", "bar"}
|
oids := []string{"foo", "bar"}
|
||||||
oidURLs := []string{
|
oidURLs := []string{
|
||||||
|
@ -255,7 +255,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetOrdersByAccountID(w, req)
|
h.GetOrdersByAccountID(w, req)
|
||||||
|
|
|
@ -64,8 +64,14 @@ type HandlerOptions struct {
|
||||||
|
|
||||||
// NewHandler returns a new ACME API handler.
|
// NewHandler returns a new ACME API handler.
|
||||||
func NewHandler(ops HandlerOptions) api.RouterHandler {
|
func NewHandler(ops HandlerOptions) api.RouterHandler {
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
client := http.Client{
|
client := http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
|
Transport: transport,
|
||||||
}
|
}
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
|
|
|
@ -148,7 +148,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("authzID", az.ID)
|
chiCtx.URLParams.Add("authzID", az.ID)
|
||||||
url := fmt.Sprintf("%s/acme/%s/authz/%s",
|
u := fmt.Sprintf("%s/acme/%s/authz/%s",
|
||||||
baseURL.String(), provName, az.ID)
|
baseURL.String(), provName, az.ID)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -280,7 +280,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
||||||
expB, err := json.Marshal(az)
|
expB, err := json.Marshal(az)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -314,7 +314,7 @@ func TestHandler_GetCertificate(t *testing.T) {
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("certID", certID)
|
chiCtx.URLParams.Add("certID", certID)
|
||||||
url := fmt.Sprintf("%s/acme/%s/certificate/%s",
|
u := fmt.Sprintf("%s/acme/%s/certificate/%s",
|
||||||
baseURL.String(), provName, certID)
|
baseURL.String(), provName, certID)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -396,7 +396,7 @@ func TestHandler_GetCertificate(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
h := &Handler{db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetCertificate(w, req)
|
h.GetCertificate(w, req)
|
||||||
|
@ -434,7 +434,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
|
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
|
u := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
|
||||||
baseURL.String(), provName, "authzID", "chID")
|
baseURL.String(), provName, "authzID", "chID")
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -635,7 +635,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
AuthorizationID: "authzID",
|
AuthorizationID: "authzID",
|
||||||
Type: acme.HTTP01,
|
Type: acme.HTTP01,
|
||||||
AccountID: "accID",
|
AccountID: "accID",
|
||||||
URL: url,
|
URL: u,
|
||||||
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
||||||
},
|
},
|
||||||
vco: &acme.ValidateChallengeOptions{
|
vco: &acme.ValidateChallengeOptions{
|
||||||
|
@ -652,7 +652,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetChallenge(w, req)
|
h.GetChallenge(w, req)
|
||||||
|
@ -678,7 +678,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")})
|
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")})
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -223,7 +223,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -399,7 +399,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,7 +108,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_addNonce(t *testing.T) {
|
func TestHandler_addNonce(t *testing.T) {
|
||||||
url := "https://ca.smallstep.com/acme/new-nonce"
|
u := "https://ca.smallstep.com/acme/new-nonce"
|
||||||
type test struct {
|
type test struct {
|
||||||
db acme.DB
|
db acme.DB
|
||||||
err *acme.Error
|
err *acme.Error
|
||||||
|
@ -141,7 +141,7 @@ func TestHandler_addNonce(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
h := &Handler{db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.addNonce(testNext)(w, req)
|
h.addNonce(testNext)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
@ -230,7 +230,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
escProvName := url.PathEscape(prov.GetName())
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
||||||
type test struct {
|
type test struct {
|
||||||
h Handler
|
h Handler
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
@ -245,7 +245,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
h: Handler{
|
h: Handler{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: NewLinker("dns", "acme"),
|
||||||
},
|
},
|
||||||
url: url,
|
url: u,
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
contentType: "foo",
|
contentType: "foo",
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
|
@ -257,7 +257,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
h: Handler{
|
h: Handler{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: NewLinker("dns", "acme"),
|
||||||
},
|
},
|
||||||
url: url,
|
url: u,
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
contentType: "foo",
|
contentType: "foo",
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
|
@ -319,11 +319,11 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
_url := url
|
_u := u
|
||||||
if tc.url != "" {
|
if tc.url != "" {
|
||||||
_url = tc.url
|
_u = tc.url
|
||||||
}
|
}
|
||||||
req := httptest.NewRequest("GET", _url, nil)
|
req := httptest.NewRequest("GET", _u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
req.Header.Add("Content-Type", tc.contentType)
|
req.Header.Add("Content-Type", tc.contentType)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -353,7 +353,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_isPostAsGet(t *testing.T) {
|
func TestHandler_isPostAsGet(t *testing.T) {
|
||||||
url := "https://ca.smallstep.com/acme/new-account"
|
u := "https://ca.smallstep.com/acme/new-account"
|
||||||
type test struct {
|
type test struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
err *acme.Error
|
err *acme.Error
|
||||||
|
@ -392,7 +392,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.isPostAsGet(testNext)(w, req)
|
h.isPostAsGet(testNext)(w, req)
|
||||||
|
@ -430,7 +430,7 @@ func (errReader) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_parseJWS(t *testing.T) {
|
func TestHandler_parseJWS(t *testing.T) {
|
||||||
url := "https://ca.smallstep.com/acme/new-account"
|
u := "https://ca.smallstep.com/acme/new-account"
|
||||||
type test struct {
|
type test struct {
|
||||||
next nextHTTP
|
next nextHTTP
|
||||||
body io.Reader
|
body io.Reader
|
||||||
|
@ -483,7 +483,7 @@ func TestHandler_parseJWS(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest("GET", url, tc.body)
|
req := httptest.NewRequest("GET", u, tc.body)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.parseJWS(tc.next)(w, req)
|
h.parseJWS(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
@ -528,7 +528,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
parsedJWS, err := jose.ParseJWS(raw)
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
url := "https://ca.smallstep.com/acme/account/1234"
|
u := "https://ca.smallstep.com/acme/account/1234"
|
||||||
type test struct {
|
type test struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
next func(http.ResponseWriter, *http.Request)
|
next func(http.ResponseWriter, *http.Request)
|
||||||
|
@ -681,7 +681,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||||
|
@ -713,7 +713,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
url := fmt.Sprintf("%s/acme/%s/account/1234",
|
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||||
baseURL, provName)
|
baseURL, provName)
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
@ -883,7 +883,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: tc.linker}
|
h := &Handler{db: tc.db, linker: tc.linker}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.lookupJWK(tc.next)(w, req)
|
h.lookupJWK(tc.next)(w, req)
|
||||||
|
@ -934,7 +934,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
parsedJWS, err := jose.ParseJWS(raw)
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
|
u := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
|
||||||
provName)
|
provName)
|
||||||
type test struct {
|
type test struct {
|
||||||
db acme.DB
|
db acme.DB
|
||||||
|
@ -1079,7 +1079,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
h := &Handler{db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.extractJWK(tc.next)(w, req)
|
h.extractJWK(tc.next)(w, req)
|
||||||
|
@ -1108,7 +1108,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_validateJWS(t *testing.T) {
|
func TestHandler_validateJWS(t *testing.T) {
|
||||||
url := "https://ca.smallstep.com/acme/account/1234"
|
u := "https://ca.smallstep.com/acme/account/1234"
|
||||||
type test struct {
|
type test struct {
|
||||||
db acme.DB
|
db acme.DB
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
@ -1198,7 +1198,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.RS256,
|
Algorithm: jose.RS256,
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1226,7 +1226,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.RS256,
|
Algorithm: jose.RS256,
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1298,7 +1298,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
},
|
},
|
||||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", url),
|
err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", u),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/both-jwk-kid": func(t *testing.T) test {
|
"fail/both-jwk-kid": func(t *testing.T) test {
|
||||||
|
@ -1313,7 +1313,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
KeyID: "bar",
|
KeyID: "bar",
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1337,7 +1337,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Protected: jose.Header{
|
Protected: jose.Header{
|
||||||
Algorithm: jose.ES256,
|
Algorithm: jose.ES256,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1362,7 +1362,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.ES256,
|
Algorithm: jose.ES256,
|
||||||
KeyID: "bar",
|
KeyID: "bar",
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1392,7 +1392,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.ES256,
|
Algorithm: jose.ES256,
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1422,7 +1422,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.RS256,
|
Algorithm: jose.RS256,
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1446,7 +1446,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
h := &Handler{db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.validateJWS(tc.next)(w, req)
|
h.validateJWS(tc.next)(w, req)
|
||||||
|
|
|
@ -264,7 +264,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("ordID", o.ID)
|
chiCtx.URLParams.Add("ordID", o.ID)
|
||||||
url := fmt.Sprintf("%s/acme/%s/order/%s",
|
u := fmt.Sprintf("%s/acme/%s/order/%s",
|
||||||
baseURL.String(), escProvName, o.ID)
|
baseURL.String(), escProvName, o.ID)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -422,7 +422,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetOrder(w, req)
|
h.GetOrder(w, req)
|
||||||
|
@ -448,7 +448,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -663,7 +663,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
escProvName := url.PathEscape(prov.GetName())
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
url := fmt.Sprintf("%s/acme/%s/order/ordID",
|
u := fmt.Sprintf("%s/acme/%s/order/ordID",
|
||||||
baseURL.String(), escProvName)
|
baseURL.String(), escProvName)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -1335,7 +1335,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.NewOrder(w, req)
|
h.NewOrder(w, req)
|
||||||
|
@ -1363,7 +1363,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
tc.vr(t, ro)
|
tc.vr(t, ro)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -1406,7 +1406,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("ordID", o.ID)
|
chiCtx.URLParams.Add("ordID", o.ID)
|
||||||
url := fmt.Sprintf("%s/acme/%s/order/%s",
|
u := fmt.Sprintf("%s/acme/%s/order/%s",
|
||||||
baseURL.String(), escProvName, o.ID)
|
baseURL.String(), escProvName, o.ID)
|
||||||
|
|
||||||
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
||||||
|
@ -1625,7 +1625,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.FinalizeOrder(w, req)
|
h.FinalizeOrder(w, req)
|
||||||
|
@ -1654,7 +1654,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
assert.FatalError(t, json.Unmarshal(body, ro))
|
assert.FatalError(t, json.Unmarshal(body, ro))
|
||||||
|
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -10,11 +10,13 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -74,23 +76,23 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||||
url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||||
|
|
||||||
resp, err := vo.HTTPGet(url.String())
|
resp, err := vo.HTTPGet(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||||
"error doing http GET for url %s", url))
|
"error doing http GET for url %s", u))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
return storeError(ctx, db, ch, false, NewError(ErrorConnectionType,
|
return storeError(ctx, db, ch, false, NewError(ErrorConnectionType,
|
||||||
"error doing http GET for url %s with status code %d", url, resp.StatusCode))
|
"error doing http GET for url %s with status code %d", u, resp.StatusCode))
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WrapErrorISE(err, "error reading "+
|
return WrapErrorISE(err, "error reading "+
|
||||||
"response body for url %s", url)
|
"response body for url %s", u)
|
||||||
}
|
}
|
||||||
keyAuth := strings.TrimSpace(string(body))
|
keyAuth := strings.TrimSpace(string(body))
|
||||||
|
|
||||||
|
@ -114,6 +116,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tlsAlert(err error) uint8 {
|
||||||
|
var opErr *net.OpError
|
||||||
|
if errors.As(err, &opErr) {
|
||||||
|
v := reflect.ValueOf(opErr.Err)
|
||||||
|
if v.Kind() == reflect.Uint8 {
|
||||||
|
return uint8(v.Uint())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||||
config := &tls.Config{
|
config := &tls.Config{
|
||||||
NextProtos: []string{"acme-tls/1"},
|
NextProtos: []string{"acme-tls/1"},
|
||||||
|
@ -129,6 +142,14 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
|
||||||
|
|
||||||
conn, err := vo.TLSDial("tcp", hostPort, config)
|
conn, err := vo.TLSDial("tcp", hostPort, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
|
||||||
|
// client and server protocols. When this happens the connection is
|
||||||
|
// closed with the error no_application_protocol(120) as required by
|
||||||
|
// RFC7301. See https://golang.org/doc/go1.17#ALPN
|
||||||
|
if tlsAlert(err) == 120 {
|
||||||
|
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
|
||||||
|
"cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge"))
|
||||||
|
}
|
||||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||||
"error doing TLS dial for %s", hostPort))
|
"error doing TLS dial for %s", hostPort))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1276,7 +1276,7 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na
|
||||||
oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
|
oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
|
||||||
}
|
}
|
||||||
|
|
||||||
keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:])
|
keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash)
|
||||||
|
|
||||||
certTemplate.ExtraExtensions = []pkix.Extension{
|
certTemplate.ExtraExtensions = []pkix.Extension{
|
||||||
{
|
{
|
||||||
|
@ -1395,7 +1395,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
||||||
assert.Equals(t, updch.Type, ch.Type)
|
assert.Equals(t, updch.Type, ch.Type)
|
||||||
assert.Equals(t, updch.Value, ch.Value)
|
assert.Equals(t, updch.Value, ch.Value)
|
||||||
|
|
||||||
err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.Value)
|
err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443:", ch.Value)
|
||||||
|
|
||||||
assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
|
assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
|
||||||
assert.Equals(t, updch.Error.Type, err.Type)
|
assert.Equals(t, updch.Error.Type, err.Type)
|
||||||
|
|
|
@ -93,8 +93,8 @@ func TestDB_getDBAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil {
|
if dbacc, err := d.getDBAccount(context.Background(), accID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -109,15 +109,13 @@ func TestDB_getDBAccount(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
||||||
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
||||||
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
||||||
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
|
||||||
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
|
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
||||||
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -174,8 +172,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
if retAccID, err := d.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -190,10 +188,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, retAccID, accID)
|
||||||
assert.Equals(t, retAccID, accID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -250,8 +246,8 @@ func TestDB_GetAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if acc, err := db.GetAccount(context.Background(), accID); err != nil {
|
if acc, err := d.GetAccount(context.Background(), accID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -266,13 +262,11 @@ func TestDB_GetAccount(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -358,8 +352,8 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil {
|
if acc, err := d.GetAccountByKeyID(context.Background(), kid); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -374,13 +368,11 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -527,8 +519,8 @@ func TestDB_CreateAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateAccount(context.Background(), tc.acc); err != nil {
|
if err := d.CreateAccount(context.Background(), tc.acc); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -688,8 +680,8 @@ func TestDB_UpdateAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.UpdateAccount(context.Background(), tc.acc); err != nil {
|
if err := d.UpdateAccount(context.Background(), tc.acc); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,8 +97,8 @@ func TestDB_getDBAuthz(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil {
|
if dbaz, err := d.getDBAuthz(context.Background(), azID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -113,18 +113,16 @@ func TestDB_getDBAuthz(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, dbaz.ID, tc.dbaz.ID)
|
||||||
assert.Equals(t, dbaz.ID, tc.dbaz.ID)
|
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
|
||||||
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
|
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
|
||||||
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
|
assert.Equals(t, dbaz.Status, tc.dbaz.Status)
|
||||||
assert.Equals(t, dbaz.Status, tc.dbaz.Status)
|
assert.Equals(t, dbaz.Token, tc.dbaz.Token)
|
||||||
assert.Equals(t, dbaz.Token, tc.dbaz.Token)
|
assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
|
||||||
assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
|
assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||||
assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
|
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
|
||||||
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
|
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
|
||||||
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -293,8 +291,8 @@ func TestDB_GetAuthorization(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if az, err := db.GetAuthorization(context.Background(), azID); err != nil {
|
if az, err := d.GetAuthorization(context.Background(), azID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -309,21 +307,19 @@ func TestDB_GetAuthorization(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, az.ID, tc.dbaz.ID)
|
||||||
assert.Equals(t, az.ID, tc.dbaz.ID)
|
assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
|
||||||
assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
|
assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
|
||||||
assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
|
assert.Equals(t, az.Status, tc.dbaz.Status)
|
||||||
assert.Equals(t, az.Status, tc.dbaz.Status)
|
assert.Equals(t, az.Token, tc.dbaz.Token)
|
||||||
assert.Equals(t, az.Token, tc.dbaz.Token)
|
assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
|
||||||
assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
|
assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
|
||||||
assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
|
assert.Equals(t, az.Challenges, []*acme.Challenge{
|
||||||
assert.Equals(t, az.Challenges, []*acme.Challenge{
|
{ID: "foo"},
|
||||||
{ID: "foo"},
|
{ID: "bar"},
|
||||||
{ID: "bar"},
|
})
|
||||||
})
|
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
|
||||||
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -445,8 +441,8 @@ func TestDB_CreateAuthorization(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateAuthorization(context.Background(), tc.az); err != nil {
|
if err := d.CreateAuthorization(context.Background(), tc.az); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -594,8 +590,8 @@ func TestDB_UpdateAuthorization(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil {
|
if err := d.UpdateAuthorization(context.Background(), tc.az); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,8 +116,8 @@ func TestDB_CreateCertificate(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateCertificate(context.Background(), tc.cert); err != nil {
|
if err := d.CreateCertificate(context.Background(), tc.cert); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -246,8 +246,8 @@ func TestDB_GetCertificate(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
cert, err := db.GetCertificate(context.Background(), certID)
|
cert, err := d.GetCertificate(context.Background(), certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
|
@ -263,14 +263,12 @@ func TestDB_GetCertificate(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, cert.ID, certID)
|
||||||
assert.Equals(t, cert.ID, certID)
|
assert.Equals(t, cert.AccountID, "accountID")
|
||||||
assert.Equals(t, cert.AccountID, "accountID")
|
assert.Equals(t, cert.OrderID, "orderID")
|
||||||
assert.Equals(t, cert.OrderID, "orderID")
|
assert.Equals(t, cert.Leaf, leaf)
|
||||||
assert.Equals(t, cert.Leaf, leaf)
|
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,8 +92,8 @@ func TestDB_getDBChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil {
|
if ch, err := d.getDBChallenge(context.Background(), chID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -108,17 +108,15 @@ func TestDB_getDBChallenge(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -206,8 +204,8 @@ func TestDB_CreateChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil {
|
if err := d.CreateChallenge(context.Background(), tc.ch); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -286,8 +284,8 @@ func TestDB_GetChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil {
|
if ch, err := d.GetChallenge(context.Background(), chID, azID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -302,17 +300,15 @@ func TestDB_GetChallenge(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
assert.Equals(t, ch.Status, tc.dbc.Status)
|
||||||
assert.Equals(t, ch.Status, tc.dbc.Status)
|
assert.Equals(t, ch.Token, tc.dbc.Token)
|
||||||
assert.Equals(t, ch.Token, tc.dbc.Token)
|
assert.Equals(t, ch.Value, tc.dbc.Value)
|
||||||
assert.Equals(t, ch.Value, tc.dbc.Value)
|
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -442,8 +438,8 @@ func TestDB_UpdateChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
if err := d.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
|
||||||
ID: id,
|
ID: id,
|
||||||
CreatedAt: clock.Now(),
|
CreatedAt: clock.Now(),
|
||||||
}
|
}
|
||||||
if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
|
if err := db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return acme.Nonce(id), nil
|
return acme.Nonce(id), nil
|
||||||
|
|
|
@ -67,8 +67,8 @@ func TestDB_CreateNonce(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if n, err := db.CreateNonce(context.Background()); err != nil {
|
if n, err := d.CreateNonce(context.Background()); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -144,8 +144,8 @@ func TestDB_DeleteNonce(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
if err := d.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
|
|
@ -42,7 +42,7 @@ func New(db nosqlDB.DB) (*DB, error) {
|
||||||
|
|
||||||
// save writes the new data to the database, overwriting the old data if it
|
// save writes the new data to the database, overwriting the old data if it
|
||||||
// existed.
|
// existed.
|
||||||
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
|
func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
newB []byte
|
newB []byte
|
||||||
|
|
|
@ -126,8 +126,8 @@ func TestDB_save(t *testing.T) {
|
||||||
}
|
}
|
||||||
for name, tc := range tests {
|
for name, tc := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := &DB{db: tc.db}
|
d := &DB{db: tc.db}
|
||||||
if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
|
if err := d.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -124,10 +124,8 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st
|
||||||
ordersByAccountMux.Lock()
|
ordersByAccountMux.Lock()
|
||||||
defer ordersByAccountMux.Unlock()
|
defer ordersByAccountMux.Unlock()
|
||||||
|
|
||||||
|
var oldOids []string
|
||||||
b, err := db.db.Get(ordersByAccountIDTable, []byte(accID))
|
b, err := db.db.Get(ordersByAccountIDTable, []byte(accID))
|
||||||
var (
|
|
||||||
oldOids []string
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !nosql.IsErrNotFound(err) {
|
if !nosql.IsErrNotFound(err) {
|
||||||
return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID)
|
return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID)
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
nosqldb "github.com/smallstep/nosql/database"
|
"github.com/smallstep/nosql/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDB_getDBOrder(t *testing.T) {
|
func TestDB_getDBOrder(t *testing.T) {
|
||||||
|
@ -31,7 +31,7 @@ func TestDB_getDBOrder(t *testing.T) {
|
||||||
assert.Equals(t, bucket, orderTable)
|
assert.Equals(t, bucket, orderTable)
|
||||||
assert.Equals(t, string(key), orderID)
|
assert.Equals(t, string(key), orderID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
||||||
|
@ -100,8 +100,8 @@ func TestDB_getDBOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if dbo, err := db.getDBOrder(context.Background(), orderID); err != nil {
|
if dbo, err := d.getDBOrder(context.Background(), orderID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -116,20 +116,18 @@ func TestDB_getDBOrder(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, dbo.ID, tc.dbo.ID)
|
||||||
assert.Equals(t, dbo.ID, tc.dbo.ID)
|
assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
|
||||||
assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
|
assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
|
||||||
assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
|
assert.Equals(t, dbo.Status, tc.dbo.Status)
|
||||||
assert.Equals(t, dbo.Status, tc.dbo.Status)
|
assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt)
|
||||||
assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt)
|
assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt)
|
||||||
assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt)
|
assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore)
|
||||||
assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore)
|
assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter)
|
||||||
assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter)
|
assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers)
|
||||||
assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers)
|
assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
||||||
assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
|
||||||
assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -164,7 +162,7 @@ func TestDB_GetOrder(t *testing.T) {
|
||||||
assert.Equals(t, bucket, orderTable)
|
assert.Equals(t, bucket, orderTable)
|
||||||
assert.Equals(t, string(key), orderID)
|
assert.Equals(t, string(key), orderID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
||||||
|
@ -206,8 +204,8 @@ func TestDB_GetOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if o, err := db.GetOrder(context.Background(), orderID); err != nil {
|
if o, err := d.GetOrder(context.Background(), orderID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -222,20 +220,18 @@ func TestDB_GetOrder(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, o.ID, tc.dbo.ID)
|
||||||
assert.Equals(t, o.ID, tc.dbo.ID)
|
assert.Equals(t, o.AccountID, tc.dbo.AccountID)
|
||||||
assert.Equals(t, o.AccountID, tc.dbo.AccountID)
|
assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
|
||||||
assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
|
assert.Equals(t, o.CertificateID, tc.dbo.CertificateID)
|
||||||
assert.Equals(t, o.CertificateID, tc.dbo.CertificateID)
|
assert.Equals(t, o.Status, tc.dbo.Status)
|
||||||
assert.Equals(t, o.Status, tc.dbo.Status)
|
assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt)
|
||||||
assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt)
|
assert.Equals(t, o.NotBefore, tc.dbo.NotBefore)
|
||||||
assert.Equals(t, o.NotBefore, tc.dbo.NotBefore)
|
assert.Equals(t, o.NotAfter, tc.dbo.NotAfter)
|
||||||
assert.Equals(t, o.NotAfter, tc.dbo.NotAfter)
|
assert.Equals(t, o.Identifiers, tc.dbo.Identifiers)
|
||||||
assert.Equals(t, o.Identifiers, tc.dbo.Identifiers)
|
assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
||||||
assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
|
||||||
assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -366,8 +362,8 @@ func TestDB_UpdateOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.UpdateOrder(context.Background(), tc.o); err != nil {
|
if err := d.UpdateOrder(context.Background(), tc.o); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -511,7 +507,7 @@ func TestDB_CreateOrder(t *testing.T) {
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
|
assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
|
||||||
assert.Equals(t, string(key), o.AccountID)
|
assert.Equals(t, string(key), o.AccountID)
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
switch string(bucket) {
|
switch string(bucket) {
|
||||||
|
@ -557,8 +553,8 @@ func TestDB_CreateOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateOrder(context.Background(), tc.o); err != nil {
|
if err := d.CreateOrder(context.Background(), tc.o); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -680,7 +676,7 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||||
assert.Equals(t, key, []byte(accID))
|
assert.Equals(t, key, []byte(accID))
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||||
|
@ -710,6 +706,34 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
||||||
err: errors.Errorf("error saving orderIDs index for account %s", accID),
|
err: errors.Errorf("error saving orderIDs index for account %s", accID),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"ok/no-old": func(t *testing.T) test {
|
||||||
|
return test{
|
||||||
|
db: &db.MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(ordersByAccountIDTable):
|
||||||
|
return nil, database.ErrNotFound
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
|
||||||
|
return nil, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
|
switch string(bucket) {
|
||||||
|
case string(ordersByAccountIDTable):
|
||||||
|
assert.Equals(t, key, []byte(accID))
|
||||||
|
assert.Equals(t, old, nil)
|
||||||
|
assert.Equals(t, nu, nil)
|
||||||
|
return nil, true, nil
|
||||||
|
default:
|
||||||
|
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
|
||||||
|
return nil, false, errors.New("force")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: []string{},
|
||||||
|
}
|
||||||
|
},
|
||||||
"ok/all-old-not-pending": func(t *testing.T) test {
|
"ok/all-old-not-pending": func(t *testing.T) test {
|
||||||
oldOids := []string{"foo", "bar"}
|
oldOids := []string{"foo", "bar"}
|
||||||
bOldOids, err := json.Marshal(oldOids)
|
bOldOids, err := json.Marshal(oldOids)
|
||||||
|
@ -967,15 +991,15 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
var (
|
var (
|
||||||
res []string
|
res []string
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if tc.addOids == nil {
|
if tc.addOids == nil {
|
||||||
res, err = db.updateAddOrderIDs(context.Background(), accID)
|
res, err = d.updateAddOrderIDs(context.Background(), accID)
|
||||||
} else {
|
} else {
|
||||||
res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...)
|
res, err = d.updateAddOrderIDs(context.Background(), accID, tc.addOids...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -993,10 +1017,8 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.True(t, reflect.DeepEqual(res, tc.res))
|
||||||
assert.True(t, reflect.DeepEqual(res, tc.res))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -289,6 +289,7 @@ func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.Certificate
|
||||||
// name or in an extensionRequest attribute [RFC2985] requesting a
|
// name or in an extensionRequest attribute [RFC2985] requesting a
|
||||||
// subjectAltName extension, or both.
|
// subjectAltName extension, or both.
|
||||||
if csr.Subject.CommonName != "" {
|
if csr.Subject.CommonName != "" {
|
||||||
|
// nolint:gocritic
|
||||||
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
|
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
|
||||||
}
|
}
|
||||||
canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames)
|
canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames)
|
||||||
|
|
29
api/api.go
29
api/api.go
|
@ -240,9 +240,9 @@ type caHandler struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new RouterHandler with the CA endpoints.
|
// New creates a new RouterHandler with the CA endpoints.
|
||||||
func New(authority Authority) RouterHandler {
|
func New(auth Authority) RouterHandler {
|
||||||
return &caHandler{
|
return &caHandler{
|
||||||
Authority: authority,
|
Authority: auth,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,7 +295,7 @@ func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||||
// certificate for the given SHA256.
|
// certificate for the given SHA256.
|
||||||
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
||||||
sha := chi.URLParam(r, "sha")
|
sha := chi.URLParam(r, "sha")
|
||||||
sum := strings.ToLower(strings.Replace(sha, "-", "", -1))
|
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
|
||||||
// Load root certificate with the
|
// Load root certificate with the
|
||||||
cert, err := h.Authority.Root(sum)
|
cert, err := h.Authority.Root(sum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -409,19 +409,20 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
||||||
"certificate": base64.StdEncoding.EncodeToString(cert.Raw),
|
"certificate": base64.StdEncoding.EncodeToString(cert.Raw),
|
||||||
}
|
}
|
||||||
for _, ext := range cert.Extensions {
|
for _, ext := range cert.Extensions {
|
||||||
if ext.Id.Equal(oidStepProvisioner) {
|
if !ext.Id.Equal(oidStepProvisioner) {
|
||||||
val := &stepProvisioner{}
|
continue
|
||||||
rest, err := asn1.Unmarshal(ext.Value, val)
|
}
|
||||||
if err != nil || len(rest) > 0 {
|
val := &stepProvisioner{}
|
||||||
break
|
rest, err := asn1.Unmarshal(ext.Value, val)
|
||||||
}
|
if err != nil || len(rest) > 0 {
|
||||||
if len(val.CredentialID) > 0 {
|
|
||||||
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
|
|
||||||
} else {
|
|
||||||
m["provisioner"] = string(val.Name)
|
|
||||||
}
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
if len(val.CredentialID) > 0 {
|
||||||
|
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
|
||||||
|
} else {
|
||||||
|
m["provisioner"] = string(val.Name)
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
rl.WithFields(m)
|
rl.WithFields(m)
|
||||||
}
|
}
|
||||||
|
|
|
@ -186,8 +186,8 @@ func TestCertificate_MarshalJSON(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"nil", fields{Certificate: nil}, []byte("null"), false},
|
{"nil", fields{Certificate: nil}, []byte("null"), false},
|
||||||
{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
|
{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
|
||||||
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"`), false},
|
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"`), false},
|
||||||
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`), false},
|
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`), false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -219,11 +219,11 @@ func TestCertificate_UnmarshalJSON(t *testing.T) {
|
||||||
{"invalid string", []byte(`"foobar"`), false, true},
|
{"invalid string", []byte(`"foobar"`), false, true},
|
||||||
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
||||||
{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
|
{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
|
||||||
{"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false, true},
|
{"invalid type", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), false, true},
|
||||||
{"empty string", []byte(`""`), false, false},
|
{"empty string", []byte(`""`), false, false},
|
||||||
{"json null", []byte(`null`), false, false},
|
{"json null", []byte(`null`), false, false},
|
||||||
{"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true, false},
|
{"valid root", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), true, false},
|
||||||
{"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), true, false},
|
{"valid cert", []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"`), true, false},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -251,7 +251,7 @@ func TestCertificate_UnmarshalJSON_json(t *testing.T) {
|
||||||
{"empty crt (null)", `{"crt":null}`, false, false},
|
{"empty crt (null)", `{"crt":null}`, false, false},
|
||||||
{"empty crt (string)", `{"crt":""}`, false, false},
|
{"empty crt (string)", `{"crt":""}`, false, false},
|
||||||
{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
|
{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
|
||||||
{"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, true, false},
|
{"valid crt", `{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"}`, true, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
type request struct {
|
type request struct {
|
||||||
|
@ -297,7 +297,7 @@ func TestCertificateRequest_MarshalJSON(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
|
{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
|
||||||
{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
|
{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
|
||||||
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `\n"`), false},
|
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `\n"`), false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -329,10 +329,10 @@ func TestCertificateRequest_UnmarshalJSON(t *testing.T) {
|
||||||
{"invalid string", []byte(`"foobar"`), false, true},
|
{"invalid string", []byte(`"foobar"`), false, true},
|
||||||
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
||||||
{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true},
|
{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true},
|
||||||
{"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false, true},
|
{"invalid type", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), false, true},
|
||||||
{"empty string", []byte(`""`), false, false},
|
{"empty string", []byte(`""`), false, false},
|
||||||
{"json null", []byte(`null`), false, false},
|
{"json null", []byte(`null`), false, false},
|
||||||
{"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true, false},
|
{"valid csr", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), true, false},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -360,7 +360,7 @@ func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) {
|
||||||
{"empty csr (null)", `{"csr":null}`, false, false},
|
{"empty csr (null)", `{"csr":null}`, false, false},
|
||||||
{"empty csr (string)", `{"csr":""}`, false, false},
|
{"empty csr (string)", `{"csr":""}`, false, false},
|
||||||
{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
|
{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
|
||||||
{"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, true, false},
|
{"valid csr", `{"csr":"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"}`, true, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
type request struct {
|
type request struct {
|
||||||
|
@ -739,7 +739,7 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin
|
||||||
return m.ret1.(bool), m.err
|
return m.ret1.(bool), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) {
|
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||||
if m.getSSHBastion != nil {
|
if m.getSSHBastion != nil {
|
||||||
return m.getSSHBastion(ctx, user, hostname)
|
return m.getSSHBastion(ctx, user, hostname)
|
||||||
}
|
}
|
||||||
|
@ -816,7 +816,7 @@ func Test_caHandler_Root(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
|
req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
|
||||||
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
|
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
|
||||||
|
|
||||||
expected := []byte(`{"ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`)
|
expected := []byte(`{"ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -860,8 +860,8 @@ func Test_caHandler_Sign(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expected1 := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
expected2 := []byte(`{"crt":"` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -934,7 +934,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
||||||
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -995,7 +995,7 @@ func Test_caHandler_Rekey(t *testing.T) {
|
||||||
{"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest},
|
{"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -1210,7 +1210,7 @@ func Test_caHandler_Roots(t *testing.T) {
|
||||||
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -1256,7 +1256,7 @@ func Test_caHandler_Federation(t *testing.T) {
|
||||||
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
|
@ -50,12 +50,10 @@ func WriteError(w http.ResponseWriter, err error) {
|
||||||
rl.WithFields(map[string]interface{}{
|
rl.WithFields(map[string]interface{}{
|
||||||
"stack-trace": fmt.Sprintf("%+v", e),
|
"stack-trace": fmt.Sprintf("%+v", e),
|
||||||
})
|
})
|
||||||
} else {
|
} else if e, ok := cause.(errs.StackTracer); ok {
|
||||||
if e, ok := cause.(errs.StackTracer); ok {
|
rl.WithFields(map[string]interface{}{
|
||||||
rl.WithFields(map[string]interface{}{
|
"stack-trace": fmt.Sprintf("%+v", e),
|
||||||
"stack-trace": fmt.Sprintf("%+v", e),
|
})
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
10
api/ssh.go
10
api/ssh.go
|
@ -52,7 +52,7 @@ func (s *SSHSignRequest) Validate() error {
|
||||||
return errors.Errorf("unknown certType %s", s.CertType)
|
return errors.Errorf("unknown certType %s", s.CertType)
|
||||||
case len(s.PublicKey) == 0:
|
case len(s.PublicKey) == 0:
|
||||||
return errors.New("missing or empty publicKey")
|
return errors.New("missing or empty publicKey")
|
||||||
case len(s.OTT) == 0:
|
case s.OTT == "":
|
||||||
return errors.New("missing or empty ott")
|
return errors.New("missing or empty ott")
|
||||||
default:
|
default:
|
||||||
// Validate identity signature if provided
|
// Validate identity signature if provided
|
||||||
|
@ -408,18 +408,18 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var config SSHConfigResponse
|
var cfg SSHConfigResponse
|
||||||
switch body.Type {
|
switch body.Type {
|
||||||
case provisioner.SSHUserCert:
|
case provisioner.SSHUserCert:
|
||||||
config.UserTemplates = ts
|
cfg.UserTemplates = ts
|
||||||
case provisioner.SSHHostCert:
|
case provisioner.SSHHostCert:
|
||||||
config.HostTemplates = ts
|
cfg.HostTemplates = ts
|
||||||
default:
|
default:
|
||||||
WriteError(w, errs.InternalServer("it should hot get here"))
|
WriteError(w, errs.InternalServer("it should hot get here"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
JSON(w, config)
|
JSON(w, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||||
|
|
|
@ -2,6 +2,7 @@ package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
@ -18,7 +19,7 @@ type SSHRekeyRequest struct {
|
||||||
// Validate validates the SSHSignRekey.
|
// Validate validates the SSHSignRekey.
|
||||||
func (s *SSHRekeyRequest) Validate() error {
|
func (s *SSHRekeyRequest) Validate() error {
|
||||||
switch {
|
switch {
|
||||||
case len(s.OTT) == 0:
|
case s.OTT == "":
|
||||||
return errors.New("missing or empty ott")
|
return errors.New("missing or empty ott")
|
||||||
case len(s.PublicKey) == 0:
|
case len(s.PublicKey) == 0:
|
||||||
return errors.New("missing or empty public key")
|
return errors.New("missing or empty public key")
|
||||||
|
@ -72,7 +73,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := h.renewIdentityCertificate(r)
|
// Match identity cert with the SSH cert
|
||||||
|
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||||
|
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||||
|
|
||||||
|
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/x509"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
@ -16,7 +18,7 @@ type SSHRenewRequest struct {
|
||||||
// Validate validates the SSHSignRequest.
|
// Validate validates the SSHSignRequest.
|
||||||
func (s *SSHRenewRequest) Validate() error {
|
func (s *SSHRenewRequest) Validate() error {
|
||||||
switch {
|
switch {
|
||||||
case len(s.OTT) == 0:
|
case s.OTT == "":
|
||||||
return errors.New("missing or empty ott")
|
return errors.New("missing or empty ott")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
@ -62,7 +64,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := h.renewIdentityCertificate(r)
|
// Match identity cert with the SSH cert
|
||||||
|
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
|
||||||
|
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
|
||||||
|
|
||||||
|
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err))
|
||||||
return
|
return
|
||||||
|
@ -74,13 +80,28 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
}, http.StatusCreated)
|
}, http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
// renewIdentityCertificate request the client TLS certificate if present.
|
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
|
||||||
func (h *caHandler) renewIdentityCertificate(r *http.Request) ([]Certificate, error) {
|
func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
|
// Clone the certificate as we can modify it.
|
||||||
|
cert, err := x509.ParseCertificate(r.TLS.PeerCertificates[0].Raw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error parsing client certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enforce the cert to match another certificate, for example an ssh
|
||||||
|
// certificate.
|
||||||
|
if !notBefore.IsZero() {
|
||||||
|
cert.NotBefore = notBefore
|
||||||
|
}
|
||||||
|
if !notAfter.IsZero() {
|
||||||
|
cert.NotAfter = notAfter
|
||||||
|
}
|
||||||
|
|
||||||
|
certChain, err := h.Authority.Renew(cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
||||||
if !r.Passive {
|
if !r.Passive {
|
||||||
return errs.NotImplemented("non-passive revocation not implemented")
|
return errs.NotImplemented("non-passive revocation not implemented")
|
||||||
}
|
}
|
||||||
if len(r.OTT) == 0 {
|
if r.OTT == "" {
|
||||||
return errs.BadRequest("missing ott")
|
return errs.BadRequest("missing ott")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -284,7 +284,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
identityCerts := []*x509.Certificate{
|
identityCerts := []*x509.Certificate{
|
||||||
parseCertificate(certPEM),
|
parseCertificate(certPEM),
|
||||||
}
|
}
|
||||||
identityCertsPEM := []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`)
|
identityCertsPEM := []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -27,7 +27,7 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
||||||
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
tok := r.Header.Get("Authorization")
|
tok := r.Header.Get("Authorization")
|
||||||
if len(tok) == 0 {
|
if tok == "" {
|
||||||
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
|
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
|
||||||
"missing authorization header token"))
|
"missing authorization header token"))
|
||||||
return
|
return
|
||||||
|
|
|
@ -54,7 +54,7 @@ func UnmarshalProvisionerDetails(typ linkedca.Provisioner_Type, data []byte) (*l
|
||||||
return &linkedca.ProvisionerDetails{Data: v.Data}, nil
|
return &linkedca.ProvisionerDetails{Data: v.Data}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DB is the DB interface expected by the step-ca ACME API.
|
// DB is the DB interface expected by the step-ca Admin API.
|
||||||
type DB interface {
|
type DB interface {
|
||||||
CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error
|
CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error
|
||||||
GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error)
|
GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error)
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
"github.com/smallstep/nosql/database"
|
"github.com/smallstep/nosql/database"
|
||||||
nosqldb "github.com/smallstep/nosql/database"
|
|
||||||
"go.step.sm/linkedca"
|
"go.step.sm/linkedca"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
)
|
)
|
||||||
|
@ -32,7 +31,7 @@ func TestDB_getDBAdminBytes(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -67,8 +66,8 @@ func TestDB_getDBAdminBytes(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if b, err := db.getDBAdminBytes(context.Background(), adminID); err != nil {
|
if b, err := d.getDBAdminBytes(context.Background(), adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -83,10 +82,8 @@ func TestDB_getDBAdminBytes(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Equals(t, string(b), "foo")
|
||||||
assert.Equals(t, string(b), "foo")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -108,7 +105,7 @@ func TestDB_getDBAdmin(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -193,8 +190,8 @@ func TestDB_getDBAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if dba, err := db.getDBAdmin(context.Background(), adminID); err != nil {
|
if dba, err := d.getDBAdmin(context.Background(), adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -209,16 +206,14 @@ func TestDB_getDBAdmin(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
assert.Equals(t, dba.ID, adminID)
|
||||||
assert.Equals(t, dba.ID, adminID)
|
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
|
||||||
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
|
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
|
||||||
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
|
assert.Equals(t, dba.Subject, tc.dba.Subject)
|
||||||
assert.Equals(t, dba.Subject, tc.dba.Subject)
|
assert.Equals(t, dba.Type, tc.dba.Type)
|
||||||
assert.Equals(t, dba.Type, tc.dba.Type)
|
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
|
||||||
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
|
assert.Fatal(t, dba.DeletedAt.IsZero())
|
||||||
assert.Fatal(t, dba.DeletedAt.IsZero())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -283,8 +278,8 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{authorityID: admin.DefaultAuthorityID}
|
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||||
if dba, err := db.unmarshalDBAdmin(tc.in, adminID); err != nil {
|
if dba, err := d.unmarshalDBAdmin(tc.in, adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -299,16 +294,14 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
assert.Equals(t, dba.ID, adminID)
|
||||||
assert.Equals(t, dba.ID, adminID)
|
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
|
||||||
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
|
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
|
||||||
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
|
assert.Equals(t, dba.Subject, tc.dba.Subject)
|
||||||
assert.Equals(t, dba.Subject, tc.dba.Subject)
|
assert.Equals(t, dba.Type, tc.dba.Type)
|
||||||
assert.Equals(t, dba.Type, tc.dba.Type)
|
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
|
||||||
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
|
assert.Fatal(t, dba.DeletedAt.IsZero())
|
||||||
assert.Fatal(t, dba.DeletedAt.IsZero())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -360,8 +353,8 @@ func TestDB_unmarshalAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{authorityID: admin.DefaultAuthorityID}
|
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||||
if adm, err := db.unmarshalAdmin(tc.in, adminID); err != nil {
|
if adm, err := d.unmarshalAdmin(tc.in, adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -376,16 +369,14 @@ func TestDB_unmarshalAdmin(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
assert.Equals(t, adm.Id, adminID)
|
||||||
assert.Equals(t, adm.Id, adminID)
|
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
|
||||||
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
|
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
|
||||||
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
|
assert.Equals(t, adm.Subject, tc.dba.Subject)
|
||||||
assert.Equals(t, adm.Subject, tc.dba.Subject)
|
assert.Equals(t, adm.Type, tc.dba.Type)
|
||||||
assert.Equals(t, adm.Type, tc.dba.Type)
|
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
|
||||||
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
|
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
|
||||||
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -407,7 +398,7 @@ func TestDB_GetAdmin(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -516,8 +507,8 @@ func TestDB_GetAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if adm, err := db.GetAdmin(context.Background(), adminID); err != nil {
|
if adm, err := d.GetAdmin(context.Background(), adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -532,16 +523,14 @@ func TestDB_GetAdmin(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
assert.Equals(t, adm.Id, adminID)
|
||||||
assert.Equals(t, adm.Id, adminID)
|
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
|
||||||
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
|
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
|
||||||
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
|
assert.Equals(t, adm.Subject, tc.dba.Subject)
|
||||||
assert.Equals(t, adm.Subject, tc.dba.Subject)
|
assert.Equals(t, adm.Type, tc.dba.Type)
|
||||||
assert.Equals(t, adm.Type, tc.dba.Type)
|
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
|
||||||
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
|
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
|
||||||
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -562,7 +551,7 @@ func TestDB_DeleteAdmin(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -670,8 +659,8 @@ func TestDB_DeleteAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.DeleteAdmin(context.Background(), adminID); err != nil {
|
if err := d.DeleteAdmin(context.Background(), adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -708,7 +697,7 @@ func TestDB_UpdateAdmin(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -821,8 +810,8 @@ func TestDB_UpdateAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.UpdateAdmin(context.Background(), tc.adm); err != nil {
|
if err := d.UpdateAdmin(context.Background(), tc.adm); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -919,8 +908,8 @@ func TestDB_CreateAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.CreateAdmin(context.Background(), tc.adm); err != nil {
|
if err := d.CreateAdmin(context.Background(), tc.adm); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -1095,8 +1084,8 @@ func TestDB_GetAdmins(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if admins, err := db.GetAdmins(context.Background()); err != nil {
|
if admins, err := d.GetAdmins(context.Background()); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -1111,10 +1100,8 @@ func TestDB_GetAdmins(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
tc.verify(t, admins)
|
||||||
tc.verify(t, admins)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ func New(db nosqlDB.DB, authorityID string) (*DB, error) {
|
||||||
|
|
||||||
// save writes the new data to the database, overwriting the old data if it
|
// save writes the new data to the database, overwriting the old data if it
|
||||||
// existed.
|
// existed.
|
||||||
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
|
func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
newB []byte
|
newB []byte
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
"github.com/smallstep/nosql/database"
|
"github.com/smallstep/nosql/database"
|
||||||
nosqldb "github.com/smallstep/nosql/database"
|
|
||||||
"go.step.sm/linkedca"
|
"go.step.sm/linkedca"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,7 +30,7 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -66,8 +65,8 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if b, err := db.getDBProvisionerBytes(context.Background(), provID); err != nil {
|
if b, err := d.getDBProvisionerBytes(context.Background(), provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -82,10 +81,8 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
assert.Equals(t, string(b), "foo")
|
||||||
assert.Equals(t, string(b), "foo")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -107,7 +104,7 @@ func TestDB_getDBProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -190,8 +187,8 @@ func TestDB_getDBProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if dbp, err := db.getDBProvisioner(context.Background(), provID); err != nil {
|
if dbp, err := d.getDBProvisioner(context.Background(), provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -206,15 +203,13 @@ func TestDB_getDBProvisioner(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
assert.Equals(t, dbp.ID, provID)
|
||||||
assert.Equals(t, dbp.ID, provID)
|
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
|
||||||
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
|
assert.Equals(t, dbp.Type, tc.dbp.Type)
|
||||||
assert.Equals(t, dbp.Type, tc.dbp.Type)
|
assert.Equals(t, dbp.Name, tc.dbp.Name)
|
||||||
assert.Equals(t, dbp.Name, tc.dbp.Name)
|
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
|
||||||
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
|
assert.Fatal(t, dbp.DeletedAt.IsZero())
|
||||||
assert.Fatal(t, dbp.DeletedAt.IsZero())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -278,8 +273,8 @@ func TestDB_unmarshalDBProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{authorityID: admin.DefaultAuthorityID}
|
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||||
if dbp, err := db.unmarshalDBProvisioner(tc.in, provID); err != nil {
|
if dbp, err := d.unmarshalDBProvisioner(tc.in, provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -294,19 +289,17 @@ func TestDB_unmarshalDBProvisioner(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
assert.Equals(t, dbp.ID, provID)
|
||||||
assert.Equals(t, dbp.ID, provID)
|
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
|
||||||
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
|
assert.Equals(t, dbp.Type, tc.dbp.Type)
|
||||||
assert.Equals(t, dbp.Type, tc.dbp.Type)
|
assert.Equals(t, dbp.Name, tc.dbp.Name)
|
||||||
assert.Equals(t, dbp.Name, tc.dbp.Name)
|
assert.Equals(t, dbp.Details, tc.dbp.Details)
|
||||||
assert.Equals(t, dbp.Details, tc.dbp.Details)
|
assert.Equals(t, dbp.Claims, tc.dbp.Claims)
|
||||||
assert.Equals(t, dbp.Claims, tc.dbp.Claims)
|
assert.Equals(t, dbp.X509Template, tc.dbp.X509Template)
|
||||||
assert.Equals(t, dbp.X509Template, tc.dbp.X509Template)
|
assert.Equals(t, dbp.SSHTemplate, tc.dbp.SSHTemplate)
|
||||||
assert.Equals(t, dbp.SSHTemplate, tc.dbp.SSHTemplate)
|
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
|
||||||
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
|
assert.Fatal(t, dbp.DeletedAt.IsZero())
|
||||||
assert.Fatal(t, dbp.DeletedAt.IsZero())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -402,8 +395,8 @@ func TestDB_unmarshalProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{authorityID: admin.DefaultAuthorityID}
|
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||||
if prov, err := db.unmarshalProvisioner(tc.in, provID); err != nil {
|
if prov, err := d.unmarshalProvisioner(tc.in, provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -418,20 +411,18 @@ func TestDB_unmarshalProvisioner(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
assert.Equals(t, prov.Id, provID)
|
||||||
assert.Equals(t, prov.Id, provID)
|
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
|
||||||
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
|
assert.Equals(t, prov.Type, tc.dbp.Type)
|
||||||
assert.Equals(t, prov.Type, tc.dbp.Type)
|
assert.Equals(t, prov.Name, tc.dbp.Name)
|
||||||
assert.Equals(t, prov.Name, tc.dbp.Name)
|
assert.Equals(t, prov.Claims, tc.dbp.Claims)
|
||||||
assert.Equals(t, prov.Claims, tc.dbp.Claims)
|
assert.Equals(t, prov.X509Template, tc.dbp.X509Template)
|
||||||
assert.Equals(t, prov.X509Template, tc.dbp.X509Template)
|
assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
|
||||||
assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
|
|
||||||
|
|
||||||
retDetailsBytes, err := json.Marshal(prov.Details.GetData())
|
retDetailsBytes, err := json.Marshal(prov.Details.GetData())
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, retDetailsBytes, tc.dbp.Details)
|
assert.Equals(t, retDetailsBytes, tc.dbp.Details)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -453,7 +444,7 @@ func TestDB_GetProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -542,8 +533,8 @@ func TestDB_GetProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if prov, err := db.GetProvisioner(context.Background(), provID); err != nil {
|
if prov, err := d.GetProvisioner(context.Background(), provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -558,20 +549,18 @@ func TestDB_GetProvisioner(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
assert.Equals(t, prov.Id, provID)
|
||||||
assert.Equals(t, prov.Id, provID)
|
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
|
||||||
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
|
assert.Equals(t, prov.Type, tc.dbp.Type)
|
||||||
assert.Equals(t, prov.Type, tc.dbp.Type)
|
assert.Equals(t, prov.Name, tc.dbp.Name)
|
||||||
assert.Equals(t, prov.Name, tc.dbp.Name)
|
assert.Equals(t, prov.Claims, tc.dbp.Claims)
|
||||||
assert.Equals(t, prov.Claims, tc.dbp.Claims)
|
assert.Equals(t, prov.X509Template, tc.dbp.X509Template)
|
||||||
assert.Equals(t, prov.X509Template, tc.dbp.X509Template)
|
assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
|
||||||
assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
|
|
||||||
|
|
||||||
retDetailsBytes, err := json.Marshal(prov.Details.GetData())
|
retDetailsBytes, err := json.Marshal(prov.Details.GetData())
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, retDetailsBytes, tc.dbp.Details)
|
assert.Equals(t, retDetailsBytes, tc.dbp.Details)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -592,7 +581,7 @@ func TestDB_DeleteProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -692,8 +681,8 @@ func TestDB_DeleteProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.DeleteProvisioner(context.Background(), provID); err != nil {
|
if err := d.DeleteProvisioner(context.Background(), provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -853,8 +842,8 @@ func TestDB_GetProvisioners(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if provs, err := db.GetProvisioners(context.Background()); err != nil {
|
if provs, err := d.GetProvisioners(context.Background()); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -869,10 +858,8 @@ func TestDB_GetProvisioners(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
tc.verify(t, provs)
|
||||||
tc.verify(t, provs)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -963,8 +950,8 @@ func TestDB_CreateProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.CreateProvisioner(context.Background(), tc.prov); err != nil {
|
if err := d.CreateProvisioner(context.Background(), tc.prov); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -1001,7 +988,7 @@ func TestDB_UpdateProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -1199,8 +1186,8 @@ func TestDB_UpdateProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.UpdateProvisioner(context.Background(), tc.prov); err != nil {
|
if err := d.UpdateProvisioner(context.Background(), tc.prov); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
|
|
@ -55,8 +55,8 @@ type subProv struct {
|
||||||
provisioner string
|
provisioner string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSubProv(subject, provisioner string) subProv {
|
func newSubProv(subject, prov string) subProv {
|
||||||
return subProv{subject, provisioner}
|
return subProv{subject, prov}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadBySubProv a admin by the subject and provisioner name.
|
// LoadBySubProv a admin by the subject and provisioner name.
|
||||||
|
|
|
@ -16,10 +16,10 @@ func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadAdminBySubProv returns an *linkedca.Admin with the given ID.
|
// LoadAdminBySubProv returns an *linkedca.Admin with the given ID.
|
||||||
func (a *Authority) LoadAdminBySubProv(subject, provisioner string) (*linkedca.Admin, bool) {
|
func (a *Authority) LoadAdminBySubProv(subject, prov string) (*linkedca.Admin, bool) {
|
||||||
a.adminMutex.RLock()
|
a.adminMutex.RLock()
|
||||||
defer a.adminMutex.RUnlock()
|
defer a.adminMutex.RUnlock()
|
||||||
return a.admins.LoadBySubProv(subject, provisioner)
|
return a.admins.LoadBySubProv(subject, prov)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdmins returns a map listing each provisioner and the JWK Key Set
|
// GetAdmins returns a map listing each provisioner and the JWK Key Set
|
||||||
|
|
|
@ -7,41 +7,44 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"log"
|
"log"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/cas"
|
|
||||||
"github.com/smallstep/certificates/scep"
|
|
||||||
"go.step.sm/linkedca"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
|
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||||
"github.com/smallstep/certificates/authority/administrator"
|
"github.com/smallstep/certificates/authority/administrator"
|
||||||
"github.com/smallstep/certificates/authority/config"
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/cas"
|
||||||
casapi "github.com/smallstep/certificates/cas/apiv1"
|
casapi "github.com/smallstep/certificates/cas/apiv1"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/certificates/kms"
|
"github.com/smallstep/certificates/kms"
|
||||||
kmsapi "github.com/smallstep/certificates/kms/apiv1"
|
kmsapi "github.com/smallstep/certificates/kms/apiv1"
|
||||||
"github.com/smallstep/certificates/kms/sshagentkms"
|
"github.com/smallstep/certificates/kms/sshagentkms"
|
||||||
|
"github.com/smallstep/certificates/scep"
|
||||||
"github.com/smallstep/certificates/templates"
|
"github.com/smallstep/certificates/templates"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
|
"go.step.sm/linkedca"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authority implements the Certificate Authority internal interface.
|
// Authority implements the Certificate Authority internal interface.
|
||||||
type Authority struct {
|
type Authority struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
keyManager kms.KeyManager
|
keyManager kms.KeyManager
|
||||||
provisioners *provisioner.Collection
|
provisioners *provisioner.Collection
|
||||||
admins *administrator.Collection
|
admins *administrator.Collection
|
||||||
db db.AuthDB
|
db db.AuthDB
|
||||||
adminDB admin.DB
|
adminDB admin.DB
|
||||||
templates *templates.Templates
|
templates *templates.Templates
|
||||||
|
linkedCAToken string
|
||||||
|
|
||||||
// X509 CA
|
// X509 CA
|
||||||
|
password []byte
|
||||||
|
issuerPassword []byte
|
||||||
x509CAService cas.CertificateAuthorityService
|
x509CAService cas.CertificateAuthorityService
|
||||||
rootX509Certs []*x509.Certificate
|
rootX509Certs []*x509.Certificate
|
||||||
rootX509CertPool *x509.CertPool
|
rootX509CertPool *x509.CertPool
|
||||||
|
@ -52,6 +55,8 @@ type Authority struct {
|
||||||
scepService *scep.Service
|
scepService *scep.Service
|
||||||
|
|
||||||
// SSH CA
|
// SSH CA
|
||||||
|
sshHostPassword []byte
|
||||||
|
sshUserPassword []byte
|
||||||
sshCAUserCertSignKey ssh.Signer
|
sshCAUserCertSignKey ssh.Signer
|
||||||
sshCAHostCertSignKey ssh.Signer
|
sshCAHostCertSignKey ssh.Signer
|
||||||
sshCAUserCerts []ssh.PublicKey
|
sshCAUserCerts []ssh.PublicKey
|
||||||
|
@ -73,14 +78,14 @@ type Authority struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates and initiates a new Authority type.
|
// New creates and initiates a new Authority type.
|
||||||
func New(config *config.Config, opts ...Option) (*Authority, error) {
|
func New(cfg *config.Config, opts ...Option) (*Authority, error) {
|
||||||
err := config.Validate()
|
err := cfg.Validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var a = &Authority{
|
var a = &Authority{
|
||||||
config: config,
|
config: cfg,
|
||||||
certificates: new(sync.Map),
|
certificates: new(sync.Map),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,6 +210,26 @@ func (a *Authority) init() error {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
// Set password if they are not set.
|
||||||
|
var configPassword []byte
|
||||||
|
if a.config.Password != "" {
|
||||||
|
configPassword = []byte(a.config.Password)
|
||||||
|
}
|
||||||
|
if configPassword != nil && a.password == nil {
|
||||||
|
a.password = configPassword
|
||||||
|
}
|
||||||
|
if a.sshHostPassword == nil {
|
||||||
|
a.sshHostPassword = a.password
|
||||||
|
}
|
||||||
|
if a.sshUserPassword == nil {
|
||||||
|
a.sshUserPassword = a.password
|
||||||
|
}
|
||||||
|
|
||||||
|
// Automatically enable admin for all linked cas.
|
||||||
|
if a.linkedCAToken != "" {
|
||||||
|
a.config.AuthorityConfig.EnableAdmin = true
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize step-ca Database if it's not already initialized with WithDB.
|
// Initialize step-ca Database if it's not already initialized with WithDB.
|
||||||
// If a.config.DB is nil then a simple, barebones in memory DB will be used.
|
// If a.config.DB is nil then a simple, barebones in memory DB will be used.
|
||||||
if a.db == nil {
|
if a.db == nil {
|
||||||
|
@ -232,6 +257,11 @@ func (a *Authority) init() error {
|
||||||
options = *a.config.AuthorityConfig.Options
|
options = *a.config.AuthorityConfig.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the issuer password if passed in the flags.
|
||||||
|
if options.CertificateIssuer != nil && a.issuerPassword != nil {
|
||||||
|
options.CertificateIssuer.Password = string(a.issuerPassword)
|
||||||
|
}
|
||||||
|
|
||||||
// Read intermediate and create X509 signer for default CAS.
|
// Read intermediate and create X509 signer for default CAS.
|
||||||
if options.Is(casapi.SoftCAS) {
|
if options.Is(casapi.SoftCAS) {
|
||||||
options.CertificateChain, err = pemutil.ReadCertificateBundle(a.config.IntermediateCert)
|
options.CertificateChain, err = pemutil.ReadCertificateBundle(a.config.IntermediateCert)
|
||||||
|
@ -240,7 +270,7 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
||||||
SigningKey: a.config.IntermediateKey,
|
SigningKey: a.config.IntermediateKey,
|
||||||
Password: []byte(a.config.Password),
|
Password: []byte(a.password),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -309,7 +339,7 @@ func (a *Authority) init() error {
|
||||||
if a.config.SSH.HostKey != "" {
|
if a.config.SSH.HostKey != "" {
|
||||||
signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
||||||
SigningKey: a.config.SSH.HostKey,
|
SigningKey: a.config.SSH.HostKey,
|
||||||
Password: []byte(a.config.Password),
|
Password: []byte(a.sshHostPassword),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -335,7 +365,7 @@ func (a *Authority) init() error {
|
||||||
if a.config.SSH.UserKey != "" {
|
if a.config.SSH.UserKey != "" {
|
||||||
signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
||||||
SigningKey: a.config.SSH.UserKey,
|
SigningKey: a.config.SSH.UserKey,
|
||||||
Password: []byte(a.config.Password),
|
Password: []byte(a.sshUserPassword),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -359,33 +389,45 @@ func (a *Authority) init() error {
|
||||||
a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, a.sshCAUserCertSignKey.PublicKey())
|
a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, a.sshCAUserCertSignKey.PublicKey())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append other public keys
|
// Append other public keys and add them to the template variables.
|
||||||
for _, key := range a.config.SSH.Keys {
|
for _, key := range a.config.SSH.Keys {
|
||||||
|
publicKey := key.PublicKey()
|
||||||
switch key.Type {
|
switch key.Type {
|
||||||
case provisioner.SSHHostCert:
|
case provisioner.SSHHostCert:
|
||||||
if key.Federated {
|
if key.Federated {
|
||||||
a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, key.PublicKey())
|
a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, publicKey)
|
||||||
} else {
|
} else {
|
||||||
a.sshCAHostCerts = append(a.sshCAHostCerts, key.PublicKey())
|
a.sshCAHostCerts = append(a.sshCAHostCerts, publicKey)
|
||||||
}
|
}
|
||||||
case provisioner.SSHUserCert:
|
case provisioner.SSHUserCert:
|
||||||
if key.Federated {
|
if key.Federated {
|
||||||
a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, key.PublicKey())
|
a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, publicKey)
|
||||||
} else {
|
} else {
|
||||||
a.sshCAUserCerts = append(a.sshCAUserCerts, key.PublicKey())
|
a.sshCAUserCerts = append(a.sshCAUserCerts, publicKey)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return errors.Errorf("unsupported type %s", key.Type)
|
return errors.Errorf("unsupported type %s", key.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Configure template variables.
|
// Configure template variables. On the template variables HostFederatedKeys
|
||||||
|
// and UserFederatedKeys we will skip the actual CA that will be available
|
||||||
|
// in HostKey and UserKey.
|
||||||
|
//
|
||||||
|
// We cannot do it in the previous blocks because this configuration can be
|
||||||
|
// injected using options.
|
||||||
|
if a.sshCAHostCertSignKey != nil {
|
||||||
tmplVars.SSH.HostKey = a.sshCAHostCertSignKey.PublicKey()
|
tmplVars.SSH.HostKey = a.sshCAHostCertSignKey.PublicKey()
|
||||||
tmplVars.SSH.UserKey = a.sshCAUserCertSignKey.PublicKey()
|
|
||||||
// On the templates we skip the first one because there's a distinction
|
|
||||||
// between the main key and federated keys.
|
|
||||||
tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts[1:]...)
|
tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts[1:]...)
|
||||||
|
} else {
|
||||||
|
tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts...)
|
||||||
|
}
|
||||||
|
if a.sshCAUserCertSignKey != nil {
|
||||||
|
tmplVars.SSH.UserKey = a.sshCAUserCertSignKey.PublicKey()
|
||||||
tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts[1:]...)
|
tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts[1:]...)
|
||||||
|
} else {
|
||||||
|
tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if a KMS with decryption capability is required and available
|
// Check if a KMS with decryption capability is required and available
|
||||||
|
@ -414,7 +456,7 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
|
||||||
SigningKey: a.config.IntermediateKey,
|
SigningKey: a.config.IntermediateKey,
|
||||||
Password: []byte(a.config.Password),
|
Password: []byte(a.password),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -423,7 +465,7 @@ func (a *Authority) init() error {
|
||||||
if km, ok := a.keyManager.(kmsapi.Decrypter); ok {
|
if km, ok := a.keyManager.(kmsapi.Decrypter); ok {
|
||||||
options.Decrypter, err = km.CreateDecrypter(&kmsapi.CreateDecrypterRequest{
|
options.Decrypter, err = km.CreateDecrypter(&kmsapi.CreateDecrypterRequest{
|
||||||
DecryptionKey: a.config.IntermediateKey,
|
DecryptionKey: a.config.IntermediateKey,
|
||||||
Password: []byte(a.config.Password),
|
Password: []byte(a.password),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -442,10 +484,24 @@ func (a *Authority) init() error {
|
||||||
// Initialize step-ca Admin Database if it's not already initialized using
|
// Initialize step-ca Admin Database if it's not already initialized using
|
||||||
// WithAdminDB.
|
// WithAdminDB.
|
||||||
if a.adminDB == nil {
|
if a.adminDB == nil {
|
||||||
// Check if AuthConfig already exists
|
if a.linkedCAToken == "" {
|
||||||
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
|
// Check if AuthConfig already exists
|
||||||
if err != nil {
|
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Use the linkedca client as the admindb.
|
||||||
|
client, err := newLinkedCAClient(a.linkedCAToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// If authorityId is configured make sure it matches the one in the token
|
||||||
|
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) {
|
||||||
|
return errors.New("error initializing linkedca: token authority and configured authority do not match")
|
||||||
|
}
|
||||||
|
client.Run()
|
||||||
|
a.adminDB = client
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -453,9 +509,9 @@ func (a *Authority) init() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
||||||
}
|
}
|
||||||
if len(provs) == 0 {
|
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
|
||||||
// Create First Provisioner
|
// Create First Provisioner
|
||||||
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, a.config.Password)
|
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return admin.WrapErrorISE(err, "error creating first provisioner")
|
return admin.WrapErrorISE(err, "error creating first provisioner")
|
||||||
}
|
}
|
||||||
|
@ -527,6 +583,9 @@ func (a *Authority) CloseForReload() {
|
||||||
if err := a.keyManager.Close(); err != nil {
|
if err := a.keyManager.Close(); err != nil {
|
||||||
log.Printf("error closing the key manager: %v", err)
|
log.Printf("error closing the key manager: %v", err)
|
||||||
}
|
}
|
||||||
|
if client, ok := a.adminDB.(*linkedCaClient); ok {
|
||||||
|
client.Stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// requiresDecrypter returns whether the Authority
|
// requiresDecrypter returns whether the Authority
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
@ -82,6 +83,10 @@ func testAuthority(t *testing.T, opts ...Option) *Authority {
|
||||||
}
|
}
|
||||||
a, err := New(c, opts...)
|
a, err := New(c, opts...)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
// Avoid errors when test tokens are created before the test authority. This
|
||||||
|
// happens in some tests where we re-create the same authority to test
|
||||||
|
// special cases without re-creating the token.
|
||||||
|
a.startTime = a.startTime.Add(-1 * time.Minute)
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -53,7 +54,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
||||||
// key in order to verify the claims and we need the issuer from the claims
|
// key in order to verify the claims and we need the issuer from the claims
|
||||||
// before we can look up the provisioner.
|
// before we can look up the provisioner.
|
||||||
var claims Claims
|
var claims Claims
|
||||||
if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
|
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,7 +77,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
||||||
// Store the token to protect against reuse unless it's skipped.
|
// Store the token to protect against reuse unless it's skipped.
|
||||||
// If we cannot get a token id from the provisioner, just hash the token.
|
// If we cannot get a token id from the provisioner, just hash the token.
|
||||||
if !SkipTokenReuseFromContext(ctx) {
|
if !SkipTokenReuseFromContext(ctx) {
|
||||||
if err = a.UseToken(token, p); err != nil {
|
if err := a.UseToken(token, p); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -111,7 +112,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
|
||||||
// to the public certificate in the `x5c` header of the token.
|
// to the public certificate in the `x5c` header of the token.
|
||||||
// 2. Asserts that the claims are valid - have not been tampered with.
|
// 2. Asserts that the claims are valid - have not been tampered with.
|
||||||
var claims jose.Claims
|
var claims jose.Claims
|
||||||
if err = jwt.Claims(leaf.PublicKey, &claims); err != nil {
|
if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
|
||||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims")
|
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,13 +122,13 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the token has not been used.
|
// Check that the token has not been used.
|
||||||
if err = a.UseToken(token, prov); err != nil {
|
if err := a.UseToken(token, prov); err != nil {
|
||||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token")
|
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||||
// more than a few minutes.
|
// more than a few minutes.
|
||||||
if err = claims.ValidateWithLeeway(jose.Expected{
|
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||||
Issuer: prov.GetName(),
|
Issuer: prov.GetName(),
|
||||||
Time: time.Now().UTC(),
|
Time: time.Now().UTC(),
|
||||||
}, time.Minute); err != nil {
|
}, time.Minute); err != nil {
|
||||||
|
@ -173,6 +174,9 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
|
||||||
}
|
}
|
||||||
|
|
||||||
// UseToken stores the token to protect against reuse.
|
// UseToken stores the token to protect against reuse.
|
||||||
|
//
|
||||||
|
// This method currently ignores any error coming from the GetTokenID, but it
|
||||||
|
// should specifically ignore the error provisioner.ErrAllowTokenReuse.
|
||||||
func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
|
func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
|
||||||
if reuseKey, err := prov.GetTokenID(token); err == nil {
|
if reuseKey, err := prov.GetTokenID(token); err == nil {
|
||||||
if reuseKey == "" {
|
if reuseKey == "" {
|
||||||
|
@ -258,7 +262,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
||||||
}
|
}
|
||||||
if err = p.AuthorizeRevoke(ctx, token); err != nil {
|
if err := p.AuthorizeRevoke(ctx, token); err != nil {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -270,10 +274,19 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
||||||
//
|
//
|
||||||
// TODO(mariano): should we authorize by default?
|
// TODO(mariano): should we authorize by default?
|
||||||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||||
|
var err error
|
||||||
|
var isRevoked bool
|
||||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
|
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
|
||||||
|
|
||||||
// Check the passive revocation table.
|
// Check the passive revocation table.
|
||||||
isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String())
|
serial := cert.SerialNumber.String()
|
||||||
|
if lca, ok := a.adminDB.(interface {
|
||||||
|
IsRevoked(string) (bool, error)
|
||||||
|
}); ok {
|
||||||
|
isRevoked, err = lca.IsRevoked(serial)
|
||||||
|
} else {
|
||||||
|
isRevoked, err = a.db.IsRevoked(serial)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||||
}
|
}
|
||||||
|
@ -291,6 +304,28 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authorizeSSHCertificate returns an error if the given certificate is revoked.
|
||||||
|
func (a *Authority) authorizeSSHCertificate(ctx context.Context, cert *ssh.Certificate) error {
|
||||||
|
var err error
|
||||||
|
var isRevoked bool
|
||||||
|
|
||||||
|
serial := strconv.FormatUint(cert.Serial, 10)
|
||||||
|
if lca, ok := a.adminDB.(interface {
|
||||||
|
IsSSHRevoked(string) (bool, error)
|
||||||
|
}); ok {
|
||||||
|
isRevoked, err = lca.IsSSHRevoked(serial)
|
||||||
|
} else {
|
||||||
|
isRevoked, err = a.db.IsSSHRevoked(serial)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHCertificate", errs.WithKeyVal("serialNumber", serial))
|
||||||
|
}
|
||||||
|
if isRevoked {
|
||||||
|
return errs.Unauthorized("authority.authorizeSSHCertificate: certificate has been revoked", errs.WithKeyVal("serialNumber", serial))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// authorizeSSHSign loads the provisioner from the token, checks that it has not
|
// authorizeSSHSign loads the provisioner from the token, checks that it has not
|
||||||
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a
|
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a
|
||||||
// list of methods to apply to the signing flow.
|
// list of methods to apply to the signing flow.
|
||||||
|
|
|
@ -917,7 +917,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if err = cert.SignCert(rand.Reader, signer); err != nil {
|
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return cert, jwk, nil
|
return cert, jwk, nil
|
||||||
|
|
|
@ -75,6 +75,7 @@ type ASN1DN struct {
|
||||||
Locality string `json:"locality,omitempty"`
|
Locality string `json:"locality,omitempty"`
|
||||||
Province string `json:"province,omitempty"`
|
Province string `json:"province,omitempty"`
|
||||||
StreetAddress string `json:"streetAddress,omitempty"`
|
StreetAddress string `json:"streetAddress,omitempty"`
|
||||||
|
SerialNumber string `json:"serialNumber,omitempty"`
|
||||||
CommonName string `json:"commonName,omitempty"`
|
CommonName string `json:"commonName,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,8 +84,9 @@ type ASN1DN struct {
|
||||||
// cas.Options.
|
// cas.Options.
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
*cas.Options
|
*cas.Options
|
||||||
AuthorityID string `json:"authorityID,omitempty"`
|
AuthorityID string `json:"authorityId,omitempty"`
|
||||||
Provisioners provisioner.List `json:"provisioners"`
|
DeploymentType string `json:"deploymentType,omitempty"`
|
||||||
|
Provisioners provisioner.List `json:"provisioners,omitempty"`
|
||||||
Admins []*linkedca.Admin `json:"-"`
|
Admins []*linkedca.Admin `json:"-"`
|
||||||
Template *ASN1DN `json:"template,omitempty"`
|
Template *ASN1DN `json:"template,omitempty"`
|
||||||
Claims *provisioner.Claims `json:"claims,omitempty"`
|
Claims *provisioner.Claims `json:"claims,omitempty"`
|
||||||
|
@ -188,9 +190,10 @@ func (c *Config) Validate() error {
|
||||||
switch {
|
switch {
|
||||||
case c.Address == "":
|
case c.Address == "":
|
||||||
return errors.New("address cannot be empty")
|
return errors.New("address cannot be empty")
|
||||||
|
|
||||||
case len(c.DNSNames) == 0:
|
case len(c.DNSNames) == 0:
|
||||||
return errors.New("dnsNames cannot be empty")
|
return errors.New("dnsNames cannot be empty")
|
||||||
|
case c.AuthorityConfig == nil:
|
||||||
|
return errors.New("authority cannot be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options holds the RA/CAS configuration.
|
// Options holds the RA/CAS configuration.
|
||||||
|
@ -222,7 +225,7 @@ func (c *Config) Validate() error {
|
||||||
c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion
|
c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion
|
||||||
}
|
}
|
||||||
if c.TLS.MinVersion == 0 {
|
if c.TLS.MinVersion == 0 {
|
||||||
c.TLS.MinVersion = c.TLS.MaxVersion
|
c.TLS.MinVersion = DefaultTLSOptions.MinVersion
|
||||||
}
|
}
|
||||||
if c.TLS.MinVersion > c.TLS.MaxVersion {
|
if c.TLS.MinVersion > c.TLS.MaxVersion {
|
||||||
return errors.New("tls minVersion cannot exceed tls maxVersion")
|
return errors.New("tls minVersion cannot exceed tls maxVersion")
|
||||||
|
|
|
@ -15,8 +15,9 @@ var (
|
||||||
// DefaultTLSRenegotiation default TLS connection renegotiation policy.
|
// DefaultTLSRenegotiation default TLS connection renegotiation policy.
|
||||||
DefaultTLSRenegotiation = false // Never regnegotiate.
|
DefaultTLSRenegotiation = false // Never regnegotiate.
|
||||||
// DefaultTLSCipherSuites specifies default step ciphersuite(s).
|
// DefaultTLSCipherSuites specifies default step ciphersuite(s).
|
||||||
|
// These are TLS 1.0 - 1.2 cipher suites.
|
||||||
DefaultTLSCipherSuites = CipherSuites{
|
DefaultTLSCipherSuites = CipherSuites{
|
||||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
||||||
}
|
}
|
||||||
// ApprovedTLSCipherSuites smallstep approved ciphersuites.
|
// ApprovedTLSCipherSuites smallstep approved ciphersuites.
|
||||||
|
@ -26,25 +27,21 @@ var (
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
|
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
||||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
|
||||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
|
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
|
||||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
|
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
|
||||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
||||||
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
|
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
|
||||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
|
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
|
||||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305",
|
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
|
||||||
}
|
}
|
||||||
// DefaultTLSOptions represents the default TLS version as well as the cipher
|
// DefaultTLSOptions represents the default TLS version as well as the cipher
|
||||||
// suites used in the TLS certificates.
|
// suites used in the TLS certificates.
|
||||||
DefaultTLSOptions = TLSOptions{
|
DefaultTLSOptions = TLSOptions{
|
||||||
CipherSuites: CipherSuites{
|
CipherSuites: DefaultTLSCipherSuites,
|
||||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
MinVersion: DefaultTLSMinVersion,
|
||||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
MaxVersion: DefaultTLSMaxVersion,
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
Renegotiation: DefaultTLSRenegotiation,
|
||||||
},
|
|
||||||
MinVersion: 1.2,
|
|
||||||
MaxVersion: 1.2,
|
|
||||||
Renegotiation: false,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -119,27 +116,38 @@ func (c CipherSuites) Value() []uint16 {
|
||||||
|
|
||||||
// cipherSuites has the list of supported cipher suites.
|
// cipherSuites has the list of supported cipher suites.
|
||||||
var cipherSuites = map[string]uint16{
|
var cipherSuites = map[string]uint16{
|
||||||
"TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
|
// TLS 1.0 - 1.2 cipher suites.
|
||||||
"TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
"TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||||
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
"TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||||
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||||
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||||
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
"TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||||
"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
|
||||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||||
|
|
||||||
|
// TLS 1.3 cipher sutes.
|
||||||
|
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
|
||||||
|
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384,
|
||||||
|
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||||
|
|
||||||
|
// Legacy names.
|
||||||
|
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||||
|
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSOptions represents the TLS options that can be specified on *tls.Config
|
// TLSOptions represents the TLS options that can be specified on *tls.Config
|
||||||
|
|
|
@ -25,7 +25,7 @@ func (s multiString) HasEmpties() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
for _, ss := range s {
|
for _, ss := range s {
|
||||||
if len(ss) == 0 {
|
if ss == "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
284
authority/export.go
Normal file
284
authority/export.go
Normal file
|
@ -0,0 +1,284 @@
|
||||||
|
package authority
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/url"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"go.step.sm/cli-utils/config"
|
||||||
|
"go.step.sm/linkedca"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Export creates a linkedca configuration form the current ca.json and loaded
|
||||||
|
// authorities.
|
||||||
|
//
|
||||||
|
// Note that export will not export neither the pki password nor the certificate
|
||||||
|
// issuer password.
|
||||||
|
func (a *Authority) Export() (c *linkedca.Configuration, err error) {
|
||||||
|
// Recover from panics
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = r.(error)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
files := make(map[string][]byte)
|
||||||
|
|
||||||
|
// The exported configuration should not include the password in it.
|
||||||
|
c = &linkedca.Configuration{
|
||||||
|
Version: "1.0",
|
||||||
|
Root: mustReadFilesOrURIs(a.config.Root, files),
|
||||||
|
FederatedRoots: mustReadFilesOrURIs(a.config.FederatedRoots, files),
|
||||||
|
Intermediate: mustReadFileOrURI(a.config.IntermediateCert, files),
|
||||||
|
IntermediateKey: mustReadFileOrURI(a.config.IntermediateKey, files),
|
||||||
|
Address: a.config.Address,
|
||||||
|
InsecureAddress: a.config.InsecureAddress,
|
||||||
|
DnsNames: a.config.DNSNames,
|
||||||
|
Db: mustMarshalToStruct(a.config.DB),
|
||||||
|
Logger: mustMarshalToStruct(a.config.Logger),
|
||||||
|
Monitoring: mustMarshalToStruct(a.config.Monitoring),
|
||||||
|
Authority: &linkedca.Authority{
|
||||||
|
Id: a.config.AuthorityConfig.AuthorityID,
|
||||||
|
EnableAdmin: a.config.AuthorityConfig.EnableAdmin,
|
||||||
|
DisableIssuedAtCheck: a.config.AuthorityConfig.DisableIssuedAtCheck,
|
||||||
|
Backdate: mustDuration(a.config.AuthorityConfig.Backdate),
|
||||||
|
DeploymentType: a.config.AuthorityConfig.DeploymentType,
|
||||||
|
},
|
||||||
|
Files: files,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSH
|
||||||
|
if v := a.config.SSH; v != nil {
|
||||||
|
c.Ssh = &linkedca.SSH{
|
||||||
|
HostKey: mustReadFileOrURI(v.HostKey, files),
|
||||||
|
UserKey: mustReadFileOrURI(v.UserKey, files),
|
||||||
|
AddUserPrincipal: v.AddUserPrincipal,
|
||||||
|
AddUserCommand: v.AddUserCommand,
|
||||||
|
}
|
||||||
|
for _, k := range v.Keys {
|
||||||
|
typ, ok := linkedca.SSHPublicKey_Type_value[strings.ToUpper(k.Type)]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.Errorf("unsupported ssh key type %s", k.Type)
|
||||||
|
}
|
||||||
|
c.Ssh.Keys = append(c.Ssh.Keys, &linkedca.SSHPublicKey{
|
||||||
|
Type: linkedca.SSHPublicKey_Type(typ),
|
||||||
|
Federated: k.Federated,
|
||||||
|
Key: mustMarshalToStruct(k),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if b := v.Bastion; b != nil {
|
||||||
|
c.Ssh.Bastion = &linkedca.Bastion{
|
||||||
|
Hostname: b.Hostname,
|
||||||
|
User: b.User,
|
||||||
|
Port: b.Port,
|
||||||
|
Command: b.Command,
|
||||||
|
Flags: b.Flags,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// KMS
|
||||||
|
if v := a.config.KMS; v != nil {
|
||||||
|
var typ int32
|
||||||
|
var ok bool
|
||||||
|
if v.Type == "" {
|
||||||
|
typ = int32(linkedca.KMS_SOFTKMS)
|
||||||
|
} else {
|
||||||
|
typ, ok = linkedca.KMS_Type_value[strings.ToUpper(v.Type)]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.Errorf("unsupported kms type %s", v.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Kms = &linkedca.KMS{
|
||||||
|
Type: linkedca.KMS_Type(typ),
|
||||||
|
CredentialsFile: v.CredentialsFile,
|
||||||
|
Uri: v.URI,
|
||||||
|
Pin: v.Pin,
|
||||||
|
ManagementKey: v.ManagementKey,
|
||||||
|
Region: v.Region,
|
||||||
|
Profile: v.Profile,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authority
|
||||||
|
// cas options
|
||||||
|
if v := a.config.AuthorityConfig.Options; v != nil {
|
||||||
|
c.Authority.Type = 0
|
||||||
|
c.Authority.CertificateAuthority = v.CertificateAuthority
|
||||||
|
c.Authority.CertificateAuthorityFingerprint = v.CertificateAuthorityFingerprint
|
||||||
|
c.Authority.CredentialsFile = v.CredentialsFile
|
||||||
|
if iss := v.CertificateIssuer; iss != nil {
|
||||||
|
typ, ok := linkedca.CertificateIssuer_Type_value[strings.ToUpper(iss.Type)]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.Errorf("unknown certificate issuer type %s", iss.Type)
|
||||||
|
}
|
||||||
|
// The exported certificate issuer should not include the password.
|
||||||
|
c.Authority.CertificateIssuer = &linkedca.CertificateIssuer{
|
||||||
|
Type: linkedca.CertificateIssuer_Type(typ),
|
||||||
|
Provisioner: iss.Provisioner,
|
||||||
|
Certificate: mustReadFileOrURI(iss.Certificate, files),
|
||||||
|
Key: mustReadFileOrURI(iss.Key, files),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// admins
|
||||||
|
for {
|
||||||
|
list, cursor := a.admins.Find("", 100)
|
||||||
|
c.Authority.Admins = append(c.Authority.Admins, list...)
|
||||||
|
if cursor == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// provisioners
|
||||||
|
for {
|
||||||
|
list, cursor := a.provisioners.Find("", 100)
|
||||||
|
for _, p := range list {
|
||||||
|
lp, err := ProvisionerToLinkedca(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.Authority.Provisioners = append(c.Authority.Provisioners, lp)
|
||||||
|
}
|
||||||
|
if cursor == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// global claims
|
||||||
|
c.Authority.Claims = claimsToLinkedca(a.config.AuthorityConfig.Claims)
|
||||||
|
// Distinguished names template
|
||||||
|
if v := a.config.AuthorityConfig.Template; v != nil {
|
||||||
|
c.Authority.Template = &linkedca.DistinguishedName{
|
||||||
|
Country: v.Country,
|
||||||
|
Organization: v.Organization,
|
||||||
|
OrganizationalUnit: v.OrganizationalUnit,
|
||||||
|
Locality: v.Locality,
|
||||||
|
Province: v.Province,
|
||||||
|
StreetAddress: v.StreetAddress,
|
||||||
|
SerialNumber: v.SerialNumber,
|
||||||
|
CommonName: v.CommonName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLS
|
||||||
|
if v := a.config.TLS; v != nil {
|
||||||
|
c.Tls = &linkedca.TLS{
|
||||||
|
MinVersion: v.MinVersion.String(),
|
||||||
|
MaxVersion: v.MaxVersion.String(),
|
||||||
|
Renegotiation: v.Renegotiation,
|
||||||
|
}
|
||||||
|
for _, cs := range v.CipherSuites.Value() {
|
||||||
|
c.Tls.CipherSuites = append(c.Tls.CipherSuites, linkedca.TLS_CiperSuite(cs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Templates
|
||||||
|
if v := a.config.Templates; v != nil {
|
||||||
|
c.Templates = &linkedca.ConfigTemplates{
|
||||||
|
Ssh: &linkedca.SSHConfigTemplate{},
|
||||||
|
Data: mustMarshalToStruct(v.Data),
|
||||||
|
}
|
||||||
|
// Remove automatically loaded vars
|
||||||
|
if c.Templates.Data != nil && c.Templates.Data.Fields != nil {
|
||||||
|
delete(c.Templates.Data.Fields, "Step")
|
||||||
|
}
|
||||||
|
for _, t := range v.SSH.Host {
|
||||||
|
typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.Errorf("unsupported template type %s", t.Type)
|
||||||
|
}
|
||||||
|
c.Templates.Ssh.Hosts = append(c.Templates.Ssh.Hosts, &linkedca.ConfigTemplate{
|
||||||
|
Type: linkedca.ConfigTemplate_Type(typ),
|
||||||
|
Name: t.Name,
|
||||||
|
Template: mustReadFileOrURI(t.TemplatePath, files),
|
||||||
|
Path: t.Path,
|
||||||
|
Comment: t.Comment,
|
||||||
|
Requires: t.RequiredData,
|
||||||
|
Content: t.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, t := range v.SSH.User {
|
||||||
|
typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.Errorf("unsupported template type %s", t.Type)
|
||||||
|
}
|
||||||
|
c.Templates.Ssh.Users = append(c.Templates.Ssh.Users, &linkedca.ConfigTemplate{
|
||||||
|
Type: linkedca.ConfigTemplate_Type(typ),
|
||||||
|
Name: t.Name,
|
||||||
|
Template: mustReadFileOrURI(t.TemplatePath, files),
|
||||||
|
Path: t.Path,
|
||||||
|
Comment: t.Comment,
|
||||||
|
Requires: t.RequiredData,
|
||||||
|
Content: t.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustDuration(d *provisioner.Duration) string {
|
||||||
|
if d == nil || d.Duration == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return d.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMarshalToStruct(v interface{}) *structpb.Struct {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
panic(errors.Wrapf(err, "error marshaling %T", v))
|
||||||
|
}
|
||||||
|
var r *structpb.Struct
|
||||||
|
if err := json.Unmarshal(b, &r); err != nil {
|
||||||
|
panic(errors.Wrapf(err, "error unmarshaling %T", v))
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustReadFileOrURI(fn string, m map[string][]byte) string {
|
||||||
|
if fn == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
stepPath := filepath.ToSlash(config.StepPath())
|
||||||
|
if !strings.HasSuffix(stepPath, "/") {
|
||||||
|
stepPath += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn = strings.TrimPrefix(filepath.ToSlash(fn), stepPath)
|
||||||
|
|
||||||
|
ok, err := isFilename(fn)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
b, err := ioutil.ReadFile(config.StepAbs(fn))
|
||||||
|
if err != nil {
|
||||||
|
panic(errors.Wrapf(err, "error reading %s", fn))
|
||||||
|
}
|
||||||
|
m[fn] = b
|
||||||
|
return fn
|
||||||
|
}
|
||||||
|
return fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustReadFilesOrURIs(fns []string, m map[string][]byte) []string {
|
||||||
|
var result []string
|
||||||
|
for _, fn := range fns {
|
||||||
|
result = append(result, mustReadFileOrURI(fn, m))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func isFilename(fn string) (bool, error) {
|
||||||
|
u, err := url.Parse(fn)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrapf(err, "error parsing %s", fn)
|
||||||
|
}
|
||||||
|
return u.Scheme == "" || u.Scheme == "file", nil
|
||||||
|
}
|
490
authority/linkedca.go
Normal file
490
authority/linkedca.go
Normal file
|
@ -0,0 +1,490 @@
|
||||||
|
package authority
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
"go.step.sm/crypto/keyutil"
|
||||||
|
"go.step.sm/crypto/tlsutil"
|
||||||
|
"go.step.sm/crypto/x509util"
|
||||||
|
"go.step.sm/linkedca"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
|
||||||
|
|
||||||
|
type linkedCaClient struct {
|
||||||
|
renewer *tlsutil.Renewer
|
||||||
|
client linkedca.MajordomoClient
|
||||||
|
authorityID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type linkedCAClaims struct {
|
||||||
|
jose.Claims
|
||||||
|
SANs []string `json:"sans"`
|
||||||
|
SHA string `json:"sha"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLinkedCAClient(token string) (*linkedCaClient, error) {
|
||||||
|
tok, err := jose.ParseSigned(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error parsing token")
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims linkedCAClaims
|
||||||
|
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error parsing token")
|
||||||
|
}
|
||||||
|
// Validate claims
|
||||||
|
if len(claims.Audience) != 1 {
|
||||||
|
return nil, errors.New("error parsing token: invalid aud claim")
|
||||||
|
}
|
||||||
|
if claims.SHA == "" {
|
||||||
|
return nil, errors.New("error parsing token: invalid sha claim")
|
||||||
|
}
|
||||||
|
// Get linkedCA endpoint from audience.
|
||||||
|
u, err := url.Parse(claims.Audience[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("error parsing token: invalid aud claim")
|
||||||
|
}
|
||||||
|
// Get authority from SANs
|
||||||
|
authority, err := getAuthority(claims.SANs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create csr to login with
|
||||||
|
signer, err := keyutil.GenerateDefaultSigner()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
csr, err := x509util.CreateCertificateRequest(claims.Subject, claims.SANs, signer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get and verify root certificate
|
||||||
|
root, err := getRootCertificate(u.Host, claims.SHA)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(root)
|
||||||
|
|
||||||
|
// Login with majordomo and get certificates
|
||||||
|
cert, tlsConfig, err := login(authority, token, csr, signer, u.Host, pool)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start TLS renewer and set the GetClientCertificate callback to it.
|
||||||
|
renewer, err := tlsutil.NewRenewer(cert, tlsConfig, func() (*tls.Certificate, *tls.Config, error) {
|
||||||
|
return login(authority, token, csr, signer, u.Host, pool)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
|
||||||
|
|
||||||
|
// Start mTLS client
|
||||||
|
conn, err := grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error connecting %s", u.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &linkedCaClient{
|
||||||
|
renewer: renewer,
|
||||||
|
client: linkedca.NewMajordomoClient(conn),
|
||||||
|
authorityID: authority,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) Run() {
|
||||||
|
c.renewer.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) Stop() {
|
||||||
|
c.renewer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||||
|
resp, err := c.client.CreateProvisioner(ctx, &linkedca.CreateProvisionerRequest{
|
||||||
|
Type: prov.Type,
|
||||||
|
Name: prov.Name,
|
||||||
|
Details: prov.Details,
|
||||||
|
Claims: prov.Claims,
|
||||||
|
X509Template: prov.X509Template,
|
||||||
|
SshTemplate: prov.SshTemplate,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error creating provisioner")
|
||||||
|
}
|
||||||
|
prov.Id = resp.Id
|
||||||
|
prov.AuthorityId = resp.AuthorityId
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||||
|
resp, err := c.client.GetProvisioner(ctx, &linkedca.GetProvisionerRequest{
|
||||||
|
Id: id,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error getting provisioners")
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
|
||||||
|
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
||||||
|
AuthorityId: c.authorityID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error getting provisioners")
|
||||||
|
}
|
||||||
|
return resp.Provisioners, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||||
|
_, err := c.client.UpdateProvisioner(ctx, &linkedca.UpdateProvisionerRequest{
|
||||||
|
Id: prov.Id,
|
||||||
|
Name: prov.Name,
|
||||||
|
Details: prov.Details,
|
||||||
|
Claims: prov.Claims,
|
||||||
|
X509Template: prov.X509Template,
|
||||||
|
SshTemplate: prov.SshTemplate,
|
||||||
|
})
|
||||||
|
return errors.Wrap(err, "error updating provisioner")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) DeleteProvisioner(ctx context.Context, id string) error {
|
||||||
|
_, err := c.client.DeleteProvisioner(ctx, &linkedca.DeleteProvisionerRequest{
|
||||||
|
Id: id,
|
||||||
|
})
|
||||||
|
return errors.Wrap(err, "error deleting provisioner")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||||
|
resp, err := c.client.CreateAdmin(ctx, &linkedca.CreateAdminRequest{
|
||||||
|
Subject: adm.Subject,
|
||||||
|
ProvisionerId: adm.ProvisionerId,
|
||||||
|
Type: adm.Type,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error creating admin")
|
||||||
|
}
|
||||||
|
adm.Id = resp.Id
|
||||||
|
adm.AuthorityId = resp.AuthorityId
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) {
|
||||||
|
resp, err := c.client.GetAdmin(ctx, &linkedca.GetAdminRequest{
|
||||||
|
Id: id,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error getting admins")
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
|
||||||
|
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
|
||||||
|
AuthorityId: c.authorityID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error getting admins")
|
||||||
|
}
|
||||||
|
return resp.Admins, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error {
|
||||||
|
_, err := c.client.UpdateAdmin(ctx, &linkedca.UpdateAdminRequest{
|
||||||
|
Id: adm.Id,
|
||||||
|
Type: adm.Type,
|
||||||
|
})
|
||||||
|
return errors.Wrap(err, "error updating admin")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
|
||||||
|
_, err := c.client.DeleteAdmin(ctx, &linkedca.DeleteAdminRequest{
|
||||||
|
Id: id,
|
||||||
|
})
|
||||||
|
return errors.Wrap(err, "error deleting admin")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
|
||||||
|
PemCertificate: serializeCertificateChain(fullchain[0]),
|
||||||
|
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
||||||
|
})
|
||||||
|
return errors.Wrap(err, "error posting certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullchain ...*x509.Certificate) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
|
||||||
|
PemCertificate: serializeCertificateChain(fullchain[0]),
|
||||||
|
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
|
||||||
|
PemParentCertificate: serializeCertificateChain(parent),
|
||||||
|
})
|
||||||
|
return errors.Wrap(err, "error posting certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
|
||||||
|
Certificate: string(ssh.MarshalAuthorizedKey(crt)),
|
||||||
|
})
|
||||||
|
return errors.Wrap(err, "error posting ssh certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, err := c.client.RevokeCertificate(ctx, &linkedca.RevokeCertificateRequest{
|
||||||
|
Serial: rci.Serial,
|
||||||
|
PemCertificate: serializeCertificate(crt),
|
||||||
|
Reason: rci.Reason,
|
||||||
|
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
|
||||||
|
Passive: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
return errors.Wrap(err, "error revoking certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) RevokeSSH(cert *ssh.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{
|
||||||
|
Serial: rci.Serial,
|
||||||
|
Certificate: serializeSSHCertificate(cert),
|
||||||
|
Reason: rci.Reason,
|
||||||
|
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
|
||||||
|
Passive: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
return errors.Wrap(err, "error revoking ssh certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) IsRevoked(serial string) (bool, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
resp, err := c.client.GetCertificateStatus(ctx, &linkedca.GetCertificateStatusRequest{
|
||||||
|
Serial: serial,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "error getting certificate status")
|
||||||
|
}
|
||||||
|
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
resp, err := c.client.GetSSHCertificateStatus(ctx, &linkedca.GetSSHCertificateStatusRequest{
|
||||||
|
Serial: serial,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "error getting certificate status")
|
||||||
|
}
|
||||||
|
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializeCertificate(crt *x509.Certificate) string {
|
||||||
|
if crt == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: crt.Raw,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializeCertificateChain(fullchain ...*x509.Certificate) string {
|
||||||
|
var chain string
|
||||||
|
for _, crt := range fullchain {
|
||||||
|
chain += string(pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: crt.Raw,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
return chain
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializeSSHCertificate(crt *ssh.Certificate) string {
|
||||||
|
if crt == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(ssh.MarshalAuthorizedKey(crt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAuthority(sans []string) (string, error) {
|
||||||
|
for _, s := range sans {
|
||||||
|
if strings.HasPrefix(s, "urn:smallstep:authority:") {
|
||||||
|
if regexp.MustCompile(uuidPattern).MatchString(s[24:]) {
|
||||||
|
return s[24:], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("error parsing token: invalid sans claim")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRootCertificate creates an insecure majordomo client and returns the
|
||||||
|
// verified root certificate.
|
||||||
|
func getRootCertificate(endpoint, fingerprint string) (*x509.Certificate, error) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error connecting %s", endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client := linkedca.NewMajordomoClient(conn)
|
||||||
|
resp, err := client.GetRootCertificate(ctx, &linkedca.GetRootCertificateRequest{
|
||||||
|
Fingerprint: fingerprint,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting root certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var block *pem.Block
|
||||||
|
b := []byte(resp.PemCertificate)
|
||||||
|
for len(b) > 0 {
|
||||||
|
block, b = pem.Decode(b)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the sha256
|
||||||
|
sum := sha256.Sum256(cert.Raw)
|
||||||
|
if !strings.EqualFold(fingerprint, hex.EncodeToString(sum[:])) {
|
||||||
|
return nil, fmt.Errorf("error verifying certificate: SHA256 fingerprint does not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("error getting root certificate: certificate not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// login creates a new majordomo client with just the root ca pool and returns
|
||||||
|
// the signed certificate and tls configuration.
|
||||||
|
func login(authority, token string, csr *x509.CertificateRequest, signer crypto.PrivateKey, endpoint string, rootCAs *x509.CertPool) (*tls.Certificate, *tls.Config, error) {
|
||||||
|
// Connect to majordomo
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
})))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrapf(err, "error connecting %s", endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login to get the signed certificate
|
||||||
|
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client := linkedca.NewMajordomoClient(conn)
|
||||||
|
resp, err := client.Login(ctx, &linkedca.LoginRequest{
|
||||||
|
AuthorityId: authority,
|
||||||
|
Token: token,
|
||||||
|
PemCertificateRequest: string(pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE REQUEST",
|
||||||
|
Bytes: csr.Raw,
|
||||||
|
})),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrapf(err, "error logging in %s", endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse login response
|
||||||
|
var block *pem.Block
|
||||||
|
var bundle []*x509.Certificate
|
||||||
|
rest := []byte(resp.PemCertificateChain)
|
||||||
|
for {
|
||||||
|
block, rest = pem.Decode(rest)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if block.Type != "CERTIFICATE" {
|
||||||
|
return nil, nil, errors.New("error decoding login response: pemCertificateChain is not a certificate bundle")
|
||||||
|
}
|
||||||
|
crt, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "error parsing login response")
|
||||||
|
}
|
||||||
|
bundle = append(bundle, crt)
|
||||||
|
}
|
||||||
|
if len(bundle) == 0 {
|
||||||
|
return nil, nil, errors.New("error decoding login response: pemCertificateChain should not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build tls.Certificate with PemCertificate and intermediates in the
|
||||||
|
// PemCertificateChain
|
||||||
|
cert := &tls.Certificate{
|
||||||
|
PrivateKey: signer,
|
||||||
|
}
|
||||||
|
rest = []byte(resp.PemCertificate)
|
||||||
|
for {
|
||||||
|
block, rest = pem.Decode(rest)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if block.Type == "CERTIFICATE" {
|
||||||
|
leaf, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "error parsing pemCertificate")
|
||||||
|
}
|
||||||
|
cert.Certificate = append(cert.Certificate, block.Bytes)
|
||||||
|
cert.Leaf = leaf
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add intermediates to the tls.Certificate
|
||||||
|
last := len(bundle) - 1
|
||||||
|
for i := 0; i < last; i++ {
|
||||||
|
cert.Certificate = append(cert.Certificate, bundle[i].Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add root to the pool if it's not there yet
|
||||||
|
rootCAs.AddCert(bundle[last])
|
||||||
|
|
||||||
|
return cert, &tls.Config{
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
}, nil
|
||||||
|
}
|
|
@ -22,9 +22,9 @@ type Option func(*Authority) error
|
||||||
|
|
||||||
// WithConfig replaces the current config with the given one. No validation is
|
// WithConfig replaces the current config with the given one. No validation is
|
||||||
// performed in the given value.
|
// performed in the given value.
|
||||||
func WithConfig(config *config.Config) Option {
|
func WithConfig(cfg *config.Config) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
a.config = config
|
a.config = cfg
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,11 +38,47 @@ func WithConfigFile(filename string) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPassword set the password to decrypt the intermediate key as well as the
|
||||||
|
// ssh host and user keys if they are not overridden by other options.
|
||||||
|
func WithPassword(password []byte) Option {
|
||||||
|
return func(a *Authority) (err error) {
|
||||||
|
a.password = password
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSSHHostPassword set the password to decrypt the key used to sign SSH host
|
||||||
|
// certificates.
|
||||||
|
func WithSSHHostPassword(password []byte) Option {
|
||||||
|
return func(a *Authority) (err error) {
|
||||||
|
a.sshHostPassword = password
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSSHUserPassword set the password to decrypt the key used to sign SSH user
|
||||||
|
// certificates.
|
||||||
|
func WithSSHUserPassword(password []byte) Option {
|
||||||
|
return func(a *Authority) (err error) {
|
||||||
|
a.sshUserPassword = password
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithIssuerPassword set the password to decrypt the certificate issuer private
|
||||||
|
// key used in RA mode.
|
||||||
|
func WithIssuerPassword(password []byte) Option {
|
||||||
|
return func(a *Authority) (err error) {
|
||||||
|
a.issuerPassword = password
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithDatabase sets an already initialized authority database to a new
|
// WithDatabase sets an already initialized authority database to a new
|
||||||
// authority. This option is intended to be use on graceful reloads.
|
// authority. This option is intended to be use on graceful reloads.
|
||||||
func WithDatabase(db db.AuthDB) Option {
|
func WithDatabase(d db.AuthDB) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
a.db = db
|
a.db = d
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -189,9 +225,18 @@ func WithX509FederatedBundle(pemCerts []byte) Option {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithAdminDB is an option to set the database backing the admin APIs.
|
// WithAdminDB is an option to set the database backing the admin APIs.
|
||||||
func WithAdminDB(db admin.DB) Option {
|
func WithAdminDB(d admin.DB) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
a.adminDB = db
|
a.adminDB = d
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLinkedCAToken is an option to set the authentication token used to enable
|
||||||
|
// linked ca.
|
||||||
|
func WithLinkedCAToken(token string) Option {
|
||||||
|
return func(a *Authority) error {
|
||||||
|
a.linkedCAToken = token
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -312,7 +312,7 @@ func (p *AWS) GetType() Type {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEncryptedKey is not available in an AWS provisioner.
|
// GetEncryptedKey is not available in an AWS provisioner.
|
||||||
func (p *AWS) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (p *AWS) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -449,13 +449,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
// There's no way to trust them other than TOFU.
|
// There's no way to trust them other than TOFU.
|
||||||
var so []SignOption
|
var so []SignOption
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region)
|
dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region)
|
||||||
so = append(so, dnsNamesValidator([]string{dnsName}))
|
so = append(so,
|
||||||
so = append(so, ipAddressesValidator([]net.IP{
|
dnsNamesValidator([]string{dnsName}),
|
||||||
net.ParseIP(doc.PrivateIP),
|
ipAddressesValidator([]net.IP{
|
||||||
}))
|
net.ParseIP(doc.PrivateIP),
|
||||||
so = append(so, emailAddressesValidator(nil))
|
}),
|
||||||
so = append(so, urisValidator(nil))
|
emailAddressesValidator(nil),
|
||||||
|
urisValidator(nil),
|
||||||
|
)
|
||||||
|
|
||||||
// Template options
|
// Template options
|
||||||
data.SetSANs([]string{dnsName, doc.PrivateIP})
|
data.SetSANs([]string{dnsName, doc.PrivateIP})
|
||||||
|
@ -515,6 +517,11 @@ func (p *AWS) readURL(url string) ([]byte, error) {
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
// Initialize IMDS versions when this is called from the cli.
|
||||||
|
if len(p.IMDSVersions) == 0 {
|
||||||
|
p.IMDSVersions = []string{"v2", "v1"}
|
||||||
|
}
|
||||||
|
|
||||||
for _, v := range p.IMDSVersions {
|
for _, v := range p.IMDSVersions {
|
||||||
switch v {
|
switch v {
|
||||||
case "v1":
|
case "v1":
|
||||||
|
@ -664,7 +671,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
if payload.Subject != doc.InstanceID &&
|
if payload.Subject != doc.InstanceID &&
|
||||||
payload.Subject != doc.PrivateIP &&
|
payload.Subject != doc.PrivateIP &&
|
||||||
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
|
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region) {
|
||||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
|
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -715,7 +722,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
||||||
// Validated principals.
|
// Validated principals.
|
||||||
principals := []string{
|
principals := []string{
|
||||||
doc.PrivateIP,
|
doc.PrivateIP,
|
||||||
fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region),
|
fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only enforce known principals if disable custom sans is true.
|
// Only enforce known principals if disable custom sans is true.
|
||||||
|
|
|
@ -141,6 +141,12 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
||||||
p7.config.signatureURL = p1.config.signatureURL
|
p7.config.signatureURL = p1.config.signatureURL
|
||||||
p7.config.tokenURL = p1.config.tokenURL
|
p7.config.tokenURL = p1.config.tokenURL
|
||||||
|
|
||||||
|
p8, err := generateAWS()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
p8.IMDSVersions = nil
|
||||||
|
p8.Accounts = p1.Accounts
|
||||||
|
p8.config = p1.config
|
||||||
|
|
||||||
caURL := "https://ca.smallstep.com"
|
caURL := "https://ca.smallstep.com"
|
||||||
u, err := url.Parse(caURL)
|
u, err := url.Parse(caURL)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
@ -156,6 +162,7 @@ func TestAWS_GetIdentityToken(t *testing.T) {
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{"foo.local", caURL}, false},
|
{"ok", p1, args{"foo.local", caURL}, false},
|
||||||
|
{"ok no imds", p8, args{"foo.local", caURL}, false},
|
||||||
{"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true},
|
{"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true},
|
||||||
{"fail identityURL", p2, args{"foo.local", caURL}, true},
|
{"fail identityURL", p2, args{"foo.local", caURL}, true},
|
||||||
{"fail signatureURL", p3, args{"foo.local", caURL}, true},
|
{"fail signatureURL", p3, args{"foo.local", caURL}, true},
|
||||||
|
@ -656,15 +663,15 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
|
switch got, err := tt.aws.AuthorizeSign(ctx, tt.args.token); {
|
||||||
if (err != nil) != tt.wantErr {
|
case (err != nil) != tt.wantErr:
|
||||||
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
case err != nil:
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(errs.StatusCoder)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
} else {
|
default:
|
||||||
assert.Len(t, tt.wantLen, got)
|
assert.Len(t, tt.wantLen, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
|
|
@ -131,9 +131,10 @@ func (p *Azure) GetTokenID(token string) (string, error) {
|
||||||
return "", errors.Wrap(err, "error verifying claims")
|
return "", errors.Wrap(err, "error verifying claims")
|
||||||
}
|
}
|
||||||
|
|
||||||
// If TOFU is disabled create return the token kid
|
// If TOFU is disabled then allow token re-use. Azure caches the token for
|
||||||
|
// 24h and without allowing the re-use we cannot use it twice.
|
||||||
if p.DisableTrustOnFirstUse {
|
if p.DisableTrustOnFirstUse {
|
||||||
return claims.ID, nil
|
return "", ErrAllowTokenReuse
|
||||||
}
|
}
|
||||||
|
|
||||||
sum := sha256.Sum256([]byte(claims.XMSMirID))
|
sum := sha256.Sum256([]byte(claims.XMSMirID))
|
||||||
|
@ -151,7 +152,7 @@ func (p *Azure) GetType() Type {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEncryptedKey is not available in an Azure provisioner.
|
// GetEncryptedKey is not available in an Azure provisioner.
|
||||||
func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (p *Azure) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -302,11 +303,13 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
||||||
var so []SignOption
|
var so []SignOption
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
// name will work only inside the virtual network
|
// name will work only inside the virtual network
|
||||||
so = append(so, commonNameValidator(name))
|
so = append(so,
|
||||||
so = append(so, dnsNamesValidator([]string{name}))
|
commonNameValidator(name),
|
||||||
so = append(so, ipAddressesValidator(nil))
|
dnsNamesValidator([]string{name}),
|
||||||
so = append(so, emailAddressesValidator(nil))
|
ipAddressesValidator(nil),
|
||||||
so = append(so, urisValidator(nil))
|
emailAddressesValidator(nil),
|
||||||
|
urisValidator(nil),
|
||||||
|
)
|
||||||
|
|
||||||
// Enforce SANs in the template.
|
// Enforce SANs in the template.
|
||||||
data.SetSANs([]string{name})
|
data.SetSANs([]string{name})
|
||||||
|
|
|
@ -72,7 +72,7 @@ func TestAzure_GetTokenID(t *testing.T) {
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", p1, args{t1}, w1, false},
|
{"ok", p1, args{t1}, w1, false},
|
||||||
{"ok no TOFU", p2, args{t2}, "the-jti", false},
|
{"ok no TOFU", p2, args{t2}, "", true},
|
||||||
{"fail token", p1, args{"bad-token"}, "", true},
|
{"fail token", p1, args{"bad-token"}, "", true},
|
||||||
{"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true},
|
{"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true},
|
||||||
}
|
}
|
||||||
|
@ -446,15 +446,15 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token)
|
switch got, err := tt.azure.AuthorizeSign(ctx, tt.args.token); {
|
||||||
if (err != nil) != tt.wantErr {
|
case (err != nil) != tt.wantErr:
|
||||||
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
case err != nil:
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(errs.StatusCoder)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
} else {
|
default:
|
||||||
assert.Len(t, tt.wantLen, got)
|
assert.Len(t, tt.wantLen, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
|
|
@ -37,8 +37,9 @@ func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||||
// provisioner.
|
// provisioner.
|
||||||
type loadByTokenPayload struct {
|
type loadByTokenPayload struct {
|
||||||
jose.Claims
|
jose.Claims
|
||||||
AuthorizedParty string `json:"azp"` // OIDC client id
|
Email string `json:"email"` // OIDC email
|
||||||
TenantID string `json:"tid"` // Microsoft Azure tenant id
|
AuthorizedParty string `json:"azp"` // OIDC client id
|
||||||
|
TenantID string `json:"tid"` // Microsoft Azure tenant id
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collection is a memory map of provisioners.
|
// Collection is a memory map of provisioners.
|
||||||
|
@ -129,12 +130,20 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
|
||||||
return p, ok
|
return p, ok
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Try with tid (Azure)
|
// Try with tid (Azure, Azure OIDC)
|
||||||
if payload.TenantID != "" {
|
if payload.TenantID != "" {
|
||||||
|
// Try to load an OIDC provisioner first.
|
||||||
|
if payload.Email != "" {
|
||||||
|
if p, ok := c.LoadByTokenID(payload.Audience[0]); ok {
|
||||||
|
return p, ok
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try to load an Azure provisioner.
|
||||||
if p, ok := c.LoadByTokenID(payload.TenantID); ok {
|
if p, ok := c.LoadByTokenID(payload.TenantID); ok {
|
||||||
return p, ok
|
return p, ok
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to aud
|
// Fallback to aud
|
||||||
return c.LoadByTokenID(payload.Audience[0])
|
return c.LoadByTokenID(payload.Audience[0])
|
||||||
}
|
}
|
||||||
|
@ -220,14 +229,15 @@ func (c *Collection) Remove(id string) error {
|
||||||
|
|
||||||
var found bool
|
var found bool
|
||||||
for i, elem := range c.sorted {
|
for i, elem := range c.sorted {
|
||||||
if elem.provisioner.GetID() == id {
|
if elem.provisioner.GetID() != id {
|
||||||
// Remove index in sorted list
|
continue
|
||||||
copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index.
|
|
||||||
c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value).
|
|
||||||
c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice.
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
// Remove index in sorted list
|
||||||
|
copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index.
|
||||||
|
c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value).
|
||||||
|
c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice.
|
||||||
|
found = true
|
||||||
|
break
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName())
|
return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName())
|
||||||
|
|
|
@ -150,7 +150,7 @@ func (p *GCP) GetType() Type {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEncryptedKey is not available in a GCP provisioner.
|
// GetEncryptedKey is not available in a GCP provisioner.
|
||||||
func (p *GCP) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (p *GCP) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,15 +244,17 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID)
|
dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID)
|
||||||
dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID)
|
dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID)
|
||||||
so = append(so, commonNameSliceValidator([]string{
|
so = append(so,
|
||||||
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
|
commonNameSliceValidator([]string{
|
||||||
}))
|
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
|
||||||
so = append(so, dnsNamesValidator([]string{
|
}),
|
||||||
dnsName1, dnsName2,
|
dnsNamesValidator([]string{
|
||||||
}))
|
dnsName1, dnsName2,
|
||||||
so = append(so, ipAddressesValidator(nil))
|
}),
|
||||||
so = append(so, emailAddressesValidator(nil))
|
ipAddressesValidator(nil),
|
||||||
so = append(so, urisValidator(nil))
|
emailAddressesValidator(nil),
|
||||||
|
urisValidator(nil),
|
||||||
|
)
|
||||||
|
|
||||||
// Template SANs
|
// Template SANs
|
||||||
data.SetSANs([]string{dnsName1, dnsName2})
|
data.SetSANs([]string{dnsName1, dnsName2})
|
||||||
|
|
|
@ -535,15 +535,15 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token)
|
switch got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token); {
|
||||||
if (err != nil) != tt.wantErr {
|
case (err != nil) != tt.wantErr:
|
||||||
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
case err != nil:
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(errs.StatusCoder)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
} else {
|
default:
|
||||||
assert.Len(t, tt.wantLen, got)
|
assert.Len(t, tt.wantLen, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
|
|
@ -18,7 +18,7 @@ const (
|
||||||
defaultCacheJitter = 1 * time.Hour
|
defaultCacheJitter = 1 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)")
|
var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)
|
||||||
|
|
||||||
type keyStore struct {
|
type keyStore struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
|
@ -29,7 +29,7 @@ func (p *noop) GetType() Type {
|
||||||
return noopType
|
return noopType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (p *noop) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,29 @@ type openIDPayload struct {
|
||||||
Groups []string `json:"groups"`
|
Groups []string `json:"groups"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *openIDPayload) IsAdmin(admins []string) bool {
|
||||||
|
if o.Email != "" {
|
||||||
|
email := sanitizeEmail(o.Email)
|
||||||
|
for _, e := range admins {
|
||||||
|
if email == sanitizeEmail(e) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The groups and emails can be in the same array for now, but consider
|
||||||
|
// making a specialized option later.
|
||||||
|
for _, name := range o.Groups {
|
||||||
|
for _, admin := range admins {
|
||||||
|
if name == admin {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// OIDC represents an OAuth 2.0 OpenID Connect provider.
|
// OIDC represents an OAuth 2.0 OpenID Connect provider.
|
||||||
//
|
//
|
||||||
// ClientSecret is mandatory, but it can be an empty string.
|
// ClientSecret is mandatory, but it can be an empty string.
|
||||||
|
@ -73,35 +96,6 @@ type OIDC struct {
|
||||||
getIdentityFunc GetIdentityFunc
|
getIdentityFunc GetIdentityFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAdmin returns true if the given email is in the Admins allowlist, false
|
|
||||||
// otherwise.
|
|
||||||
func (o *OIDC) IsAdmin(email string) bool {
|
|
||||||
if email != "" {
|
|
||||||
email = sanitizeEmail(email)
|
|
||||||
for _, e := range o.Admins {
|
|
||||||
if email == sanitizeEmail(e) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsAdminGroup returns true if the one group in the given list is in the Admins
|
|
||||||
// allowlist, false otherwise.
|
|
||||||
func (o *OIDC) IsAdminGroup(groups []string) bool {
|
|
||||||
for _, g := range groups {
|
|
||||||
// The groups and emails can be in the same array for now, but consider
|
|
||||||
// making a specialized option later.
|
|
||||||
for _, gadmin := range o.Admins {
|
|
||||||
if g == gadmin {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func sanitizeEmail(email string) string {
|
func sanitizeEmail(email string) string {
|
||||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||||
email = email[:i] + strings.ToLower(email[i:])
|
email = email[:i] + strings.ToLower(email[i:])
|
||||||
|
@ -154,7 +148,7 @@ func (o *OIDC) GetType() Type {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEncryptedKey is not available in an OIDC provisioner.
|
// GetEncryptedKey is not available in an OIDC provisioner.
|
||||||
func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (o *OIDC) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,7 +193,7 @@ func (o *OIDC) Init(config Config) (err error) {
|
||||||
}
|
}
|
||||||
// Replace {tenantid} with the configured one
|
// Replace {tenantid} with the configured one
|
||||||
if o.TenantID != "" {
|
if o.TenantID != "" {
|
||||||
o.configuration.Issuer = strings.Replace(o.configuration.Issuer, "{tenantid}", o.TenantID, -1)
|
o.configuration.Issuer = strings.ReplaceAll(o.configuration.Issuer, "{tenantid}", o.TenantID)
|
||||||
}
|
}
|
||||||
// Get JWK key set
|
// Get JWK key set
|
||||||
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
|
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
|
||||||
|
@ -234,7 +228,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate domains (case-insensitive)
|
// Validate domains (case-insensitive)
|
||||||
if p.Email != "" && len(o.Domains) > 0 && !o.IsAdmin(p.Email) {
|
if p.Email != "" && len(o.Domains) > 0 && !p.IsAdmin(o.Admins) {
|
||||||
email := sanitizeEmail(p.Email)
|
email := sanitizeEmail(p.Email)
|
||||||
var found bool
|
var found bool
|
||||||
for _, d := range o.Domains {
|
for _, d := range o.Domains {
|
||||||
|
@ -313,9 +307,10 @@ func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only admins can revoke certificates.
|
// Only admins can revoke certificates.
|
||||||
if o.IsAdmin(claims.Email) {
|
if claims.IsAdmin(o.Admins) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
|
return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -351,7 +346,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
||||||
// Use the default template unless no-templates are configured and email is
|
// Use the default template unless no-templates are configured and email is
|
||||||
// an admin, in that case we will use the CR template.
|
// an admin, in that case we will use the CR template.
|
||||||
defaultTemplate := x509util.DefaultLeafTemplate
|
defaultTemplate := x509util.DefaultLeafTemplate
|
||||||
if !o.Options.GetX509Options().HasTemplate() && o.IsAdmin(claims.Email) {
|
if !o.Options.GetX509Options().HasTemplate() && claims.IsAdmin(o.Admins) {
|
||||||
defaultTemplate = x509util.DefaultAdminLeafTemplate
|
defaultTemplate = x509util.DefaultAdminLeafTemplate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -420,10 +415,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
||||||
|
|
||||||
// Use the default template unless no-templates are configured and email is
|
// Use the default template unless no-templates are configured and email is
|
||||||
// an admin, in that case we will use the parameters in the request.
|
// an admin, in that case we will use the parameters in the request.
|
||||||
isAdmin := o.IsAdmin(claims.Email)
|
isAdmin := claims.IsAdmin(o.Admins)
|
||||||
if !isAdmin && len(claims.Groups) > 0 {
|
|
||||||
isAdmin = o.IsAdminGroup(claims.Groups)
|
|
||||||
}
|
|
||||||
defaultTemplate := sshutil.DefaultTemplate
|
defaultTemplate := sshutil.DefaultTemplate
|
||||||
if isAdmin && !o.Options.GetSSHOptions().HasTemplate() {
|
if isAdmin && !o.Options.GetSSHOptions().HasTemplate() {
|
||||||
defaultTemplate = sshutil.DefaultAdminTemplate
|
defaultTemplate = sshutil.DefaultAdminTemplate
|
||||||
|
@ -471,10 +463,11 @@ func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only admins can revoke certificates.
|
// Only admins can revoke certificates.
|
||||||
if !o.IsAdmin(claims.Email) {
|
if claims.IsAdmin(o.Admins) {
|
||||||
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAndDecode(uri string, v interface{}) error {
|
func getAndDecode(uri string, v interface{}) error {
|
||||||
|
|
|
@ -321,32 +321,26 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
} else {
|
} else if assert.NotNil(t, got) {
|
||||||
if assert.NotNil(t, got) {
|
assert.Len(t, 5, got)
|
||||||
if tt.name == "admin" {
|
for _, o := range got {
|
||||||
assert.Len(t, 5, got)
|
switch v := o.(type) {
|
||||||
} else {
|
case certificateOptionsFunc:
|
||||||
assert.Len(t, 5, got)
|
case *provisionerExtensionOption:
|
||||||
}
|
assert.Equals(t, v.Type, int(TypeOIDC))
|
||||||
for _, o := range got {
|
assert.Equals(t, v.Name, tt.prov.GetName())
|
||||||
switch v := o.(type) {
|
assert.Equals(t, v.CredentialID, tt.prov.ClientID)
|
||||||
case certificateOptionsFunc:
|
assert.Len(t, 0, v.KeyValuePairs)
|
||||||
case *provisionerExtensionOption:
|
case profileDefaultDuration:
|
||||||
assert.Equals(t, v.Type, int(TypeOIDC))
|
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
|
||||||
assert.Equals(t, v.Name, tt.prov.GetName())
|
case defaultPublicKeyValidator:
|
||||||
assert.Equals(t, v.CredentialID, tt.prov.ClientID)
|
case *validityValidator:
|
||||||
assert.Len(t, 0, v.KeyValuePairs)
|
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
|
||||||
case profileDefaultDuration:
|
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
|
||||||
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
|
case emailOnlyIdentity:
|
||||||
case defaultPublicKeyValidator:
|
assert.Equals(t, string(v), "name@smallstep.com")
|
||||||
case *validityValidator:
|
default:
|
||||||
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
|
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||||
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
|
|
||||||
case emailOnlyIdentity:
|
|
||||||
assert.Equals(t, string(v), "name@smallstep.com")
|
|
||||||
default:
|
|
||||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -698,3 +692,39 @@ func Test_sanitizeEmail(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_openIDPayload_IsAdmin(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Email string
|
||||||
|
Groups []string
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
admins []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"ok email", fields{"admin@smallstep.com", nil}, args{[]string{"admin@smallstep.com"}}, true},
|
||||||
|
{"ok email multiple", fields{"admin@smallstep.com", []string{"admin", "eng"}}, args{[]string{"eng@smallstep.com", "admin@smallstep.com"}}, true},
|
||||||
|
{"ok email sanitized", fields{"admin@Smallstep.com", nil}, args{[]string{"admin@smallStep.com"}}, true},
|
||||||
|
{"ok group", fields{"", []string{"admin"}}, args{[]string{"admin"}}, true},
|
||||||
|
{"ok group multiple", fields{"admin@smallstep.com", []string{"engineering", "admin"}}, args{[]string{"admin"}}, true},
|
||||||
|
{"fail missing", fields{"eng@smallstep.com", []string{"admin"}}, args{[]string{"admin@smallstep.com"}}, false},
|
||||||
|
{"fail email letter case", fields{"Admin@smallstep.com", []string{}}, args{[]string{"admin@smallstep.com"}}, false},
|
||||||
|
{"fail group letter case", fields{"", []string{"Admin"}}, args{[]string{"admin"}}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
o := &openIDPayload{
|
||||||
|
Email: tt.fields.Email,
|
||||||
|
Groups: tt.fields.Groups,
|
||||||
|
}
|
||||||
|
if got := o.IsAdmin(tt.args.admins); got != tt.want {
|
||||||
|
t.Errorf("openIDPayload.IsAdmin() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -138,7 +138,7 @@ func unsafeParseSigned(s string) (map[string]interface{}, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claims := make(map[string]interface{})
|
claims := make(map[string]interface{})
|
||||||
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return claims, nil
|
return claims, nil
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
stderrors "errors"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -32,6 +33,17 @@ type Interface interface {
|
||||||
AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error)
|
AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrAllowTokenReuse is an error that is returned by provisioners that allows
|
||||||
|
// the reuse of tokens.
|
||||||
|
//
|
||||||
|
// This is, for example, returned by the Azure provisioner when
|
||||||
|
// DisableTrustOnFirstUse is set to true. Azure caches tokens for up to 24hr and
|
||||||
|
// has no mechanism for getting a different token - this can be an issue when
|
||||||
|
// rebooting a VM. In contrast, AWS and GCP have facilities for requesting a new
|
||||||
|
// token. Therefore, for the Azure provisioner we are enabling token reuse, with
|
||||||
|
// the understanding that we are not following security best practices
|
||||||
|
var ErrAllowTokenReuse = stderrors.New("allow token reuse")
|
||||||
|
|
||||||
// Audiences stores all supported audiences by request type.
|
// Audiences stores all supported audiences by request type.
|
||||||
type Audiences struct {
|
type Audiences struct {
|
||||||
Sign []string
|
Sign []string
|
||||||
|
@ -111,7 +123,7 @@ func (a Audiences) WithFragment(fragment string) Audiences {
|
||||||
|
|
||||||
// generateSignAudience generates a sign audience with the format
|
// generateSignAudience generates a sign audience with the format
|
||||||
// https://<host>/1.0/sign#provisionerID
|
// https://<host>/1.0/sign#provisionerID
|
||||||
func generateSignAudience(caURL string, provisionerID string) (string, error) {
|
func generateSignAudience(caURL, provisionerID string) (string, error) {
|
||||||
u, err := url.Parse(caURL)
|
u, err := url.Parse(caURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrapf(err, "error parsing %s", caURL)
|
return "", errors.Wrapf(err, "error parsing %s", caURL)
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -455,10 +456,10 @@ func containsAllMembers(group, subgroup []string) bool {
|
||||||
}
|
}
|
||||||
visit := make(map[string]struct{}, lg)
|
visit := make(map[string]struct{}, lg)
|
||||||
for i := 0; i < lg; i++ {
|
for i := 0; i < lg; i++ {
|
||||||
visit[group[i]] = struct{}{}
|
visit[strings.ToLower(group[i])] = struct{}{}
|
||||||
}
|
}
|
||||||
for i := 0; i < lsg; i++ {
|
for i := 0; i < lsg; i++ {
|
||||||
if _, ok := visit[subgroup[i]]; !ok {
|
if _, ok := visit[strings.ToLower(subgroup[i])]; !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,7 +44,7 @@ func TestSSHOptions_Modify(t *testing.T) {
|
||||||
valid func(*ssh.Certificate)
|
valid func(*ssh.Certificate)
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"fail/unexpected-cert-type": func() test {
|
"fail/unexpected-cert-type": func() test {
|
||||||
return test{
|
return test{
|
||||||
so: SignSSHOptions{CertType: "foo"},
|
so: SignSSHOptions{CertType: "foo"},
|
||||||
|
@ -117,7 +117,7 @@ func TestSSHOptions_Match(t *testing.T) {
|
||||||
cmp SignSSHOptions
|
cmp SignSSHOptions
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"fail/cert-type": func() test {
|
"fail/cert-type": func() test {
|
||||||
return test{
|
return test{
|
||||||
so: SignSSHOptions{CertType: "foo"},
|
so: SignSSHOptions{CertType: "foo"},
|
||||||
|
@ -208,7 +208,7 @@ func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
expected []string
|
expected []string
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok": func() test {
|
"ok": func() test {
|
||||||
a := []string{"foo", "bar"}
|
a := []string{"foo", "bar"}
|
||||||
return test{
|
return test{
|
||||||
|
@ -234,7 +234,7 @@ func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
expected string
|
expected string
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok": func() test {
|
"ok": func() test {
|
||||||
a := "foo"
|
a := "foo"
|
||||||
return test{
|
return test{
|
||||||
|
@ -260,7 +260,7 @@ func Test_sshCertTypeModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
expected uint32
|
expected uint32
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok/user": func() test {
|
"ok/user": func() test {
|
||||||
return test{
|
return test{
|
||||||
modifier: sshCertTypeModifier("user"),
|
modifier: sshCertTypeModifier("user"),
|
||||||
|
@ -299,7 +299,7 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
expected uint64
|
expected uint64
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok": func() test {
|
"ok": func() test {
|
||||||
return test{
|
return test{
|
||||||
modifier: sshCertValidAfterModifier(15),
|
modifier: sshCertValidAfterModifier(15),
|
||||||
|
@ -324,7 +324,7 @@ func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
valid func(*ssh.Certificate)
|
valid func(*ssh.Certificate)
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok/changes": func() test {
|
"ok/changes": func() test {
|
||||||
n := time.Now()
|
n := time.Now()
|
||||||
va := NewTimeDuration(n.Add(1 * time.Minute))
|
va := NewTimeDuration(n.Add(1 * time.Minute))
|
||||||
|
@ -388,7 +388,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
|
||||||
valid func(*ssh.Certificate)
|
valid func(*ssh.Certificate)
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"fail/unexpected-cert-type": func() test {
|
"fail/unexpected-cert-type": func() test {
|
||||||
cert := &ssh.Certificate{CertType: 3}
|
cert := &ssh.Certificate{CertType: 3}
|
||||||
return test{
|
return test{
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
@ -30,7 +29,6 @@ type SSHPOP struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Claims *Claims `json:"claims,omitempty"`
|
Claims *Claims `json:"claims,omitempty"`
|
||||||
db db.AuthDB
|
|
||||||
claimer *Claimer
|
claimer *Claimer
|
||||||
audiences Audiences
|
audiences Audiences
|
||||||
sshPubKeys *SSHKeys
|
sshPubKeys *SSHKeys
|
||||||
|
@ -102,7 +100,6 @@ func (p *SSHPOP) Init(config Config) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||||
p.db = config.DB
|
|
||||||
p.sshPubKeys = config.SSHKeys
|
p.sshPubKeys = config.SSHKeys
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -110,6 +107,8 @@ func (p *SSHPOP) Init(config Config) error {
|
||||||
// authorizeToken performs common jwt authorization actions and returns the
|
// authorizeToken performs common jwt authorization actions and returns the
|
||||||
// claims for case specific downstream parsing.
|
// claims for case specific downstream parsing.
|
||||||
// e.g. a Sign request will auth/validate different fields than a Revoke request.
|
// e.g. a Sign request will auth/validate different fields than a Revoke request.
|
||||||
|
//
|
||||||
|
// Checking for certificate revocation has been moved to the authority package.
|
||||||
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
|
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
|
||||||
sshCert, jwt, err := ExtractSSHPOPCert(token)
|
sshCert, jwt, err := ExtractSSHPOPCert(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -117,14 +116,6 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
||||||
"sshpop.authorizeToken; error extracting sshpop header from token")
|
"sshpop.authorizeToken; error extracting sshpop header from token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for revocation.
|
|
||||||
if isRevoked, err := p.db.IsSSHRevoked(strconv.FormatUint(sshCert.Serial, 10)); err != nil {
|
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err,
|
|
||||||
"sshpop.authorizeToken; error checking checking sshpop cert revocation")
|
|
||||||
} else if isRevoked {
|
|
||||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check validity period of the certificate.
|
// Check validity period of the certificate.
|
||||||
n := time.Now()
|
n := time.Now()
|
||||||
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
|
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/db"
|
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
|
@ -47,7 +46,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if err = cert.SignCert(rand.Reader, signer); err != nil {
|
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return cert, jwk, nil
|
return cert, jwk, nil
|
||||||
|
@ -83,52 +82,9 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
|
err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/error-revoked-db-check": func(t *testing.T) test {
|
|
||||||
p, err := generateSSHPOP()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, errors.New("force")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
p: p,
|
|
||||||
token: tok,
|
|
||||||
code: http.StatusInternalServerError,
|
|
||||||
err: errors.New("sshpop.authorizeToken; error checking checking sshpop cert revocation: force"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/cert-already-revoked": func(t *testing.T) test {
|
|
||||||
p, err := generateSSHPOP()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return true, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
return test{
|
|
||||||
p: p,
|
|
||||||
token: tok,
|
|
||||||
code: http.StatusUnauthorized,
|
|
||||||
err: errors.New("sshpop.authorizeToken; sshpop certificate is revoked"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/cert-not-yet-valid": func(t *testing.T) test {
|
"fail/cert-not-yet-valid": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{
|
cert, jwk, err := createSSHCert(&ssh.Certificate{
|
||||||
CertType: ssh.UserCert,
|
CertType: ssh.UserCert,
|
||||||
ValidAfter: uint64(time.Now().Add(time.Minute).Unix()),
|
ValidAfter: uint64(time.Now().Add(time.Minute).Unix()),
|
||||||
|
@ -146,11 +102,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
"fail/cert-past-validity": func(t *testing.T) test {
|
"fail/cert-past-validity": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{
|
cert, jwk, err := createSSHCert(&ssh.Certificate{
|
||||||
CertType: ssh.UserCert,
|
CertType: ssh.UserCert,
|
||||||
ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()),
|
ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()),
|
||||||
|
@ -168,11 +119,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
"fail/no-signer-found": func(t *testing.T) test {
|
"fail/no-signer-found": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||||
|
@ -187,11 +133,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
"fail/error-parsing-claims-bad-sig": func(t *testing.T) test {
|
"fail/error-parsing-claims-bad-sig": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
@ -208,11 +149,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
"fail/invalid-claims-issuer": func(t *testing.T) test {
|
"fail/invalid-claims-issuer": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "",
|
tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "",
|
||||||
|
@ -228,11 +164,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
"fail/invalid-audience": func(t *testing.T) test {
|
"fail/invalid-audience": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateToken("foo", p.GetName(), "invalid-aud", "",
|
tok, err := generateToken("foo", p.GetName(), "invalid-aud", "",
|
||||||
|
@ -248,11 +179,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
"fail/empty-subject": func(t *testing.T) test {
|
"fail/empty-subject": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
|
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
|
||||||
|
@ -268,11 +194,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateSSHPOPToken(p, cert, jwk)
|
tok, err := generateSSHPOPToken(p, cert, jwk)
|
||||||
|
@ -293,10 +214,8 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.NotNil(t, claims)
|
||||||
assert.NotNil(t, claims)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -330,11 +249,6 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
|
||||||
"fail/subject-not-equal-serial": func(t *testing.T) test {
|
"fail/subject-not-equal-serial": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "",
|
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "",
|
||||||
|
@ -350,11 +264,6 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "",
|
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "",
|
||||||
|
@ -419,11 +328,6 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
|
||||||
"fail/not-host-cert": func(t *testing.T) test {
|
"fail/not-host-cert": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "",
|
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "",
|
||||||
|
@ -439,11 +343,6 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "",
|
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "",
|
||||||
|
@ -511,11 +410,6 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
||||||
"fail/not-host-cert": func(t *testing.T) test {
|
"fail/not-host-cert": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "",
|
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "",
|
||||||
|
@ -531,11 +425,6 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
p, err := generateSSHPOP()
|
p, err := generateSSHPOP()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
p.db = &db.MockAuthDB{
|
|
||||||
MIsSSHRevoked: func(sn string) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
|
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "",
|
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "",
|
||||||
|
|
|
@ -732,7 +732,7 @@ func withSSHPOPFile(cert *ssh.Certificate) tokOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
|
func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
|
||||||
so := new(jose.SignerOptions)
|
so := new(jose.SignerOptions)
|
||||||
so.WithType("JWT")
|
so.WithType("JWT")
|
||||||
so.WithHeader("kid", jwk.KeyID)
|
so.WithHeader("kid", jwk.KeyID)
|
||||||
|
@ -773,7 +773,7 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
|
||||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateOIDCToken(sub, iss, aud string, email string, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
|
func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
|
||||||
so := new(jose.SignerOptions)
|
so := new(jose.SignerOptions)
|
||||||
so.WithType("JWT")
|
so.WithType("JWT")
|
||||||
so.WithHeader("kid", jwk.KeyID)
|
so.WithHeader("kid", jwk.KeyID)
|
||||||
|
|
|
@ -4,12 +4,17 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
"github.com/smallstep/certificates/authority/config"
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
|
step "go.step.sm/cli-utils/config"
|
||||||
|
"go.step.sm/cli-utils/ui"
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/linkedca"
|
"go.step.sm/linkedca"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
@ -234,6 +239,14 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
|
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
|
||||||
|
if password == "" {
|
||||||
|
pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
password = string(pass)
|
||||||
|
}
|
||||||
|
|
||||||
jwk, jwe, err := jose.GenerateDefaultKeyPair([]byte(password))
|
jwk, jwe, err := jose.GenerateDefaultKeyPair([]byte(password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, admin.WrapErrorISE(err, "error generating JWK key pair")
|
return nil, admin.WrapErrorISE(err, "error generating JWK key pair")
|
||||||
|
@ -398,6 +411,13 @@ func durationsToCertificates(d *linkedca.Durations) (min, max, def *provisioner.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func durationsToLinkedca(d *provisioner.Duration) string {
|
||||||
|
if d == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return d.Duration.String()
|
||||||
|
}
|
||||||
|
|
||||||
// claimsToCertificates converts the linkedca provisioner claims type to the
|
// claimsToCertificates converts the linkedca provisioner claims type to the
|
||||||
// certifictes claims type.
|
// certifictes claims type.
|
||||||
func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
|
func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
|
||||||
|
@ -438,6 +458,109 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
|
||||||
return pc, nil
|
return pc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
disableRenewal := config.DefaultDisableRenewal
|
||||||
|
if c.DisableRenewal != nil {
|
||||||
|
disableRenewal = *c.DisableRenewal
|
||||||
|
}
|
||||||
|
|
||||||
|
lc := &linkedca.Claims{
|
||||||
|
DisableRenewal: disableRenewal,
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {
|
||||||
|
lc.X509 = &linkedca.X509Claims{
|
||||||
|
Enabled: true,
|
||||||
|
Durations: &linkedca.Durations{
|
||||||
|
Default: durationsToLinkedca(c.DefaultTLSDur),
|
||||||
|
Min: durationsToLinkedca(c.MinTLSDur),
|
||||||
|
Max: durationsToLinkedca(c.MaxTLSDur),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.EnableSSHCA != nil && *c.EnableSSHCA {
|
||||||
|
lc.Ssh = &linkedca.SSHClaims{
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
if c.DefaultUserSSHDur != nil || c.MinUserSSHDur != nil || c.MaxUserSSHDur != nil {
|
||||||
|
lc.Ssh.UserDurations = &linkedca.Durations{
|
||||||
|
Default: durationsToLinkedca(c.DefaultUserSSHDur),
|
||||||
|
Min: durationsToLinkedca(c.MinUserSSHDur),
|
||||||
|
Max: durationsToLinkedca(c.MaxUserSSHDur),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.DefaultHostSSHDur != nil || c.MinHostSSHDur != nil || c.MaxHostSSHDur != nil {
|
||||||
|
lc.Ssh.HostDurations = &linkedca.Durations{
|
||||||
|
Default: durationsToLinkedca(c.DefaultHostSSHDur),
|
||||||
|
Min: durationsToLinkedca(c.MinHostSSHDur),
|
||||||
|
Max: durationsToLinkedca(c.MaxHostSSHDur),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lc
|
||||||
|
}
|
||||||
|
|
||||||
|
func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *linkedca.Template, error) {
|
||||||
|
var err error
|
||||||
|
var x509Template, sshTemplate *linkedca.Template
|
||||||
|
|
||||||
|
if p == nil {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.X509 != nil && p.X509.HasTemplate() {
|
||||||
|
x509Template = &linkedca.Template{
|
||||||
|
Template: nil,
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.X509.Template != "" {
|
||||||
|
x509Template.Template = []byte(p.SSH.Template)
|
||||||
|
} else if p.X509.TemplateFile != "" {
|
||||||
|
filename := step.StepAbs(p.X509.TemplateFile)
|
||||||
|
if x509Template.Template, err = ioutil.ReadFile(filename); err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "error reading x509 template")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.SSH != nil && p.SSH.HasTemplate() {
|
||||||
|
sshTemplate = &linkedca.Template{
|
||||||
|
Template: nil,
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.SSH.Template != "" {
|
||||||
|
sshTemplate.Template = []byte(p.SSH.Template)
|
||||||
|
} else if p.SSH.TemplateFile != "" {
|
||||||
|
filename := step.StepAbs(p.SSH.TemplateFile)
|
||||||
|
if sshTemplate.Template, err = ioutil.ReadFile(filename); err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "error reading ssh template")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return x509Template, sshTemplate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func provisionerPEMToLinkedca(b []byte) [][]byte {
|
||||||
|
var roots [][]byte
|
||||||
|
var block *pem.Block
|
||||||
|
for {
|
||||||
|
if block, b = pem.Decode(b); block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
roots = append(roots, pem.EncodeToMemory(block))
|
||||||
|
}
|
||||||
|
return roots
|
||||||
|
}
|
||||||
|
|
||||||
// ProvisionerToCertificates converts the linkedca provisioner type to the certificates provisioner
|
// ProvisionerToCertificates converts the linkedca provisioner type to the certificates provisioner
|
||||||
// interface.
|
// interface.
|
||||||
func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, error) {
|
func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, error) {
|
||||||
|
@ -448,7 +571,7 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
|
||||||
|
|
||||||
details := p.Details.GetData()
|
details := p.Details.GetData()
|
||||||
if details == nil {
|
if details == nil {
|
||||||
return nil, fmt.Errorf("provisioner does not have any details")
|
return nil, errors.New("provisioner does not have any details")
|
||||||
}
|
}
|
||||||
|
|
||||||
options := optionsToCertificates(p)
|
options := optionsToCertificates(p)
|
||||||
|
@ -457,7 +580,7 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
|
||||||
case *linkedca.ProvisionerDetails_JWK:
|
case *linkedca.ProvisionerDetails_JWK:
|
||||||
jwk := new(jose.JSONWebKey)
|
jwk := new(jose.JSONWebKey)
|
||||||
if err := json.Unmarshal(d.JWK.PublicKey, &jwk); err != nil {
|
if err := json.Unmarshal(d.JWK.PublicKey, &jwk); err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "error unmarshaling public key")
|
||||||
}
|
}
|
||||||
return &provisioner.JWK{
|
return &provisioner.JWK{
|
||||||
ID: p.Id,
|
ID: p.Id,
|
||||||
|
@ -588,6 +711,233 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvisionerToLinkedca converts a provisioner.Interface to a
|
||||||
|
// linkedca.Provisioner type.
|
||||||
|
func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, error) {
|
||||||
|
switch p := p.(type) {
|
||||||
|
case *provisioner.JWK:
|
||||||
|
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
publicKey, err := json.Marshal(p.Key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error marshaling key")
|
||||||
|
}
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_JWK,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_JWK{
|
||||||
|
JWK: &linkedca.JWKProvisioner{
|
||||||
|
PublicKey: publicKey,
|
||||||
|
EncryptedPrivateKey: []byte(p.EncryptedKey),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
X509Template: x509Template,
|
||||||
|
SshTemplate: sshTemplate,
|
||||||
|
}, nil
|
||||||
|
case *provisioner.OIDC:
|
||||||
|
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_OIDC,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_OIDC{
|
||||||
|
OIDC: &linkedca.OIDCProvisioner{
|
||||||
|
ClientId: p.ClientID,
|
||||||
|
ClientSecret: p.ClientSecret,
|
||||||
|
ConfigurationEndpoint: p.ConfigurationEndpoint,
|
||||||
|
Admins: p.Admins,
|
||||||
|
Domains: p.Domains,
|
||||||
|
Groups: p.Groups,
|
||||||
|
ListenAddress: p.ListenAddress,
|
||||||
|
TenantId: p.TenantID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
X509Template: x509Template,
|
||||||
|
SshTemplate: sshTemplate,
|
||||||
|
}, nil
|
||||||
|
case *provisioner.GCP:
|
||||||
|
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_GCP,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_GCP{
|
||||||
|
GCP: &linkedca.GCPProvisioner{
|
||||||
|
ServiceAccounts: p.ServiceAccounts,
|
||||||
|
ProjectIds: p.ProjectIDs,
|
||||||
|
DisableCustomSans: p.DisableCustomSANs,
|
||||||
|
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,
|
||||||
|
InstanceAge: p.InstanceAge.String(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
X509Template: x509Template,
|
||||||
|
SshTemplate: sshTemplate,
|
||||||
|
}, nil
|
||||||
|
case *provisioner.AWS:
|
||||||
|
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_AWS,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_AWS{
|
||||||
|
AWS: &linkedca.AWSProvisioner{
|
||||||
|
Accounts: p.Accounts,
|
||||||
|
DisableCustomSans: p.DisableCustomSANs,
|
||||||
|
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,
|
||||||
|
InstanceAge: p.InstanceAge.String(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
X509Template: x509Template,
|
||||||
|
SshTemplate: sshTemplate,
|
||||||
|
}, nil
|
||||||
|
case *provisioner.Azure:
|
||||||
|
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_AZURE,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_Azure{
|
||||||
|
Azure: &linkedca.AzureProvisioner{
|
||||||
|
TenantId: p.TenantID,
|
||||||
|
ResourceGroups: p.ResourceGroups,
|
||||||
|
Audience: p.Audience,
|
||||||
|
DisableCustomSans: p.DisableCustomSANs,
|
||||||
|
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
X509Template: x509Template,
|
||||||
|
SshTemplate: sshTemplate,
|
||||||
|
}, nil
|
||||||
|
case *provisioner.ACME:
|
||||||
|
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_ACME,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_ACME{
|
||||||
|
ACME: &linkedca.ACMEProvisioner{
|
||||||
|
ForceCn: p.ForceCN,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
X509Template: x509Template,
|
||||||
|
SshTemplate: sshTemplate,
|
||||||
|
}, nil
|
||||||
|
case *provisioner.X5C:
|
||||||
|
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_X5C,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_X5C{
|
||||||
|
X5C: &linkedca.X5CProvisioner{
|
||||||
|
Roots: provisionerPEMToLinkedca(p.Roots),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
X509Template: x509Template,
|
||||||
|
SshTemplate: sshTemplate,
|
||||||
|
}, nil
|
||||||
|
case *provisioner.K8sSA:
|
||||||
|
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_K8SSA,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_K8SSA{
|
||||||
|
K8SSA: &linkedca.K8SSAProvisioner{
|
||||||
|
PublicKeys: provisionerPEMToLinkedca(p.PubKeys),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
X509Template: x509Template,
|
||||||
|
SshTemplate: sshTemplate,
|
||||||
|
}, nil
|
||||||
|
case *provisioner.SSHPOP:
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_SSHPOP,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_SSHPOP{
|
||||||
|
SSHPOP: &linkedca.SSHPOPProvisioner{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
}, nil
|
||||||
|
case *provisioner.SCEP:
|
||||||
|
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: p.ID,
|
||||||
|
Type: linkedca.Provisioner_SCEP,
|
||||||
|
Name: p.GetName(),
|
||||||
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
Data: &linkedca.ProvisionerDetails_SCEP{
|
||||||
|
SCEP: &linkedca.SCEPProvisioner{
|
||||||
|
ForceCn: p.ForceCN,
|
||||||
|
Challenge: p.GetChallengePassword(),
|
||||||
|
Capabilities: p.Capabilities,
|
||||||
|
MinimumPublicKeyLength: int32(p.MinimumPublicKeyLength),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Claims: claimsToLinkedca(p.Claims),
|
||||||
|
X509Template: x509Template,
|
||||||
|
SshTemplate: sshTemplate,
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("provisioner %s not implemented", p.GetType())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func parseInstanceAge(age string) (provisioner.Duration, error) {
|
func parseInstanceAge(age string) (provisioner.Duration, error) {
|
||||||
var instanceAge provisioner.Duration
|
var instanceAge provisioner.Duration
|
||||||
if age != "" {
|
if age != "" {
|
||||||
|
|
|
@ -108,7 +108,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
|
||||||
|
|
||||||
// GetSSHBastion returns the bastion configuration, for the given pair user,
|
// GetSSHBastion returns the bastion configuration, for the given pair user,
|
||||||
// hostname.
|
// hostname.
|
||||||
func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname string) (*config.Bastion, error) {
|
func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*config.Bastion, error) {
|
||||||
if a.sshBastionFunc != nil {
|
if a.sshBastionFunc != nil {
|
||||||
bs, err := a.sshBastionFunc(ctx, user, hostname)
|
bs, err := a.sshBastionFunc(ctx, user, hostname)
|
||||||
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
|
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
|
||||||
|
@ -239,7 +239,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -249,7 +249,11 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
|
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
|
||||||
func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
|
func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
|
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
|
||||||
return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period")
|
return nil, errs.BadRequest("renewSSH: cannot renew certificate without validity period")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
backdate := a.config.AuthorityConfig.Backdate.Duration
|
backdate := a.config.AuthorityConfig.Backdate.Duration
|
||||||
|
@ -294,7 +298,7 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -319,6 +323,10 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
|
||||||
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
|
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
backdate := a.config.AuthorityConfig.Backdate.Duration
|
backdate := a.config.AuthorityConfig.Backdate.Duration
|
||||||
duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second
|
duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
@ -369,13 +377,23 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db")
|
||||||
}
|
}
|
||||||
|
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Authority) storeSSHCertificate(cert *ssh.Certificate) error {
|
||||||
|
type sshCertificateStorer interface {
|
||||||
|
StoreSSHCertificate(crt *ssh.Certificate) error
|
||||||
|
}
|
||||||
|
if s, ok := a.adminDB.(sshCertificateStorer); ok {
|
||||||
|
return s.StoreSSHCertificate(cert)
|
||||||
|
}
|
||||||
|
return a.db.StoreSSHCertificate(cert)
|
||||||
|
}
|
||||||
|
|
||||||
// IsValidForAddUser checks if a user provisioner certificate can be issued to
|
// IsValidForAddUser checks if a user provisioner certificate can be issued to
|
||||||
// the given certificate.
|
// the given certificate.
|
||||||
func IsValidForAddUser(cert *ssh.Certificate) error {
|
func IsValidForAddUser(cert *ssh.Certificate) error {
|
||||||
|
@ -451,7 +469,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
|
||||||
}
|
}
|
||||||
cert.Signature = sig
|
cert.Signature = sig
|
||||||
|
|
||||||
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -459,7 +477,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckSSHHost checks the given principal has been registered before.
|
// CheckSSHHost checks the given principal has been registered before.
|
||||||
func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) {
|
func (a *Authority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) {
|
||||||
if a.sshCheckHostFunc != nil {
|
if a.sshCheckHostFunc != nil {
|
||||||
exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates())
|
exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -513,5 +531,5 @@ func (a *Authority) getAddUserCommand(principal string) string {
|
||||||
} else {
|
} else {
|
||||||
cmd = a.config.SSH.AddUserCommand
|
cmd = a.config.SSH.AddUserCommand
|
||||||
}
|
}
|
||||||
return strings.Replace(cmd, "<principal>", principal, -1)
|
return strings.ReplaceAll(cmd, "<principal>", principal)
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,6 +87,52 @@ func (m sshTestOptionsModifier) Modify(cert *ssh.Certificate, opts provisioner.S
|
||||||
return fmt.Errorf(string(m))
|
return fmt.Errorf(string(m))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthority_initHostOnly(t *testing.T) {
|
||||||
|
auth := testAuthority(t, func(a *Authority) error {
|
||||||
|
a.config.SSH.UserKey = ""
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check keys
|
||||||
|
keys, err := auth.GetSSHRoots(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, 1, keys.HostKeys)
|
||||||
|
assert.Len(t, 0, keys.UserKeys)
|
||||||
|
|
||||||
|
// Check templates, user templates should work fine.
|
||||||
|
_, err = auth.GetSSHConfig(context.Background(), "user", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = auth.GetSSHConfig(context.Background(), "host", map[string]string{
|
||||||
|
"Certificate": "ssh_host_ecdsa_key-cert.pub",
|
||||||
|
"Key": "ssh_host_ecdsa_key",
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthority_initUserOnly(t *testing.T) {
|
||||||
|
auth := testAuthority(t, func(a *Authority) error {
|
||||||
|
a.config.SSH.HostKey = ""
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check keys
|
||||||
|
keys, err := auth.GetSSHRoots(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, 0, keys.HostKeys)
|
||||||
|
assert.Len(t, 1, keys.UserKeys)
|
||||||
|
|
||||||
|
// Check templates, host templates should work fine.
|
||||||
|
_, err = auth.GetSSHConfig(context.Background(), "host", map[string]string{
|
||||||
|
"Certificate": "ssh_host_ecdsa_key-cert.pub",
|
||||||
|
"Key": "ssh_host_ecdsa_key",
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = auth.GetSSHConfig(context.Background(), "user", nil)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAuthority_SignSSH(t *testing.T) {
|
func TestAuthority_SignSSH(t *testing.T) {
|
||||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
@ -153,6 +199,8 @@ func TestAuthority_SignSSH(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"ok-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false},
|
{"ok-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false},
|
||||||
{"ok-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false},
|
{"ok-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false},
|
||||||
|
{"ok-user-only", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false},
|
||||||
|
{"ok-host-only", fields{nil, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false},
|
||||||
{"ok-opts-type-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false},
|
{"ok-opts-type-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false},
|
||||||
{"ok-opts-type-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false},
|
{"ok-opts-type-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false},
|
||||||
{"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
|
{"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
|
||||||
|
@ -750,6 +798,11 @@ func TestAuthority_RekeySSH(t *testing.T) {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
|
||||||
a := testAuthority(t)
|
a := testAuthority(t)
|
||||||
|
a.db = &db.MockAuthDB{
|
||||||
|
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
auth *Authority
|
auth *Authority
|
||||||
|
@ -763,6 +816,56 @@ func TestAuthority_RekeySSH(t *testing.T) {
|
||||||
code int
|
code int
|
||||||
}
|
}
|
||||||
tests := map[string]func(t *testing.T) *test{
|
tests := map[string]func(t *testing.T) *test{
|
||||||
|
"fail/is-revoked": func(t *testing.T) *test {
|
||||||
|
auth := testAuthority(t)
|
||||||
|
auth.db = &db.MockAuthDB{
|
||||||
|
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return &test{
|
||||||
|
auth: auth,
|
||||||
|
userSigner: signer,
|
||||||
|
hostSigner: signer,
|
||||||
|
cert: &ssh.Certificate{
|
||||||
|
Serial: 1234567890,
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
ValidPrincipals: []string{"foo", "bar"},
|
||||||
|
KeyId: "foo",
|
||||||
|
},
|
||||||
|
key: pub,
|
||||||
|
signOpts: []provisioner.SignOption{},
|
||||||
|
err: errors.New("authority.authorizeSSHCertificate: certificate has been revoked"),
|
||||||
|
code: http.StatusUnauthorized,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/is-revoked-error": func(t *testing.T) *test {
|
||||||
|
auth := testAuthority(t)
|
||||||
|
auth.db = &db.MockAuthDB{
|
||||||
|
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||||
|
return false, errors.New("an error")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return &test{
|
||||||
|
auth: auth,
|
||||||
|
userSigner: signer,
|
||||||
|
hostSigner: signer,
|
||||||
|
cert: &ssh.Certificate{
|
||||||
|
Serial: 1234567890,
|
||||||
|
ValidAfter: uint64(now.Unix()),
|
||||||
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
ValidPrincipals: []string{"foo", "bar"},
|
||||||
|
KeyId: "foo",
|
||||||
|
},
|
||||||
|
key: pub,
|
||||||
|
signOpts: []provisioner.SignOption{},
|
||||||
|
err: errors.New("authority.authorizeSSHCertificate: an error"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
},
|
||||||
"fail/opts-type": func(t *testing.T) *test {
|
"fail/opts-type": func(t *testing.T) *test {
|
||||||
return &test{
|
return &test{
|
||||||
userSigner: signer,
|
userSigner: signer,
|
||||||
|
@ -831,6 +934,9 @@ func TestAuthority_RekeySSH(t *testing.T) {
|
||||||
"fail/db-store": func(t *testing.T) *test {
|
"fail/db-store": func(t *testing.T) *test {
|
||||||
return &test{
|
return &test{
|
||||||
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
|
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
|
||||||
|
MIsSSHRevoked: func(sn string) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
},
|
||||||
MStoreSSHCertificate: func(cert *ssh.Certificate) error {
|
MStoreSSHCertificate: func(cert *ssh.Certificate) error {
|
||||||
return errors.New("force")
|
return errors.New("force")
|
||||||
},
|
},
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"go.step.sm/crypto/keyutil"
|
"go.step.sm/crypto/keyutil"
|
||||||
"go.step.sm/crypto/pemutil"
|
"go.step.sm/crypto/pemutil"
|
||||||
"go.step.sm/crypto/x509util"
|
"go.step.sm/crypto/x509util"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetTLSOptions returns the tls options configured.
|
// GetTLSOptions returns the tls options configured.
|
||||||
|
@ -36,7 +37,6 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
|
||||||
if def == nil {
|
if def == nil {
|
||||||
return errors.New("default ASN1DN template cannot be nil")
|
return errors.New("default ASN1DN template cannot be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(crt.Subject.Country) == 0 && def.Country != "" {
|
if len(crt.Subject.Country) == 0 && def.Country != "" {
|
||||||
crt.Subject.Country = append(crt.Subject.Country, def.Country)
|
crt.Subject.Country = append(crt.Subject.Country, def.Country)
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,12 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
|
||||||
if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" {
|
if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" {
|
||||||
crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress)
|
crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress)
|
||||||
}
|
}
|
||||||
|
if crt.Subject.SerialNumber == "" && def.SerialNumber != "" {
|
||||||
|
crt.Subject.SerialNumber = def.SerialNumber
|
||||||
|
}
|
||||||
|
if crt.Subject.CommonName == "" && def.CommonName != "" {
|
||||||
|
crt.Subject.CommonName = def.CommonName
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -280,9 +285,15 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
|
||||||
// `StoreCertificate(...*x509.Certificate) error` instead of just
|
// `StoreCertificate(...*x509.Certificate) error` instead of just
|
||||||
// `StoreCertificate(*x509.Certificate) error`.
|
// `StoreCertificate(*x509.Certificate) error`.
|
||||||
func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
|
func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
|
||||||
if s, ok := a.db.(interface {
|
type certificateChainStorer interface {
|
||||||
StoreCertificateChain(...*x509.Certificate) error
|
StoreCertificateChain(...*x509.Certificate) error
|
||||||
}); ok {
|
}
|
||||||
|
// Store certificate in linkedca
|
||||||
|
if s, ok := a.adminDB.(certificateChainStorer); ok {
|
||||||
|
return s.StoreCertificateChain(fullchain...)
|
||||||
|
}
|
||||||
|
// Store certificate in local db
|
||||||
|
if s, ok := a.db.(certificateChainStorer); ok {
|
||||||
return s.StoreCertificateChain(fullchain...)
|
return s.StoreCertificateChain(fullchain...)
|
||||||
}
|
}
|
||||||
return a.db.StoreCertificate(fullchain[0])
|
return a.db.StoreCertificate(fullchain[0])
|
||||||
|
@ -293,9 +304,15 @@ func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
|
||||||
//
|
//
|
||||||
// TODO: at some point we should implement this in the standard implementation.
|
// TODO: at some point we should implement this in the standard implementation.
|
||||||
func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error {
|
func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error {
|
||||||
if s, ok := a.db.(interface {
|
type renewedCertificateChainStorer interface {
|
||||||
StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error
|
StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error
|
||||||
}); ok {
|
}
|
||||||
|
// Store certificate in linkedca
|
||||||
|
if s, ok := a.adminDB.(renewedCertificateChainStorer); ok {
|
||||||
|
return s.StoreRenewedCertificate(oldCert, fullchain...)
|
||||||
|
}
|
||||||
|
// Store certificate in local db
|
||||||
|
if s, ok := a.db.(renewedCertificateChainStorer); ok {
|
||||||
return s.StoreRenewedCertificate(oldCert, fullchain...)
|
return s.StoreRenewedCertificate(oldCert, fullchain...)
|
||||||
}
|
}
|
||||||
return a.db.StoreCertificate(fullchain[0])
|
return a.db.StoreCertificate(fullchain[0])
|
||||||
|
@ -368,22 +385,22 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
||||||
}
|
}
|
||||||
rci.ProvisionerID = p.GetID()
|
rci.ProvisionerID = p.GetID()
|
||||||
rci.TokenID, err = p.GetTokenID(revokeOpts.OTT)
|
rci.TokenID, err = p.GetTokenID(revokeOpts.OTT)
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, provisioner.ErrAllowTokenReuse) {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err,
|
return errs.Wrap(http.StatusInternalServerError, err,
|
||||||
"authority.Revoke; could not get ID for token")
|
"authority.Revoke; could not get ID for token")
|
||||||
}
|
}
|
||||||
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
|
opts = append(opts,
|
||||||
opts = append(opts, errs.WithKeyVal("tokenID", rci.TokenID))
|
errs.WithKeyVal("provisionerID", rci.ProvisionerID),
|
||||||
} else {
|
errs.WithKeyVal("tokenID", rci.TokenID),
|
||||||
|
)
|
||||||
|
} else if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil {
|
||||||
// Load the Certificate provisioner if one exists.
|
// Load the Certificate provisioner if one exists.
|
||||||
if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil {
|
rci.ProvisionerID = p.GetID()
|
||||||
rci.ProvisionerID = p.GetID()
|
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
|
||||||
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod {
|
if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod {
|
||||||
err = a.db.RevokeSSH(rci)
|
err = a.revokeSSH(nil, rci)
|
||||||
} else {
|
} else {
|
||||||
// Revoke an X.509 certificate using CAS. If the certificate is not
|
// Revoke an X.509 certificate using CAS. If the certificate is not
|
||||||
// provided we will try to read it from the db. If the read fails we
|
// provided we will try to read it from the db. If the read fails we
|
||||||
|
@ -410,7 +427,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save as revoked in the Db.
|
// Save as revoked in the Db.
|
||||||
err = a.db.Revoke(rci)
|
err = a.revoke(revokedCert, rci)
|
||||||
}
|
}
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
|
@ -425,6 +442,24 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Authority) revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||||
|
if lca, ok := a.adminDB.(interface {
|
||||||
|
Revoke(*x509.Certificate, *db.RevokedCertificateInfo) error
|
||||||
|
}); ok {
|
||||||
|
return lca.Revoke(crt, rci)
|
||||||
|
}
|
||||||
|
return a.db.Revoke(rci)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) revokeSSH(crt *ssh.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||||
|
if lca, ok := a.adminDB.(interface {
|
||||||
|
RevokeSSH(*ssh.Certificate, *db.RevokedCertificateInfo) error
|
||||||
|
}); ok {
|
||||||
|
return lca.RevokeSSH(crt, rci)
|
||||||
|
}
|
||||||
|
return a.db.Revoke(rci)
|
||||||
|
}
|
||||||
|
|
||||||
// GetTLSCertificate creates a new leaf certificate to be used by the CA HTTPS server.
|
// GetTLSCertificate creates a new leaf certificate to be used by the CA HTTPS server.
|
||||||
func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
|
func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
|
||||||
fatal := func(err error) (*tls.Certificate, error) {
|
fatal := func(err error) (*tls.Certificate, error) {
|
||||||
|
|
|
@ -426,6 +426,7 @@ ZYtQ9Ot36qc=
|
||||||
{Id: stepOIDProvisioner, Value: []byte("foo")},
|
{Id: stepOIDProvisioner, Value: []byte("foo")},
|
||||||
{Id: []int{1, 1, 1}, Value: []byte("bar")}}))
|
{Id: []int{1, 1, 1}, Value: []byte("bar")}}))
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
// nolint:gocritic
|
||||||
enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{
|
enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{
|
||||||
NotBefore: now,
|
NotBefore: now,
|
||||||
NotAfter: now.Add(365 * 24 * time.Hour),
|
NotAfter: now.Add(365 * 24 * time.Hour),
|
||||||
|
|
|
@ -345,7 +345,7 @@ func readACMEError(r io.ReadCloser) error {
|
||||||
ae := new(acme.Error)
|
ae := new(acme.Error)
|
||||||
err = json.Unmarshal(b, &ae)
|
err = json.Unmarshal(b, &ae)
|
||||||
// If we successfully marshaled to an ACMEError then return the ACMEError.
|
// If we successfully marshaled to an ACMEError then return the ACMEError.
|
||||||
if err != nil || len(ae.Error()) == 0 {
|
if err != nil || ae.Error() == "" {
|
||||||
fmt.Printf("b = %s\n", b)
|
fmt.Printf("b = %s\n", b)
|
||||||
// Throw up our hands.
|
// Throw up our hands.
|
||||||
return errors.Errorf("%s", b)
|
return errors.Errorf("%s", b)
|
||||||
|
|
|
@ -1247,6 +1247,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
|
||||||
Type: "Certificate",
|
Type: "Certificate",
|
||||||
Bytes: leaf.Raw,
|
Bytes: leaf.Raw,
|
||||||
})
|
})
|
||||||
|
// nolint:gocritic
|
||||||
certBytes := append(leafb, leafb...)
|
certBytes := append(leafb, leafb...)
|
||||||
certBytes = append(certBytes, leafb...)
|
certBytes = append(certBytes, leafb...)
|
||||||
ac := &ACMEClient{
|
ac := &ACMEClient{
|
||||||
|
|
|
@ -70,7 +70,7 @@ func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error)
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AdminClient) generateAdminToken(path string) (string, error) {
|
func (c *AdminClient) generateAdminToken(urlPath string) (string, error) {
|
||||||
// A random jwt id will be used to identify duplicated tokens
|
// A random jwt id will be used to identify duplicated tokens
|
||||||
jwtID, err := randutil.Hex(64) // 256 bits
|
jwtID, err := randutil.Hex(64) // 256 bits
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -82,7 +82,7 @@ func (c *AdminClient) generateAdminToken(path string) (string, error) {
|
||||||
token.WithJWTID(jwtID),
|
token.WithJWTID(jwtID),
|
||||||
token.WithKid(c.x5cJWK.KeyID),
|
token.WithKid(c.x5cJWK.KeyID),
|
||||||
token.WithIssuer(c.x5cIssuer),
|
token.WithIssuer(c.x5cIssuer),
|
||||||
token.WithAudience(path),
|
token.WithAudience(urlPath),
|
||||||
token.WithValidity(now, now.Add(token.DefaultValidity)),
|
token.WithValidity(now, now.Add(token.DefaultValidity)),
|
||||||
token.WithX5CCerts(c.x5cCertStrs),
|
token.WithX5CCerts(c.x5cCertStrs),
|
||||||
}
|
}
|
||||||
|
@ -348,14 +348,15 @@ func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provi
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var u *url.URL
|
var u *url.URL
|
||||||
if len(o.id) > 0 {
|
switch {
|
||||||
|
case len(o.id) > 0:
|
||||||
u = c.endpoint.ResolveReference(&url.URL{
|
u = c.endpoint.ResolveReference(&url.URL{
|
||||||
Path: "/admin/provisioners/id",
|
Path: "/admin/provisioners/id",
|
||||||
RawQuery: o.rawQuery(),
|
RawQuery: o.rawQuery(),
|
||||||
})
|
})
|
||||||
} else if len(o.name) > 0 {
|
case len(o.name) > 0:
|
||||||
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
|
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
|
||||||
} else {
|
default:
|
||||||
return nil, errors.New("must set either name or id in method options")
|
return nil, errors.New("must set either name or id in method options")
|
||||||
}
|
}
|
||||||
tok, err := c.generateAdminToken(u.Path)
|
tok, err := c.generateAdminToken(u.Path)
|
||||||
|
@ -456,14 +457,15 @@ func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(o.id) > 0 {
|
switch {
|
||||||
|
case len(o.id) > 0:
|
||||||
u = c.endpoint.ResolveReference(&url.URL{
|
u = c.endpoint.ResolveReference(&url.URL{
|
||||||
Path: path.Join(adminURLPrefix, "provisioners/id"),
|
Path: path.Join(adminURLPrefix, "provisioners/id"),
|
||||||
RawQuery: o.rawQuery(),
|
RawQuery: o.rawQuery(),
|
||||||
})
|
})
|
||||||
} else if len(o.name) > 0 {
|
case len(o.name) > 0:
|
||||||
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
|
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
|
||||||
} else {
|
default:
|
||||||
return errors.New("must set either name or id in method options")
|
return errors.New("must set either name or id in method options")
|
||||||
}
|
}
|
||||||
tok, err := c.generateAdminToken(u.Path)
|
tok, err := c.generateAdminToken(u.Path)
|
||||||
|
|
|
@ -30,7 +30,7 @@ func Bootstrap(token string) (*Client, error) {
|
||||||
|
|
||||||
// Validate bootstrap token
|
// Validate bootstrap token
|
||||||
switch {
|
switch {
|
||||||
case len(claims.SHA) == 0:
|
case claims.SHA == "":
|
||||||
return nil, errors.New("invalid bootstrap token: sha claim is not present")
|
return nil, errors.New("invalid bootstrap token: sha claim is not present")
|
||||||
case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"):
|
case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"):
|
||||||
return nil, errors.New("invalid bootstrap token: aud claim is not a url")
|
return nil, errors.New("invalid bootstrap token: aud claim is not a url")
|
||||||
|
|
110
ca/ca.go
110
ca/ca.go
|
@ -29,10 +29,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
configFile string
|
configFile string
|
||||||
password []byte
|
linkedCAToken string
|
||||||
issuerPassword []byte
|
password []byte
|
||||||
database db.AuthDB
|
issuerPassword []byte
|
||||||
|
sshHostPassword []byte
|
||||||
|
sshUserPassword []byte
|
||||||
|
database db.AuthDB
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *options) apply(opts []Option) {
|
func (o *options) apply(opts []Option) {
|
||||||
|
@ -60,6 +63,22 @@ func WithPassword(password []byte) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithSSHHostPassword sets the given password to decrypt the key used to sign
|
||||||
|
// ssh host certificates.
|
||||||
|
func WithSSHHostPassword(password []byte) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.sshHostPassword = password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSSHUserPassword sets the given password to decrypt the key used to sign
|
||||||
|
// ssh user certificates.
|
||||||
|
func WithSSHUserPassword(password []byte) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.sshUserPassword = password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithIssuerPassword sets the given password as the configured certificate
|
// WithIssuerPassword sets the given password as the configured certificate
|
||||||
// issuer password in the CA options.
|
// issuer password in the CA options.
|
||||||
func WithIssuerPassword(password []byte) Option {
|
func WithIssuerPassword(password []byte) Option {
|
||||||
|
@ -69,9 +88,16 @@ func WithIssuerPassword(password []byte) Option {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDatabase sets the given authority database to the CA options.
|
// WithDatabase sets the given authority database to the CA options.
|
||||||
func WithDatabase(db db.AuthDB) Option {
|
func WithDatabase(d db.AuthDB) Option {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.database = db
|
o.database = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLinkedCAToken sets the token used to authenticate with the linkedca.
|
||||||
|
func WithLinkedCAToken(token string) Option {
|
||||||
|
return func(o *options) {
|
||||||
|
o.linkedCAToken = token
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,35 +113,34 @@ type CA struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates and initializes the CA with the given configuration and options.
|
// New creates and initializes the CA with the given configuration and options.
|
||||||
func New(config *config.Config, opts ...Option) (*CA, error) {
|
func New(cfg *config.Config, opts ...Option) (*CA, error) {
|
||||||
ca := &CA{
|
ca := &CA{
|
||||||
config: config,
|
config: cfg,
|
||||||
opts: new(options),
|
opts: new(options),
|
||||||
}
|
}
|
||||||
ca.opts.apply(opts)
|
ca.opts.apply(opts)
|
||||||
return ca.Init(config)
|
return ca.Init(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes the CA with the given configuration.
|
// Init initializes the CA with the given configuration.
|
||||||
func (ca *CA) Init(config *config.Config) (*CA, error) {
|
func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||||
// Intermediate Password.
|
// Set password, it's ok to set nil password, the ca will prompt for them if
|
||||||
if len(ca.opts.password) > 0 {
|
// they are required.
|
||||||
ca.config.Password = string(ca.opts.password)
|
opts := []authority.Option{
|
||||||
|
authority.WithPassword(ca.opts.password),
|
||||||
|
authority.WithSSHHostPassword(ca.opts.sshHostPassword),
|
||||||
|
authority.WithSSHUserPassword(ca.opts.sshUserPassword),
|
||||||
|
authority.WithIssuerPassword(ca.opts.issuerPassword),
|
||||||
|
}
|
||||||
|
if ca.opts.linkedCAToken != "" {
|
||||||
|
opts = append(opts, authority.WithLinkedCAToken(ca.opts.linkedCAToken))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Certificate issuer password for RA mode.
|
|
||||||
if len(ca.opts.issuerPassword) > 0 {
|
|
||||||
if ca.config.AuthorityConfig != nil && ca.config.AuthorityConfig.CertificateIssuer != nil {
|
|
||||||
ca.config.AuthorityConfig.CertificateIssuer.Password = string(ca.opts.issuerPassword)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var opts []authority.Option
|
|
||||||
if ca.opts.database != nil {
|
if ca.opts.database != nil {
|
||||||
opts = append(opts, authority.WithDatabase(ca.opts.database))
|
opts = append(opts, authority.WithDatabase(ca.opts.database))
|
||||||
}
|
}
|
||||||
|
|
||||||
auth, err := authority.New(config, opts...)
|
auth, err := authority.New(cfg, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -141,8 +166,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
})
|
})
|
||||||
|
|
||||||
//Add ACME api endpoints in /acme and /1.0/acme
|
//Add ACME api endpoints in /acme and /1.0/acme
|
||||||
dns := config.DNSNames[0]
|
dns := cfg.DNSNames[0]
|
||||||
u, err := url.Parse("https://" + config.Address)
|
u, err := url.Parse("https://" + cfg.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -154,7 +179,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
// ACME Router
|
// ACME Router
|
||||||
prefix := "acme"
|
prefix := "acme"
|
||||||
var acmeDB acme.DB
|
var acmeDB acme.DB
|
||||||
if config.DB == nil {
|
if cfg.DB == nil {
|
||||||
acmeDB = nil
|
acmeDB = nil
|
||||||
} else {
|
} else {
|
||||||
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
|
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
|
||||||
|
@ -163,7 +188,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
|
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
|
||||||
Backdate: *config.AuthorityConfig.Backdate,
|
Backdate: *cfg.AuthorityConfig.Backdate,
|
||||||
DB: acmeDB,
|
DB: acmeDB,
|
||||||
DNS: dns,
|
DNS: dns,
|
||||||
Prefix: prefix,
|
Prefix: prefix,
|
||||||
|
@ -179,7 +204,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Admin API Router
|
// Admin API Router
|
||||||
if config.AuthorityConfig.EnableAdmin {
|
if cfg.AuthorityConfig.EnableAdmin {
|
||||||
adminDB := auth.GetAdminDatabase()
|
adminDB := auth.GetAdminDatabase()
|
||||||
if adminDB != nil {
|
if adminDB != nil {
|
||||||
adminHandler := adminAPI.NewHandler(auth)
|
adminHandler := adminAPI.NewHandler(auth)
|
||||||
|
@ -223,8 +248,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
//dumpRoutes(mux)
|
//dumpRoutes(mux)
|
||||||
|
|
||||||
// Add monitoring if configured
|
// Add monitoring if configured
|
||||||
if len(config.Monitoring) > 0 {
|
if len(cfg.Monitoring) > 0 {
|
||||||
m, err := monitoring.New(config.Monitoring)
|
m, err := monitoring.New(cfg.Monitoring)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -233,8 +258,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add logger if configured
|
// Add logger if configured
|
||||||
if len(config.Logger) > 0 {
|
if len(cfg.Logger) > 0 {
|
||||||
logger, err := logging.New("ca", config.Logger)
|
logger, err := logging.New("ca", cfg.Logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -242,16 +267,16 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
insecureHandler = logger.Middleware(insecureHandler)
|
insecureHandler = logger.Middleware(insecureHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
ca.srv = server.New(config.Address, handler, tlsConfig)
|
ca.srv = server.New(cfg.Address, handler, tlsConfig)
|
||||||
|
|
||||||
// only start the insecure server if the insecure address is configured
|
// only start the insecure server if the insecure address is configured
|
||||||
// and, currently, also only when it should serve SCEP endpoints.
|
// and, currently, also only when it should serve SCEP endpoints.
|
||||||
if ca.shouldServeSCEPEndpoints() && config.InsecureAddress != "" {
|
if ca.shouldServeSCEPEndpoints() && cfg.InsecureAddress != "" {
|
||||||
// TODO: instead opt for having a single server.Server but two
|
// TODO: instead opt for having a single server.Server but two
|
||||||
// http.Servers handling the HTTP and HTTPS handler? The latter
|
// http.Servers handling the HTTP and HTTPS handler? The latter
|
||||||
// will probably introduce more complexity in terms of graceful
|
// will probably introduce more complexity in terms of graceful
|
||||||
// reload.
|
// reload.
|
||||||
ca.insecureSrv = server.New(config.InsecureAddress, insecureHandler, nil)
|
ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ca, nil
|
return ca, nil
|
||||||
|
@ -260,24 +285,24 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
// Run starts the CA calling to the server ListenAndServe method.
|
// Run starts the CA calling to the server ListenAndServe method.
|
||||||
func (ca *CA) Run() error {
|
func (ca *CA) Run() error {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
errors := make(chan error, 1)
|
errs := make(chan error, 1)
|
||||||
|
|
||||||
if ca.insecureSrv != nil {
|
if ca.insecureSrv != nil {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
errors <- ca.insecureSrv.ListenAndServe()
|
errs <- ca.insecureSrv.ListenAndServe()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
errors <- ca.srv.ListenAndServe()
|
errs <- ca.srv.ListenAndServe()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// wait till error occurs; ensures the servers keep listening
|
// wait till error occurs; ensures the servers keep listening
|
||||||
err := <-errors
|
err := <-errs
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
|
@ -306,7 +331,7 @@ func (ca *CA) Stop() error {
|
||||||
// Reload reloads the configuration of the CA and calls to the server Reload
|
// Reload reloads the configuration of the CA and calls to the server Reload
|
||||||
// method.
|
// method.
|
||||||
func (ca *CA) Reload() error {
|
func (ca *CA) Reload() error {
|
||||||
config, err := config.LoadConfiguration(ca.opts.configFile)
|
cfg, err := config.LoadConfiguration(ca.opts.configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "error reloading ca configuration")
|
return errors.Wrap(err, "error reloading ca configuration")
|
||||||
}
|
}
|
||||||
|
@ -318,14 +343,17 @@ func (ca *CA) Reload() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do not allow reload if the database configuration has changed.
|
// Do not allow reload if the database configuration has changed.
|
||||||
if !reflect.DeepEqual(ca.config.DB, config.DB) {
|
if !reflect.DeepEqual(ca.config.DB, cfg.DB) {
|
||||||
logContinue("Reload failed because the database configuration has changed.")
|
logContinue("Reload failed because the database configuration has changed.")
|
||||||
return errors.New("error reloading ca: database configuration cannot change")
|
return errors.New("error reloading ca: database configuration cannot change")
|
||||||
}
|
}
|
||||||
|
|
||||||
newCA, err := New(config,
|
newCA, err := New(cfg,
|
||||||
WithPassword(ca.opts.password),
|
WithPassword(ca.opts.password),
|
||||||
|
WithSSHHostPassword(ca.opts.sshHostPassword),
|
||||||
|
WithSSHUserPassword(ca.opts.sshUserPassword),
|
||||||
WithIssuerPassword(ca.opts.issuerPassword),
|
WithIssuerPassword(ca.opts.issuerPassword),
|
||||||
|
WithLinkedCAToken(ca.opts.linkedCAToken),
|
||||||
WithConfigFile(ca.opts.configFile),
|
WithConfigFile(ca.opts.configFile),
|
||||||
WithDatabase(ca.auth.GetDatabase()),
|
WithDatabase(ca.auth.GetDatabase()),
|
||||||
)
|
)
|
||||||
|
|
|
@ -322,7 +322,7 @@ ZEp7knvU2psWRw==
|
||||||
assert.Equals(t, intermediate, realIntermediate)
|
assert.Equals(t, intermediate, realIntermediate)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
@ -375,7 +375,7 @@ func TestCAProvisioners(t *testing.T) {
|
||||||
assert.Equals(t, a, b)
|
assert.Equals(t, a, b)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
@ -436,7 +436,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
|
||||||
assert.Equals(t, ek.Key, tc.expectedKey)
|
assert.Equals(t, ek.Key, tc.expectedKey)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
@ -497,7 +497,7 @@ func TestCARoot(t *testing.T) {
|
||||||
assert.Equals(t, root.RootPEM.Certificate, rootCrt)
|
assert.Equals(t, root.RootPEM.Certificate, rootCrt)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
@ -665,7 +665,7 @@ func TestCARenew(t *testing.T) {
|
||||||
assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
|
assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
|
32
ca/client.go
32
ca/client.go
|
@ -74,17 +74,17 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) {
|
||||||
c.Client.Transport = tr
|
c.Client.Transport = tr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *uaClient) Get(url string) (*http.Response, error) {
|
func (c *uaClient) Get(u string) (*http.Response, error) {
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
req, err := http.NewRequest("GET", u, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "new request GET %s failed", url)
|
return nil, errors.Wrapf(err, "new request GET %s failed", u)
|
||||||
}
|
}
|
||||||
req.Header.Set("User-Agent", UserAgent)
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
return c.Client.Do(req)
|
return c.Client.Do(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response, error) {
|
func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) {
|
||||||
req, err := http.NewRequest("POST", url, body)
|
req, err := http.NewRequest("POST", u, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -305,7 +305,7 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin
|
||||||
err error
|
err error
|
||||||
opts []jose.Option
|
opts []jose.Option
|
||||||
)
|
)
|
||||||
if len(passwordFile) != 0 {
|
if passwordFile != "" {
|
||||||
opts = append(opts, jose.WithPasswordFile(passwordFile))
|
opts = append(opts, jose.WithPasswordFile(passwordFile))
|
||||||
}
|
}
|
||||||
blk, err := pemutil.Serialize(key)
|
blk, err := pemutil.Serialize(key)
|
||||||
|
@ -326,14 +326,14 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin
|
||||||
|
|
||||||
for _, e := range o.x5cCert.Extensions {
|
for _, e := range o.x5cCert.Extensions {
|
||||||
if e.Id.Equal(stepOIDProvisioner) {
|
if e.Id.Equal(stepOIDProvisioner) {
|
||||||
var provisioner stepProvisionerASN1
|
var prov stepProvisionerASN1
|
||||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
if _, err := asn1.Unmarshal(e.Value, &prov); err != nil {
|
||||||
return errors.Wrap(err, "error unmarshaling provisioner OID from certificate")
|
return errors.Wrap(err, "error unmarshaling provisioner OID from certificate")
|
||||||
}
|
}
|
||||||
o.x5cIssuer = string(provisioner.Name)
|
o.x5cIssuer = string(prov.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(o.x5cIssuer) == 0 {
|
if o.x5cIssuer == "" {
|
||||||
return errors.New("provisioner extension not found in certificate")
|
return errors.New("provisioner extension not found in certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -631,7 +631,7 @@ retry:
|
||||||
// do not match.
|
// do not match.
|
||||||
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
||||||
var retried bool
|
var retried bool
|
||||||
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
|
sha256Sum = strings.ToLower(strings.ReplaceAll(sha256Sum, "-", ""))
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
|
||||||
retry:
|
retry:
|
||||||
resp, err := newInsecureClient().Get(u.String())
|
resp, err := newInsecureClient().Get(u.String())
|
||||||
|
@ -651,7 +651,7 @@ retry:
|
||||||
}
|
}
|
||||||
// verify the sha256
|
// verify the sha256
|
||||||
sum := sha256.Sum256(root.RootPEM.Raw)
|
sum := sha256.Sum256(root.RootPEM.Raw)
|
||||||
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) {
|
if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) {
|
||||||
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
|
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
|
||||||
}
|
}
|
||||||
return &root, nil
|
return &root, nil
|
||||||
|
@ -1066,16 +1066,16 @@ retry:
|
||||||
}
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var config api.SSHConfigResponse
|
var cfg api.SSHConfigResponse
|
||||||
if err := readJSON(resp.Body, &config); err != nil {
|
if err := readJSON(resp.Body, &cfg); err != nil {
|
||||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||||
}
|
}
|
||||||
return &config, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the
|
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the
|
||||||
// given principal.
|
// given principal.
|
||||||
func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrincipalResponse, error) {
|
func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalResponse, error) {
|
||||||
var retried bool
|
var retried bool
|
||||||
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
|
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
|
||||||
Type: provisioner.SSHHostCert,
|
Type: provisioner.SSHHostCert,
|
||||||
|
|
|
@ -135,7 +135,7 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
|
||||||
return csr
|
return csr
|
||||||
}
|
}
|
||||||
|
|
||||||
func equalJSON(t *testing.T, a interface{}, b interface{}) bool {
|
func equalJSON(t *testing.T, a, b interface{}) bool {
|
||||||
if reflect.DeepEqual(a, b) {
|
if reflect.DeepEqual(a, b) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -187,11 +187,12 @@ func TestLoadClient(t *testing.T) {
|
||||||
} else {
|
} else {
|
||||||
gotTransport := got.Client.Transport.(*http.Transport)
|
gotTransport := got.Client.Transport.(*http.Transport)
|
||||||
wantTransport := tt.want.Client.Transport.(*http.Transport)
|
wantTransport := tt.want.Client.Transport.(*http.Transport)
|
||||||
if gotTransport.TLSClientConfig.GetClientCertificate == nil {
|
switch {
|
||||||
|
case gotTransport.TLSClientConfig.GetClientCertificate == nil:
|
||||||
t.Error("LoadClient() transport does not define GetClientCertificate")
|
t.Error("LoadClient() transport does not define GetClientCertificate")
|
||||||
} else if !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()) {
|
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()):
|
||||||
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
|
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
|
||||||
} else {
|
default:
|
||||||
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
|
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("LoadClient() GetClientCertificate error = %v", err)
|
t.Errorf("LoadClient() GetClientCertificate error = %v", err)
|
||||||
|
|
7
ca/testdata/ca.json
vendored
7
ca/testdata/ca.json
vendored
|
@ -9,12 +9,11 @@
|
||||||
"logger": {"format": "text"},
|
"logger": {"format": "text"},
|
||||||
"tls": {
|
"tls": {
|
||||||
"minVersion": 1.2,
|
"minVersion": 1.2,
|
||||||
"maxVersion": 1.2,
|
"maxVersion": 1.3,
|
||||||
"renegotiation": false,
|
"renegotiation": false,
|
||||||
"cipherSuites": [
|
"cipherSuites": [
|
||||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305",
|
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
|
||||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
|
||||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"authority": {
|
"authority": {
|
||||||
|
|
12
ca/tls.go
12
ca/tls.go
|
@ -105,7 +105,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
|
|
||||||
tr := getDefaultTransport(tlsConfig)
|
tr := getDefaultTransport(tlsConfig)
|
||||||
// Use mutable tls.Config on renew
|
// Use mutable tls.Config on renew
|
||||||
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic
|
||||||
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
// Update renew function with transport
|
// Update renew function with transport
|
||||||
tr := getDefaultTransport(tlsConfig)
|
tr := getDefaultTransport(tlsConfig)
|
||||||
// Use mutable tls.Config on renew
|
// Use mutable tls.Config on renew
|
||||||
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic
|
||||||
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
|
@ -195,7 +195,7 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport.
|
// buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport.
|
||||||
// nolint:unused
|
// nolint:unused,gocritic
|
||||||
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
d := getDefaultDialer()
|
d := getDefaultDialer()
|
||||||
|
@ -253,6 +253,8 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:gocritic
|
||||||
|
// using a new variable for clarity
|
||||||
chain := append(certPEM, caPEM...)
|
chain := append(certPEM, caPEM...)
|
||||||
cert, err := tls.X509KeyPair(chain, keyPEM)
|
cert, err := tls.X509KeyPair(chain, keyPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -277,9 +279,9 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
|
||||||
|
|
||||||
// getDefaultDialer returns a new dialer with the default configuration.
|
// getDefaultDialer returns a new dialer with the default configuration.
|
||||||
func getDefaultDialer() *net.Dialer {
|
func getDefaultDialer() *net.Dialer {
|
||||||
|
// With the KeepAlive parameter set to 0, it will be use Golang's default.
|
||||||
return &net.Dialer{
|
return &net.Dialer{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,10 +38,17 @@ type Options struct {
|
||||||
CertificateChain []*x509.Certificate `json:"-"`
|
CertificateChain []*x509.Certificate `json:"-"`
|
||||||
Signer crypto.Signer `json:"-"`
|
Signer crypto.Signer `json:"-"`
|
||||||
|
|
||||||
// IsCreator is set to true when we're creating a certificate authority. Is
|
// IsCreator is set to true when we're creating a certificate authority. It
|
||||||
// used to skip some validations when initializing a CertificateAuthority.
|
// is used to skip some validations when initializing a
|
||||||
|
// CertificateAuthority. This option is used on SoftCAS and CloudCAS.
|
||||||
IsCreator bool `json:"-"`
|
IsCreator bool `json:"-"`
|
||||||
|
|
||||||
|
// IsCAGetter is set to true when we're just using the
|
||||||
|
// CertificateAuthorityGetter interface to retrieve the root certificate. It
|
||||||
|
// is used to skip some validations when initializing a
|
||||||
|
// CertificateAuthority. This option is used on StepCAS.
|
||||||
|
IsCAGetter bool `json:"-"`
|
||||||
|
|
||||||
// KeyManager is the KMS used to generate keys in SoftCAS.
|
// KeyManager is the KMS used to generate keys in SoftCAS.
|
||||||
KeyManager kms.KeyManager `json:"-"`
|
KeyManager kms.KeyManager `json:"-"`
|
||||||
|
|
||||||
|
|
|
@ -108,6 +108,9 @@ type GetCertificateAuthorityResponse struct {
|
||||||
RootCertificate *x509.Certificate
|
RootCertificate *x509.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateKeyRequest is the request used to generate a new key using a KMS.
|
||||||
|
type CreateKeyRequest = apiv1.CreateKeyRequest
|
||||||
|
|
||||||
// CreateCertificateAuthorityRequest is the request used to generate a root or
|
// CreateCertificateAuthorityRequest is the request used to generate a root or
|
||||||
// intermediate certificate.
|
// intermediate certificate.
|
||||||
type CreateCertificateAuthorityRequest struct {
|
type CreateCertificateAuthorityRequest struct {
|
||||||
|
@ -126,7 +129,7 @@ type CreateCertificateAuthorityRequest struct {
|
||||||
// CreateKey defines the KMS CreateKeyRequest to use when creating a new
|
// CreateKey defines the KMS CreateKeyRequest to use when creating a new
|
||||||
// CertificateAuthority. If CreateKey is nil, a default algorithm will be
|
// CertificateAuthority. If CreateKey is nil, a default algorithm will be
|
||||||
// used.
|
// used.
|
||||||
CreateKey *apiv1.CreateKeyRequest
|
CreateKey *CreateKeyRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateCertificateAuthorityResponse is the response for
|
// CreateCertificateAuthorityResponse is the response for
|
||||||
|
@ -136,6 +139,7 @@ type CreateCertificateAuthorityResponse struct {
|
||||||
Name string
|
Name string
|
||||||
Certificate *x509.Certificate
|
Certificate *x509.Certificate
|
||||||
CertificateChain []*x509.Certificate
|
CertificateChain []*x509.Certificate
|
||||||
|
KeyName string
|
||||||
PublicKey crypto.PublicKey
|
PublicKey crypto.PublicKey
|
||||||
PrivateKey crypto.PrivateKey
|
PrivateKey crypto.PrivateKey
|
||||||
Signer crypto.Signer
|
Signer crypto.Signer
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package apiv1
|
package apiv1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/x509"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -26,6 +27,12 @@ type CertificateAuthorityCreator interface {
|
||||||
CreateCertificateAuthority(req *CreateCertificateAuthorityRequest) (*CreateCertificateAuthorityResponse, error)
|
CreateCertificateAuthority(req *CreateCertificateAuthorityRequest) (*CreateCertificateAuthorityResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignatureAlgorithmGetter is an optional implementation in a crypto.Signer
|
||||||
|
// that returns the SignatureAlgorithm to use.
|
||||||
|
type SignatureAlgorithmGetter interface {
|
||||||
|
SignatureAlgorithm() x509.SignatureAlgorithm
|
||||||
|
}
|
||||||
|
|
||||||
// Type represents the CAS type used.
|
// Type represents the CAS type used.
|
||||||
type Type string
|
type Type string
|
||||||
|
|
||||||
|
|
|
@ -29,9 +29,7 @@ func init() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var now = func() time.Time {
|
var now = time.Now
|
||||||
return time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
// The actual regular expression that matches a certificate authority is:
|
// The actual regular expression that matches a certificate authority is:
|
||||||
// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$
|
// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -103,7 +102,7 @@ MHcCAQEEIN51Rgg6YcQVLeCRzumdw4pjM3VWqFIdCbnsV3Up1e/goAoGCCqGSM49
|
||||||
AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk
|
AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk
|
||||||
Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg==
|
Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg==
|
||||||
-----END EC PRIVATE KEY-----`
|
-----END EC PRIVATE KEY-----`
|
||||||
// nolint:unused,deadcode
|
// nolint:unused,deadcode,gocritic
|
||||||
testIntermediateKey = `-----BEGIN EC PRIVATE KEY-----
|
testIntermediateKey = `-----BEGIN EC PRIVATE KEY-----
|
||||||
MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49
|
MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49
|
||||||
AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6
|
AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6
|
||||||
|
@ -190,7 +189,7 @@ func (b *badSigner) Public() crypto.PublicKey {
|
||||||
return b.pub
|
return b.pub
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
func (b *badSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("💥")
|
return nil, fmt.Errorf("💥")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -730,7 +729,7 @@ func TestCloudCAS_RevokeCertificate(t *testing.T) {
|
||||||
func Test_createCertificateID(t *testing.T) {
|
func Test_createCertificateID(t *testing.T) {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
setTeeReader(t, buf)
|
setTeeReader(t, buf)
|
||||||
uuid, err := uuid.NewRandomFromReader(rand.Reader)
|
id, err := uuid.NewRandomFromReader(rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -741,7 +740,7 @@ func Test_createCertificateID(t *testing.T) {
|
||||||
want string
|
want string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", uuid.String(), false},
|
{"ok", id.String(), false},
|
||||||
{"fail", "", true},
|
{"fail", "", true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -858,7 +857,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) {
|
||||||
return lis.Dial()
|
return lis.Dial()
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn))
|
client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn))
|
||||||
|
|
|
@ -19,9 +19,7 @@ func init() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var now = func() time.Time {
|
var now = time.Now
|
||||||
return time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoftCAS implements a Certificate Authority Service using Golang or KMS
|
// SoftCAS implements a Certificate Authority Service using Golang or KMS
|
||||||
// crypto. This is the default CAS used in step-ca.
|
// crypto. This is the default CAS used in step-ca.
|
||||||
|
@ -68,7 +66,7 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
|
||||||
}
|
}
|
||||||
req.Template.Issuer = c.CertificateChain[0].Subject
|
req.Template.Issuer = c.CertificateChain[0].Subject
|
||||||
|
|
||||||
cert, err := x509util.CreateCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer)
|
cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -93,7 +91,7 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R
|
||||||
req.Template.NotAfter = t.Add(req.Lifetime)
|
req.Template.NotAfter = t.Add(req.Lifetime)
|
||||||
req.Template.Issuer = c.CertificateChain[0].Subject
|
req.Template.Issuer = c.CertificateChain[0].Subject
|
||||||
|
|
||||||
cert, err := x509util.CreateCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer)
|
cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -150,12 +148,12 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori
|
||||||
var cert *x509.Certificate
|
var cert *x509.Certificate
|
||||||
switch req.Type {
|
switch req.Type {
|
||||||
case apiv1.RootCA:
|
case apiv1.RootCA:
|
||||||
cert, err = x509util.CreateCertificate(req.Template, req.Template, signer.Public(), signer)
|
cert, err = createCertificate(req.Template, req.Template, signer.Public(), signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case apiv1.IntermediateCA:
|
case apiv1.IntermediateCA:
|
||||||
cert, err = x509util.CreateCertificate(req.Template, req.Parent.Certificate, signer.Public(), req.Parent.Signer)
|
cert, err = createCertificate(req.Template, req.Parent.Certificate, signer.Public(), req.Parent.Signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -174,6 +172,7 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori
|
||||||
Name: cert.Subject.CommonName,
|
Name: cert.Subject.CommonName,
|
||||||
Certificate: cert,
|
Certificate: cert,
|
||||||
CertificateChain: chain,
|
CertificateChain: chain,
|
||||||
|
KeyName: key.Name,
|
||||||
PublicKey: key.PublicKey,
|
PublicKey: key.PublicKey,
|
||||||
PrivateKey: key.PrivateKey,
|
PrivateKey: key.PrivateKey,
|
||||||
Signer: signer,
|
Signer: signer,
|
||||||
|
@ -210,3 +209,16 @@ func (c *SoftCAS) createSigner(req *kmsapi.CreateSignerRequest) (crypto.Signer,
|
||||||
}
|
}
|
||||||
return c.KeyManager.CreateSigner(req)
|
return c.KeyManager.CreateSigner(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createCertificate sets the SignatureAlgorithm of the template if necessary
|
||||||
|
// and calls x509util.CreateCertificate.
|
||||||
|
func createCertificate(template, parent *x509.Certificate, pub crypto.PublicKey, signer crypto.Signer) (*x509.Certificate, error) {
|
||||||
|
// Signers can specify the signature algorithm. This is especially important
|
||||||
|
// when x509.CreateCertificate attempts to validate a RSAPSS signature.
|
||||||
|
if template.SignatureAlgorithm == 0 {
|
||||||
|
if sa, ok := signer.(apiv1.SignatureAlgorithmGetter); ok {
|
||||||
|
template.SignatureAlgorithm = sa.SignatureAlgorithm()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return x509util.CreateCertificate(template, parent, pub, signer)
|
||||||
|
}
|
||||||
|
|
|
@ -75,6 +75,15 @@ var (
|
||||||
testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour))
|
testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type signatureAlgorithmSigner struct {
|
||||||
|
crypto.Signer
|
||||||
|
algorithm x509.SignatureAlgorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *signatureAlgorithmSigner) SignatureAlgorithm() x509.SignatureAlgorithm {
|
||||||
|
return s.algorithm
|
||||||
|
}
|
||||||
|
|
||||||
type mockKeyManager struct {
|
type mockKeyManager struct {
|
||||||
signer crypto.Signer
|
signer crypto.Signer
|
||||||
errGetPublicKey error
|
errGetPublicKey error
|
||||||
|
@ -97,6 +106,7 @@ func (m *mockKeyManager) CreateKey(req *kmsapi.CreateKeyRequest) (*kmsapi.Create
|
||||||
signer = m.signer
|
signer = m.signer
|
||||||
}
|
}
|
||||||
return &kmsapi.CreateKeyResponse{
|
return &kmsapi.CreateKeyResponse{
|
||||||
|
Name: req.Name,
|
||||||
PrivateKey: signer,
|
PrivateKey: signer,
|
||||||
PublicKey: signer.Public(),
|
PublicKey: signer.Public(),
|
||||||
}, m.errCreateKey
|
}, m.errCreateKey
|
||||||
|
@ -124,7 +134,7 @@ func (b *badSigner) Public() crypto.PublicKey {
|
||||||
return testSigner.Public()
|
return testSigner.Public()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
func (b *badSigner) Sign(_ io.Reader, _ []byte, _ crypto.SignerOpts) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("💥")
|
return nil, fmt.Errorf("💥")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,6 +257,13 @@ func TestSoftCAS_CreateCertificate(t *testing.T) {
|
||||||
tmplNoSerial := *testTemplate
|
tmplNoSerial := *testTemplate
|
||||||
tmplNoSerial.SerialNumber = nil
|
tmplNoSerial.SerialNumber = nil
|
||||||
|
|
||||||
|
saTemplate := *testSignedTemplate
|
||||||
|
saTemplate.SignatureAlgorithm = 0
|
||||||
|
saSigner := &signatureAlgorithmSigner{
|
||||||
|
Signer: testSigner,
|
||||||
|
algorithm: x509.PureEd25519,
|
||||||
|
}
|
||||||
|
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Issuer *x509.Certificate
|
Issuer *x509.Certificate
|
||||||
Signer crypto.Signer
|
Signer crypto.Signer
|
||||||
|
@ -267,6 +284,12 @@ func TestSoftCAS_CreateCertificate(t *testing.T) {
|
||||||
Certificate: testSignedTemplate,
|
Certificate: testSignedTemplate,
|
||||||
CertificateChain: []*x509.Certificate{testIssuer},
|
CertificateChain: []*x509.Certificate{testIssuer},
|
||||||
}, false},
|
}, false},
|
||||||
|
{"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.CreateCertificateRequest{
|
||||||
|
Template: &saTemplate, Lifetime: 24 * time.Hour,
|
||||||
|
}}, &apiv1.CreateCertificateResponse{
|
||||||
|
Certificate: testSignedTemplate,
|
||||||
|
CertificateChain: []*x509.Certificate{testIssuer},
|
||||||
|
}, false},
|
||||||
{"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{
|
{"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{
|
||||||
Template: &tmplNotBefore, Lifetime: 24 * time.Hour,
|
Template: &tmplNotBefore, Lifetime: 24 * time.Hour,
|
||||||
}}, &apiv1.CreateCertificateResponse{
|
}}, &apiv1.CreateCertificateResponse{
|
||||||
|
@ -316,6 +339,11 @@ func TestSoftCAS_RenewCertificate(t *testing.T) {
|
||||||
tmplNoSerial := *testTemplate
|
tmplNoSerial := *testTemplate
|
||||||
tmplNoSerial.SerialNumber = nil
|
tmplNoSerial.SerialNumber = nil
|
||||||
|
|
||||||
|
saSigner := &signatureAlgorithmSigner{
|
||||||
|
Signer: testSigner,
|
||||||
|
algorithm: x509.PureEd25519,
|
||||||
|
}
|
||||||
|
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Issuer *x509.Certificate
|
Issuer *x509.Certificate
|
||||||
Signer crypto.Signer
|
Signer crypto.Signer
|
||||||
|
@ -336,6 +364,12 @@ func TestSoftCAS_RenewCertificate(t *testing.T) {
|
||||||
Certificate: testSignedTemplate,
|
Certificate: testSignedTemplate,
|
||||||
CertificateChain: []*x509.Certificate{testIssuer},
|
CertificateChain: []*x509.Certificate{testIssuer},
|
||||||
}, false},
|
}, false},
|
||||||
|
{"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.RenewCertificateRequest{
|
||||||
|
Template: testTemplate, Lifetime: 24 * time.Hour,
|
||||||
|
}}, &apiv1.RenewCertificateResponse{
|
||||||
|
Certificate: testSignedTemplate,
|
||||||
|
CertificateChain: []*x509.Certificate{testIssuer},
|
||||||
|
}, false},
|
||||||
{"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true},
|
{"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true},
|
||||||
{"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true},
|
{"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true},
|
||||||
{"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{
|
{"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{
|
||||||
|
@ -425,6 +459,11 @@ func Test_now(t *testing.T) {
|
||||||
func TestSoftCAS_CreateCertificateAuthority(t *testing.T) {
|
func TestSoftCAS_CreateCertificateAuthority(t *testing.T) {
|
||||||
mockNow(t)
|
mockNow(t)
|
||||||
|
|
||||||
|
saSigner := &signatureAlgorithmSigner{
|
||||||
|
Signer: testSigner,
|
||||||
|
algorithm: x509.PureEd25519,
|
||||||
|
}
|
||||||
|
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Issuer *x509.Certificate
|
Issuer *x509.Certificate
|
||||||
Signer crypto.Signer
|
Signer crypto.Signer
|
||||||
|
@ -467,6 +506,33 @@ func TestSoftCAS_CreateCertificateAuthority(t *testing.T) {
|
||||||
PrivateKey: testSigner,
|
PrivateKey: testSigner,
|
||||||
Signer: testSigner,
|
Signer: testSigner,
|
||||||
}, false},
|
}, false},
|
||||||
|
{"ok signature algorithm", fields{nil, nil, &mockKeyManager{signer: saSigner}}, args{&apiv1.CreateCertificateAuthorityRequest{
|
||||||
|
Type: apiv1.RootCA,
|
||||||
|
Template: testRootTemplate,
|
||||||
|
Lifetime: 24 * time.Hour,
|
||||||
|
}}, &apiv1.CreateCertificateAuthorityResponse{
|
||||||
|
Name: "Test Root CA",
|
||||||
|
Certificate: testSignedRootTemplate,
|
||||||
|
PublicKey: testSignedRootTemplate.PublicKey,
|
||||||
|
PrivateKey: saSigner,
|
||||||
|
Signer: saSigner,
|
||||||
|
}, false},
|
||||||
|
{"ok createKey", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
|
||||||
|
Type: apiv1.RootCA,
|
||||||
|
Template: testRootTemplate,
|
||||||
|
Lifetime: 24 * time.Hour,
|
||||||
|
CreateKey: &kmsapi.CreateKeyRequest{
|
||||||
|
Name: "root_ca.crt",
|
||||||
|
SignatureAlgorithm: kmsapi.ECDSAWithSHA256,
|
||||||
|
},
|
||||||
|
}}, &apiv1.CreateCertificateAuthorityResponse{
|
||||||
|
Name: "Test Root CA",
|
||||||
|
Certificate: testSignedRootTemplate,
|
||||||
|
PublicKey: testSignedRootTemplate.PublicKey,
|
||||||
|
KeyName: "root_ca.crt",
|
||||||
|
PrivateKey: testSigner,
|
||||||
|
Signer: testSigner,
|
||||||
|
}, false},
|
||||||
{"fail template", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
|
{"fail template", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
|
||||||
Type: apiv1.RootCA,
|
Type: apiv1.RootCA,
|
||||||
Lifetime: 24 * time.Hour,
|
Lifetime: 24 * time.Hour,
|
||||||
|
|
|
@ -47,10 +47,13 @@ func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create configured issuer
|
var iss stepIssuer
|
||||||
iss, err := newStepIssuer(caURL, client, opts.CertificateIssuer)
|
// Create configured issuer unless we only want to use GetCertificateAuthority.
|
||||||
if err != nil {
|
// This avoid the request for the password if not provided.
|
||||||
return nil, err
|
if !opts.IsCAGetter {
|
||||||
|
if iss, err = newStepIssuer(caURL, client, opts.CertificateIssuer); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &StepCAS{
|
return &StepCAS{
|
||||||
|
@ -87,9 +90,9 @@ func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R
|
||||||
return nil, apiv1.ErrNotImplemented{Message: "stepCAS does not support mTLS renewals"}
|
return nil, apiv1.ErrNotImplemented{Message: "stepCAS does not support mTLS renewals"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RevokeCertificate revokes a certificate.
|
||||||
func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
|
func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
|
||||||
switch {
|
if req.SerialNumber == "" && req.Certificate == nil {
|
||||||
case req.SerialNumber == "" && req.Certificate == nil:
|
|
||||||
return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required")
|
return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -411,6 +411,19 @@ func TestNew(t *testing.T) {
|
||||||
client: client,
|
client: client,
|
||||||
fingerprint: testRootFingerprint,
|
fingerprint: testRootFingerprint,
|
||||||
}, false},
|
}, false},
|
||||||
|
{"ok ca getter", args{context.TODO(), apiv1.Options{
|
||||||
|
IsCAGetter: true,
|
||||||
|
CertificateAuthority: caURL.String(),
|
||||||
|
CertificateAuthorityFingerprint: testRootFingerprint,
|
||||||
|
CertificateIssuer: &apiv1.CertificateIssuer{
|
||||||
|
Type: "jwk",
|
||||||
|
Provisioner: "ra@doe.org",
|
||||||
|
},
|
||||||
|
}}, &StepCAS{
|
||||||
|
iss: nil,
|
||||||
|
client: client,
|
||||||
|
fingerprint: testRootFingerprint,
|
||||||
|
}, false},
|
||||||
{"fail authority", args{context.TODO(), apiv1.Options{
|
{"fail authority", args{context.TODO(), apiv1.Options{
|
||||||
CertificateAuthority: "",
|
CertificateAuthority: "",
|
||||||
CertificateAuthorityFingerprint: testRootFingerprint,
|
CertificateAuthorityFingerprint: testRootFingerprint,
|
||||||
|
|
|
@ -19,9 +19,7 @@ const defaultValidity = 5 * time.Minute
|
||||||
|
|
||||||
// timeNow returns the current time.
|
// timeNow returns the current time.
|
||||||
// This method is used for unit testing purposes.
|
// This method is used for unit testing purposes.
|
||||||
var timeNow = func() time.Time {
|
var timeNow = time.Now
|
||||||
return time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
type x5cIssuer struct {
|
type x5cIssuer struct {
|
||||||
caURL *url.URL
|
caURL *url.URL
|
||||||
|
|
|
@ -22,7 +22,7 @@ func (b noneSigner) Public() crypto.PublicKey {
|
||||||
return []byte(b)
|
return []byte(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b noneSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
|
func (b noneSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
|
||||||
return digest, nil
|
return digest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,13 +24,16 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var credentialsFile, region string
|
var credentialsFile, region string
|
||||||
var ssh bool
|
var enableSSH bool
|
||||||
flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the AWS KMS credentials.")
|
flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the AWS KMS credentials.")
|
||||||
flag.StringVar(®ion, "region", "", "AWS KMS region name.")
|
flag.StringVar(®ion, "region", "", "AWS KMS region name.")
|
||||||
flag.BoolVar(&ssh, "ssh", false, "Create SSH keys.")
|
flag.BoolVar(&enableSSH, "ssh", false, "Create SSH keys.")
|
||||||
flag.Usage = usage
|
flag.Usage = usage
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
// Initialize windows terminal
|
||||||
|
ui.Init()
|
||||||
|
|
||||||
c, err := awskms.New(context.Background(), apiv1.Options{
|
c, err := awskms.New(context.Background(), apiv1.Options{
|
||||||
Type: string(apiv1.AmazonKMS),
|
Type: string(apiv1.AmazonKMS),
|
||||||
Region: region,
|
Region: region,
|
||||||
|
@ -44,16 +47,20 @@ func main() {
|
||||||
fatal(err)
|
fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ssh {
|
if enableSSH {
|
||||||
ui.Println()
|
ui.Println()
|
||||||
if err := createSSH(c); err != nil {
|
if err := createSSH(c); err != nil {
|
||||||
fatal(err)
|
fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset windows terminal
|
||||||
|
ui.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
func fatal(err error) {
|
func fatal(err error) {
|
||||||
fmt.Fprintln(os.Stderr, err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
ui.Reset()
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,7 +120,7 @@ func createX509(c *awskms.KMS) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
@ -156,7 +163,7 @@ func createX509(c *awskms.KMS) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
@ -186,7 +193,7 @@ func createSSH(c *awskms.KMS) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
if err := fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,7 +214,7 @@ func createSSH(c *awskms.KMS) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
if err := fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,10 +22,12 @@ import (
|
||||||
"go.step.sm/cli-utils/command"
|
"go.step.sm/cli-utils/command"
|
||||||
"go.step.sm/cli-utils/command/version"
|
"go.step.sm/cli-utils/command/version"
|
||||||
"go.step.sm/cli-utils/config"
|
"go.step.sm/cli-utils/config"
|
||||||
|
"go.step.sm/cli-utils/ui"
|
||||||
"go.step.sm/cli-utils/usage"
|
"go.step.sm/cli-utils/usage"
|
||||||
|
|
||||||
// Enabled kms interfaces.
|
// Enabled kms interfaces.
|
||||||
_ "github.com/smallstep/certificates/kms/awskms"
|
_ "github.com/smallstep/certificates/kms/awskms"
|
||||||
|
_ "github.com/smallstep/certificates/kms/azurekms"
|
||||||
_ "github.com/smallstep/certificates/kms/cloudkms"
|
_ "github.com/smallstep/certificates/kms/cloudkms"
|
||||||
_ "github.com/smallstep/certificates/kms/softkms"
|
_ "github.com/smallstep/certificates/kms/softkms"
|
||||||
_ "github.com/smallstep/certificates/kms/sshagentkms"
|
_ "github.com/smallstep/certificates/kms/sshagentkms"
|
||||||
|
@ -52,6 +54,11 @@ func init() {
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func exit(code int) {
|
||||||
|
ui.Reset()
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
// appHelpTemplate contains the modified template for the main app
|
// appHelpTemplate contains the modified template for the main app
|
||||||
var appHelpTemplate = `## NAME
|
var appHelpTemplate = `## NAME
|
||||||
**{{.HelpName}}** -- {{.Usage}}
|
**{{.HelpName}}** -- {{.Usage}}
|
||||||
|
@ -90,6 +97,9 @@ Please send us a sentence or two, good or bad: **feedback@smallstep.com** or htt
|
||||||
`
|
`
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
// Initialize windows terminal
|
||||||
|
ui.Init()
|
||||||
|
|
||||||
// Override global framework components
|
// Override global framework components
|
||||||
cli.VersionPrinter = func(c *cli.Context) {
|
cli.VersionPrinter = func(c *cli.Context) {
|
||||||
version.Command(c)
|
version.Command(c)
|
||||||
|
@ -107,7 +117,9 @@ func main() {
|
||||||
app.HelpName = "step-ca"
|
app.HelpName = "step-ca"
|
||||||
app.Version = config.Version()
|
app.Version = config.Version()
|
||||||
app.Usage = "an online certificate authority for secure automated certificate management"
|
app.Usage = "an online certificate authority for secure automated certificate management"
|
||||||
app.UsageText = `**step-ca** <config> [**--password-file**=<file>] [**--issuer-password-file**=<file>] [**--resolver**=<addr>] [**--help**] [**--version**]`
|
app.UsageText = `**step-ca** <config> [**--password-file**=<file>]
|
||||||
|
[**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>]
|
||||||
|
[**--issuer-password-file**=<file>] [**--resolver**=<addr>] [**--help**] [**--version**]`
|
||||||
app.Description = `**step-ca** runs the Step Online Certificate Authority
|
app.Description = `**step-ca** runs the Step Online Certificate Authority
|
||||||
(Step CA) using the given configuration.
|
(Step CA) using the given configuration.
|
||||||
See the README.md for more detailed configuration documentation.
|
See the README.md for more detailed configuration documentation.
|
||||||
|
@ -162,8 +174,10 @@ $ step-ca $STEPPATH/config/ca.json --password-file ./password.txt
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintln(os.Stderr, err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
}
|
}
|
||||||
os.Exit(1)
|
exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func flagValue(f cli.Flag) reflect.Value {
|
func flagValue(f cli.Flag) reflect.Value {
|
||||||
|
@ -178,8 +192,8 @@ var placeholderString = regexp.MustCompile(`<.*?>`)
|
||||||
|
|
||||||
func stringifyFlag(f cli.Flag) string {
|
func stringifyFlag(f cli.Flag) string {
|
||||||
fv := flagValue(f)
|
fv := flagValue(f)
|
||||||
usage := fv.FieldByName("Usage").String()
|
usg := fv.FieldByName("Usage").String()
|
||||||
placeholder := placeholderString.FindString(usage)
|
placeholder := placeholderString.FindString(usg)
|
||||||
if placeholder == "" {
|
if placeholder == "" {
|
||||||
switch f.(type) {
|
switch f.(type) {
|
||||||
case cli.BoolFlag, cli.BoolTFlag:
|
case cli.BoolFlag, cli.BoolTFlag:
|
||||||
|
@ -187,5 +201,5 @@ func stringifyFlag(f cli.Flag) string {
|
||||||
placeholder = "<value>"
|
placeholder = "<value>"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage
|
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usg
|
||||||
}
|
}
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue