Merge branch 'master' into hs/acme-revocation

This commit is contained in:
Herman Slatman 2021-10-30 15:41:29 +02:00
commit 3151255a25
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
153 changed files with 6603 additions and 1745 deletions

View file

@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.15', '1.16' ] go: [ '1.15', '1.16', '1.17' ]
outputs: outputs:
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
steps: steps:
@ -62,8 +62,15 @@ jobs:
needs: test needs: test
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
outputs: outputs:
debversion: ${{ steps.extract-tag.outputs.DEB_VERSION }}
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
steps: steps:
-
name: Extract Tag Names
id: extract-tag
run: |
DEB_VERSION=$(echo ${GITHUB_REF#refs/tags/v} | sed 's/-/./')
echo "::set-output name=DEB_VERSION::${DEB_VERSION}"
- -
name: Is Pre-release name: Is Pre-release
id: is_prerelease id: is_prerelease
@ -99,62 +106,71 @@ jobs:
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.16 go-version: 1.17
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@56f5b77f7fa4a8fe068bf22b732ec036cc9bc13f # v2.4.1
with:
version: latest
args: release --rm-dist
env:
GITHUB_TOKEN: ${{ secrets.PAT }}
release_deb:
name: Build & Upload Debian Package To Github
runs-on: ubuntu-20.04
needs: create_release
steps:
-
name: Checkout
uses: actions/checkout@v2
with:
fetch-depth: 0
-
name: Set up Go
uses: actions/setup-go@v2
with:
go-version: '1.16'
- -
name: APT Install name: APT Install
id: aptInstall id: aptInstall
run: sudo apt-get -y install build-essential debhelper fakeroot run: sudo apt-get -y install build-essential debhelper fakeroot
- -
name: Build Debian package name: Build Debian package
id: build id: make_debian
run: | run: |
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
make debian make debian
# need to restore the git state otherwise goreleaser fails due to dirty state
git restore debian/changelog
git clean -fd
- -
name: Upload Debian Package name: Install cosign
id: upload_deb uses: sigstore/cosign-installer@v1.1.0
with:
cosign-release: 'v1.1.0'
-
name: Write cosign key to disk
id: write_key
run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
-
name: Get Release Date
id: release_date
run: | run: |
tag_name="${GITHUB_REF##*/}" RELEASE_DATE=$(date +"%y-%m-%d")
hub release edit $(find ./.releases -type f -printf "-a %p ") -m "" "$tag_name" echo "::set-output name=RELEASE_DATE::${RELEASE_DATE}"
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0
with:
version: latest
args: release --rm-dist
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.PAT }}
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}
DEB_VERSION: ${{ needs.create_release.outputs.debversion }}
RELEASE_DATE: ${{ steps.release_date.outputs.RELEASE_DATE }}
build_upload_docker: build_upload_docker:
name: Build & Upload Docker Images name: Build & Upload Docker Images
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
needs: test needs: test
steps: steps:
- name: Checkout -
name: Checkout
uses: actions/checkout@v2 uses: actions/checkout@v2
- name: Setup Go -
name: Setup Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: '1.16' go-version: '1.17'
- name: Build -
name: Install cosign
uses: sigstore/cosign-installer@v1.1.0
with:
cosign-release: 'v1.1.0'
-
name: Write cosign key to disk
id: write_key
run: echo "${{ secrets.COSIGN_KEY }}" > "/tmp/cosign.key"
-
name: Build
id: build id: build
run: | run: |
PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin PATH=$PATH:/usr/local/go/bin:/home/admin/go/bin
@ -162,3 +178,4 @@ jobs:
env: env:
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }}
COSIGN_PWD: ${{ secrets.COSIGN_PWD }}

View file

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.15', '1.16' ] go: [ '1.15', '1.16', '1.17' ]
steps: steps:
- -
name: Checkout name: Checkout

4
.gitignore vendored
View file

@ -14,8 +14,8 @@
# Others # Others
*.swp *.swp
.travis-releases .releases
coverage.txt coverage.txt
vendor
output output
vendor
.idea .idea

View file

@ -36,22 +36,30 @@ linters-settings:
- performance - performance
- style - style
- experimental - experimental
- diagnostic
disabled-checks: disabled-checks:
- wrapperFunc - commentFormatting
- dupImport # https://github.com/go-critic/go-critic/issues/845 - commentedOutCode
- evalOrder
- hugeParam
- octalLiteral
- rangeValCopy
- tooManyResultsChecker
- unnamedResult
linters: linters:
disable-all: true disable-all: true
enable: enable:
- gofmt
- revive
- govet
- misspell
- ineffassign
- deadcode - deadcode
- gocritic
- gofmt
- gosimple
- govet
- ineffassign
- misspell
- revive
- staticcheck - staticcheck
- unused - unused
- gosimple
run: run:
skip-dirs: skip-dirs:

View file

@ -1,34 +1,27 @@
# This is an example .goreleaser.yml file with some sane defaults. # This is an example .goreleaser.yml file with some sane defaults.
# Make sure to check the documentation at http://goreleaser.com # Make sure to check the documentation at http://goreleaser.com
project_name: step-ca project_name: step-ca
before: before:
hooks: hooks:
# You may remove this if you don't use go modules. # You may remove this if you don't use go modules.
- go mod download - go mod download
builds: builds:
- -
id: step-ca id: step-ca
env: env:
- CGO_ENABLED=0 - CGO_ENABLED=0
goos: targets:
- linux - darwin_amd64
- darwin - darwin_arm64
- windows - freebsd_amd64
goarch: - linux_386
- amd64 - linux_amd64
- arm - linux_arm64
- arm64 - linux_arm_6
- 386 - linux_arm_7
goarm: - windows_amd64
- 6
- 7
ignore:
- goos: windows
goarch: 386
- goos: windows
goarm: 6
- goos: windows
goarm: 7
flags: flags:
- -trimpath - -trimpath
main: ./cmd/step-ca/main.go main: ./cmd/step-ca/main.go
@ -39,25 +32,16 @@ builds:
id: step-cloudkms-init id: step-cloudkms-init
env: env:
- CGO_ENABLED=0 - CGO_ENABLED=0
goos: targets:
- linux - darwin_amd64
- darwin - darwin_arm64
- windows - freebsd_amd64
goarch: - linux_386
- amd64 - linux_amd64
- arm - linux_arm64
- arm64 - linux_arm_6
- 386 - linux_arm_7
goarm: - windows_amd64
- 6
- 7
ignore:
- goos: windows
goarch: 386
- goos: windows
goarm: 6
- goos: windows
goarm: 7
flags: flags:
- -trimpath - -trimpath
main: ./cmd/step-cloudkms-init/main.go main: ./cmd/step-cloudkms-init/main.go
@ -68,31 +52,23 @@ builds:
id: step-awskms-init id: step-awskms-init
env: env:
- CGO_ENABLED=0 - CGO_ENABLED=0
goos: targets:
- linux - darwin_amd64
- darwin - darwin_arm64
- windows - freebsd_amd64
goarch: - linux_386
- amd64 - linux_amd64
- arm - linux_arm64
- arm64 - linux_arm_6
- 386 - linux_arm_7
goarm: - windows_amd64
- 6
- 7
ignore:
- goos: windows
goarch: 386
- goos: windows
goarm: 6
- goos: windows
goarm: 7
flags: flags:
- -trimpath - -trimpath
main: ./cmd/step-awskms-init/main.go main: ./cmd/step-awskms-init/main.go
binary: bin/step-awskms-init binary: bin/step-awskms-init
ldflags: ldflags:
- -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}} - -w -X main.Version={{.Version}} -X main.BuildTime={{.Date}}
archives: archives:
- -
# Can be used to change the archive formats for specific GOOSs. # Can be used to change the archive formats for specific GOOSs.
@ -106,13 +82,25 @@ archives:
files: files:
- README.md - README.md
- LICENSE - LICENSE
source: source:
enabled: true enabled: true
name_template: '{{ .ProjectName }}_{{ .Version }}' name_template: '{{ .ProjectName }}_{{ .Version }}'
checksum: checksum:
name_template: 'checksums.txt' name_template: 'checksums.txt'
extra_files:
- glob: ./.releases/*
signs:
- cmd: cosign
stdin: '{{ .Env.COSIGN_PWD }}'
args: ["sign-blob", "-key=/tmp/cosign.key", "-output=${signature}", "${artifact}"]
artifacts: all
snapshot: snapshot:
name_template: "{{ .Tag }}-next" name_template: "{{ .Tag }}-next"
release: release:
# Repo in which the release will be created. # Repo in which the release will be created.
# Default is extracted from the origin remote URL or empty if its private hosted. # Default is extracted from the origin remote URL or empty if its private hosted.
@ -139,7 +127,55 @@ release:
# You can change the name of the release. # You can change the name of the release.
# Default is `{{.Tag}}` # Default is `{{.Tag}}`
#name_template: "{{.ProjectName}}-v{{.Version}} {{.Env.USER}}" name_template: "Step CA {{ .Tag }} ({{ .Env.RELEASE_DATE }})"
# Header template for the release body.
# Defaults to empty.
header: |
## Official Release Artifacts
#### Linux
- 📦 [step-ca_linux_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_linux_{{ .Version }}_amd64.tar.gz)
- 📦 [step-ca_{{ .Env.DEB_VERSION }}_amd64.deb](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_{{ .Env.DEB_VERSION }}_amd64.deb)
#### OSX Darwin
- 📦 [step-ca_darwin_{{ .Version }}_amd64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_amd64.tar.gz)
- 📦 [step-ca_darwin_{{ .Version }}_arm64.tar.gz](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_darwin_{{ .Version }}_arm64.tar.gz)
#### Windows
- 📦 [step-ca_windows_{{ .Version }}_arm64.zip](https://dl.step.sm/gh-release/certificates/gh-release-header/{{ .Tag }}/step-ca_windows_{{ .Version }}_amd64.zip)
For more builds across platforms and architectures, see the `Assets` section below.
And for packaged versions (Docker, k8s, Homebrew), see our [installation docs](https://smallstep.com/docs/step-ca/installation).
Don't see the artifact you need? Open an issue [here](https://github.com/smallstep/certificates/issues/new/choose).
## Signatures and Checksums
`step-ca` uses [sigstore/cosign](https://github.com/sigstore/cosign) for signing and verifying release artifacts.
Below is an example using `cosign` to verify a release artifact:
```
cosign verify-blob \
-key https://raw.githubusercontent.com/smallstep/certificates/master/cosign.pub \
-signature ~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz.sig
~/Downloads/step-ca_darwin_{{ .Version }}_amd64.tar.gz
```
The `checksums.txt` file (in the `Assets` section below) contains a checksum for every artifact in the release.
# Footer template for the release body.
# Defaults to empty.
footer: |
## Thanks!
Those were the changes on {{ .Tag }}!
Come join us on [Discord](https://discord.gg/X2RKGwEbV9) to ask questions, chat about PKI, or get a sneak peak at the freshest PKI memes.
# You can disable this pipe in order to not upload any artifacts. # You can disable this pipe in order to not upload any artifacts.
# Defaults to false. # Defaults to false.
@ -149,6 +185,8 @@ release:
# The filename on the release will be the last part of the path (base). If # The filename on the release will be the last part of the path (base). If
# another file with the same name exists, the latest one found will be used. # another file with the same name exists, the latest one found will be used.
# Defaults to empty. # Defaults to empty.
extra_files:
- glob: ./.releases/*
#extra_files: #extra_files:
# - glob: ./path/to/file.txt # - glob: ./path/to/file.txt
# - glob: ./glob/**/to/**/file/**/* # - glob: ./glob/**/to/**/file/**/*

View file

@ -4,10 +4,62 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased - 0.0.1] - DATE ## [Unreleased - 0.17.7] - DATE
### Added ### Added
- Support for generate extractable keys and certificates on a pkcs#11 module.
### Changed ### Changed
### Deprecated ### Deprecated
### Removed ### Removed
### Fixed ### Fixed
### Security ### Security
## [0.17.6] - 2021-10-20
### Notes
- 0.17.5 failed in CI/CD
## [0.17.5] - 2021-10-20
### Added
- Support for Azure Key Vault as a KMS.
- Adapt `pki` package to support key managers.
- gocritic linter
### Fixed
- gocritic warnings
## [0.17.4] - 2021-09-28
### Fixed
- Support host-only or user-only SSH CA.
## [0.17.3] - 2021-09-24
### Added
- go 1.17 to github action test matrix
- Support for CloudKMS RSA-PSS signers without using templates.
- Add flags to support individual passwords for the intermediate and SSH keys.
- Global support for group admins in the OIDC provisioner.
### Changed
- Using go 1.17 for binaries
### Fixed
- Upgrade go-jose.v2 to fix a bug in the JWK fingerprint of Ed25519 keys.
### Security
- Use cosign to sign and upload signatures for multi-arch Docker container.
- Add debian checksum
## [0.17.2] - 2021-08-30
### Added
- Additional way to distinguish Azure IID and Azure OIDC tokens.
### Security
- Sign over all goreleaser github artifacts using cosign
## [0.17.1] - 2021-08-26
## [0.17.0] - 2021-08-25
### Added
- Add support for Linked CAs using protocol buffers and gRPC
- `step-ca init` adds support for
- configuring a StepCAS RA
- configuring a Linked CA
- congifuring a `step-ca` using Helm
### Changed
- Update badger driver to use v2 by default
- Update TLS cipher suites to include 1.3
### Security
- Fix key version when SHA512WithRSA is used. There was a typo creating RSA keys with SHA256 digests instead of SHA512.

View file

@ -29,7 +29,7 @@ ci: testcgo build
bootstra%: bootstra%:
# Using a released version of golangci-lint to take into account custom replacements in their go.mod # Using a released version of golangci-lint to take into account custom replacements in their go.mod
$Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v1.39.0 $Q curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(shell go env GOPATH)/bin v1.42.0
.PHONY: bootstra% .PHONY: bootstra%
@ -68,7 +68,7 @@ PUSHTYPE := branch
endif endif
VERSION := $(shell echo $(VERSION) | sed 's/^v//') VERSION := $(shell echo $(VERSION) | sed 's/^v//')
DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/~/g') DEB_VERSION := $(shell echo $(VERSION) | sed 's/-/./g')
ifdef V ifdef V
$(info TRAVIS_TAG is $(TRAVIS_TAG)) $(info TRAVIS_TAG is $(TRAVIS_TAG))
@ -154,7 +154,7 @@ fmt:
$Q gofmt -l -w $(SRC) $Q gofmt -l -w $(SRC)
lint: lint:
$Q $(GOFLAGS) LOG_LEVEL=error golangci-lint run --timeout=30m $Q golangci-lint run --timeout=30m
lintcgo: lintcgo:
$Q LOG_LEVEL=error golangci-lint run --timeout=30m $Q LOG_LEVEL=error golangci-lint run --timeout=30m

View file

@ -18,7 +18,14 @@ You can use it to:
Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [safe, sane defaults](https://smallstep.com/docs/step-ca/certificate-authority-server-production#sane-cryptographic-defaults). Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [safe, sane defaults](https://smallstep.com/docs/step-ca/certificate-authority-server-production#sane-cryptographic-defaults).
**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions).** ---
**Don't want to run your own CA?**
To get up and running quickly, or as an alternative to running your own `step-ca` server, consider creating a [free hosted smallstep Certificate Manager authority](https://info.smallstep.com/certificate-manager-early-access-mvp/).
---
**Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions) or [Join our Discord](https://u.step.sm/discord).**
[Website](https://smallstep.com/certificates) | [Website](https://smallstep.com/certificates) |
[Documentation](https://smallstep.com/docs) | [Documentation](https://smallstep.com/docs) |
@ -27,7 +34,6 @@ Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [
[Contributor's Guide](./docs/CONTRIBUTING.md) [Contributor's Guide](./docs/CONTRIBUTING.md)
[![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest) [![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest)
[![CA Image](https://images.microbadger.com/badges/image/smallstep/step-ca.svg)](https://microbadger.com/images/smallstep/step-ca)
[![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates) [![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates)
[![Build Status](https://travis-ci.com/smallstep/certificates.svg?branch=master)](https://travis-ci.com/smallstep/certificates) [![Build Status](https://travis-ci.com/smallstep/certificates.svg?branch=master)](https://travis-ci.com/smallstep/certificates)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
@ -58,10 +64,11 @@ You can issue certificates in exchange for:
- ID tokens from Okta, GSuite, Azure AD, Auth0. - ID tokens from Okta, GSuite, Azure AD, Auth0.
- ID tokens from an OAuth OIDC service that you host, like [Keycloak](https://www.keycloak.org/) or [Dex](https://github.com/dexidp/dex) - ID tokens from an OAuth OIDC service that you host, like [Keycloak](https://www.keycloak.org/) or [Dex](https://github.com/dexidp/dex)
- [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure - [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure
- [Single-use, short-lived JWK tokens]() issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc. - [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
- A trusted X.509 certificate (X5C provisioner) - A trusted X.509 certificate (X5C provisioner)
- Expiring SSH host certificates needing rotation (the SSHPOP provisioner) - A SCEP challenge (SCEP provisioner)
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/configuration#jwk) - An SSH host certificates needing renewal (the SSHPOP provisioner)
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners)
### 🏔 Your own private ACME server ### 🏔 Your own private ACME server
@ -74,16 +81,17 @@ ACME is the protocol used by Let's Encrypt to automate the issuance of HTTPS cer
- For `tls-alpn-01`, respond to the challenge at the TLS layer ([as Caddy does](https://caddy.community/t/caddy-supports-the-acme-tls-alpn-challenge/4860)) to prove that you control the web server - For `tls-alpn-01`, respond to the challenge at the TLS layer ([as Caddy does](https://caddy.community/t/caddy-supports-the-acme-tls-alpn-challenge/4860)) to prove that you control the web server
- Works with any ACME client. We've written examples for: - Works with any ACME client. We've written examples for:
- [certbot](https://smallstep.com/blog/private-acme-server/#certbotuploadsacme-certbotpng-certbot-example) - [certbot](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#certbot)
- [acme.sh](https://smallstep.com/blog/private-acme-server/#acmeshuploadsacme-acme-shpng-acmesh-example) - [acme.sh](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#acmesh)
- [Caddy](https://smallstep.com/blog/private-acme-server/#caddyuploadsacme-caddypng-caddy-example) - [win-acme](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#win-acme)
- [Traefik](https://smallstep.com/blog/private-acme-server/#traefikuploadsacme-traefikpng-traefik-example) - [Caddy](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#caddy-v2)
- [Apache](https://smallstep.com/blog/private-acme-server/#apacheuploadsacme-apachepng-apache-example) - [Traefik](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#traefik)
- [nginx](https://smallstep.com/blog/private-acme-server/#nginxuploadsacme-nginxpng-nginx-example) - [Apache](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#apache)
- [nginx](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#nginx)
- Get certificates programmatically using ACME, using these libraries: - Get certificates programmatically using ACME, using these libraries:
- [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/blog/private-acme-server/#golanguploadsacme-golangpng-go-example)) - [`lego`](https://github.com/go-acme/lego) for Golang ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#golang))
- certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/blog/private-acme-server/#pythonuploadsacme-pythonpng-python-example)) - certbot's [`acme` module](https://github.com/certbot/certbot/tree/master/acme) for Python ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#python))
- [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/blog/private-acme-server/#nodejsuploadsacme-node-jspng-nodejs-example)) - [`acme-client`](https://github.com/publishlab/node-acme-client) for Node.js ([example usage](https://smallstep.com/docs/tutorials/acme-protocol-acme-clients#node))
- Our own [`step` CLI tool](https://github.com/smallstep/cli) is also an ACME client! - Our own [`step` CLI tool](https://github.com/smallstep/cli) is also an ACME client!
- See our [ACME tutorial](https://smallstep.com/docs/tutorials/acme-challenge) for more - See our [ACME tutorial](https://smallstep.com/docs/tutorials/acme-challenge) for more

View file

@ -19,7 +19,7 @@ type NewAccountRequest struct {
func validateContacts(cs []string) error { func validateContacts(cs []string) error {
for _, c := range cs { for _, c := range cs {
if len(c) == 0 { if c == "" {
return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string") return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
} }
} }

View file

@ -178,7 +178,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) u := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
oids := []string{"foo", "bar"} oids := []string{"foo", "bar"}
oidURLs := []string{ oidURLs := []string{
@ -255,7 +255,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrdersByAccountID(w, req) h.GetOrdersByAccountID(w, req)

View file

@ -64,8 +64,14 @@ type HandlerOptions struct {
// NewHandler returns a new ACME API handler. // NewHandler returns a new ACME API handler.
func NewHandler(ops HandlerOptions) api.RouterHandler { func NewHandler(ops HandlerOptions) api.RouterHandler {
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
client := http.Client{ client := http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
Transport: transport,
} }
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,

View file

@ -148,7 +148,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("authzID", az.ID) chiCtx.URLParams.Add("authzID", az.ID)
url := fmt.Sprintf("%s/acme/%s/authz/%s", u := fmt.Sprintf("%s/acme/%s/authz/%s",
baseURL.String(), provName, az.ID) baseURL.String(), provName, az.ID)
type test struct { type test struct {
@ -280,7 +280,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
expB, err := json.Marshal(az) expB, err := json.Marshal(az)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })
@ -314,7 +314,7 @@ func TestHandler_GetCertificate(t *testing.T) {
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("certID", certID) chiCtx.URLParams.Add("certID", certID)
url := fmt.Sprintf("%s/acme/%s/certificate/%s", u := fmt.Sprintf("%s/acme/%s/certificate/%s",
baseURL.String(), provName, certID) baseURL.String(), provName, certID)
type test struct { type test struct {
@ -396,7 +396,7 @@ func TestHandler_GetCertificate(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetCertificate(w, req) h.GetCertificate(w, req)
@ -434,7 +434,7 @@ func TestHandler_GetChallenge(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s", u := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
baseURL.String(), provName, "authzID", "chID") baseURL.String(), provName, "authzID", "chID")
type test struct { type test struct {
@ -635,7 +635,7 @@ func TestHandler_GetChallenge(t *testing.T) {
AuthorizationID: "authzID", AuthorizationID: "authzID",
Type: acme.HTTP01, Type: acme.HTTP01,
AccountID: "accID", AccountID: "accID",
URL: url, URL: u,
Error: acme.NewError(acme.ErrorConnectionType, "force"), Error: acme.NewError(acme.ErrorConnectionType, "force"),
}, },
vco: &acme.ValidateChallengeOptions{ vco: &acme.ValidateChallengeOptions{
@ -652,7 +652,7 @@ func TestHandler_GetChallenge(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetChallenge(w, req) h.GetChallenge(w, req)
@ -678,7 +678,7 @@ func TestHandler_GetChallenge(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")}) assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")})
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })

View file

@ -223,7 +223,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return return
} }
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 { if hdr.JSONWebKey == nil && hdr.KeyID == "" {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return return
} }
@ -399,7 +399,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return return
} }

View file

@ -108,7 +108,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) {
} }
func TestHandler_addNonce(t *testing.T) { func TestHandler_addNonce(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-nonce" u := "https://ca.smallstep.com/acme/new-nonce"
type test struct { type test struct {
db acme.DB db acme.DB
err *acme.Error err *acme.Error
@ -141,7 +141,7 @@ func TestHandler_addNonce(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.addNonce(testNext)(w, req) h.addNonce(testNext)(w, req)
res := w.Result() res := w.Result()
@ -230,7 +230,7 @@ func TestHandler_verifyContentType(t *testing.T) {
prov := newProv() prov := newProv()
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
type test struct { type test struct {
h Handler h Handler
ctx context.Context ctx context.Context
@ -245,7 +245,7 @@ func TestHandler_verifyContentType(t *testing.T) {
h: Handler{ h: Handler{
linker: NewLinker("dns", "acme"), linker: NewLinker("dns", "acme"),
}, },
url: url, url: u,
ctx: context.Background(), ctx: context.Background(),
contentType: "foo", contentType: "foo",
statusCode: 500, statusCode: 500,
@ -257,7 +257,7 @@ func TestHandler_verifyContentType(t *testing.T) {
h: Handler{ h: Handler{
linker: NewLinker("dns", "acme"), linker: NewLinker("dns", "acme"),
}, },
url: url, url: u,
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
@ -319,11 +319,11 @@ func TestHandler_verifyContentType(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
_url := url _u := u
if tc.url != "" { if tc.url != "" {
_url = tc.url _u = tc.url
} }
req := httptest.NewRequest("GET", _url, nil) req := httptest.NewRequest("GET", _u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
req.Header.Add("Content-Type", tc.contentType) req.Header.Add("Content-Type", tc.contentType)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -353,7 +353,7 @@ func TestHandler_verifyContentType(t *testing.T) {
} }
func TestHandler_isPostAsGet(t *testing.T) { func TestHandler_isPostAsGet(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-account" u := "https://ca.smallstep.com/acme/new-account"
type test struct { type test struct {
ctx context.Context ctx context.Context
err *acme.Error err *acme.Error
@ -392,7 +392,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.isPostAsGet(testNext)(w, req) h.isPostAsGet(testNext)(w, req)
@ -430,7 +430,7 @@ func (errReader) Close() error {
} }
func TestHandler_parseJWS(t *testing.T) { func TestHandler_parseJWS(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-account" u := "https://ca.smallstep.com/acme/new-account"
type test struct { type test struct {
next nextHTTP next nextHTTP
body io.Reader body io.Reader
@ -483,7 +483,7 @@ func TestHandler_parseJWS(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest("GET", url, tc.body) req := httptest.NewRequest("GET", u, tc.body)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.parseJWS(tc.next)(w, req) h.parseJWS(tc.next)(w, req)
res := w.Result() res := w.Result()
@ -528,7 +528,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
url := "https://ca.smallstep.com/acme/account/1234" u := "https://ca.smallstep.com/acme/account/1234"
type test struct { type test struct {
ctx context.Context ctx context.Context
next func(http.ResponseWriter, *http.Request) next func(http.ResponseWriter, *http.Request)
@ -681,7 +681,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.verifyAndExtractJWSPayload(tc.next)(w, req) h.verifyAndExtractJWSPayload(tc.next)(w, req)
@ -713,7 +713,7 @@ func TestHandler_lookupJWK(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/account/1234", u := fmt.Sprintf("%s/acme/%s/account/1234",
baseURL, provName) baseURL, provName)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -883,7 +883,7 @@ func TestHandler_lookupJWK(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker} h := &Handler{db: tc.db, linker: tc.linker}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.lookupJWK(tc.next)(w, req) h.lookupJWK(tc.next)(w, req)
@ -934,7 +934,7 @@ func TestHandler_extractJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", u := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
provName) provName)
type test struct { type test struct {
db acme.DB db acme.DB
@ -1079,7 +1079,7 @@ func TestHandler_extractJWK(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.extractJWK(tc.next)(w, req) h.extractJWK(tc.next)(w, req)
@ -1108,7 +1108,7 @@ func TestHandler_extractJWK(t *testing.T) {
} }
func TestHandler_validateJWS(t *testing.T) { func TestHandler_validateJWS(t *testing.T) {
url := "https://ca.smallstep.com/acme/account/1234" u := "https://ca.smallstep.com/acme/account/1234"
type test struct { type test struct {
db acme.DB db acme.DB
ctx context.Context ctx context.Context
@ -1198,7 +1198,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.RS256, Algorithm: jose.RS256,
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1226,7 +1226,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.RS256, Algorithm: jose.RS256,
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1298,7 +1298,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", url), err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", u),
} }
}, },
"fail/both-jwk-kid": func(t *testing.T) test { "fail/both-jwk-kid": func(t *testing.T) test {
@ -1313,7 +1313,7 @@ func TestHandler_validateJWS(t *testing.T) {
KeyID: "bar", KeyID: "bar",
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1337,7 +1337,7 @@ func TestHandler_validateJWS(t *testing.T) {
Protected: jose.Header{ Protected: jose.Header{
Algorithm: jose.ES256, Algorithm: jose.ES256,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1362,7 +1362,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.ES256, Algorithm: jose.ES256,
KeyID: "bar", KeyID: "bar",
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1392,7 +1392,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.ES256, Algorithm: jose.ES256,
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1422,7 +1422,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.RS256, Algorithm: jose.RS256,
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1446,7 +1446,7 @@ func TestHandler_validateJWS(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.validateJWS(tc.next)(w, req) h.validateJWS(tc.next)(w, req)

View file

@ -264,7 +264,7 @@ func TestHandler_GetOrder(t *testing.T) {
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID) chiCtx.URLParams.Add("ordID", o.ID)
url := fmt.Sprintf("%s/acme/%s/order/%s", u := fmt.Sprintf("%s/acme/%s/order/%s",
baseURL.String(), escProvName, o.ID) baseURL.String(), escProvName, o.ID)
type test struct { type test struct {
@ -422,7 +422,7 @@ func TestHandler_GetOrder(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrder(w, req) h.GetOrder(w, req)
@ -448,7 +448,7 @@ func TestHandler_GetOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })
@ -663,7 +663,7 @@ func TestHandler_NewOrder(t *testing.T) {
prov := newProv() prov := newProv()
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/order/ordID", u := fmt.Sprintf("%s/acme/%s/order/ordID",
baseURL.String(), escProvName) baseURL.String(), escProvName)
type test struct { type test struct {
@ -1335,7 +1335,7 @@ func TestHandler_NewOrder(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.NewOrder(w, req) h.NewOrder(w, req)
@ -1363,7 +1363,7 @@ func TestHandler_NewOrder(t *testing.T) {
tc.vr(t, ro) tc.vr(t, ro)
} }
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })
@ -1406,7 +1406,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID) chiCtx.URLParams.Add("ordID", o.ID)
url := fmt.Sprintf("%s/acme/%s/order/%s", u := fmt.Sprintf("%s/acme/%s/order/%s",
baseURL.String(), escProvName, o.ID) baseURL.String(), escProvName, o.ID)
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
@ -1625,7 +1625,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.FinalizeOrder(w, req) h.FinalizeOrder(w, req)
@ -1654,7 +1654,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
assert.FatalError(t, json.Unmarshal(body, ro)) assert.FatalError(t, json.Unmarshal(body, ro))
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })

View file

@ -10,11 +10,13 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"strings" "strings"
"time" "time"
@ -74,23 +76,23 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
} }
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(url.String()) resp, err := vo.HTTPGet(u.String())
if err != nil { if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
"error doing http GET for url %s", url)) "error doing http GET for url %s", u))
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return storeError(ctx, db, ch, false, NewError(ErrorConnectionType, return storeError(ctx, db, ch, false, NewError(ErrorConnectionType,
"error doing http GET for url %s with status code %d", url, resp.StatusCode)) "error doing http GET for url %s with status code %d", u, resp.StatusCode))
} }
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return WrapErrorISE(err, "error reading "+ return WrapErrorISE(err, "error reading "+
"response body for url %s", url) "response body for url %s", u)
} }
keyAuth := strings.TrimSpace(string(body)) keyAuth := strings.TrimSpace(string(body))
@ -114,6 +116,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
return nil return nil
} }
func tlsAlert(err error) uint8 {
var opErr *net.OpError
if errors.As(err, &opErr) {
v := reflect.ValueOf(opErr.Err)
if v.Kind() == reflect.Uint8 {
return uint8(v.Uint())
}
}
return 0
}
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
config := &tls.Config{ config := &tls.Config{
NextProtos: []string{"acme-tls/1"}, NextProtos: []string{"acme-tls/1"},
@ -129,6 +142,14 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
conn, err := vo.TLSDial("tcp", hostPort, config) conn, err := vo.TLSDial("tcp", hostPort, config)
if err != nil { if err != nil {
// With Go 1.17+ tls.Dial fails if there's no overlap between configured
// client and server protocols. When this happens the connection is
// closed with the error no_application_protocol(120) as required by
// RFC7301. See https://golang.org/doc/go1.17#ALPN
if tlsAlert(err) == 120 {
return storeError(ctx, db, ch, true, NewError(ErrorRejectedIdentifierType,
"cannot negotiate ALPN acme-tls/1 protocol for tls-alpn-01 challenge"))
}
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
"error doing TLS dial for %s", hostPort)) "error doing TLS dial for %s", hostPort))
} }

View file

@ -1276,7 +1276,7 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na
oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
} }
keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:]) keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash)
certTemplate.ExtraExtensions = []pkix.Extension{ certTemplate.ExtraExtensions = []pkix.Extension{
{ {
@ -1395,7 +1395,7 @@ func TestTLSALPN01Validate(t *testing.T) {
assert.Equals(t, updch.Type, ch.Type) assert.Equals(t, updch.Type, ch.Type)
assert.Equals(t, updch.Value, ch.Value) assert.Equals(t, updch.Value, ch.Value)
err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443: tls: DialWithDialer timed out", ch.Value) err := NewError(ErrorConnectionType, "error doing TLS dial for %v:443:", ch.Value)
assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error())
assert.Equals(t, updch.Error.Type, err.Type) assert.Equals(t, updch.Error.Type, err.Type)

View file

@ -93,8 +93,8 @@ func TestDB_getDBAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil { if dbacc, err := d.getDBAccount(context.Background(), accID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -109,15 +109,13 @@ func TestDB_getDBAccount(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, dbacc.ID, tc.dbacc.ID)
assert.Equals(t, dbacc.ID, tc.dbacc.ID) assert.Equals(t, dbacc.Status, tc.dbacc.Status)
assert.Equals(t, dbacc.Status, tc.dbacc.Status) assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt) assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt) assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact) assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
}
} }
}) })
} }
@ -174,8 +172,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil { if retAccID, err := d.getAccountIDByKeyID(context.Background(), kid); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -190,10 +188,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, retAccID, accID)
assert.Equals(t, retAccID, accID)
}
} }
}) })
} }
@ -250,8 +246,8 @@ func TestDB_GetAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if acc, err := db.GetAccount(context.Background(), accID); err != nil { if acc, err := d.GetAccount(context.Background(), accID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -266,13 +262,11 @@ func TestDB_GetAccount(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, acc.ID, tc.dbacc.ID)
assert.Equals(t, acc.ID, tc.dbacc.ID) assert.Equals(t, acc.Status, tc.dbacc.Status)
assert.Equals(t, acc.Status, tc.dbacc.Status) assert.Equals(t, acc.Contact, tc.dbacc.Contact)
assert.Equals(t, acc.Contact, tc.dbacc.Contact) assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
}
} }
}) })
} }
@ -358,8 +352,8 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil { if acc, err := d.GetAccountByKeyID(context.Background(), kid); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -374,13 +368,11 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, acc.ID, tc.dbacc.ID)
assert.Equals(t, acc.ID, tc.dbacc.ID) assert.Equals(t, acc.Status, tc.dbacc.Status)
assert.Equals(t, acc.Status, tc.dbacc.Status) assert.Equals(t, acc.Contact, tc.dbacc.Contact)
assert.Equals(t, acc.Contact, tc.dbacc.Contact) assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
}
} }
}) })
} }
@ -527,8 +519,8 @@ func TestDB_CreateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateAccount(context.Background(), tc.acc); err != nil { if err := d.CreateAccount(context.Background(), tc.acc); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -688,8 +680,8 @@ func TestDB_UpdateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.UpdateAccount(context.Background(), tc.acc); err != nil { if err := d.UpdateAccount(context.Background(), tc.acc); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -97,8 +97,8 @@ func TestDB_getDBAuthz(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil { if dbaz, err := d.getDBAuthz(context.Background(), azID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -113,18 +113,16 @@ func TestDB_getDBAuthz(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, dbaz.ID, tc.dbaz.ID)
assert.Equals(t, dbaz.ID, tc.dbaz.ID) assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID) assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier) assert.Equals(t, dbaz.Status, tc.dbaz.Status)
assert.Equals(t, dbaz.Status, tc.dbaz.Status) assert.Equals(t, dbaz.Token, tc.dbaz.Token)
assert.Equals(t, dbaz.Token, tc.dbaz.Token) assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt) assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt) assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error()) assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
}
} }
}) })
} }
@ -293,8 +291,8 @@ func TestDB_GetAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if az, err := db.GetAuthorization(context.Background(), azID); err != nil { if az, err := d.GetAuthorization(context.Background(), azID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -309,21 +307,19 @@ func TestDB_GetAuthorization(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, az.ID, tc.dbaz.ID)
assert.Equals(t, az.ID, tc.dbaz.ID) assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
assert.Equals(t, az.AccountID, tc.dbaz.AccountID) assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
assert.Equals(t, az.Identifier, tc.dbaz.Identifier) assert.Equals(t, az.Status, tc.dbaz.Status)
assert.Equals(t, az.Status, tc.dbaz.Status) assert.Equals(t, az.Token, tc.dbaz.Token)
assert.Equals(t, az.Token, tc.dbaz.Token) assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard) assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt) assert.Equals(t, az.Challenges, []*acme.Challenge{
assert.Equals(t, az.Challenges, []*acme.Challenge{ {ID: "foo"},
{ID: "foo"}, {ID: "bar"},
{ID: "bar"}, })
}) assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
}
} }
}) })
} }
@ -445,8 +441,8 @@ func TestDB_CreateAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateAuthorization(context.Background(), tc.az); err != nil { if err := d.CreateAuthorization(context.Background(), tc.az); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -594,8 +590,8 @@ func TestDB_UpdateAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil { if err := d.UpdateAuthorization(context.Background(), tc.az); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -116,8 +116,8 @@ func TestDB_CreateCertificate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateCertificate(context.Background(), tc.cert); err != nil { if err := d.CreateCertificate(context.Background(), tc.cert); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -246,8 +246,8 @@ func TestDB_GetCertificate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
cert, err := db.GetCertificate(context.Background(), certID) cert, err := d.GetCertificate(context.Background(), certID)
if err != nil { if err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
@ -263,14 +263,12 @@ func TestDB_GetCertificate(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, cert.ID, certID)
assert.Equals(t, cert.ID, certID) assert.Equals(t, cert.AccountID, "accountID")
assert.Equals(t, cert.AccountID, "accountID") assert.Equals(t, cert.OrderID, "orderID")
assert.Equals(t, cert.OrderID, "orderID") assert.Equals(t, cert.Leaf, leaf)
assert.Equals(t, cert.Leaf, leaf) assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
}
} }
}) })
} }

View file

@ -92,8 +92,8 @@ func TestDB_getDBChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil { if ch, err := d.getDBChallenge(context.Background(), chID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -108,17 +108,15 @@ func TestDB_getDBChallenge(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, ch.ID, tc.dbc.ID)
assert.Equals(t, ch.ID, tc.dbc.ID) assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
assert.Equals(t, ch.AccountID, tc.dbc.AccountID) assert.Equals(t, ch.Type, tc.dbc.Type)
assert.Equals(t, ch.Type, tc.dbc.Type) assert.Equals(t, ch.Status, tc.dbc.Status)
assert.Equals(t, ch.Status, tc.dbc.Status) assert.Equals(t, ch.Token, tc.dbc.Token)
assert.Equals(t, ch.Token, tc.dbc.Token) assert.Equals(t, ch.Value, tc.dbc.Value)
assert.Equals(t, ch.Value, tc.dbc.Value) assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
}
} }
}) })
} }
@ -206,8 +204,8 @@ func TestDB_CreateChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil { if err := d.CreateChallenge(context.Background(), tc.ch); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -286,8 +284,8 @@ func TestDB_GetChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil { if ch, err := d.GetChallenge(context.Background(), chID, azID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -302,17 +300,15 @@ func TestDB_GetChallenge(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, ch.ID, tc.dbc.ID)
assert.Equals(t, ch.ID, tc.dbc.ID) assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
assert.Equals(t, ch.AccountID, tc.dbc.AccountID) assert.Equals(t, ch.Type, tc.dbc.Type)
assert.Equals(t, ch.Type, tc.dbc.Type) assert.Equals(t, ch.Status, tc.dbc.Status)
assert.Equals(t, ch.Status, tc.dbc.Status) assert.Equals(t, ch.Token, tc.dbc.Token)
assert.Equals(t, ch.Token, tc.dbc.Token) assert.Equals(t, ch.Value, tc.dbc.Value)
assert.Equals(t, ch.Value, tc.dbc.Value) assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
}
} }
}) })
} }
@ -442,8 +438,8 @@ func TestDB_UpdateChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil { if err := d.UpdateChallenge(context.Background(), tc.ch); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -31,7 +31,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
ID: id, ID: id,
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
} }
if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil { if err := db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
return "", err return "", err
} }
return acme.Nonce(id), nil return acme.Nonce(id), nil

View file

@ -67,8 +67,8 @@ func TestDB_CreateNonce(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if n, err := db.CreateNonce(context.Background()); err != nil { if n, err := d.CreateNonce(context.Background()); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -144,8 +144,8 @@ func TestDB_DeleteNonce(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil { if err := d.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {

View file

@ -42,7 +42,7 @@ func New(db nosqlDB.DB) (*DB, error) {
// save writes the new data to the database, overwriting the old data if it // save writes the new data to the database, overwriting the old data if it
// existed. // existed.
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
var ( var (
err error err error
newB []byte newB []byte

View file

@ -126,8 +126,8 @@ func TestDB_save(t *testing.T) {
} }
for name, tc := range tests { for name, tc := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := &DB{db: tc.db} d := &DB{db: tc.db}
if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil { if err := d.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -124,10 +124,8 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st
ordersByAccountMux.Lock() ordersByAccountMux.Lock()
defer ordersByAccountMux.Unlock() defer ordersByAccountMux.Unlock()
var oldOids []string
b, err := db.db.Get(ordersByAccountIDTable, []byte(accID)) b, err := db.db.Get(ordersByAccountIDTable, []byte(accID))
var (
oldOids []string
)
if err != nil { if err != nil {
if !nosql.IsErrNotFound(err) { if !nosql.IsErrNotFound(err) {
return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID)

View file

@ -12,7 +12,7 @@ import (
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
) )
func TestDB_getDBOrder(t *testing.T) { func TestDB_getDBOrder(t *testing.T) {
@ -31,7 +31,7 @@ func TestDB_getDBOrder(t *testing.T) {
assert.Equals(t, bucket, orderTable) assert.Equals(t, bucket, orderTable)
assert.Equals(t, string(key), orderID) assert.Equals(t, string(key), orderID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
@ -100,8 +100,8 @@ func TestDB_getDBOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if dbo, err := db.getDBOrder(context.Background(), orderID); err != nil { if dbo, err := d.getDBOrder(context.Background(), orderID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -116,20 +116,18 @@ func TestDB_getDBOrder(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, dbo.ID, tc.dbo.ID)
assert.Equals(t, dbo.ID, tc.dbo.ID) assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID) assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID) assert.Equals(t, dbo.Status, tc.dbo.Status)
assert.Equals(t, dbo.Status, tc.dbo.Status) assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt)
assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt) assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt)
assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt) assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore)
assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore) assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter)
assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter) assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers)
assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers) assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs) assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
}
} }
}) })
} }
@ -164,7 +162,7 @@ func TestDB_GetOrder(t *testing.T) {
assert.Equals(t, bucket, orderTable) assert.Equals(t, bucket, orderTable)
assert.Equals(t, string(key), orderID) assert.Equals(t, string(key), orderID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
@ -206,8 +204,8 @@ func TestDB_GetOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if o, err := db.GetOrder(context.Background(), orderID); err != nil { if o, err := d.GetOrder(context.Background(), orderID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -222,20 +220,18 @@ func TestDB_GetOrder(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, o.ID, tc.dbo.ID)
assert.Equals(t, o.ID, tc.dbo.ID) assert.Equals(t, o.AccountID, tc.dbo.AccountID)
assert.Equals(t, o.AccountID, tc.dbo.AccountID) assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID) assert.Equals(t, o.CertificateID, tc.dbo.CertificateID)
assert.Equals(t, o.CertificateID, tc.dbo.CertificateID) assert.Equals(t, o.Status, tc.dbo.Status)
assert.Equals(t, o.Status, tc.dbo.Status) assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt)
assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt) assert.Equals(t, o.NotBefore, tc.dbo.NotBefore)
assert.Equals(t, o.NotBefore, tc.dbo.NotBefore) assert.Equals(t, o.NotAfter, tc.dbo.NotAfter)
assert.Equals(t, o.NotAfter, tc.dbo.NotAfter) assert.Equals(t, o.Identifiers, tc.dbo.Identifiers)
assert.Equals(t, o.Identifiers, tc.dbo.Identifiers) assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs) assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
}
} }
}) })
} }
@ -366,8 +362,8 @@ func TestDB_UpdateOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.UpdateOrder(context.Background(), tc.o); err != nil { if err := d.UpdateOrder(context.Background(), tc.o); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -511,7 +507,7 @@ func TestDB_CreateOrder(t *testing.T) {
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, string(bucket), string(ordersByAccountIDTable)) assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
assert.Equals(t, string(key), o.AccountID) assert.Equals(t, string(key), o.AccountID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
switch string(bucket) { switch string(bucket) {
@ -557,8 +553,8 @@ func TestDB_CreateOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateOrder(context.Background(), tc.o); err != nil { if err := d.CreateOrder(context.Background(), tc.o); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -680,7 +676,7 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte(accID)) assert.Equals(t, key, []byte(accID))
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, bucket, ordersByAccountIDTable)
@ -710,6 +706,34 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
err: errors.Errorf("error saving orderIDs index for account %s", accID), err: errors.Errorf("error saving orderIDs index for account %s", accID),
} }
}, },
"ok/no-old": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
switch string(bucket) {
case string(ordersByAccountIDTable):
return nil, database.ErrNotFound
default:
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
return nil, errors.New("force")
}
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
switch string(bucket) {
case string(ordersByAccountIDTable):
assert.Equals(t, key, []byte(accID))
assert.Equals(t, old, nil)
assert.Equals(t, nu, nil)
return nil, true, nil
default:
assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket)))
return nil, false, errors.New("force")
}
},
},
res: []string{},
}
},
"ok/all-old-not-pending": func(t *testing.T) test { "ok/all-old-not-pending": func(t *testing.T) test {
oldOids := []string{"foo", "bar"} oldOids := []string{"foo", "bar"}
bOldOids, err := json.Marshal(oldOids) bOldOids, err := json.Marshal(oldOids)
@ -967,15 +991,15 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
var ( var (
res []string res []string
err error err error
) )
if tc.addOids == nil { if tc.addOids == nil {
res, err = db.updateAddOrderIDs(context.Background(), accID) res, err = d.updateAddOrderIDs(context.Background(), accID)
} else { } else {
res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...) res, err = d.updateAddOrderIDs(context.Background(), accID, tc.addOids...)
} }
if err != nil { if err != nil {
@ -993,10 +1017,8 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.True(t, reflect.DeepEqual(res, tc.res))
assert.True(t, reflect.DeepEqual(res, tc.res))
}
} }
}) })
} }

View file

@ -289,6 +289,7 @@ func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.Certificate
// name or in an extensionRequest attribute [RFC2985] requesting a // name or in an extensionRequest attribute [RFC2985] requesting a
// subjectAltName extension, or both. // subjectAltName extension, or both.
if csr.Subject.CommonName != "" { if csr.Subject.CommonName != "" {
// nolint:gocritic
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
} }
canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames) canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames)

View file

@ -240,9 +240,9 @@ type caHandler struct {
} }
// New creates a new RouterHandler with the CA endpoints. // New creates a new RouterHandler with the CA endpoints.
func New(authority Authority) RouterHandler { func New(auth Authority) RouterHandler {
return &caHandler{ return &caHandler{
Authority: authority, Authority: auth,
} }
} }
@ -295,7 +295,7 @@ func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
// certificate for the given SHA256. // certificate for the given SHA256.
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
sha := chi.URLParam(r, "sha") sha := chi.URLParam(r, "sha")
sum := strings.ToLower(strings.Replace(sha, "-", "", -1)) sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
// Load root certificate with the // Load root certificate with the
cert, err := h.Authority.Root(sum) cert, err := h.Authority.Root(sum)
if err != nil { if err != nil {
@ -409,19 +409,20 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
"certificate": base64.StdEncoding.EncodeToString(cert.Raw), "certificate": base64.StdEncoding.EncodeToString(cert.Raw),
} }
for _, ext := range cert.Extensions { for _, ext := range cert.Extensions {
if ext.Id.Equal(oidStepProvisioner) { if !ext.Id.Equal(oidStepProvisioner) {
val := &stepProvisioner{} continue
rest, err := asn1.Unmarshal(ext.Value, val) }
if err != nil || len(rest) > 0 { val := &stepProvisioner{}
break rest, err := asn1.Unmarshal(ext.Value, val)
} if err != nil || len(rest) > 0 {
if len(val.CredentialID) > 0 {
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
} else {
m["provisioner"] = string(val.Name)
}
break break
} }
if len(val.CredentialID) > 0 {
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
} else {
m["provisioner"] = string(val.Name)
}
break
} }
rl.WithFields(m) rl.WithFields(m)
} }

View file

@ -186,8 +186,8 @@ func TestCertificate_MarshalJSON(t *testing.T) {
}{ }{
{"nil", fields{Certificate: nil}, []byte("null"), false}, {"nil", fields{Certificate: nil}, []byte("null"), false},
{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false}, {"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"`), false}, {"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"`), false},
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`), false}, {"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`), false},
} }
for _, tt := range tests { for _, tt := range tests {
@ -219,11 +219,11 @@ func TestCertificate_UnmarshalJSON(t *testing.T) {
{"invalid string", []byte(`"foobar"`), false, true}, {"invalid string", []byte(`"foobar"`), false, true},
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true}, {"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true}, {"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
{"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false, true}, {"invalid type", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), false, true},
{"empty string", []byte(`""`), false, false}, {"empty string", []byte(`""`), false, false},
{"json null", []byte(`null`), false, false}, {"json null", []byte(`null`), false, false},
{"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true, false}, {"valid root", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), true, false},
{"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), true, false}, {"valid cert", []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"`), true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -251,7 +251,7 @@ func TestCertificate_UnmarshalJSON_json(t *testing.T) {
{"empty crt (null)", `{"crt":null}`, false, false}, {"empty crt (null)", `{"crt":null}`, false, false},
{"empty crt (string)", `{"crt":""}`, false, false}, {"empty crt (string)", `{"crt":""}`, false, false},
{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true}, {"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
{"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, true, false}, {"valid crt", `{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"}`, true, false},
} }
type request struct { type request struct {
@ -297,7 +297,7 @@ func TestCertificateRequest_MarshalJSON(t *testing.T) {
}{ }{
{"nil", fields{CertificateRequest: nil}, []byte("null"), false}, {"nil", fields{CertificateRequest: nil}, []byte("null"), false},
{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false}, {"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `\n"`), false}, {"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `\n"`), false},
} }
for _, tt := range tests { for _, tt := range tests {
@ -329,10 +329,10 @@ func TestCertificateRequest_UnmarshalJSON(t *testing.T) {
{"invalid string", []byte(`"foobar"`), false, true}, {"invalid string", []byte(`"foobar"`), false, true},
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true}, {"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true}, {"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true},
{"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false, true}, {"invalid type", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), false, true},
{"empty string", []byte(`""`), false, false}, {"empty string", []byte(`""`), false, false},
{"json null", []byte(`null`), false, false}, {"json null", []byte(`null`), false, false},
{"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true, false}, {"valid csr", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -360,7 +360,7 @@ func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) {
{"empty csr (null)", `{"csr":null}`, false, false}, {"empty csr (null)", `{"csr":null}`, false, false},
{"empty csr (string)", `{"csr":""}`, false, false}, {"empty csr (string)", `{"csr":""}`, false, false},
{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true}, {"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
{"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, true, false}, {"valid csr", `{"csr":"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"}`, true, false},
} }
type request struct { type request struct {
@ -739,7 +739,7 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin
return m.ret1.(bool), m.err return m.ret1.(bool), m.err
} }
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) { func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
if m.getSSHBastion != nil { if m.getSSHBastion != nil {
return m.getSSHBastion(ctx, user, hostname) return m.getSSHBastion(ctx, user, hostname)
} }
@ -816,7 +816,7 @@ func Test_caHandler_Root(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil) req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
expected := []byte(`{"ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`) expected := []byte(`{"ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -860,8 +860,8 @@ func Test_caHandler_Sign(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
expected1 := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
expected2 := []byte(`{"crt":"` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
tests := []struct { tests := []struct {
name string name string
@ -934,7 +934,7 @@ func Test_caHandler_Renew(t *testing.T) {
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, {"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
} }
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -995,7 +995,7 @@ func Test_caHandler_Rekey(t *testing.T) {
{"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest}, {"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest},
} }
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -1210,7 +1210,7 @@ func Test_caHandler_Roots(t *testing.T) {
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
} }
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -1256,7 +1256,7 @@ func Test_caHandler_Federation(t *testing.T) {
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
} }
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -50,12 +50,10 @@ func WriteError(w http.ResponseWriter, err error) {
rl.WithFields(map[string]interface{}{ rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e), "stack-trace": fmt.Sprintf("%+v", e),
}) })
} else { } else if e, ok := cause.(errs.StackTracer); ok {
if e, ok := cause.(errs.StackTracer); ok { rl.WithFields(map[string]interface{}{
rl.WithFields(map[string]interface{}{ "stack-trace": fmt.Sprintf("%+v", e),
"stack-trace": fmt.Sprintf("%+v", e), })
})
}
} }
} }
} }

View file

@ -52,7 +52,7 @@ func (s *SSHSignRequest) Validate() error {
return errors.Errorf("unknown certType %s", s.CertType) return errors.Errorf("unknown certType %s", s.CertType)
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty publicKey") return errors.New("missing or empty publicKey")
case len(s.OTT) == 0: case s.OTT == "":
return errors.New("missing or empty ott") return errors.New("missing or empty ott")
default: default:
// Validate identity signature if provided // Validate identity signature if provided
@ -408,18 +408,18 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
var config SSHConfigResponse var cfg SSHConfigResponse
switch body.Type { switch body.Type {
case provisioner.SSHUserCert: case provisioner.SSHUserCert:
config.UserTemplates = ts cfg.UserTemplates = ts
case provisioner.SSHHostCert: case provisioner.SSHHostCert:
config.HostTemplates = ts cfg.HostTemplates = ts
default: default:
WriteError(w, errs.InternalServer("it should hot get here")) WriteError(w, errs.InternalServer("it should hot get here"))
return return
} }
JSON(w, config) JSON(w, cfg)
} }
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.

View file

@ -2,6 +2,7 @@ package api
import ( import (
"net/http" "net/http"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -18,7 +19,7 @@ type SSHRekeyRequest struct {
// Validate validates the SSHSignRekey. // Validate validates the SSHSignRekey.
func (s *SSHRekeyRequest) Validate() error { func (s *SSHRekeyRequest) Validate() error {
switch { switch {
case len(s.OTT) == 0: case s.OTT == "":
return errors.New("missing or empty ott") return errors.New("missing or empty ott")
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty public key") return errors.New("missing or empty public key")
@ -72,7 +73,11 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
return return
} }
identity, err := h.renewIdentityCertificate(r) // Match identity cert with the SSH cert
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err))
return return

View file

@ -1,7 +1,9 @@
package api package api
import ( import (
"crypto/x509"
"net/http" "net/http"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -16,7 +18,7 @@ type SSHRenewRequest struct {
// Validate validates the SSHSignRequest. // Validate validates the SSHSignRequest.
func (s *SSHRenewRequest) Validate() error { func (s *SSHRenewRequest) Validate() error {
switch { switch {
case len(s.OTT) == 0: case s.OTT == "":
return errors.New("missing or empty ott") return errors.New("missing or empty ott")
default: default:
return nil return nil
@ -62,7 +64,11 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
return return
} }
identity, err := h.renewIdentityCertificate(r) // Match identity cert with the SSH cert
notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err))
return return
@ -74,13 +80,28 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
}, http.StatusCreated) }, http.StatusCreated)
} }
// renewIdentityCertificate request the client TLS certificate if present. // renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
func (h *caHandler) renewIdentityCertificate(r *http.Request) ([]Certificate, error) { func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
return nil, nil return nil, nil
} }
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) // Clone the certificate as we can modify it.
cert, err := x509.ParseCertificate(r.TLS.PeerCertificates[0].Raw)
if err != nil {
return nil, errors.Wrap(err, "error parsing client certificate")
}
// Enforce the cert to match another certificate, for example an ssh
// certificate.
if !notBefore.IsZero() {
cert.NotBefore = notBefore
}
if !notAfter.IsZero() {
cert.NotAfter = notAfter
}
certChain, err := h.Authority.Renew(cert)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -36,7 +36,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
if !r.Passive { if !r.Passive {
return errs.NotImplemented("non-passive revocation not implemented") return errs.NotImplemented("non-passive revocation not implemented")
} }
if len(r.OTT) == 0 { if r.OTT == "" {
return errs.BadRequest("missing ott") return errs.BadRequest("missing ott")
} }
return return

View file

@ -284,7 +284,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
identityCerts := []*x509.Certificate{ identityCerts := []*x509.Certificate{
parseCertificate(certPEM), parseCertificate(certPEM),
} }
identityCertsPEM := []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`) identityCertsPEM := []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`)
tests := []struct { tests := []struct {
name string name string

View file

@ -27,7 +27,7 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization") tok := r.Header.Get("Authorization")
if len(tok) == 0 { if tok == "" {
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
"missing authorization header token")) "missing authorization header token"))
return return

View file

@ -54,7 +54,7 @@ func UnmarshalProvisionerDetails(typ linkedca.Provisioner_Type, data []byte) (*l
return &linkedca.ProvisionerDetails{Data: v.Data}, nil return &linkedca.ProvisionerDetails{Data: v.Data}, nil
} }
// DB is the DB interface expected by the step-ca ACME API. // DB is the DB interface expected by the step-ca Admin API.
type DB interface { type DB interface {
CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error
GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error)

View file

@ -12,7 +12,6 @@ import (
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
nosqldb "github.com/smallstep/nosql/database"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
) )
@ -32,7 +31,7 @@ func TestDB_getDBAdminBytes(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -67,8 +66,8 @@ func TestDB_getDBAdminBytes(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if b, err := db.getDBAdminBytes(context.Background(), adminID); err != nil { if b, err := d.getDBAdminBytes(context.Background(), adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -83,10 +82,8 @@ func TestDB_getDBAdminBytes(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, string(b), "foo")
assert.Equals(t, string(b), "foo")
}
} }
}) })
} }
@ -108,7 +105,7 @@ func TestDB_getDBAdmin(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -193,8 +190,8 @@ func TestDB_getDBAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if dba, err := db.getDBAdmin(context.Background(), adminID); err != nil { if dba, err := d.getDBAdmin(context.Background(), adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -209,16 +206,14 @@ func TestDB_getDBAdmin(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dba.ID, adminID)
assert.Equals(t, dba.ID, adminID) assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID) assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID) assert.Equals(t, dba.Subject, tc.dba.Subject)
assert.Equals(t, dba.Subject, tc.dba.Subject) assert.Equals(t, dba.Type, tc.dba.Type)
assert.Equals(t, dba.Type, tc.dba.Type) assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt) assert.Fatal(t, dba.DeletedAt.IsZero())
assert.Fatal(t, dba.DeletedAt.IsZero())
}
} }
}) })
} }
@ -283,8 +278,8 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{authorityID: admin.DefaultAuthorityID} d := DB{authorityID: admin.DefaultAuthorityID}
if dba, err := db.unmarshalDBAdmin(tc.in, adminID); err != nil { if dba, err := d.unmarshalDBAdmin(tc.in, adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -299,16 +294,14 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dba.ID, adminID)
assert.Equals(t, dba.ID, adminID) assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID) assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID) assert.Equals(t, dba.Subject, tc.dba.Subject)
assert.Equals(t, dba.Subject, tc.dba.Subject) assert.Equals(t, dba.Type, tc.dba.Type)
assert.Equals(t, dba.Type, tc.dba.Type) assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt) assert.Fatal(t, dba.DeletedAt.IsZero())
assert.Fatal(t, dba.DeletedAt.IsZero())
}
} }
}) })
} }
@ -360,8 +353,8 @@ func TestDB_unmarshalAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{authorityID: admin.DefaultAuthorityID} d := DB{authorityID: admin.DefaultAuthorityID}
if adm, err := db.unmarshalAdmin(tc.in, adminID); err != nil { if adm, err := d.unmarshalAdmin(tc.in, adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -376,16 +369,14 @@ func TestDB_unmarshalAdmin(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, adm.Id, adminID)
assert.Equals(t, adm.Id, adminID) assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID) assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID) assert.Equals(t, adm.Subject, tc.dba.Subject)
assert.Equals(t, adm.Subject, tc.dba.Subject) assert.Equals(t, adm.Type, tc.dba.Type)
assert.Equals(t, adm.Type, tc.dba.Type) assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt)) assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
}
} }
}) })
} }
@ -407,7 +398,7 @@ func TestDB_GetAdmin(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -516,8 +507,8 @@ func TestDB_GetAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if adm, err := db.GetAdmin(context.Background(), adminID); err != nil { if adm, err := d.GetAdmin(context.Background(), adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -532,16 +523,14 @@ func TestDB_GetAdmin(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, adm.Id, adminID)
assert.Equals(t, adm.Id, adminID) assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID) assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID) assert.Equals(t, adm.Subject, tc.dba.Subject)
assert.Equals(t, adm.Subject, tc.dba.Subject) assert.Equals(t, adm.Type, tc.dba.Type)
assert.Equals(t, adm.Type, tc.dba.Type) assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt)) assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
}
} }
}) })
} }
@ -562,7 +551,7 @@ func TestDB_DeleteAdmin(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -670,8 +659,8 @@ func TestDB_DeleteAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.DeleteAdmin(context.Background(), adminID); err != nil { if err := d.DeleteAdmin(context.Background(), adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -708,7 +697,7 @@ func TestDB_UpdateAdmin(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -821,8 +810,8 @@ func TestDB_UpdateAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.UpdateAdmin(context.Background(), tc.adm); err != nil { if err := d.UpdateAdmin(context.Background(), tc.adm); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -919,8 +908,8 @@ func TestDB_CreateAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.CreateAdmin(context.Background(), tc.adm); err != nil { if err := d.CreateAdmin(context.Background(), tc.adm); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -1095,8 +1084,8 @@ func TestDB_GetAdmins(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if admins, err := db.GetAdmins(context.Background()); err != nil { if admins, err := d.GetAdmins(context.Background()); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -1111,10 +1100,8 @@ func TestDB_GetAdmins(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { tc.verify(t, admins)
tc.verify(t, admins)
}
} }
}) })
} }

View file

@ -35,7 +35,7 @@ func New(db nosqlDB.DB, authorityID string) (*DB, error) {
// save writes the new data to the database, overwriting the old data if it // save writes the new data to the database, overwriting the old data if it
// existed. // existed.
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
var ( var (
err error err error
newB []byte newB []byte

View file

@ -12,7 +12,6 @@ import (
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
nosqldb "github.com/smallstep/nosql/database"
"go.step.sm/linkedca" "go.step.sm/linkedca"
) )
@ -31,7 +30,7 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -66,8 +65,8 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if b, err := db.getDBProvisionerBytes(context.Background(), provID); err != nil { if b, err := d.getDBProvisionerBytes(context.Background(), provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -82,10 +81,8 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, string(b), "foo")
assert.Equals(t, string(b), "foo")
}
} }
}) })
} }
@ -107,7 +104,7 @@ func TestDB_getDBProvisioner(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -190,8 +187,8 @@ func TestDB_getDBProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if dbp, err := db.getDBProvisioner(context.Background(), provID); err != nil { if dbp, err := d.getDBProvisioner(context.Background(), provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -206,15 +203,13 @@ func TestDB_getDBProvisioner(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dbp.ID, provID)
assert.Equals(t, dbp.ID, provID) assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID) assert.Equals(t, dbp.Type, tc.dbp.Type)
assert.Equals(t, dbp.Type, tc.dbp.Type) assert.Equals(t, dbp.Name, tc.dbp.Name)
assert.Equals(t, dbp.Name, tc.dbp.Name) assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt) assert.Fatal(t, dbp.DeletedAt.IsZero())
assert.Fatal(t, dbp.DeletedAt.IsZero())
}
} }
}) })
} }
@ -278,8 +273,8 @@ func TestDB_unmarshalDBProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{authorityID: admin.DefaultAuthorityID} d := DB{authorityID: admin.DefaultAuthorityID}
if dbp, err := db.unmarshalDBProvisioner(tc.in, provID); err != nil { if dbp, err := d.unmarshalDBProvisioner(tc.in, provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -294,19 +289,17 @@ func TestDB_unmarshalDBProvisioner(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dbp.ID, provID)
assert.Equals(t, dbp.ID, provID) assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID) assert.Equals(t, dbp.Type, tc.dbp.Type)
assert.Equals(t, dbp.Type, tc.dbp.Type) assert.Equals(t, dbp.Name, tc.dbp.Name)
assert.Equals(t, dbp.Name, tc.dbp.Name) assert.Equals(t, dbp.Details, tc.dbp.Details)
assert.Equals(t, dbp.Details, tc.dbp.Details) assert.Equals(t, dbp.Claims, tc.dbp.Claims)
assert.Equals(t, dbp.Claims, tc.dbp.Claims) assert.Equals(t, dbp.X509Template, tc.dbp.X509Template)
assert.Equals(t, dbp.X509Template, tc.dbp.X509Template) assert.Equals(t, dbp.SSHTemplate, tc.dbp.SSHTemplate)
assert.Equals(t, dbp.SSHTemplate, tc.dbp.SSHTemplate) assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt) assert.Fatal(t, dbp.DeletedAt.IsZero())
assert.Fatal(t, dbp.DeletedAt.IsZero())
}
} }
}) })
} }
@ -402,8 +395,8 @@ func TestDB_unmarshalProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{authorityID: admin.DefaultAuthorityID} d := DB{authorityID: admin.DefaultAuthorityID}
if prov, err := db.unmarshalProvisioner(tc.in, provID); err != nil { if prov, err := d.unmarshalProvisioner(tc.in, provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -418,20 +411,18 @@ func TestDB_unmarshalProvisioner(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, prov.Id, provID)
assert.Equals(t, prov.Id, provID) assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID) assert.Equals(t, prov.Type, tc.dbp.Type)
assert.Equals(t, prov.Type, tc.dbp.Type) assert.Equals(t, prov.Name, tc.dbp.Name)
assert.Equals(t, prov.Name, tc.dbp.Name) assert.Equals(t, prov.Claims, tc.dbp.Claims)
assert.Equals(t, prov.Claims, tc.dbp.Claims) assert.Equals(t, prov.X509Template, tc.dbp.X509Template)
assert.Equals(t, prov.X509Template, tc.dbp.X509Template) assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
retDetailsBytes, err := json.Marshal(prov.Details.GetData()) retDetailsBytes, err := json.Marshal(prov.Details.GetData())
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, retDetailsBytes, tc.dbp.Details) assert.Equals(t, retDetailsBytes, tc.dbp.Details)
}
} }
}) })
} }
@ -453,7 +444,7 @@ func TestDB_GetProvisioner(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -542,8 +533,8 @@ func TestDB_GetProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if prov, err := db.GetProvisioner(context.Background(), provID); err != nil { if prov, err := d.GetProvisioner(context.Background(), provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -558,20 +549,18 @@ func TestDB_GetProvisioner(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, prov.Id, provID)
assert.Equals(t, prov.Id, provID) assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID) assert.Equals(t, prov.Type, tc.dbp.Type)
assert.Equals(t, prov.Type, tc.dbp.Type) assert.Equals(t, prov.Name, tc.dbp.Name)
assert.Equals(t, prov.Name, tc.dbp.Name) assert.Equals(t, prov.Claims, tc.dbp.Claims)
assert.Equals(t, prov.Claims, tc.dbp.Claims) assert.Equals(t, prov.X509Template, tc.dbp.X509Template)
assert.Equals(t, prov.X509Template, tc.dbp.X509Template) assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
retDetailsBytes, err := json.Marshal(prov.Details.GetData()) retDetailsBytes, err := json.Marshal(prov.Details.GetData())
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, retDetailsBytes, tc.dbp.Details) assert.Equals(t, retDetailsBytes, tc.dbp.Details)
}
} }
}) })
} }
@ -592,7 +581,7 @@ func TestDB_DeleteProvisioner(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -692,8 +681,8 @@ func TestDB_DeleteProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.DeleteProvisioner(context.Background(), provID); err != nil { if err := d.DeleteProvisioner(context.Background(), provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -853,8 +842,8 @@ func TestDB_GetProvisioners(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if provs, err := db.GetProvisioners(context.Background()); err != nil { if provs, err := d.GetProvisioners(context.Background()); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -869,10 +858,8 @@ func TestDB_GetProvisioners(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { tc.verify(t, provs)
tc.verify(t, provs)
}
} }
}) })
} }
@ -963,8 +950,8 @@ func TestDB_CreateProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.CreateProvisioner(context.Background(), tc.prov); err != nil { if err := d.CreateProvisioner(context.Background(), tc.prov); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -1001,7 +988,7 @@ func TestDB_UpdateProvisioner(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -1199,8 +1186,8 @@ func TestDB_UpdateProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.UpdateProvisioner(context.Background(), tc.prov); err != nil { if err := d.UpdateProvisioner(context.Background(), tc.prov); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {

View file

@ -55,8 +55,8 @@ type subProv struct {
provisioner string provisioner string
} }
func newSubProv(subject, provisioner string) subProv { func newSubProv(subject, prov string) subProv {
return subProv{subject, provisioner} return subProv{subject, prov}
} }
// LoadBySubProv a admin by the subject and provisioner name. // LoadBySubProv a admin by the subject and provisioner name.

View file

@ -16,10 +16,10 @@ func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
} }
// LoadAdminBySubProv returns an *linkedca.Admin with the given ID. // LoadAdminBySubProv returns an *linkedca.Admin with the given ID.
func (a *Authority) LoadAdminBySubProv(subject, provisioner string) (*linkedca.Admin, bool) { func (a *Authority) LoadAdminBySubProv(subject, prov string) (*linkedca.Admin, bool) {
a.adminMutex.RLock() a.adminMutex.RLock()
defer a.adminMutex.RUnlock() defer a.adminMutex.RUnlock()
return a.admins.LoadBySubProv(subject, provisioner) return a.admins.LoadBySubProv(subject, prov)
} }
// GetAdmins returns a map listing each provisioner and the JWK Key Set // GetAdmins returns a map listing each provisioner and the JWK Key Set

View file

@ -7,41 +7,44 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"log" "log"
"strings"
"sync" "sync"
"time" "time"
"github.com/smallstep/certificates/cas"
"github.com/smallstep/certificates/scep"
"go.step.sm/linkedca"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql" adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/administrator" "github.com/smallstep/certificates/authority/administrator"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas"
casapi "github.com/smallstep/certificates/cas/apiv1" casapi "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/kms" "github.com/smallstep/certificates/kms"
kmsapi "github.com/smallstep/certificates/kms/apiv1" kmsapi "github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/kms/sshagentkms" "github.com/smallstep/certificates/kms/sshagentkms"
"github.com/smallstep/certificates/scep"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// Authority implements the Certificate Authority internal interface. // Authority implements the Certificate Authority internal interface.
type Authority struct { type Authority struct {
config *config.Config config *config.Config
keyManager kms.KeyManager keyManager kms.KeyManager
provisioners *provisioner.Collection provisioners *provisioner.Collection
admins *administrator.Collection admins *administrator.Collection
db db.AuthDB db db.AuthDB
adminDB admin.DB adminDB admin.DB
templates *templates.Templates templates *templates.Templates
linkedCAToken string
// X509 CA // X509 CA
password []byte
issuerPassword []byte
x509CAService cas.CertificateAuthorityService x509CAService cas.CertificateAuthorityService
rootX509Certs []*x509.Certificate rootX509Certs []*x509.Certificate
rootX509CertPool *x509.CertPool rootX509CertPool *x509.CertPool
@ -52,6 +55,8 @@ type Authority struct {
scepService *scep.Service scepService *scep.Service
// SSH CA // SSH CA
sshHostPassword []byte
sshUserPassword []byte
sshCAUserCertSignKey ssh.Signer sshCAUserCertSignKey ssh.Signer
sshCAHostCertSignKey ssh.Signer sshCAHostCertSignKey ssh.Signer
sshCAUserCerts []ssh.PublicKey sshCAUserCerts []ssh.PublicKey
@ -73,14 +78,14 @@ type Authority struct {
} }
// New creates and initiates a new Authority type. // New creates and initiates a new Authority type.
func New(config *config.Config, opts ...Option) (*Authority, error) { func New(cfg *config.Config, opts ...Option) (*Authority, error) {
err := config.Validate() err := cfg.Validate()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var a = &Authority{ var a = &Authority{
config: config, config: cfg,
certificates: new(sync.Map), certificates: new(sync.Map),
} }
@ -205,6 +210,26 @@ func (a *Authority) init() error {
var err error var err error
// Set password if they are not set.
var configPassword []byte
if a.config.Password != "" {
configPassword = []byte(a.config.Password)
}
if configPassword != nil && a.password == nil {
a.password = configPassword
}
if a.sshHostPassword == nil {
a.sshHostPassword = a.password
}
if a.sshUserPassword == nil {
a.sshUserPassword = a.password
}
// Automatically enable admin for all linked cas.
if a.linkedCAToken != "" {
a.config.AuthorityConfig.EnableAdmin = true
}
// Initialize step-ca Database if it's not already initialized with WithDB. // Initialize step-ca Database if it's not already initialized with WithDB.
// If a.config.DB is nil then a simple, barebones in memory DB will be used. // If a.config.DB is nil then a simple, barebones in memory DB will be used.
if a.db == nil { if a.db == nil {
@ -232,6 +257,11 @@ func (a *Authority) init() error {
options = *a.config.AuthorityConfig.Options options = *a.config.AuthorityConfig.Options
} }
// Set the issuer password if passed in the flags.
if options.CertificateIssuer != nil && a.issuerPassword != nil {
options.CertificateIssuer.Password = string(a.issuerPassword)
}
// Read intermediate and create X509 signer for default CAS. // Read intermediate and create X509 signer for default CAS.
if options.Is(casapi.SoftCAS) { if options.Is(casapi.SoftCAS) {
options.CertificateChain, err = pemutil.ReadCertificateBundle(a.config.IntermediateCert) options.CertificateChain, err = pemutil.ReadCertificateBundle(a.config.IntermediateCert)
@ -240,7 +270,7 @@ func (a *Authority) init() error {
} }
options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
SigningKey: a.config.IntermediateKey, SigningKey: a.config.IntermediateKey,
Password: []byte(a.config.Password), Password: []byte(a.password),
}) })
if err != nil { if err != nil {
return err return err
@ -309,7 +339,7 @@ func (a *Authority) init() error {
if a.config.SSH.HostKey != "" { if a.config.SSH.HostKey != "" {
signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
SigningKey: a.config.SSH.HostKey, SigningKey: a.config.SSH.HostKey,
Password: []byte(a.config.Password), Password: []byte(a.sshHostPassword),
}) })
if err != nil { if err != nil {
return err return err
@ -335,7 +365,7 @@ func (a *Authority) init() error {
if a.config.SSH.UserKey != "" { if a.config.SSH.UserKey != "" {
signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ signer, err := a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
SigningKey: a.config.SSH.UserKey, SigningKey: a.config.SSH.UserKey,
Password: []byte(a.config.Password), Password: []byte(a.sshUserPassword),
}) })
if err != nil { if err != nil {
return err return err
@ -359,33 +389,45 @@ func (a *Authority) init() error {
a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, a.sshCAUserCertSignKey.PublicKey()) a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, a.sshCAUserCertSignKey.PublicKey())
} }
// Append other public keys // Append other public keys and add them to the template variables.
for _, key := range a.config.SSH.Keys { for _, key := range a.config.SSH.Keys {
publicKey := key.PublicKey()
switch key.Type { switch key.Type {
case provisioner.SSHHostCert: case provisioner.SSHHostCert:
if key.Federated { if key.Federated {
a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, key.PublicKey()) a.sshCAHostFederatedCerts = append(a.sshCAHostFederatedCerts, publicKey)
} else { } else {
a.sshCAHostCerts = append(a.sshCAHostCerts, key.PublicKey()) a.sshCAHostCerts = append(a.sshCAHostCerts, publicKey)
} }
case provisioner.SSHUserCert: case provisioner.SSHUserCert:
if key.Federated { if key.Federated {
a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, key.PublicKey()) a.sshCAUserFederatedCerts = append(a.sshCAUserFederatedCerts, publicKey)
} else { } else {
a.sshCAUserCerts = append(a.sshCAUserCerts, key.PublicKey()) a.sshCAUserCerts = append(a.sshCAUserCerts, publicKey)
} }
default: default:
return errors.Errorf("unsupported type %s", key.Type) return errors.Errorf("unsupported type %s", key.Type)
} }
} }
}
// Configure template variables. // Configure template variables. On the template variables HostFederatedKeys
// and UserFederatedKeys we will skip the actual CA that will be available
// in HostKey and UserKey.
//
// We cannot do it in the previous blocks because this configuration can be
// injected using options.
if a.sshCAHostCertSignKey != nil {
tmplVars.SSH.HostKey = a.sshCAHostCertSignKey.PublicKey() tmplVars.SSH.HostKey = a.sshCAHostCertSignKey.PublicKey()
tmplVars.SSH.UserKey = a.sshCAUserCertSignKey.PublicKey()
// On the templates we skip the first one because there's a distinction
// between the main key and federated keys.
tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts[1:]...) tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts[1:]...)
} else {
tmplVars.SSH.HostFederatedKeys = append(tmplVars.SSH.HostFederatedKeys, a.sshCAHostFederatedCerts...)
}
if a.sshCAUserCertSignKey != nil {
tmplVars.SSH.UserKey = a.sshCAUserCertSignKey.PublicKey()
tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts[1:]...) tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts[1:]...)
} else {
tmplVars.SSH.UserFederatedKeys = append(tmplVars.SSH.UserFederatedKeys, a.sshCAUserFederatedCerts...)
} }
// Check if a KMS with decryption capability is required and available // Check if a KMS with decryption capability is required and available
@ -414,7 +456,7 @@ func (a *Authority) init() error {
} }
options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{ options.Signer, err = a.keyManager.CreateSigner(&kmsapi.CreateSignerRequest{
SigningKey: a.config.IntermediateKey, SigningKey: a.config.IntermediateKey,
Password: []byte(a.config.Password), Password: []byte(a.password),
}) })
if err != nil { if err != nil {
return err return err
@ -423,7 +465,7 @@ func (a *Authority) init() error {
if km, ok := a.keyManager.(kmsapi.Decrypter); ok { if km, ok := a.keyManager.(kmsapi.Decrypter); ok {
options.Decrypter, err = km.CreateDecrypter(&kmsapi.CreateDecrypterRequest{ options.Decrypter, err = km.CreateDecrypter(&kmsapi.CreateDecrypterRequest{
DecryptionKey: a.config.IntermediateKey, DecryptionKey: a.config.IntermediateKey,
Password: []byte(a.config.Password), Password: []byte(a.password),
}) })
if err != nil { if err != nil {
return err return err
@ -442,10 +484,24 @@ func (a *Authority) init() error {
// Initialize step-ca Admin Database if it's not already initialized using // Initialize step-ca Admin Database if it's not already initialized using
// WithAdminDB. // WithAdminDB.
if a.adminDB == nil { if a.adminDB == nil {
// Check if AuthConfig already exists if a.linkedCAToken == "" {
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID) // Check if AuthConfig already exists
if err != nil { a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
return err if err != nil {
return err
}
} else {
// Use the linkedca client as the admindb.
client, err := newLinkedCAClient(a.linkedCAToken)
if err != nil {
return err
}
// If authorityId is configured make sure it matches the one in the token
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) {
return errors.New("error initializing linkedca: token authority and configured authority do not match")
}
client.Run()
a.adminDB = client
} }
} }
@ -453,9 +509,9 @@ func (a *Authority) init() error {
if err != nil { if err != nil {
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority") return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
} }
if len(provs) == 0 { if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
// Create First Provisioner // Create First Provisioner
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, a.config.Password) prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password))
if err != nil { if err != nil {
return admin.WrapErrorISE(err, "error creating first provisioner") return admin.WrapErrorISE(err, "error creating first provisioner")
} }
@ -527,6 +583,9 @@ func (a *Authority) CloseForReload() {
if err := a.keyManager.Close(); err != nil { if err := a.keyManager.Close(); err != nil {
log.Printf("error closing the key manager: %v", err) log.Printf("error closing the key manager: %v", err)
} }
if client, ok := a.adminDB.(*linkedCaClient); ok {
client.Stop()
}
} }
// requiresDecrypter returns whether the Authority // requiresDecrypter returns whether the Authority

View file

@ -11,6 +11,7 @@ import (
"net" "net"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
@ -82,6 +83,10 @@ func testAuthority(t *testing.T, opts ...Option) *Authority {
} }
a, err := New(c, opts...) a, err := New(c, opts...)
assert.FatalError(t, err) assert.FatalError(t, err)
// Avoid errors when test tokens are created before the test authority. This
// happens in some tests where we re-create the same authority to test
// special cases without re-creating the token.
a.startTime = a.startTime.Add(-1 * time.Minute)
return a return a
} }

View file

@ -6,6 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
@ -53,7 +54,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// key in order to verify the claims and we need the issuer from the claims // key in order to verify the claims and we need the issuer from the claims
// before we can look up the provisioner. // before we can look up the provisioner.
var claims Claims var claims Claims
if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil { if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken") return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
} }
@ -76,7 +77,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// Store the token to protect against reuse unless it's skipped. // Store the token to protect against reuse unless it's skipped.
// If we cannot get a token id from the provisioner, just hash the token. // If we cannot get a token id from the provisioner, just hash the token.
if !SkipTokenReuseFromContext(ctx) { if !SkipTokenReuseFromContext(ctx) {
if err = a.UseToken(token, p); err != nil { if err := a.UseToken(token, p); err != nil {
return nil, err return nil, err
} }
} }
@ -111,7 +112,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
// to the public certificate in the `x5c` header of the token. // to the public certificate in the `x5c` header of the token.
// 2. Asserts that the claims are valid - have not been tampered with. // 2. Asserts that the claims are valid - have not been tampered with.
var claims jose.Claims var claims jose.Claims
if err = jwt.Claims(leaf.PublicKey, &claims); err != nil { if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims") return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims")
} }
@ -121,13 +122,13 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
} }
// Check that the token has not been used. // Check that the token has not been used.
if err = a.UseToken(token, prov); err != nil { if err := a.UseToken(token, prov); err != nil {
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token") return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token")
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes. // more than a few minutes.
if err = claims.ValidateWithLeeway(jose.Expected{ if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: prov.GetName(), Issuer: prov.GetName(),
Time: time.Now().UTC(), Time: time.Now().UTC(),
}, time.Minute); err != nil { }, time.Minute); err != nil {
@ -173,6 +174,9 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
} }
// UseToken stores the token to protect against reuse. // UseToken stores the token to protect against reuse.
//
// This method currently ignores any error coming from the GetTokenID, but it
// should specifically ignore the error provisioner.ErrAllowTokenReuse.
func (a *Authority) UseToken(token string, prov provisioner.Interface) error { func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
if reuseKey, err := prov.GetTokenID(token); err == nil { if reuseKey, err := prov.GetTokenID(token); err == nil {
if reuseKey == "" { if reuseKey == "" {
@ -258,7 +262,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
} }
if err = p.AuthorizeRevoke(ctx, token); err != nil { if err := p.AuthorizeRevoke(ctx, token); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
} }
return nil return nil
@ -270,10 +274,19 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
// //
// TODO(mariano): should we authorize by default? // TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenew(cert *x509.Certificate) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
var err error
var isRevoked bool
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())} var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
// Check the passive revocation table. // Check the passive revocation table.
isRevoked, err := a.db.IsRevoked(cert.SerialNumber.String()) serial := cert.SerialNumber.String()
if lca, ok := a.adminDB.(interface {
IsRevoked(string) (bool, error)
}); ok {
isRevoked, err = lca.IsRevoked(serial)
} else {
isRevoked, err = a.db.IsRevoked(serial)
}
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
} }
@ -291,6 +304,28 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
return nil return nil
} }
// authorizeSSHCertificate returns an error if the given certificate is revoked.
func (a *Authority) authorizeSSHCertificate(ctx context.Context, cert *ssh.Certificate) error {
var err error
var isRevoked bool
serial := strconv.FormatUint(cert.Serial, 10)
if lca, ok := a.adminDB.(interface {
IsSSHRevoked(string) (bool, error)
}); ok {
isRevoked, err = lca.IsSSHRevoked(serial)
} else {
isRevoked, err = a.db.IsSSHRevoked(serial)
}
if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeSSHCertificate", errs.WithKeyVal("serialNumber", serial))
}
if isRevoked {
return errs.Unauthorized("authority.authorizeSSHCertificate: certificate has been revoked", errs.WithKeyVal("serialNumber", serial))
}
return nil
}
// authorizeSSHSign loads the provisioner from the token, checks that it has not // authorizeSSHSign loads the provisioner from the token, checks that it has not
// been used again and calls the provisioner AuthorizeSSHSign method. Returns a // been used again and calls the provisioner AuthorizeSSHSign method. Returns a
// list of methods to apply to the signing flow. // list of methods to apply to the signing flow.

View file

@ -917,7 +917,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if err = cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
return cert, jwk, nil return cert, jwk, nil

View file

@ -75,6 +75,7 @@ type ASN1DN struct {
Locality string `json:"locality,omitempty"` Locality string `json:"locality,omitempty"`
Province string `json:"province,omitempty"` Province string `json:"province,omitempty"`
StreetAddress string `json:"streetAddress,omitempty"` StreetAddress string `json:"streetAddress,omitempty"`
SerialNumber string `json:"serialNumber,omitempty"`
CommonName string `json:"commonName,omitempty"` CommonName string `json:"commonName,omitempty"`
} }
@ -83,8 +84,9 @@ type ASN1DN struct {
// cas.Options. // cas.Options.
type AuthConfig struct { type AuthConfig struct {
*cas.Options *cas.Options
AuthorityID string `json:"authorityID,omitempty"` AuthorityID string `json:"authorityId,omitempty"`
Provisioners provisioner.List `json:"provisioners"` DeploymentType string `json:"deploymentType,omitempty"`
Provisioners provisioner.List `json:"provisioners,omitempty"`
Admins []*linkedca.Admin `json:"-"` Admins []*linkedca.Admin `json:"-"`
Template *ASN1DN `json:"template,omitempty"` Template *ASN1DN `json:"template,omitempty"`
Claims *provisioner.Claims `json:"claims,omitempty"` Claims *provisioner.Claims `json:"claims,omitempty"`
@ -188,9 +190,10 @@ func (c *Config) Validate() error {
switch { switch {
case c.Address == "": case c.Address == "":
return errors.New("address cannot be empty") return errors.New("address cannot be empty")
case len(c.DNSNames) == 0: case len(c.DNSNames) == 0:
return errors.New("dnsNames cannot be empty") return errors.New("dnsNames cannot be empty")
case c.AuthorityConfig == nil:
return errors.New("authority cannot be nil")
} }
// Options holds the RA/CAS configuration. // Options holds the RA/CAS configuration.
@ -222,7 +225,7 @@ func (c *Config) Validate() error {
c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion c.TLS.MaxVersion = DefaultTLSOptions.MaxVersion
} }
if c.TLS.MinVersion == 0 { if c.TLS.MinVersion == 0 {
c.TLS.MinVersion = c.TLS.MaxVersion c.TLS.MinVersion = DefaultTLSOptions.MinVersion
} }
if c.TLS.MinVersion > c.TLS.MaxVersion { if c.TLS.MinVersion > c.TLS.MaxVersion {
return errors.New("tls minVersion cannot exceed tls maxVersion") return errors.New("tls minVersion cannot exceed tls maxVersion")

View file

@ -15,8 +15,9 @@ var (
// DefaultTLSRenegotiation default TLS connection renegotiation policy. // DefaultTLSRenegotiation default TLS connection renegotiation policy.
DefaultTLSRenegotiation = false // Never regnegotiate. DefaultTLSRenegotiation = false // Never regnegotiate.
// DefaultTLSCipherSuites specifies default step ciphersuite(s). // DefaultTLSCipherSuites specifies default step ciphersuite(s).
// These are TLS 1.0 - 1.2 cipher suites.
DefaultTLSCipherSuites = CipherSuites{ DefaultTLSCipherSuites = CipherSuites{
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
} }
// ApprovedTLSCipherSuites smallstep approved ciphersuites. // ApprovedTLSCipherSuites smallstep approved ciphersuites.
@ -26,25 +27,21 @@ var (
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
} }
// DefaultTLSOptions represents the default TLS version as well as the cipher // DefaultTLSOptions represents the default TLS version as well as the cipher
// suites used in the TLS certificates. // suites used in the TLS certificates.
DefaultTLSOptions = TLSOptions{ DefaultTLSOptions = TLSOptions{
CipherSuites: CipherSuites{ CipherSuites: DefaultTLSCipherSuites,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", MinVersion: DefaultTLSMinVersion,
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", MaxVersion: DefaultTLSMaxVersion,
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", Renegotiation: DefaultTLSRenegotiation,
},
MinVersion: 1.2,
MaxVersion: 1.2,
Renegotiation: false,
} }
) )
@ -119,27 +116,38 @@ func (c CipherSuites) Value() []uint16 {
// cipherSuites has the list of supported cipher suites. // cipherSuites has the list of supported cipher suites.
var cipherSuites = map[string]uint16{ var cipherSuites = map[string]uint16{
"TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA, // TLS 1.0 - 1.2 cipher suites.
"TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA,
"TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
"TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
"TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
// TLS 1.3 cipher sutes.
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384,
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
// Legacy names.
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
} }
// TLSOptions represents the TLS options that can be specified on *tls.Config // TLSOptions represents the TLS options that can be specified on *tls.Config

View file

@ -25,7 +25,7 @@ func (s multiString) HasEmpties() bool {
return true return true
} }
for _, ss := range s { for _, ss := range s {
if len(ss) == 0 { if ss == "" {
return true return true
} }
} }

284
authority/export.go Normal file
View file

@ -0,0 +1,284 @@
package authority
import (
"encoding/json"
"io/ioutil"
"net/url"
"path/filepath"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/cli-utils/config"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/structpb"
)
// Export creates a linkedca configuration form the current ca.json and loaded
// authorities.
//
// Note that export will not export neither the pki password nor the certificate
// issuer password.
func (a *Authority) Export() (c *linkedca.Configuration, err error) {
// Recover from panics
defer func() {
if r := recover(); r != nil {
err = r.(error)
}
}()
files := make(map[string][]byte)
// The exported configuration should not include the password in it.
c = &linkedca.Configuration{
Version: "1.0",
Root: mustReadFilesOrURIs(a.config.Root, files),
FederatedRoots: mustReadFilesOrURIs(a.config.FederatedRoots, files),
Intermediate: mustReadFileOrURI(a.config.IntermediateCert, files),
IntermediateKey: mustReadFileOrURI(a.config.IntermediateKey, files),
Address: a.config.Address,
InsecureAddress: a.config.InsecureAddress,
DnsNames: a.config.DNSNames,
Db: mustMarshalToStruct(a.config.DB),
Logger: mustMarshalToStruct(a.config.Logger),
Monitoring: mustMarshalToStruct(a.config.Monitoring),
Authority: &linkedca.Authority{
Id: a.config.AuthorityConfig.AuthorityID,
EnableAdmin: a.config.AuthorityConfig.EnableAdmin,
DisableIssuedAtCheck: a.config.AuthorityConfig.DisableIssuedAtCheck,
Backdate: mustDuration(a.config.AuthorityConfig.Backdate),
DeploymentType: a.config.AuthorityConfig.DeploymentType,
},
Files: files,
}
// SSH
if v := a.config.SSH; v != nil {
c.Ssh = &linkedca.SSH{
HostKey: mustReadFileOrURI(v.HostKey, files),
UserKey: mustReadFileOrURI(v.UserKey, files),
AddUserPrincipal: v.AddUserPrincipal,
AddUserCommand: v.AddUserCommand,
}
for _, k := range v.Keys {
typ, ok := linkedca.SSHPublicKey_Type_value[strings.ToUpper(k.Type)]
if !ok {
return nil, errors.Errorf("unsupported ssh key type %s", k.Type)
}
c.Ssh.Keys = append(c.Ssh.Keys, &linkedca.SSHPublicKey{
Type: linkedca.SSHPublicKey_Type(typ),
Federated: k.Federated,
Key: mustMarshalToStruct(k),
})
}
if b := v.Bastion; b != nil {
c.Ssh.Bastion = &linkedca.Bastion{
Hostname: b.Hostname,
User: b.User,
Port: b.Port,
Command: b.Command,
Flags: b.Flags,
}
}
}
// KMS
if v := a.config.KMS; v != nil {
var typ int32
var ok bool
if v.Type == "" {
typ = int32(linkedca.KMS_SOFTKMS)
} else {
typ, ok = linkedca.KMS_Type_value[strings.ToUpper(v.Type)]
if !ok {
return nil, errors.Errorf("unsupported kms type %s", v.Type)
}
}
c.Kms = &linkedca.KMS{
Type: linkedca.KMS_Type(typ),
CredentialsFile: v.CredentialsFile,
Uri: v.URI,
Pin: v.Pin,
ManagementKey: v.ManagementKey,
Region: v.Region,
Profile: v.Profile,
}
}
// Authority
// cas options
if v := a.config.AuthorityConfig.Options; v != nil {
c.Authority.Type = 0
c.Authority.CertificateAuthority = v.CertificateAuthority
c.Authority.CertificateAuthorityFingerprint = v.CertificateAuthorityFingerprint
c.Authority.CredentialsFile = v.CredentialsFile
if iss := v.CertificateIssuer; iss != nil {
typ, ok := linkedca.CertificateIssuer_Type_value[strings.ToUpper(iss.Type)]
if !ok {
return nil, errors.Errorf("unknown certificate issuer type %s", iss.Type)
}
// The exported certificate issuer should not include the password.
c.Authority.CertificateIssuer = &linkedca.CertificateIssuer{
Type: linkedca.CertificateIssuer_Type(typ),
Provisioner: iss.Provisioner,
Certificate: mustReadFileOrURI(iss.Certificate, files),
Key: mustReadFileOrURI(iss.Key, files),
}
}
}
// admins
for {
list, cursor := a.admins.Find("", 100)
c.Authority.Admins = append(c.Authority.Admins, list...)
if cursor == "" {
break
}
}
// provisioners
for {
list, cursor := a.provisioners.Find("", 100)
for _, p := range list {
lp, err := ProvisionerToLinkedca(p)
if err != nil {
return nil, err
}
c.Authority.Provisioners = append(c.Authority.Provisioners, lp)
}
if cursor == "" {
break
}
}
// global claims
c.Authority.Claims = claimsToLinkedca(a.config.AuthorityConfig.Claims)
// Distinguished names template
if v := a.config.AuthorityConfig.Template; v != nil {
c.Authority.Template = &linkedca.DistinguishedName{
Country: v.Country,
Organization: v.Organization,
OrganizationalUnit: v.OrganizationalUnit,
Locality: v.Locality,
Province: v.Province,
StreetAddress: v.StreetAddress,
SerialNumber: v.SerialNumber,
CommonName: v.CommonName,
}
}
// TLS
if v := a.config.TLS; v != nil {
c.Tls = &linkedca.TLS{
MinVersion: v.MinVersion.String(),
MaxVersion: v.MaxVersion.String(),
Renegotiation: v.Renegotiation,
}
for _, cs := range v.CipherSuites.Value() {
c.Tls.CipherSuites = append(c.Tls.CipherSuites, linkedca.TLS_CiperSuite(cs))
}
}
// Templates
if v := a.config.Templates; v != nil {
c.Templates = &linkedca.ConfigTemplates{
Ssh: &linkedca.SSHConfigTemplate{},
Data: mustMarshalToStruct(v.Data),
}
// Remove automatically loaded vars
if c.Templates.Data != nil && c.Templates.Data.Fields != nil {
delete(c.Templates.Data.Fields, "Step")
}
for _, t := range v.SSH.Host {
typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))]
if !ok {
return nil, errors.Errorf("unsupported template type %s", t.Type)
}
c.Templates.Ssh.Hosts = append(c.Templates.Ssh.Hosts, &linkedca.ConfigTemplate{
Type: linkedca.ConfigTemplate_Type(typ),
Name: t.Name,
Template: mustReadFileOrURI(t.TemplatePath, files),
Path: t.Path,
Comment: t.Comment,
Requires: t.RequiredData,
Content: t.Content,
})
}
for _, t := range v.SSH.User {
typ, ok := linkedca.ConfigTemplate_Type_value[strings.ToUpper(string(t.Type))]
if !ok {
return nil, errors.Errorf("unsupported template type %s", t.Type)
}
c.Templates.Ssh.Users = append(c.Templates.Ssh.Users, &linkedca.ConfigTemplate{
Type: linkedca.ConfigTemplate_Type(typ),
Name: t.Name,
Template: mustReadFileOrURI(t.TemplatePath, files),
Path: t.Path,
Comment: t.Comment,
Requires: t.RequiredData,
Content: t.Content,
})
}
}
return c, nil
}
func mustDuration(d *provisioner.Duration) string {
if d == nil || d.Duration == 0 {
return ""
}
return d.String()
}
func mustMarshalToStruct(v interface{}) *structpb.Struct {
b, err := json.Marshal(v)
if err != nil {
panic(errors.Wrapf(err, "error marshaling %T", v))
}
var r *structpb.Struct
if err := json.Unmarshal(b, &r); err != nil {
panic(errors.Wrapf(err, "error unmarshaling %T", v))
}
return r
}
func mustReadFileOrURI(fn string, m map[string][]byte) string {
if fn == "" {
return ""
}
stepPath := filepath.ToSlash(config.StepPath())
if !strings.HasSuffix(stepPath, "/") {
stepPath += "/"
}
fn = strings.TrimPrefix(filepath.ToSlash(fn), stepPath)
ok, err := isFilename(fn)
if err != nil {
panic(err)
}
if ok {
b, err := ioutil.ReadFile(config.StepAbs(fn))
if err != nil {
panic(errors.Wrapf(err, "error reading %s", fn))
}
m[fn] = b
return fn
}
return fn
}
func mustReadFilesOrURIs(fns []string, m map[string][]byte) []string {
var result []string
for _, fn := range fns {
result = append(result, mustReadFileOrURI(fn, m))
}
return result
}
func isFilename(fn string) (bool, error) {
u, err := url.Parse(fn)
if err != nil {
return false, errors.Wrapf(err, "error parsing %s", fn)
}
return u.Scheme == "" || u.Scheme == "file", nil
}

490
authority/linkedca.go Normal file
View file

@ -0,0 +1,490 @@
package authority
import (
"context"
"crypto"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"fmt"
"net/url"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/tlsutil"
"go.step.sm/crypto/x509util"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
type linkedCaClient struct {
renewer *tlsutil.Renewer
client linkedca.MajordomoClient
authorityID string
}
type linkedCAClaims struct {
jose.Claims
SANs []string `json:"sans"`
SHA string `json:"sha"`
}
func newLinkedCAClient(token string) (*linkedCaClient, error) {
tok, err := jose.ParseSigned(token)
if err != nil {
return nil, errors.Wrap(err, "error parsing token")
}
var claims linkedCAClaims
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errors.Wrap(err, "error parsing token")
}
// Validate claims
if len(claims.Audience) != 1 {
return nil, errors.New("error parsing token: invalid aud claim")
}
if claims.SHA == "" {
return nil, errors.New("error parsing token: invalid sha claim")
}
// Get linkedCA endpoint from audience.
u, err := url.Parse(claims.Audience[0])
if err != nil {
return nil, errors.New("error parsing token: invalid aud claim")
}
// Get authority from SANs
authority, err := getAuthority(claims.SANs)
if err != nil {
return nil, err
}
// Create csr to login with
signer, err := keyutil.GenerateDefaultSigner()
if err != nil {
return nil, err
}
csr, err := x509util.CreateCertificateRequest(claims.Subject, claims.SANs, signer)
if err != nil {
return nil, err
}
// Get and verify root certificate
root, err := getRootCertificate(u.Host, claims.SHA)
if err != nil {
return nil, err
}
pool := x509.NewCertPool()
pool.AddCert(root)
// Login with majordomo and get certificates
cert, tlsConfig, err := login(authority, token, csr, signer, u.Host, pool)
if err != nil {
return nil, err
}
// Start TLS renewer and set the GetClientCertificate callback to it.
renewer, err := tlsutil.NewRenewer(cert, tlsConfig, func() (*tls.Certificate, *tls.Config, error) {
return login(authority, token, csr, signer, u.Host, pool)
})
if err != nil {
return nil, err
}
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
// Start mTLS client
conn, err := grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
if err != nil {
return nil, errors.Wrapf(err, "error connecting %s", u.Host)
}
return &linkedCaClient{
renewer: renewer,
client: linkedca.NewMajordomoClient(conn),
authorityID: authority,
}, nil
}
func (c *linkedCaClient) Run() {
c.renewer.Run()
}
func (c *linkedCaClient) Stop() {
c.renewer.Stop()
}
func (c *linkedCaClient) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
resp, err := c.client.CreateProvisioner(ctx, &linkedca.CreateProvisionerRequest{
Type: prov.Type,
Name: prov.Name,
Details: prov.Details,
Claims: prov.Claims,
X509Template: prov.X509Template,
SshTemplate: prov.SshTemplate,
})
if err != nil {
return errors.Wrap(err, "error creating provisioner")
}
prov.Id = resp.Id
prov.AuthorityId = resp.AuthorityId
return nil
}
func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linkedca.Provisioner, error) {
resp, err := c.client.GetProvisioner(ctx, &linkedca.GetProvisionerRequest{
Id: id,
})
if err != nil {
return nil, errors.Wrap(err, "error getting provisioners")
}
return resp, nil
}
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
AuthorityId: c.authorityID,
})
if err != nil {
return nil, errors.Wrap(err, "error getting provisioners")
}
return resp.Provisioners, nil
}
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
_, err := c.client.UpdateProvisioner(ctx, &linkedca.UpdateProvisionerRequest{
Id: prov.Id,
Name: prov.Name,
Details: prov.Details,
Claims: prov.Claims,
X509Template: prov.X509Template,
SshTemplate: prov.SshTemplate,
})
return errors.Wrap(err, "error updating provisioner")
}
func (c *linkedCaClient) DeleteProvisioner(ctx context.Context, id string) error {
_, err := c.client.DeleteProvisioner(ctx, &linkedca.DeleteProvisionerRequest{
Id: id,
})
return errors.Wrap(err, "error deleting provisioner")
}
func (c *linkedCaClient) CreateAdmin(ctx context.Context, adm *linkedca.Admin) error {
resp, err := c.client.CreateAdmin(ctx, &linkedca.CreateAdminRequest{
Subject: adm.Subject,
ProvisionerId: adm.ProvisionerId,
Type: adm.Type,
})
if err != nil {
return errors.Wrap(err, "error creating admin")
}
adm.Id = resp.Id
adm.AuthorityId = resp.AuthorityId
return nil
}
func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Admin, error) {
resp, err := c.client.GetAdmin(ctx, &linkedca.GetAdminRequest{
Id: id,
})
if err != nil {
return nil, errors.Wrap(err, "error getting admins")
}
return resp, nil
}
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
AuthorityId: c.authorityID,
})
if err != nil {
return nil, errors.Wrap(err, "error getting admins")
}
return resp.Admins, nil
}
func (c *linkedCaClient) UpdateAdmin(ctx context.Context, adm *linkedca.Admin) error {
_, err := c.client.UpdateAdmin(ctx, &linkedca.UpdateAdminRequest{
Id: adm.Id,
Type: adm.Type,
})
return errors.Wrap(err, "error updating admin")
}
func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
_, err := c.client.DeleteAdmin(ctx, &linkedca.DeleteAdminRequest{
Id: id,
})
return errors.Wrap(err, "error deleting admin")
}
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
PemCertificate: serializeCertificateChain(fullchain[0]),
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
})
return errors.Wrap(err, "error posting certificate")
}
func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullchain ...*x509.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
PemCertificate: serializeCertificateChain(fullchain[0]),
PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
PemParentCertificate: serializeCertificateChain(parent),
})
return errors.Wrap(err, "error posting certificate")
}
func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
Certificate: string(ssh.MarshalAuthorizedKey(crt)),
})
return errors.Wrap(err, "error posting ssh certificate")
}
func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.RevokeCertificate(ctx, &linkedca.RevokeCertificateRequest{
Serial: rci.Serial,
PemCertificate: serializeCertificate(crt),
Reason: rci.Reason,
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
Passive: true,
})
return errors.Wrap(err, "error revoking certificate")
}
func (c *linkedCaClient) RevokeSSH(cert *ssh.Certificate, rci *db.RevokedCertificateInfo) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{
Serial: rci.Serial,
Certificate: serializeSSHCertificate(cert),
Reason: rci.Reason,
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
Passive: true,
})
return errors.Wrap(err, "error revoking ssh certificate")
}
func (c *linkedCaClient) IsRevoked(serial string) (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
resp, err := c.client.GetCertificateStatus(ctx, &linkedca.GetCertificateStatusRequest{
Serial: serial,
})
if err != nil {
return false, errors.Wrap(err, "error getting certificate status")
}
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
}
func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
resp, err := c.client.GetSSHCertificateStatus(ctx, &linkedca.GetSSHCertificateStatusRequest{
Serial: serial,
})
if err != nil {
return false, errors.Wrap(err, "error getting certificate status")
}
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
}
func serializeCertificate(crt *x509.Certificate) string {
if crt == nil {
return ""
}
return string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: crt.Raw,
}))
}
func serializeCertificateChain(fullchain ...*x509.Certificate) string {
var chain string
for _, crt := range fullchain {
chain += string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: crt.Raw,
}))
}
return chain
}
func serializeSSHCertificate(crt *ssh.Certificate) string {
if crt == nil {
return ""
}
return string(ssh.MarshalAuthorizedKey(crt))
}
func getAuthority(sans []string) (string, error) {
for _, s := range sans {
if strings.HasPrefix(s, "urn:smallstep:authority:") {
if regexp.MustCompile(uuidPattern).MatchString(s[24:]) {
return s[24:], nil
}
}
}
return "", fmt.Errorf("error parsing token: invalid sans claim")
}
// getRootCertificate creates an insecure majordomo client and returns the
// verified root certificate.
func getRootCertificate(endpoint, fingerprint string) (*x509.Certificate, error) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,
})))
if err != nil {
return nil, errors.Wrapf(err, "error connecting %s", endpoint)
}
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
client := linkedca.NewMajordomoClient(conn)
resp, err := client.GetRootCertificate(ctx, &linkedca.GetRootCertificateRequest{
Fingerprint: fingerprint,
})
if err != nil {
return nil, fmt.Errorf("error getting root certificate: %w", err)
}
var block *pem.Block
b := []byte(resp.PemCertificate)
for len(b) > 0 {
block, b = pem.Decode(b)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("error parsing certificate: %w", err)
}
// verify the sha256
sum := sha256.Sum256(cert.Raw)
if !strings.EqualFold(fingerprint, hex.EncodeToString(sum[:])) {
return nil, fmt.Errorf("error verifying certificate: SHA256 fingerprint does not match")
}
return cert, nil
}
return nil, fmt.Errorf("error getting root certificate: certificate not found")
}
// login creates a new majordomo client with just the root ca pool and returns
// the signed certificate and tls configuration.
func login(authority, token string, csr *x509.CertificateRequest, signer crypto.PrivateKey, endpoint string, rootCAs *x509.CertPool) (*tls.Certificate, *tls.Config, error) {
// Connect to majordomo
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
conn, err := grpc.DialContext(ctx, endpoint, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: rootCAs,
})))
if err != nil {
return nil, nil, errors.Wrapf(err, "error connecting %s", endpoint)
}
// Login to get the signed certificate
ctx, cancel = context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
client := linkedca.NewMajordomoClient(conn)
resp, err := client.Login(ctx, &linkedca.LoginRequest{
AuthorityId: authority,
Token: token,
PemCertificateRequest: string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csr.Raw,
})),
})
if err != nil {
return nil, nil, errors.Wrapf(err, "error logging in %s", endpoint)
}
// Parse login response
var block *pem.Block
var bundle []*x509.Certificate
rest := []byte(resp.PemCertificateChain)
for {
block, rest = pem.Decode(rest)
if block == nil {
break
}
if block.Type != "CERTIFICATE" {
return nil, nil, errors.New("error decoding login response: pemCertificateChain is not a certificate bundle")
}
crt, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, nil, errors.Wrap(err, "error parsing login response")
}
bundle = append(bundle, crt)
}
if len(bundle) == 0 {
return nil, nil, errors.New("error decoding login response: pemCertificateChain should not be empty")
}
// Build tls.Certificate with PemCertificate and intermediates in the
// PemCertificateChain
cert := &tls.Certificate{
PrivateKey: signer,
}
rest = []byte(resp.PemCertificate)
for {
block, rest = pem.Decode(rest)
if block == nil {
break
}
if block.Type == "CERTIFICATE" {
leaf, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, nil, errors.Wrap(err, "error parsing pemCertificate")
}
cert.Certificate = append(cert.Certificate, block.Bytes)
cert.Leaf = leaf
}
}
// Add intermediates to the tls.Certificate
last := len(bundle) - 1
for i := 0; i < last; i++ {
cert.Certificate = append(cert.Certificate, bundle[i].Raw)
}
// Add root to the pool if it's not there yet
rootCAs.AddCert(bundle[last])
return cert, &tls.Config{
RootCAs: rootCAs,
}, nil
}

View file

@ -22,9 +22,9 @@ type Option func(*Authority) error
// WithConfig replaces the current config with the given one. No validation is // WithConfig replaces the current config with the given one. No validation is
// performed in the given value. // performed in the given value.
func WithConfig(config *config.Config) Option { func WithConfig(cfg *config.Config) Option {
return func(a *Authority) error { return func(a *Authority) error {
a.config = config a.config = cfg
return nil return nil
} }
} }
@ -38,11 +38,47 @@ func WithConfigFile(filename string) Option {
} }
} }
// WithPassword set the password to decrypt the intermediate key as well as the
// ssh host and user keys if they are not overridden by other options.
func WithPassword(password []byte) Option {
return func(a *Authority) (err error) {
a.password = password
return
}
}
// WithSSHHostPassword set the password to decrypt the key used to sign SSH host
// certificates.
func WithSSHHostPassword(password []byte) Option {
return func(a *Authority) (err error) {
a.sshHostPassword = password
return
}
}
// WithSSHUserPassword set the password to decrypt the key used to sign SSH user
// certificates.
func WithSSHUserPassword(password []byte) Option {
return func(a *Authority) (err error) {
a.sshUserPassword = password
return
}
}
// WithIssuerPassword set the password to decrypt the certificate issuer private
// key used in RA mode.
func WithIssuerPassword(password []byte) Option {
return func(a *Authority) (err error) {
a.issuerPassword = password
return
}
}
// WithDatabase sets an already initialized authority database to a new // WithDatabase sets an already initialized authority database to a new
// authority. This option is intended to be use on graceful reloads. // authority. This option is intended to be use on graceful reloads.
func WithDatabase(db db.AuthDB) Option { func WithDatabase(d db.AuthDB) Option {
return func(a *Authority) error { return func(a *Authority) error {
a.db = db a.db = d
return nil return nil
} }
} }
@ -189,9 +225,18 @@ func WithX509FederatedBundle(pemCerts []byte) Option {
} }
// WithAdminDB is an option to set the database backing the admin APIs. // WithAdminDB is an option to set the database backing the admin APIs.
func WithAdminDB(db admin.DB) Option { func WithAdminDB(d admin.DB) Option {
return func(a *Authority) error { return func(a *Authority) error {
a.adminDB = db a.adminDB = d
return nil
}
}
// WithLinkedCAToken is an option to set the authentication token used to enable
// linked ca.
func WithLinkedCAToken(token string) Option {
return func(a *Authority) error {
a.linkedCAToken = token
return nil return nil
} }
} }

View file

@ -312,7 +312,7 @@ func (p *AWS) GetType() Type {
} }
// GetEncryptedKey is not available in an AWS provisioner. // GetEncryptedKey is not available in an AWS provisioner.
func (p *AWS) GetEncryptedKey() (kid string, key string, ok bool) { func (p *AWS) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }
@ -449,13 +449,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// There's no way to trust them other than TOFU. // There's no way to trust them other than TOFU.
var so []SignOption var so []SignOption
if p.DisableCustomSANs { if p.DisableCustomSANs {
dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region)
so = append(so, dnsNamesValidator([]string{dnsName})) so = append(so,
so = append(so, ipAddressesValidator([]net.IP{ dnsNamesValidator([]string{dnsName}),
net.ParseIP(doc.PrivateIP), ipAddressesValidator([]net.IP{
})) net.ParseIP(doc.PrivateIP),
so = append(so, emailAddressesValidator(nil)) }),
so = append(so, urisValidator(nil)) emailAddressesValidator(nil),
urisValidator(nil),
)
// Template options // Template options
data.SetSANs([]string{dnsName, doc.PrivateIP}) data.SetSANs([]string{dnsName, doc.PrivateIP})
@ -515,6 +517,11 @@ func (p *AWS) readURL(url string) ([]byte, error) {
var resp *http.Response var resp *http.Response
var err error var err error
// Initialize IMDS versions when this is called from the cli.
if len(p.IMDSVersions) == 0 {
p.IMDSVersions = []string{"v2", "v1"}
}
for _, v := range p.IMDSVersions { for _, v := range p.IMDSVersions {
switch v { switch v {
case "v1": case "v1":
@ -664,7 +671,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
if p.DisableCustomSANs { if p.DisableCustomSANs {
if payload.Subject != doc.InstanceID && if payload.Subject != doc.InstanceID &&
payload.Subject != doc.PrivateIP && payload.Subject != doc.PrivateIP &&
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) { payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region) {
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)") return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
} }
} }
@ -715,7 +722,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validated principals. // Validated principals.
principals := []string{ principals := []string{
doc.PrivateIP, doc.PrivateIP,
fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region), fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region),
} }
// Only enforce known principals if disable custom sans is true. // Only enforce known principals if disable custom sans is true.

View file

@ -141,6 +141,12 @@ func TestAWS_GetIdentityToken(t *testing.T) {
p7.config.signatureURL = p1.config.signatureURL p7.config.signatureURL = p1.config.signatureURL
p7.config.tokenURL = p1.config.tokenURL p7.config.tokenURL = p1.config.tokenURL
p8, err := generateAWS()
assert.FatalError(t, err)
p8.IMDSVersions = nil
p8.Accounts = p1.Accounts
p8.config = p1.config
caURL := "https://ca.smallstep.com" caURL := "https://ca.smallstep.com"
u, err := url.Parse(caURL) u, err := url.Parse(caURL)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -156,6 +162,7 @@ func TestAWS_GetIdentityToken(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{"foo.local", caURL}, false}, {"ok", p1, args{"foo.local", caURL}, false},
{"ok no imds", p8, args{"foo.local", caURL}, false},
{"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true}, {"fail ca url", p1, args{"foo.local", "://ca.smallstep.com"}, true},
{"fail identityURL", p2, args{"foo.local", caURL}, true}, {"fail identityURL", p2, args{"foo.local", caURL}, true},
{"fail signatureURL", p3, args{"foo.local", caURL}, true}, {"fail signatureURL", p3, args{"foo.local", caURL}, true},
@ -656,15 +663,15 @@ func TestAWS_AuthorizeSign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token) switch got, err := tt.aws.AuthorizeSign(ctx, tt.args.token); {
if (err != nil) != tt.wantErr { case (err != nil) != tt.wantErr:
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} else { default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {

View file

@ -131,9 +131,10 @@ func (p *Azure) GetTokenID(token string) (string, error) {
return "", errors.Wrap(err, "error verifying claims") return "", errors.Wrap(err, "error verifying claims")
} }
// If TOFU is disabled create return the token kid // If TOFU is disabled then allow token re-use. Azure caches the token for
// 24h and without allowing the re-use we cannot use it twice.
if p.DisableTrustOnFirstUse { if p.DisableTrustOnFirstUse {
return claims.ID, nil return "", ErrAllowTokenReuse
} }
sum := sha256.Sum256([]byte(claims.XMSMirID)) sum := sha256.Sum256([]byte(claims.XMSMirID))
@ -151,7 +152,7 @@ func (p *Azure) GetType() Type {
} }
// GetEncryptedKey is not available in an Azure provisioner. // GetEncryptedKey is not available in an Azure provisioner.
func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) { func (p *Azure) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }
@ -302,11 +303,13 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
var so []SignOption var so []SignOption
if p.DisableCustomSANs { if p.DisableCustomSANs {
// name will work only inside the virtual network // name will work only inside the virtual network
so = append(so, commonNameValidator(name)) so = append(so,
so = append(so, dnsNamesValidator([]string{name})) commonNameValidator(name),
so = append(so, ipAddressesValidator(nil)) dnsNamesValidator([]string{name}),
so = append(so, emailAddressesValidator(nil)) ipAddressesValidator(nil),
so = append(so, urisValidator(nil)) emailAddressesValidator(nil),
urisValidator(nil),
)
// Enforce SANs in the template. // Enforce SANs in the template.
data.SetSANs([]string{name}) data.SetSANs([]string{name})

View file

@ -72,7 +72,7 @@ func TestAzure_GetTokenID(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1}, w1, false}, {"ok", p1, args{t1}, w1, false},
{"ok no TOFU", p2, args{t2}, "the-jti", false}, {"ok no TOFU", p2, args{t2}, "", true},
{"fail token", p1, args{"bad-token"}, "", true}, {"fail token", p1, args{"bad-token"}, "", true},
{"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true}, {"fail claims", p1, args{"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.ey.fooo"}, "", true},
} }
@ -446,15 +446,15 @@ func TestAzure_AuthorizeSign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token) switch got, err := tt.azure.AuthorizeSign(ctx, tt.args.token); {
if (err != nil) != tt.wantErr { case (err != nil) != tt.wantErr:
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} else { default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {

View file

@ -37,8 +37,9 @@ func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// provisioner. // provisioner.
type loadByTokenPayload struct { type loadByTokenPayload struct {
jose.Claims jose.Claims
AuthorizedParty string `json:"azp"` // OIDC client id Email string `json:"email"` // OIDC email
TenantID string `json:"tid"` // Microsoft Azure tenant id AuthorizedParty string `json:"azp"` // OIDC client id
TenantID string `json:"tid"` // Microsoft Azure tenant id
} }
// Collection is a memory map of provisioners. // Collection is a memory map of provisioners.
@ -129,12 +130,20 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
return p, ok return p, ok
} }
} }
// Try with tid (Azure) // Try with tid (Azure, Azure OIDC)
if payload.TenantID != "" { if payload.TenantID != "" {
// Try to load an OIDC provisioner first.
if payload.Email != "" {
if p, ok := c.LoadByTokenID(payload.Audience[0]); ok {
return p, ok
}
}
// Try to load an Azure provisioner.
if p, ok := c.LoadByTokenID(payload.TenantID); ok { if p, ok := c.LoadByTokenID(payload.TenantID); ok {
return p, ok return p, ok
} }
} }
// Fallback to aud // Fallback to aud
return c.LoadByTokenID(payload.Audience[0]) return c.LoadByTokenID(payload.Audience[0])
} }
@ -220,14 +229,15 @@ func (c *Collection) Remove(id string) error {
var found bool var found bool
for i, elem := range c.sorted { for i, elem := range c.sorted {
if elem.provisioner.GetID() == id { if elem.provisioner.GetID() != id {
// Remove index in sorted list continue
copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index.
c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value).
c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice.
found = true
break
} }
// Remove index in sorted list
copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index.
c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value).
c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice.
found = true
break
} }
if !found { if !found {
return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName()) return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName())

View file

@ -150,7 +150,7 @@ func (p *GCP) GetType() Type {
} }
// GetEncryptedKey is not available in a GCP provisioner. // GetEncryptedKey is not available in a GCP provisioner.
func (p *GCP) GetEncryptedKey() (kid string, key string, ok bool) { func (p *GCP) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }
@ -244,15 +244,17 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
if p.DisableCustomSANs { if p.DisableCustomSANs {
dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID) dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID)
dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID) dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID)
so = append(so, commonNameSliceValidator([]string{ so = append(so,
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2, commonNameSliceValidator([]string{
})) ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
so = append(so, dnsNamesValidator([]string{ }),
dnsName1, dnsName2, dnsNamesValidator([]string{
})) dnsName1, dnsName2,
so = append(so, ipAddressesValidator(nil)) }),
so = append(so, emailAddressesValidator(nil)) ipAddressesValidator(nil),
so = append(so, urisValidator(nil)) emailAddressesValidator(nil),
urisValidator(nil),
)
// Template SANs // Template SANs
data.SetSANs([]string{dnsName1, dnsName2}) data.SetSANs([]string{dnsName1, dnsName2})

View file

@ -535,15 +535,15 @@ func TestGCP_AuthorizeSign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token) switch got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token); {
if (err != nil) != tt.wantErr { case (err != nil) != tt.wantErr:
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} else { default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {

View file

@ -18,7 +18,7 @@ const (
defaultCacheJitter = 1 * time.Hour defaultCacheJitter = 1 * time.Hour
) )
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)") var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)
type keyStore struct { type keyStore struct {
sync.RWMutex sync.RWMutex

View file

@ -29,7 +29,7 @@ func (p *noop) GetType() Type {
return noopType return noopType
} }
func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) { func (p *noop) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }

View file

@ -49,6 +49,29 @@ type openIDPayload struct {
Groups []string `json:"groups"` Groups []string `json:"groups"`
} }
func (o *openIDPayload) IsAdmin(admins []string) bool {
if o.Email != "" {
email := sanitizeEmail(o.Email)
for _, e := range admins {
if email == sanitizeEmail(e) {
return true
}
}
}
// The groups and emails can be in the same array for now, but consider
// making a specialized option later.
for _, name := range o.Groups {
for _, admin := range admins {
if name == admin {
return true
}
}
}
return false
}
// OIDC represents an OAuth 2.0 OpenID Connect provider. // OIDC represents an OAuth 2.0 OpenID Connect provider.
// //
// ClientSecret is mandatory, but it can be an empty string. // ClientSecret is mandatory, but it can be an empty string.
@ -73,35 +96,6 @@ type OIDC struct {
getIdentityFunc GetIdentityFunc getIdentityFunc GetIdentityFunc
} }
// IsAdmin returns true if the given email is in the Admins allowlist, false
// otherwise.
func (o *OIDC) IsAdmin(email string) bool {
if email != "" {
email = sanitizeEmail(email)
for _, e := range o.Admins {
if email == sanitizeEmail(e) {
return true
}
}
}
return false
}
// IsAdminGroup returns true if the one group in the given list is in the Admins
// allowlist, false otherwise.
func (o *OIDC) IsAdminGroup(groups []string) bool {
for _, g := range groups {
// The groups and emails can be in the same array for now, but consider
// making a specialized option later.
for _, gadmin := range o.Admins {
if g == gadmin {
return true
}
}
}
return false
}
func sanitizeEmail(email string) string { func sanitizeEmail(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 { if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i] + strings.ToLower(email[i:]) email = email[:i] + strings.ToLower(email[i:])
@ -154,7 +148,7 @@ func (o *OIDC) GetType() Type {
} }
// GetEncryptedKey is not available in an OIDC provisioner. // GetEncryptedKey is not available in an OIDC provisioner.
func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) { func (o *OIDC) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }
@ -199,7 +193,7 @@ func (o *OIDC) Init(config Config) (err error) {
} }
// Replace {tenantid} with the configured one // Replace {tenantid} with the configured one
if o.TenantID != "" { if o.TenantID != "" {
o.configuration.Issuer = strings.Replace(o.configuration.Issuer, "{tenantid}", o.TenantID, -1) o.configuration.Issuer = strings.ReplaceAll(o.configuration.Issuer, "{tenantid}", o.TenantID)
} }
// Get JWK key set // Get JWK key set
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI) o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
@ -234,7 +228,7 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
} }
// Validate domains (case-insensitive) // Validate domains (case-insensitive)
if p.Email != "" && len(o.Domains) > 0 && !o.IsAdmin(p.Email) { if p.Email != "" && len(o.Domains) > 0 && !p.IsAdmin(o.Admins) {
email := sanitizeEmail(p.Email) email := sanitizeEmail(p.Email)
var found bool var found bool
for _, d := range o.Domains { for _, d := range o.Domains {
@ -313,9 +307,10 @@ func (o *OIDC) AuthorizeRevoke(ctx context.Context, token string) error {
} }
// Only admins can revoke certificates. // Only admins can revoke certificates.
if o.IsAdmin(claims.Email) { if claims.IsAdmin(o.Admins) {
return nil return nil
} }
return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token") return errs.Unauthorized("oidc.AuthorizeRevoke; cannot revoke with non-admin oidc token")
} }
@ -351,7 +346,7 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// Use the default template unless no-templates are configured and email is // Use the default template unless no-templates are configured and email is
// an admin, in that case we will use the CR template. // an admin, in that case we will use the CR template.
defaultTemplate := x509util.DefaultLeafTemplate defaultTemplate := x509util.DefaultLeafTemplate
if !o.Options.GetX509Options().HasTemplate() && o.IsAdmin(claims.Email) { if !o.Options.GetX509Options().HasTemplate() && claims.IsAdmin(o.Admins) {
defaultTemplate = x509util.DefaultAdminLeafTemplate defaultTemplate = x509util.DefaultAdminLeafTemplate
} }
@ -420,10 +415,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Use the default template unless no-templates are configured and email is // Use the default template unless no-templates are configured and email is
// an admin, in that case we will use the parameters in the request. // an admin, in that case we will use the parameters in the request.
isAdmin := o.IsAdmin(claims.Email) isAdmin := claims.IsAdmin(o.Admins)
if !isAdmin && len(claims.Groups) > 0 {
isAdmin = o.IsAdminGroup(claims.Groups)
}
defaultTemplate := sshutil.DefaultTemplate defaultTemplate := sshutil.DefaultTemplate
if isAdmin && !o.Options.GetSSHOptions().HasTemplate() { if isAdmin && !o.Options.GetSSHOptions().HasTemplate() {
defaultTemplate = sshutil.DefaultAdminTemplate defaultTemplate = sshutil.DefaultAdminTemplate
@ -471,10 +463,11 @@ func (o *OIDC) AuthorizeSSHRevoke(ctx context.Context, token string) error {
} }
// Only admins can revoke certificates. // Only admins can revoke certificates.
if !o.IsAdmin(claims.Email) { if claims.IsAdmin(o.Admins) {
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token") return nil
} }
return nil
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
} }
func getAndDecode(uri string, v interface{}) error { func getAndDecode(uri string, v interface{}) error {

View file

@ -321,32 +321,26 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else { } else if assert.NotNil(t, got) {
if assert.NotNil(t, got) { assert.Len(t, 5, got)
if tt.name == "admin" { for _, o := range got {
assert.Len(t, 5, got) switch v := o.(type) {
} else { case certificateOptionsFunc:
assert.Len(t, 5, got) case *provisionerExtensionOption:
} assert.Equals(t, v.Type, int(TypeOIDC))
for _, o := range got { assert.Equals(t, v.Name, tt.prov.GetName())
switch v := o.(type) { assert.Equals(t, v.CredentialID, tt.prov.ClientID)
case certificateOptionsFunc: assert.Len(t, 0, v.KeyValuePairs)
case *provisionerExtensionOption: case profileDefaultDuration:
assert.Equals(t, v.Type, int(TypeOIDC)) assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
assert.Equals(t, v.Name, tt.prov.GetName()) case defaultPublicKeyValidator:
assert.Equals(t, v.CredentialID, tt.prov.ClientID) case *validityValidator:
assert.Len(t, 0, v.KeyValuePairs) assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
case profileDefaultDuration: assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) case emailOnlyIdentity:
case defaultPublicKeyValidator: assert.Equals(t, string(v), "name@smallstep.com")
case *validityValidator: default:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
case emailOnlyIdentity:
assert.Equals(t, string(v), "name@smallstep.com")
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
} }
} }
} }
@ -698,3 +692,39 @@ func Test_sanitizeEmail(t *testing.T) {
}) })
} }
} }
func Test_openIDPayload_IsAdmin(t *testing.T) {
type fields struct {
Email string
Groups []string
}
type args struct {
admins []string
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{"ok email", fields{"admin@smallstep.com", nil}, args{[]string{"admin@smallstep.com"}}, true},
{"ok email multiple", fields{"admin@smallstep.com", []string{"admin", "eng"}}, args{[]string{"eng@smallstep.com", "admin@smallstep.com"}}, true},
{"ok email sanitized", fields{"admin@Smallstep.com", nil}, args{[]string{"admin@smallStep.com"}}, true},
{"ok group", fields{"", []string{"admin"}}, args{[]string{"admin"}}, true},
{"ok group multiple", fields{"admin@smallstep.com", []string{"engineering", "admin"}}, args{[]string{"admin"}}, true},
{"fail missing", fields{"eng@smallstep.com", []string{"admin"}}, args{[]string{"admin@smallstep.com"}}, false},
{"fail email letter case", fields{"Admin@smallstep.com", []string{}}, args{[]string{"admin@smallstep.com"}}, false},
{"fail group letter case", fields{"", []string{"Admin"}}, args{[]string{"admin"}}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := &openIDPayload{
Email: tt.fields.Email,
Groups: tt.fields.Groups,
}
if got := o.IsAdmin(tt.args.admins); got != tt.want {
t.Errorf("openIDPayload.IsAdmin() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -138,7 +138,7 @@ func unsafeParseSigned(s string) (map[string]interface{}, error) {
return nil, err return nil, err
} }
claims := make(map[string]interface{}) claims := make(map[string]interface{})
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, err return nil, err
} }
return claims, nil return claims, nil

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
stderrors "errors"
"net/url" "net/url"
"regexp" "regexp"
"strings" "strings"
@ -32,6 +33,17 @@ type Interface interface {
AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error)
} }
// ErrAllowTokenReuse is an error that is returned by provisioners that allows
// the reuse of tokens.
//
// This is, for example, returned by the Azure provisioner when
// DisableTrustOnFirstUse is set to true. Azure caches tokens for up to 24hr and
// has no mechanism for getting a different token - this can be an issue when
// rebooting a VM. In contrast, AWS and GCP have facilities for requesting a new
// token. Therefore, for the Azure provisioner we are enabling token reuse, with
// the understanding that we are not following security best practices
var ErrAllowTokenReuse = stderrors.New("allow token reuse")
// Audiences stores all supported audiences by request type. // Audiences stores all supported audiences by request type.
type Audiences struct { type Audiences struct {
Sign []string Sign []string
@ -111,7 +123,7 @@ func (a Audiences) WithFragment(fragment string) Audiences {
// generateSignAudience generates a sign audience with the format // generateSignAudience generates a sign audience with the format
// https://<host>/1.0/sign#provisionerID // https://<host>/1.0/sign#provisionerID
func generateSignAudience(caURL string, provisionerID string) (string, error) { func generateSignAudience(caURL, provisionerID string) (string, error) {
u, err := url.Parse(caURL) u, err := url.Parse(caURL)
if err != nil { if err != nil {
return "", errors.Wrapf(err, "error parsing %s", caURL) return "", errors.Wrapf(err, "error parsing %s", caURL)

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"math/big" "math/big"
"strings"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -455,10 +456,10 @@ func containsAllMembers(group, subgroup []string) bool {
} }
visit := make(map[string]struct{}, lg) visit := make(map[string]struct{}, lg)
for i := 0; i < lg; i++ { for i := 0; i < lg; i++ {
visit[group[i]] = struct{}{} visit[strings.ToLower(group[i])] = struct{}{}
} }
for i := 0; i < lsg; i++ { for i := 0; i < lsg; i++ {
if _, ok := visit[subgroup[i]]; !ok { if _, ok := visit[strings.ToLower(subgroup[i])]; !ok {
return false return false
} }
} }

View file

@ -44,7 +44,7 @@ func TestSSHOptions_Modify(t *testing.T) {
valid func(*ssh.Certificate) valid func(*ssh.Certificate)
err error err error
} }
tests := map[string](func() test){ tests := map[string]func() test{
"fail/unexpected-cert-type": func() test { "fail/unexpected-cert-type": func() test {
return test{ return test{
so: SignSSHOptions{CertType: "foo"}, so: SignSSHOptions{CertType: "foo"},
@ -117,7 +117,7 @@ func TestSSHOptions_Match(t *testing.T) {
cmp SignSSHOptions cmp SignSSHOptions
err error err error
} }
tests := map[string](func() test){ tests := map[string]func() test{
"fail/cert-type": func() test { "fail/cert-type": func() test {
return test{ return test{
so: SignSSHOptions{CertType: "foo"}, so: SignSSHOptions{CertType: "foo"},
@ -208,7 +208,7 @@ func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
expected []string expected []string
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok": func() test { "ok": func() test {
a := []string{"foo", "bar"} a := []string{"foo", "bar"}
return test{ return test{
@ -234,7 +234,7 @@ func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
expected string expected string
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok": func() test { "ok": func() test {
a := "foo" a := "foo"
return test{ return test{
@ -260,7 +260,7 @@ func Test_sshCertTypeModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
expected uint32 expected uint32
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok/user": func() test { "ok/user": func() test {
return test{ return test{
modifier: sshCertTypeModifier("user"), modifier: sshCertTypeModifier("user"),
@ -299,7 +299,7 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
expected uint64 expected uint64
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok": func() test { "ok": func() test {
return test{ return test{
modifier: sshCertValidAfterModifier(15), modifier: sshCertValidAfterModifier(15),
@ -324,7 +324,7 @@ func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
valid func(*ssh.Certificate) valid func(*ssh.Certificate)
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok/changes": func() test { "ok/changes": func() test {
n := time.Now() n := time.Now()
va := NewTimeDuration(n.Add(1 * time.Minute)) va := NewTimeDuration(n.Add(1 * time.Minute))
@ -388,7 +388,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
valid func(*ssh.Certificate) valid func(*ssh.Certificate)
err error err error
} }
tests := map[string](func() test){ tests := map[string]func() test{
"fail/unexpected-cert-type": func() test { "fail/unexpected-cert-type": func() test {
cert := &ssh.Certificate{CertType: 3} cert := &ssh.Certificate{CertType: 3}
return test{ return test{

View file

@ -8,7 +8,6 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -30,7 +29,6 @@ type SSHPOP struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
db db.AuthDB
claimer *Claimer claimer *Claimer
audiences Audiences audiences Audiences
sshPubKeys *SSHKeys sshPubKeys *SSHKeys
@ -102,7 +100,6 @@ func (p *SSHPOP) Init(config Config) error {
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.db = config.DB
p.sshPubKeys = config.SSHKeys p.sshPubKeys = config.SSHKeys
return nil return nil
} }
@ -110,6 +107,8 @@ func (p *SSHPOP) Init(config Config) error {
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
// claims for case specific downstream parsing. // claims for case specific downstream parsing.
// e.g. a Sign request will auth/validate different fields than a Revoke request. // e.g. a Sign request will auth/validate different fields than a Revoke request.
//
// Checking for certificate revocation has been moved to the authority package.
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
sshCert, jwt, err := ExtractSSHPOPCert(token) sshCert, jwt, err := ExtractSSHPOPCert(token)
if err != nil { if err != nil {
@ -117,14 +116,6 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
"sshpop.authorizeToken; error extracting sshpop header from token") "sshpop.authorizeToken; error extracting sshpop header from token")
} }
// Check for revocation.
if isRevoked, err := p.db.IsSSHRevoked(strconv.FormatUint(sshCert.Serial, 10)); err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err,
"sshpop.authorizeToken; error checking checking sshpop cert revocation")
} else if isRevoked {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate is revoked")
}
// Check validity period of the certificate. // Check validity period of the certificate.
n := time.Now() n := time.Now()
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {

View file

@ -11,7 +11,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -47,7 +46,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if err = cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
return cert, jwk, nil return cert, jwk, nil
@ -83,52 +82,9 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "), err: errors.New("sshpop.authorizeToken; error extracting sshpop header from token: extractSSHPOPCert; error parsing token: "),
} }
}, },
"fail/error-revoked-db-check": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, errors.New("force")
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusInternalServerError,
err: errors.New("sshpop.authorizeToken; error checking checking sshpop cert revocation: force"),
}
},
"fail/cert-already-revoked": func(t *testing.T) test {
p, err := generateSSHPOP()
assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return true, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk)
assert.FatalError(t, err)
return test{
p: p,
token: tok,
code: http.StatusUnauthorized,
err: errors.New("sshpop.authorizeToken; sshpop certificate is revoked"),
}
},
"fail/cert-not-yet-valid": func(t *testing.T) test { "fail/cert-not-yet-valid": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{ cert, jwk, err := createSSHCert(&ssh.Certificate{
CertType: ssh.UserCert, CertType: ssh.UserCert,
ValidAfter: uint64(time.Now().Add(time.Minute).Unix()), ValidAfter: uint64(time.Now().Add(time.Minute).Unix()),
@ -146,11 +102,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
"fail/cert-past-validity": func(t *testing.T) test { "fail/cert-past-validity": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{ cert, jwk, err := createSSHCert(&ssh.Certificate{
CertType: ssh.UserCert, CertType: ssh.UserCert,
ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()), ValidBefore: uint64(time.Now().Add(-time.Minute).Unix()),
@ -168,11 +119,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
"fail/no-signer-found": func(t *testing.T) test { "fail/no-signer-found": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk) tok, err := generateSSHPOPToken(p, cert, jwk)
@ -187,11 +133,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
"fail/error-parsing-claims-bad-sig": func(t *testing.T) test { "fail/error-parsing-claims-bad-sig": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) cert, _, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) otherJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
@ -208,11 +149,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
"fail/invalid-claims-issuer": func(t *testing.T) test { "fail/invalid-claims-issuer": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "", tok, err := generateToken("foo", "bar", testAudiences.Sign[0], "",
@ -228,11 +164,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
"fail/invalid-audience": func(t *testing.T) test { "fail/invalid-audience": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), "invalid-aud", "", tok, err := generateToken("foo", p.GetName(), "invalid-aud", "",
@ -248,11 +179,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
"fail/empty-subject": func(t *testing.T) test { "fail/empty-subject": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "", tok, err := generateToken("", p.GetName(), testAudiences.Sign[0], "",
@ -268,11 +194,6 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateSSHPOPToken(p, cert, jwk) tok, err := generateSSHPOPToken(p, cert, jwk)
@ -293,10 +214,8 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.NotNil(t, claims)
assert.NotNil(t, claims)
}
} }
}) })
} }
@ -330,11 +249,6 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
"fail/subject-not-equal-serial": func(t *testing.T) test { "fail/subject-not-equal-serial": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRevoke[0], "",
@ -350,11 +264,6 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.UserCert}, sshSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "", tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRevoke[0], "",
@ -419,11 +328,6 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
"fail/not-host-cert": func(t *testing.T) test { "fail/not-host-cert": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0], "",
@ -439,11 +343,6 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "", tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRenew[0], "",
@ -511,11 +410,6 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
"fail/not-host-cert": func(t *testing.T) test { "fail/not-host-cert": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.UserCert}, sshUserSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "", tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRekey[0], "",
@ -531,11 +425,6 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateSSHPOP() p, err := generateSSHPOP()
assert.FatalError(t, err) assert.FatalError(t, err)
p.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner) cert, jwk, err := createSSHCert(&ssh.Certificate{Serial: 123455, CertType: ssh.HostCert}, sshHostSigner)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "", tok, err := generateToken("123455", p.GetName(), testAudiences.SSHRekey[0], "",

View file

@ -732,7 +732,7 @@ func withSSHPOPFile(cert *ssh.Certificate) tokOption {
} }
} }
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions) so := new(jose.SignerOptions)
so.WithType("JWT") so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID) so.WithHeader("kid", jwk.KeyID)
@ -773,7 +773,7 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateOIDCToken(sub, iss, aud string, email string, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions) so := new(jose.SignerOptions)
so.WithType("JWT") so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID) so.WithHeader("kid", jwk.KeyID)

View file

@ -4,12 +4,17 @@ import (
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem"
"fmt" "fmt"
"io/ioutil"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
step "go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
@ -234,6 +239,14 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error {
} }
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) { func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
if password == "" {
pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one")
if err != nil {
return nil, err
}
password = string(pass)
}
jwk, jwe, err := jose.GenerateDefaultKeyPair([]byte(password)) jwk, jwe, err := jose.GenerateDefaultKeyPair([]byte(password))
if err != nil { if err != nil {
return nil, admin.WrapErrorISE(err, "error generating JWK key pair") return nil, admin.WrapErrorISE(err, "error generating JWK key pair")
@ -398,6 +411,13 @@ func durationsToCertificates(d *linkedca.Durations) (min, max, def *provisioner.
return return
} }
func durationsToLinkedca(d *provisioner.Duration) string {
if d == nil {
return ""
}
return d.Duration.String()
}
// claimsToCertificates converts the linkedca provisioner claims type to the // claimsToCertificates converts the linkedca provisioner claims type to the
// certifictes claims type. // certifictes claims type.
func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) { func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
@ -438,6 +458,109 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
return pc, nil return pc, nil
} }
func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims {
if c == nil {
return nil
}
disableRenewal := config.DefaultDisableRenewal
if c.DisableRenewal != nil {
disableRenewal = *c.DisableRenewal
}
lc := &linkedca.Claims{
DisableRenewal: disableRenewal,
}
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {
lc.X509 = &linkedca.X509Claims{
Enabled: true,
Durations: &linkedca.Durations{
Default: durationsToLinkedca(c.DefaultTLSDur),
Min: durationsToLinkedca(c.MinTLSDur),
Max: durationsToLinkedca(c.MaxTLSDur),
},
}
}
if c.EnableSSHCA != nil && *c.EnableSSHCA {
lc.Ssh = &linkedca.SSHClaims{
Enabled: true,
}
if c.DefaultUserSSHDur != nil || c.MinUserSSHDur != nil || c.MaxUserSSHDur != nil {
lc.Ssh.UserDurations = &linkedca.Durations{
Default: durationsToLinkedca(c.DefaultUserSSHDur),
Min: durationsToLinkedca(c.MinUserSSHDur),
Max: durationsToLinkedca(c.MaxUserSSHDur),
}
}
if c.DefaultHostSSHDur != nil || c.MinHostSSHDur != nil || c.MaxHostSSHDur != nil {
lc.Ssh.HostDurations = &linkedca.Durations{
Default: durationsToLinkedca(c.DefaultHostSSHDur),
Min: durationsToLinkedca(c.MinHostSSHDur),
Max: durationsToLinkedca(c.MaxHostSSHDur),
}
}
}
return lc
}
func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *linkedca.Template, error) {
var err error
var x509Template, sshTemplate *linkedca.Template
if p == nil {
return nil, nil, nil
}
if p.X509 != nil && p.X509.HasTemplate() {
x509Template = &linkedca.Template{
Template: nil,
Data: nil,
}
if p.X509.Template != "" {
x509Template.Template = []byte(p.SSH.Template)
} else if p.X509.TemplateFile != "" {
filename := step.StepAbs(p.X509.TemplateFile)
if x509Template.Template, err = ioutil.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading x509 template")
}
}
}
if p.SSH != nil && p.SSH.HasTemplate() {
sshTemplate = &linkedca.Template{
Template: nil,
Data: nil,
}
if p.SSH.Template != "" {
sshTemplate.Template = []byte(p.SSH.Template)
} else if p.SSH.TemplateFile != "" {
filename := step.StepAbs(p.SSH.TemplateFile)
if sshTemplate.Template, err = ioutil.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading ssh template")
}
}
}
return x509Template, sshTemplate, nil
}
func provisionerPEMToLinkedca(b []byte) [][]byte {
var roots [][]byte
var block *pem.Block
for {
if block, b = pem.Decode(b); block == nil {
break
}
roots = append(roots, pem.EncodeToMemory(block))
}
return roots
}
// ProvisionerToCertificates converts the linkedca provisioner type to the certificates provisioner // ProvisionerToCertificates converts the linkedca provisioner type to the certificates provisioner
// interface. // interface.
func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, error) { func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, error) {
@ -448,7 +571,7 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
details := p.Details.GetData() details := p.Details.GetData()
if details == nil { if details == nil {
return nil, fmt.Errorf("provisioner does not have any details") return nil, errors.New("provisioner does not have any details")
} }
options := optionsToCertificates(p) options := optionsToCertificates(p)
@ -457,7 +580,7 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
case *linkedca.ProvisionerDetails_JWK: case *linkedca.ProvisionerDetails_JWK:
jwk := new(jose.JSONWebKey) jwk := new(jose.JSONWebKey)
if err := json.Unmarshal(d.JWK.PublicKey, &jwk); err != nil { if err := json.Unmarshal(d.JWK.PublicKey, &jwk); err != nil {
return nil, err return nil, errors.Wrap(err, "error unmarshaling public key")
} }
return &provisioner.JWK{ return &provisioner.JWK{
ID: p.Id, ID: p.Id,
@ -588,6 +711,233 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface,
} }
} }
// ProvisionerToLinkedca converts a provisioner.Interface to a
// linkedca.Provisioner type.
func ProvisionerToLinkedca(p provisioner.Interface) (*linkedca.Provisioner, error) {
switch p := p.(type) {
case *provisioner.JWK:
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
if err != nil {
return nil, err
}
publicKey, err := json.Marshal(p.Key)
if err != nil {
return nil, errors.Wrap(err, "error marshaling key")
}
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_JWK,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_JWK{
JWK: &linkedca.JWKProvisioner{
PublicKey: publicKey,
EncryptedPrivateKey: []byte(p.EncryptedKey),
},
},
},
Claims: claimsToLinkedca(p.Claims),
X509Template: x509Template,
SshTemplate: sshTemplate,
}, nil
case *provisioner.OIDC:
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
if err != nil {
return nil, err
}
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_OIDC,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_OIDC{
OIDC: &linkedca.OIDCProvisioner{
ClientId: p.ClientID,
ClientSecret: p.ClientSecret,
ConfigurationEndpoint: p.ConfigurationEndpoint,
Admins: p.Admins,
Domains: p.Domains,
Groups: p.Groups,
ListenAddress: p.ListenAddress,
TenantId: p.TenantID,
},
},
},
Claims: claimsToLinkedca(p.Claims),
X509Template: x509Template,
SshTemplate: sshTemplate,
}, nil
case *provisioner.GCP:
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
if err != nil {
return nil, err
}
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_GCP,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_GCP{
GCP: &linkedca.GCPProvisioner{
ServiceAccounts: p.ServiceAccounts,
ProjectIds: p.ProjectIDs,
DisableCustomSans: p.DisableCustomSANs,
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,
InstanceAge: p.InstanceAge.String(),
},
},
},
Claims: claimsToLinkedca(p.Claims),
X509Template: x509Template,
SshTemplate: sshTemplate,
}, nil
case *provisioner.AWS:
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
if err != nil {
return nil, err
}
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_AWS,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_AWS{
AWS: &linkedca.AWSProvisioner{
Accounts: p.Accounts,
DisableCustomSans: p.DisableCustomSANs,
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,
InstanceAge: p.InstanceAge.String(),
},
},
},
Claims: claimsToLinkedca(p.Claims),
X509Template: x509Template,
SshTemplate: sshTemplate,
}, nil
case *provisioner.Azure:
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
if err != nil {
return nil, err
}
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_AZURE,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_Azure{
Azure: &linkedca.AzureProvisioner{
TenantId: p.TenantID,
ResourceGroups: p.ResourceGroups,
Audience: p.Audience,
DisableCustomSans: p.DisableCustomSANs,
DisableTrustOnFirstUse: p.DisableTrustOnFirstUse,
},
},
},
Claims: claimsToLinkedca(p.Claims),
X509Template: x509Template,
SshTemplate: sshTemplate,
}, nil
case *provisioner.ACME:
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
if err != nil {
return nil, err
}
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_ACME,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
ForceCn: p.ForceCN,
},
},
},
Claims: claimsToLinkedca(p.Claims),
X509Template: x509Template,
SshTemplate: sshTemplate,
}, nil
case *provisioner.X5C:
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
if err != nil {
return nil, err
}
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_X5C,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_X5C{
X5C: &linkedca.X5CProvisioner{
Roots: provisionerPEMToLinkedca(p.Roots),
},
},
},
Claims: claimsToLinkedca(p.Claims),
X509Template: x509Template,
SshTemplate: sshTemplate,
}, nil
case *provisioner.K8sSA:
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
if err != nil {
return nil, err
}
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_K8SSA,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_K8SSA{
K8SSA: &linkedca.K8SSAProvisioner{
PublicKeys: provisionerPEMToLinkedca(p.PubKeys),
},
},
},
Claims: claimsToLinkedca(p.Claims),
X509Template: x509Template,
SshTemplate: sshTemplate,
}, nil
case *provisioner.SSHPOP:
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_SSHPOP,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_SSHPOP{
SSHPOP: &linkedca.SSHPOPProvisioner{},
},
},
Claims: claimsToLinkedca(p.Claims),
}, nil
case *provisioner.SCEP:
x509Template, sshTemplate, err := provisionerOptionsToLinkedca(p.Options)
if err != nil {
return nil, err
}
return &linkedca.Provisioner{
Id: p.ID,
Type: linkedca.Provisioner_SCEP,
Name: p.GetName(),
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_SCEP{
SCEP: &linkedca.SCEPProvisioner{
ForceCn: p.ForceCN,
Challenge: p.GetChallengePassword(),
Capabilities: p.Capabilities,
MinimumPublicKeyLength: int32(p.MinimumPublicKeyLength),
},
},
},
Claims: claimsToLinkedca(p.Claims),
X509Template: x509Template,
SshTemplate: sshTemplate,
}, nil
default:
return nil, fmt.Errorf("provisioner %s not implemented", p.GetType())
}
}
func parseInstanceAge(age string) (provisioner.Duration, error) { func parseInstanceAge(age string) (provisioner.Duration, error) {
var instanceAge provisioner.Duration var instanceAge provisioner.Duration
if age != "" { if age != "" {

View file

@ -108,7 +108,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
// GetSSHBastion returns the bastion configuration, for the given pair user, // GetSSHBastion returns the bastion configuration, for the given pair user,
// hostname. // hostname.
func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname string) (*config.Bastion, error) { func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*config.Bastion, error) {
if a.sshBastionFunc != nil { if a.sshBastionFunc != nil {
bs, err := a.sshBastionFunc(ctx, user, hostname) bs, err := a.sshBastionFunc(ctx, user, hostname)
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion") return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
@ -239,7 +239,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
} }
} }
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db") return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db")
} }
@ -249,7 +249,11 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest("rewnewSSH: cannot renew certificate without validity period") return nil, errs.BadRequest("renewSSH: cannot renew certificate without validity period")
}
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
return nil, err
} }
backdate := a.config.AuthorityConfig.Backdate.Duration backdate := a.config.AuthorityConfig.Backdate.Duration
@ -294,7 +298,7 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate")
} }
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db") return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db")
} }
@ -319,6 +323,10 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period") return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
} }
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
return nil, err
}
backdate := a.config.AuthorityConfig.Backdate.Duration backdate := a.config.AuthorityConfig.Backdate.Duration
duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second
now := time.Now() now := time.Now()
@ -369,13 +377,23 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
} }
} }
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db") return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db")
} }
return cert, nil return cert, nil
} }
func (a *Authority) storeSSHCertificate(cert *ssh.Certificate) error {
type sshCertificateStorer interface {
StoreSSHCertificate(crt *ssh.Certificate) error
}
if s, ok := a.adminDB.(sshCertificateStorer); ok {
return s.StoreSSHCertificate(cert)
}
return a.db.StoreSSHCertificate(cert)
}
// IsValidForAddUser checks if a user provisioner certificate can be issued to // IsValidForAddUser checks if a user provisioner certificate can be issued to
// the given certificate. // the given certificate.
func IsValidForAddUser(cert *ssh.Certificate) error { func IsValidForAddUser(cert *ssh.Certificate) error {
@ -451,7 +469,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
} }
cert.Signature = sig cert.Signature = sig
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { if err = a.storeSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db") return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db")
} }
@ -459,7 +477,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
} }
// CheckSSHHost checks the given principal has been registered before. // CheckSSHHost checks the given principal has been registered before.
func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) { func (a *Authority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) {
if a.sshCheckHostFunc != nil { if a.sshCheckHostFunc != nil {
exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates()) exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates())
if err != nil { if err != nil {
@ -513,5 +531,5 @@ func (a *Authority) getAddUserCommand(principal string) string {
} else { } else {
cmd = a.config.SSH.AddUserCommand cmd = a.config.SSH.AddUserCommand
} }
return strings.Replace(cmd, "<principal>", principal, -1) return strings.ReplaceAll(cmd, "<principal>", principal)
} }

View file

@ -87,6 +87,52 @@ func (m sshTestOptionsModifier) Modify(cert *ssh.Certificate, opts provisioner.S
return fmt.Errorf(string(m)) return fmt.Errorf(string(m))
} }
func TestAuthority_initHostOnly(t *testing.T) {
auth := testAuthority(t, func(a *Authority) error {
a.config.SSH.UserKey = ""
return nil
})
// Check keys
keys, err := auth.GetSSHRoots(context.Background())
assert.NoError(t, err)
assert.Len(t, 1, keys.HostKeys)
assert.Len(t, 0, keys.UserKeys)
// Check templates, user templates should work fine.
_, err = auth.GetSSHConfig(context.Background(), "user", nil)
assert.NoError(t, err)
_, err = auth.GetSSHConfig(context.Background(), "host", map[string]string{
"Certificate": "ssh_host_ecdsa_key-cert.pub",
"Key": "ssh_host_ecdsa_key",
})
assert.Error(t, err)
}
func TestAuthority_initUserOnly(t *testing.T) {
auth := testAuthority(t, func(a *Authority) error {
a.config.SSH.HostKey = ""
return nil
})
// Check keys
keys, err := auth.GetSSHRoots(context.Background())
assert.NoError(t, err)
assert.Len(t, 0, keys.HostKeys)
assert.Len(t, 1, keys.UserKeys)
// Check templates, host templates should work fine.
_, err = auth.GetSSHConfig(context.Background(), "host", map[string]string{
"Certificate": "ssh_host_ecdsa_key-cert.pub",
"Key": "ssh_host_ecdsa_key",
})
assert.NoError(t, err)
_, err = auth.GetSSHConfig(context.Background(), "user", nil)
assert.Error(t, err)
}
func TestAuthority_SignSSH(t *testing.T) { func TestAuthority_SignSSH(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -153,6 +199,8 @@ func TestAuthority_SignSSH(t *testing.T) {
}{ }{
{"ok-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false}, {"ok-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false},
{"ok-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false}, {"ok-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false},
{"ok-user-only", fields{signer, nil}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{userTemplate, userOptions}}, want{CertType: ssh.UserCert}, false},
{"ok-host-only", fields{nil, signer}, args{pub, provisioner.SignSSHOptions{}, []provisioner.SignOption{hostTemplate, hostOptions}}, want{CertType: ssh.HostCert}, false},
{"ok-opts-type-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false}, {"ok-opts-type-user", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user"}, []provisioner.SignOption{userTemplate}}, want{CertType: ssh.UserCert}, false},
{"ok-opts-type-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false}, {"ok-opts-type-host", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "host"}, []provisioner.SignOption{hostTemplate}}, want{CertType: ssh.HostCert}, false},
{"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false}, {"ok-opts-principals", fields{signer, signer}, args{pub, provisioner.SignSSHOptions{CertType: "user", Principals: []string{"user"}}, []provisioner.SignOption{userTemplateWithUser}}, want{CertType: ssh.UserCert, Principals: []string{"user"}}, false},
@ -750,6 +798,11 @@ func TestAuthority_RekeySSH(t *testing.T) {
now := time.Now().UTC() now := time.Now().UTC()
a := testAuthority(t) a := testAuthority(t)
a.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
}
type test struct { type test struct {
auth *Authority auth *Authority
@ -763,6 +816,56 @@ func TestAuthority_RekeySSH(t *testing.T) {
code int code int
} }
tests := map[string]func(t *testing.T) *test{ tests := map[string]func(t *testing.T) *test{
"fail/is-revoked": func(t *testing.T) *test {
auth := testAuthority(t)
auth.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return true, nil
},
}
return &test{
auth: auth,
userSigner: signer,
hostSigner: signer,
cert: &ssh.Certificate{
Serial: 1234567890,
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
CertType: ssh.UserCert,
ValidPrincipals: []string{"foo", "bar"},
KeyId: "foo",
},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("authority.authorizeSSHCertificate: certificate has been revoked"),
code: http.StatusUnauthorized,
}
},
"fail/is-revoked-error": func(t *testing.T) *test {
auth := testAuthority(t)
auth.db = &db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, errors.New("an error")
},
}
return &test{
auth: auth,
userSigner: signer,
hostSigner: signer,
cert: &ssh.Certificate{
Serial: 1234567890,
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
CertType: ssh.UserCert,
ValidPrincipals: []string{"foo", "bar"},
KeyId: "foo",
},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("authority.authorizeSSHCertificate: an error"),
code: http.StatusInternalServerError,
}
},
"fail/opts-type": func(t *testing.T) *test { "fail/opts-type": func(t *testing.T) *test {
return &test{ return &test{
userSigner: signer, userSigner: signer,
@ -831,6 +934,9 @@ func TestAuthority_RekeySSH(t *testing.T) {
"fail/db-store": func(t *testing.T) *test { "fail/db-store": func(t *testing.T) *test {
return &test{ return &test{
auth: testAuthority(t, WithDatabase(&db.MockAuthDB{ auth: testAuthority(t, WithDatabase(&db.MockAuthDB{
MIsSSHRevoked: func(sn string) (bool, error) {
return false, nil
},
MStoreSSHCertificate: func(cert *ssh.Certificate) error { MStoreSSHCertificate: func(cert *ssh.Certificate) error {
return errors.New("force") return errors.New("force")
}, },

View file

@ -21,6 +21,7 @@ import (
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
) )
// GetTLSOptions returns the tls options configured. // GetTLSOptions returns the tls options configured.
@ -36,7 +37,6 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
if def == nil { if def == nil {
return errors.New("default ASN1DN template cannot be nil") return errors.New("default ASN1DN template cannot be nil")
} }
if len(crt.Subject.Country) == 0 && def.Country != "" { if len(crt.Subject.Country) == 0 && def.Country != "" {
crt.Subject.Country = append(crt.Subject.Country, def.Country) crt.Subject.Country = append(crt.Subject.Country, def.Country)
} }
@ -55,7 +55,12 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" { if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" {
crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress) crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress)
} }
if crt.Subject.SerialNumber == "" && def.SerialNumber != "" {
crt.Subject.SerialNumber = def.SerialNumber
}
if crt.Subject.CommonName == "" && def.CommonName != "" {
crt.Subject.CommonName = def.CommonName
}
return nil return nil
} }
} }
@ -280,9 +285,15 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
// `StoreCertificate(...*x509.Certificate) error` instead of just // `StoreCertificate(...*x509.Certificate) error` instead of just
// `StoreCertificate(*x509.Certificate) error`. // `StoreCertificate(*x509.Certificate) error`.
func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
if s, ok := a.db.(interface { type certificateChainStorer interface {
StoreCertificateChain(...*x509.Certificate) error StoreCertificateChain(...*x509.Certificate) error
}); ok { }
// Store certificate in linkedca
if s, ok := a.adminDB.(certificateChainStorer); ok {
return s.StoreCertificateChain(fullchain...)
}
// Store certificate in local db
if s, ok := a.db.(certificateChainStorer); ok {
return s.StoreCertificateChain(fullchain...) return s.StoreCertificateChain(fullchain...)
} }
return a.db.StoreCertificate(fullchain[0]) return a.db.StoreCertificate(fullchain[0])
@ -293,9 +304,15 @@ func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
// //
// TODO: at some point we should implement this in the standard implementation. // TODO: at some point we should implement this in the standard implementation.
func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error { func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error {
if s, ok := a.db.(interface { type renewedCertificateChainStorer interface {
StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error
}); ok { }
// Store certificate in linkedca
if s, ok := a.adminDB.(renewedCertificateChainStorer); ok {
return s.StoreRenewedCertificate(oldCert, fullchain...)
}
// Store certificate in local db
if s, ok := a.db.(renewedCertificateChainStorer); ok {
return s.StoreRenewedCertificate(oldCert, fullchain...) return s.StoreRenewedCertificate(oldCert, fullchain...)
} }
return a.db.StoreCertificate(fullchain[0]) return a.db.StoreCertificate(fullchain[0])
@ -368,22 +385,22 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
} }
rci.ProvisionerID = p.GetID() rci.ProvisionerID = p.GetID()
rci.TokenID, err = p.GetTokenID(revokeOpts.OTT) rci.TokenID, err = p.GetTokenID(revokeOpts.OTT)
if err != nil { if err != nil && !errors.Is(err, provisioner.ErrAllowTokenReuse) {
return errs.Wrap(http.StatusInternalServerError, err, return errs.Wrap(http.StatusInternalServerError, err,
"authority.Revoke; could not get ID for token") "authority.Revoke; could not get ID for token")
} }
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID)) opts = append(opts,
opts = append(opts, errs.WithKeyVal("tokenID", rci.TokenID)) errs.WithKeyVal("provisionerID", rci.ProvisionerID),
} else { errs.WithKeyVal("tokenID", rci.TokenID),
)
} else if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil {
// Load the Certificate provisioner if one exists. // Load the Certificate provisioner if one exists.
if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil { rci.ProvisionerID = p.GetID()
rci.ProvisionerID = p.GetID() opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
}
} }
if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod { if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod {
err = a.db.RevokeSSH(rci) err = a.revokeSSH(nil, rci)
} else { } else {
// Revoke an X.509 certificate using CAS. If the certificate is not // Revoke an X.509 certificate using CAS. If the certificate is not
// provided we will try to read it from the db. If the read fails we // provided we will try to read it from the db. If the read fails we
@ -410,7 +427,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
} }
// Save as revoked in the Db. // Save as revoked in the Db.
err = a.db.Revoke(rci) err = a.revoke(revokedCert, rci)
} }
switch err { switch err {
case nil: case nil:
@ -425,6 +442,24 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
} }
} }
func (a *Authority) revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error {
if lca, ok := a.adminDB.(interface {
Revoke(*x509.Certificate, *db.RevokedCertificateInfo) error
}); ok {
return lca.Revoke(crt, rci)
}
return a.db.Revoke(rci)
}
func (a *Authority) revokeSSH(crt *ssh.Certificate, rci *db.RevokedCertificateInfo) error {
if lca, ok := a.adminDB.(interface {
RevokeSSH(*ssh.Certificate, *db.RevokedCertificateInfo) error
}); ok {
return lca.RevokeSSH(crt, rci)
}
return a.db.Revoke(rci)
}
// GetTLSCertificate creates a new leaf certificate to be used by the CA HTTPS server. // GetTLSCertificate creates a new leaf certificate to be used by the CA HTTPS server.
func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) { func (a *Authority) GetTLSCertificate() (*tls.Certificate, error) {
fatal := func(err error) (*tls.Certificate, error) { fatal := func(err error) (*tls.Certificate, error) {

View file

@ -426,6 +426,7 @@ ZYtQ9Ot36qc=
{Id: stepOIDProvisioner, Value: []byte("foo")}, {Id: stepOIDProvisioner, Value: []byte("foo")},
{Id: []int{1, 1, 1}, Value: []byte("bar")}})) {Id: []int{1, 1, 1}, Value: []byte("bar")}}))
now := time.Now().UTC() now := time.Now().UTC()
// nolint:gocritic
enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{ enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{
NotBefore: now, NotBefore: now,
NotAfter: now.Add(365 * 24 * time.Hour), NotAfter: now.Add(365 * 24 * time.Hour),

View file

@ -345,7 +345,7 @@ func readACMEError(r io.ReadCloser) error {
ae := new(acme.Error) ae := new(acme.Error)
err = json.Unmarshal(b, &ae) err = json.Unmarshal(b, &ae)
// If we successfully marshaled to an ACMEError then return the ACMEError. // If we successfully marshaled to an ACMEError then return the ACMEError.
if err != nil || len(ae.Error()) == 0 { if err != nil || ae.Error() == "" {
fmt.Printf("b = %s\n", b) fmt.Printf("b = %s\n", b)
// Throw up our hands. // Throw up our hands.
return errors.Errorf("%s", b) return errors.Errorf("%s", b)

View file

@ -1247,6 +1247,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
Type: "Certificate", Type: "Certificate",
Bytes: leaf.Raw, Bytes: leaf.Raw,
}) })
// nolint:gocritic
certBytes := append(leafb, leafb...) certBytes := append(leafb, leafb...)
certBytes = append(certBytes, leafb...) certBytes = append(certBytes, leafb...)
ac := &ACMEClient{ ac := &ACMEClient{

View file

@ -70,7 +70,7 @@ func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error)
}, nil }, nil
} }
func (c *AdminClient) generateAdminToken(path string) (string, error) { func (c *AdminClient) generateAdminToken(urlPath string) (string, error) {
// A random jwt id will be used to identify duplicated tokens // A random jwt id will be used to identify duplicated tokens
jwtID, err := randutil.Hex(64) // 256 bits jwtID, err := randutil.Hex(64) // 256 bits
if err != nil { if err != nil {
@ -82,7 +82,7 @@ func (c *AdminClient) generateAdminToken(path string) (string, error) {
token.WithJWTID(jwtID), token.WithJWTID(jwtID),
token.WithKid(c.x5cJWK.KeyID), token.WithKid(c.x5cJWK.KeyID),
token.WithIssuer(c.x5cIssuer), token.WithIssuer(c.x5cIssuer),
token.WithAudience(path), token.WithAudience(urlPath),
token.WithValidity(now, now.Add(token.DefaultValidity)), token.WithValidity(now, now.Add(token.DefaultValidity)),
token.WithX5CCerts(c.x5cCertStrs), token.WithX5CCerts(c.x5cCertStrs),
} }
@ -348,14 +348,15 @@ func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provi
return nil, err return nil, err
} }
var u *url.URL var u *url.URL
if len(o.id) > 0 { switch {
case len(o.id) > 0:
u = c.endpoint.ResolveReference(&url.URL{ u = c.endpoint.ResolveReference(&url.URL{
Path: "/admin/provisioners/id", Path: "/admin/provisioners/id",
RawQuery: o.rawQuery(), RawQuery: o.rawQuery(),
}) })
} else if len(o.name) > 0 { case len(o.name) > 0:
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)}) u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
} else { default:
return nil, errors.New("must set either name or id in method options") return nil, errors.New("must set either name or id in method options")
} }
tok, err := c.generateAdminToken(u.Path) tok, err := c.generateAdminToken(u.Path)
@ -456,14 +457,15 @@ func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error {
return err return err
} }
if len(o.id) > 0 { switch {
case len(o.id) > 0:
u = c.endpoint.ResolveReference(&url.URL{ u = c.endpoint.ResolveReference(&url.URL{
Path: path.Join(adminURLPrefix, "provisioners/id"), Path: path.Join(adminURLPrefix, "provisioners/id"),
RawQuery: o.rawQuery(), RawQuery: o.rawQuery(),
}) })
} else if len(o.name) > 0 { case len(o.name) > 0:
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)}) u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
} else { default:
return errors.New("must set either name or id in method options") return errors.New("must set either name or id in method options")
} }
tok, err := c.generateAdminToken(u.Path) tok, err := c.generateAdminToken(u.Path)

View file

@ -30,7 +30,7 @@ func Bootstrap(token string) (*Client, error) {
// Validate bootstrap token // Validate bootstrap token
switch { switch {
case len(claims.SHA) == 0: case claims.SHA == "":
return nil, errors.New("invalid bootstrap token: sha claim is not present") return nil, errors.New("invalid bootstrap token: sha claim is not present")
case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"): case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"):
return nil, errors.New("invalid bootstrap token: aud claim is not a url") return nil, errors.New("invalid bootstrap token: aud claim is not a url")

110
ca/ca.go
View file

@ -29,10 +29,13 @@ import (
) )
type options struct { type options struct {
configFile string configFile string
password []byte linkedCAToken string
issuerPassword []byte password []byte
database db.AuthDB issuerPassword []byte
sshHostPassword []byte
sshUserPassword []byte
database db.AuthDB
} }
func (o *options) apply(opts []Option) { func (o *options) apply(opts []Option) {
@ -60,6 +63,22 @@ func WithPassword(password []byte) Option {
} }
} }
// WithSSHHostPassword sets the given password to decrypt the key used to sign
// ssh host certificates.
func WithSSHHostPassword(password []byte) Option {
return func(o *options) {
o.sshHostPassword = password
}
}
// WithSSHUserPassword sets the given password to decrypt the key used to sign
// ssh user certificates.
func WithSSHUserPassword(password []byte) Option {
return func(o *options) {
o.sshUserPassword = password
}
}
// WithIssuerPassword sets the given password as the configured certificate // WithIssuerPassword sets the given password as the configured certificate
// issuer password in the CA options. // issuer password in the CA options.
func WithIssuerPassword(password []byte) Option { func WithIssuerPassword(password []byte) Option {
@ -69,9 +88,16 @@ func WithIssuerPassword(password []byte) Option {
} }
// WithDatabase sets the given authority database to the CA options. // WithDatabase sets the given authority database to the CA options.
func WithDatabase(db db.AuthDB) Option { func WithDatabase(d db.AuthDB) Option {
return func(o *options) { return func(o *options) {
o.database = db o.database = d
}
}
// WithLinkedCAToken sets the token used to authenticate with the linkedca.
func WithLinkedCAToken(token string) Option {
return func(o *options) {
o.linkedCAToken = token
} }
} }
@ -87,35 +113,34 @@ type CA struct {
} }
// New creates and initializes the CA with the given configuration and options. // New creates and initializes the CA with the given configuration and options.
func New(config *config.Config, opts ...Option) (*CA, error) { func New(cfg *config.Config, opts ...Option) (*CA, error) {
ca := &CA{ ca := &CA{
config: config, config: cfg,
opts: new(options), opts: new(options),
} }
ca.opts.apply(opts) ca.opts.apply(opts)
return ca.Init(config) return ca.Init(cfg)
} }
// Init initializes the CA with the given configuration. // Init initializes the CA with the given configuration.
func (ca *CA) Init(config *config.Config) (*CA, error) { func (ca *CA) Init(cfg *config.Config) (*CA, error) {
// Intermediate Password. // Set password, it's ok to set nil password, the ca will prompt for them if
if len(ca.opts.password) > 0 { // they are required.
ca.config.Password = string(ca.opts.password) opts := []authority.Option{
authority.WithPassword(ca.opts.password),
authority.WithSSHHostPassword(ca.opts.sshHostPassword),
authority.WithSSHUserPassword(ca.opts.sshUserPassword),
authority.WithIssuerPassword(ca.opts.issuerPassword),
}
if ca.opts.linkedCAToken != "" {
opts = append(opts, authority.WithLinkedCAToken(ca.opts.linkedCAToken))
} }
// Certificate issuer password for RA mode.
if len(ca.opts.issuerPassword) > 0 {
if ca.config.AuthorityConfig != nil && ca.config.AuthorityConfig.CertificateIssuer != nil {
ca.config.AuthorityConfig.CertificateIssuer.Password = string(ca.opts.issuerPassword)
}
}
var opts []authority.Option
if ca.opts.database != nil { if ca.opts.database != nil {
opts = append(opts, authority.WithDatabase(ca.opts.database)) opts = append(opts, authority.WithDatabase(ca.opts.database))
} }
auth, err := authority.New(config, opts...) auth, err := authority.New(cfg, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -141,8 +166,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
}) })
//Add ACME api endpoints in /acme and /1.0/acme //Add ACME api endpoints in /acme and /1.0/acme
dns := config.DNSNames[0] dns := cfg.DNSNames[0]
u, err := url.Parse("https://" + config.Address) u, err := url.Parse("https://" + cfg.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -154,7 +179,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
// ACME Router // ACME Router
prefix := "acme" prefix := "acme"
var acmeDB acme.DB var acmeDB acme.DB
if config.DB == nil { if cfg.DB == nil {
acmeDB = nil acmeDB = nil
} else { } else {
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
@ -163,7 +188,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
} }
} }
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
Backdate: *config.AuthorityConfig.Backdate, Backdate: *cfg.AuthorityConfig.Backdate,
DB: acmeDB, DB: acmeDB,
DNS: dns, DNS: dns,
Prefix: prefix, Prefix: prefix,
@ -179,7 +204,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
}) })
// Admin API Router // Admin API Router
if config.AuthorityConfig.EnableAdmin { if cfg.AuthorityConfig.EnableAdmin {
adminDB := auth.GetAdminDatabase() adminDB := auth.GetAdminDatabase()
if adminDB != nil { if adminDB != nil {
adminHandler := adminAPI.NewHandler(auth) adminHandler := adminAPI.NewHandler(auth)
@ -223,8 +248,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
//dumpRoutes(mux) //dumpRoutes(mux)
// Add monitoring if configured // Add monitoring if configured
if len(config.Monitoring) > 0 { if len(cfg.Monitoring) > 0 {
m, err := monitoring.New(config.Monitoring) m, err := monitoring.New(cfg.Monitoring)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -233,8 +258,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
} }
// Add logger if configured // Add logger if configured
if len(config.Logger) > 0 { if len(cfg.Logger) > 0 {
logger, err := logging.New("ca", config.Logger) logger, err := logging.New("ca", cfg.Logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,16 +267,16 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
insecureHandler = logger.Middleware(insecureHandler) insecureHandler = logger.Middleware(insecureHandler)
} }
ca.srv = server.New(config.Address, handler, tlsConfig) ca.srv = server.New(cfg.Address, handler, tlsConfig)
// only start the insecure server if the insecure address is configured // only start the insecure server if the insecure address is configured
// and, currently, also only when it should serve SCEP endpoints. // and, currently, also only when it should serve SCEP endpoints.
if ca.shouldServeSCEPEndpoints() && config.InsecureAddress != "" { if ca.shouldServeSCEPEndpoints() && cfg.InsecureAddress != "" {
// TODO: instead opt for having a single server.Server but two // TODO: instead opt for having a single server.Server but two
// http.Servers handling the HTTP and HTTPS handler? The latter // http.Servers handling the HTTP and HTTPS handler? The latter
// will probably introduce more complexity in terms of graceful // will probably introduce more complexity in terms of graceful
// reload. // reload.
ca.insecureSrv = server.New(config.InsecureAddress, insecureHandler, nil) ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil)
} }
return ca, nil return ca, nil
@ -260,24 +285,24 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
// Run starts the CA calling to the server ListenAndServe method. // Run starts the CA calling to the server ListenAndServe method.
func (ca *CA) Run() error { func (ca *CA) Run() error {
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 1) errs := make(chan error, 1)
if ca.insecureSrv != nil { if ca.insecureSrv != nil {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errors <- ca.insecureSrv.ListenAndServe() errs <- ca.insecureSrv.ListenAndServe()
}() }()
} }
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errors <- ca.srv.ListenAndServe() errs <- ca.srv.ListenAndServe()
}() }()
// wait till error occurs; ensures the servers keep listening // wait till error occurs; ensures the servers keep listening
err := <-errors err := <-errs
wg.Wait() wg.Wait()
@ -306,7 +331,7 @@ func (ca *CA) Stop() error {
// Reload reloads the configuration of the CA and calls to the server Reload // Reload reloads the configuration of the CA and calls to the server Reload
// method. // method.
func (ca *CA) Reload() error { func (ca *CA) Reload() error {
config, err := config.LoadConfiguration(ca.opts.configFile) cfg, err := config.LoadConfiguration(ca.opts.configFile)
if err != nil { if err != nil {
return errors.Wrap(err, "error reloading ca configuration") return errors.Wrap(err, "error reloading ca configuration")
} }
@ -318,14 +343,17 @@ func (ca *CA) Reload() error {
} }
// Do not allow reload if the database configuration has changed. // Do not allow reload if the database configuration has changed.
if !reflect.DeepEqual(ca.config.DB, config.DB) { if !reflect.DeepEqual(ca.config.DB, cfg.DB) {
logContinue("Reload failed because the database configuration has changed.") logContinue("Reload failed because the database configuration has changed.")
return errors.New("error reloading ca: database configuration cannot change") return errors.New("error reloading ca: database configuration cannot change")
} }
newCA, err := New(config, newCA, err := New(cfg,
WithPassword(ca.opts.password), WithPassword(ca.opts.password),
WithSSHHostPassword(ca.opts.sshHostPassword),
WithSSHUserPassword(ca.opts.sshUserPassword),
WithIssuerPassword(ca.opts.issuerPassword), WithIssuerPassword(ca.opts.issuerPassword),
WithLinkedCAToken(ca.opts.linkedCAToken),
WithConfigFile(ca.opts.configFile), WithConfigFile(ca.opts.configFile),
WithDatabase(ca.auth.GetDatabase()), WithDatabase(ca.auth.GetDatabase()),
) )

View file

@ -322,7 +322,7 @@ ZEp7knvU2psWRw==
assert.Equals(t, intermediate, realIntermediate) assert.Equals(t, intermediate, realIntermediate)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)
@ -375,7 +375,7 @@ func TestCAProvisioners(t *testing.T) {
assert.Equals(t, a, b) assert.Equals(t, a, b)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)
@ -436,7 +436,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
assert.Equals(t, ek.Key, tc.expectedKey) assert.Equals(t, ek.Key, tc.expectedKey)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)
@ -497,7 +497,7 @@ func TestCARoot(t *testing.T) {
assert.Equals(t, root.RootPEM.Certificate, rootCrt) assert.Equals(t, root.RootPEM.Certificate, rootCrt)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)
@ -665,7 +665,7 @@ func TestCARenew(t *testing.T) {
assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions) assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)

View file

@ -74,17 +74,17 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) {
c.Client.Transport = tr c.Client.Transport = tr
} }
func (c *uaClient) Get(url string) (*http.Response, error) { func (c *uaClient) Get(u string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", u, nil)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "new request GET %s failed", url) return nil, errors.Wrapf(err, "new request GET %s failed", u)
} }
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
return c.Client.Do(req) return c.Client.Do(req)
} }
func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response, error) { func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", url, body) req, err := http.NewRequest("POST", u, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -305,7 +305,7 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin
err error err error
opts []jose.Option opts []jose.Option
) )
if len(passwordFile) != 0 { if passwordFile != "" {
opts = append(opts, jose.WithPasswordFile(passwordFile)) opts = append(opts, jose.WithPasswordFile(passwordFile))
} }
blk, err := pemutil.Serialize(key) blk, err := pemutil.Serialize(key)
@ -326,14 +326,14 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin
for _, e := range o.x5cCert.Extensions { for _, e := range o.x5cCert.Extensions {
if e.Id.Equal(stepOIDProvisioner) { if e.Id.Equal(stepOIDProvisioner) {
var provisioner stepProvisionerASN1 var prov stepProvisionerASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { if _, err := asn1.Unmarshal(e.Value, &prov); err != nil {
return errors.Wrap(err, "error unmarshaling provisioner OID from certificate") return errors.Wrap(err, "error unmarshaling provisioner OID from certificate")
} }
o.x5cIssuer = string(provisioner.Name) o.x5cIssuer = string(prov.Name)
} }
} }
if len(o.x5cIssuer) == 0 { if o.x5cIssuer == "" {
return errors.New("provisioner extension not found in certificate") return errors.New("provisioner extension not found in certificate")
} }
@ -631,7 +631,7 @@ retry:
// do not match. // do not match.
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) { func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
var retried bool var retried bool
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1)) sha256Sum = strings.ToLower(strings.ReplaceAll(sha256Sum, "-", ""))
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum}) u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
retry: retry:
resp, err := newInsecureClient().Get(u.String()) resp, err := newInsecureClient().Get(u.String())
@ -651,7 +651,7 @@ retry:
} }
// verify the sha256 // verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw) sum := sha256.Sum256(root.RootPEM.Raw)
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) { if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) {
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match") return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
} }
return &root, nil return &root, nil
@ -1066,16 +1066,16 @@ retry:
} }
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var config api.SSHConfigResponse var cfg api.SSHConfigResponse
if err := readJSON(resp.Body, &config); err != nil { if err := readJSON(resp.Body, &cfg); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errors.Wrapf(err, "error reading %s", u)
} }
return &config, nil return &cfg, nil
} }
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the // SSHCheckHost performs the POST /ssh/check-host request to the CA with the
// given principal. // given principal.
func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrincipalResponse, error) { func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalResponse, error) {
var retried bool var retried bool
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
Type: provisioner.SSHHostCert, Type: provisioner.SSHHostCert,

View file

@ -135,7 +135,7 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
return csr return csr
} }
func equalJSON(t *testing.T, a interface{}, b interface{}) bool { func equalJSON(t *testing.T, a, b interface{}) bool {
if reflect.DeepEqual(a, b) { if reflect.DeepEqual(a, b) {
return true return true
} }

View file

@ -187,11 +187,12 @@ func TestLoadClient(t *testing.T) {
} else { } else {
gotTransport := got.Client.Transport.(*http.Transport) gotTransport := got.Client.Transport.(*http.Transport)
wantTransport := tt.want.Client.Transport.(*http.Transport) wantTransport := tt.want.Client.Transport.(*http.Transport)
if gotTransport.TLSClientConfig.GetClientCertificate == nil { switch {
case gotTransport.TLSClientConfig.GetClientCertificate == nil:
t.Error("LoadClient() transport does not define GetClientCertificate") t.Error("LoadClient() transport does not define GetClientCertificate")
} else if !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()) { case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()):
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
} else { default:
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
if err != nil { if err != nil {
t.Errorf("LoadClient() GetClientCertificate error = %v", err) t.Errorf("LoadClient() GetClientCertificate error = %v", err)

7
ca/testdata/ca.json vendored
View file

@ -9,12 +9,11 @@
"logger": {"format": "text"}, "logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.2, "maxVersion": 1.3,
"renegotiation": false, "renegotiation": false,
"cipherSuites": [ "cipherSuites": [
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"
] ]
}, },
"authority": { "authority": {

View file

@ -105,7 +105,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
tr := getDefaultTransport(tlsConfig) tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew // Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
@ -154,7 +154,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
// Update renew function with transport // Update renew function with transport
tr := getDefaultTransport(tlsConfig) tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew // Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
@ -195,7 +195,7 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
} }
// buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport. // buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport.
// nolint:unused // nolint:unused,gocritic
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) { func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) {
d := getDefaultDialer() d := getDefaultDialer()
@ -253,6 +253,8 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
return nil, err return nil, err
} }
// nolint:gocritic
// using a new variable for clarity
chain := append(certPEM, caPEM...) chain := append(certPEM, caPEM...)
cert, err := tls.X509KeyPair(chain, keyPEM) cert, err := tls.X509KeyPair(chain, keyPEM)
if err != nil { if err != nil {
@ -277,9 +279,9 @@ func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
// getDefaultDialer returns a new dialer with the default configuration. // getDefaultDialer returns a new dialer with the default configuration.
func getDefaultDialer() *net.Dialer { func getDefaultDialer() *net.Dialer {
// With the KeepAlive parameter set to 0, it will be use Golang's default.
return &net.Dialer{ return &net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
} }
} }

View file

@ -38,10 +38,17 @@ type Options struct {
CertificateChain []*x509.Certificate `json:"-"` CertificateChain []*x509.Certificate `json:"-"`
Signer crypto.Signer `json:"-"` Signer crypto.Signer `json:"-"`
// IsCreator is set to true when we're creating a certificate authority. Is // IsCreator is set to true when we're creating a certificate authority. It
// used to skip some validations when initializing a CertificateAuthority. // is used to skip some validations when initializing a
// CertificateAuthority. This option is used on SoftCAS and CloudCAS.
IsCreator bool `json:"-"` IsCreator bool `json:"-"`
// IsCAGetter is set to true when we're just using the
// CertificateAuthorityGetter interface to retrieve the root certificate. It
// is used to skip some validations when initializing a
// CertificateAuthority. This option is used on StepCAS.
IsCAGetter bool `json:"-"`
// KeyManager is the KMS used to generate keys in SoftCAS. // KeyManager is the KMS used to generate keys in SoftCAS.
KeyManager kms.KeyManager `json:"-"` KeyManager kms.KeyManager `json:"-"`

View file

@ -108,6 +108,9 @@ type GetCertificateAuthorityResponse struct {
RootCertificate *x509.Certificate RootCertificate *x509.Certificate
} }
// CreateKeyRequest is the request used to generate a new key using a KMS.
type CreateKeyRequest = apiv1.CreateKeyRequest
// CreateCertificateAuthorityRequest is the request used to generate a root or // CreateCertificateAuthorityRequest is the request used to generate a root or
// intermediate certificate. // intermediate certificate.
type CreateCertificateAuthorityRequest struct { type CreateCertificateAuthorityRequest struct {
@ -126,7 +129,7 @@ type CreateCertificateAuthorityRequest struct {
// CreateKey defines the KMS CreateKeyRequest to use when creating a new // CreateKey defines the KMS CreateKeyRequest to use when creating a new
// CertificateAuthority. If CreateKey is nil, a default algorithm will be // CertificateAuthority. If CreateKey is nil, a default algorithm will be
// used. // used.
CreateKey *apiv1.CreateKeyRequest CreateKey *CreateKeyRequest
} }
// CreateCertificateAuthorityResponse is the response for // CreateCertificateAuthorityResponse is the response for
@ -136,6 +139,7 @@ type CreateCertificateAuthorityResponse struct {
Name string Name string
Certificate *x509.Certificate Certificate *x509.Certificate
CertificateChain []*x509.Certificate CertificateChain []*x509.Certificate
KeyName string
PublicKey crypto.PublicKey PublicKey crypto.PublicKey
PrivateKey crypto.PrivateKey PrivateKey crypto.PrivateKey
Signer crypto.Signer Signer crypto.Signer

View file

@ -1,6 +1,7 @@
package apiv1 package apiv1
import ( import (
"crypto/x509"
"net/http" "net/http"
"strings" "strings"
) )
@ -26,6 +27,12 @@ type CertificateAuthorityCreator interface {
CreateCertificateAuthority(req *CreateCertificateAuthorityRequest) (*CreateCertificateAuthorityResponse, error) CreateCertificateAuthority(req *CreateCertificateAuthorityRequest) (*CreateCertificateAuthorityResponse, error)
} }
// SignatureAlgorithmGetter is an optional implementation in a crypto.Signer
// that returns the SignatureAlgorithm to use.
type SignatureAlgorithmGetter interface {
SignatureAlgorithm() x509.SignatureAlgorithm
}
// Type represents the CAS type used. // Type represents the CAS type used.
type Type string type Type string

View file

@ -29,9 +29,7 @@ func init() {
}) })
} }
var now = func() time.Time { var now = time.Now
return time.Now()
}
// The actual regular expression that matches a certificate authority is: // The actual regular expression that matches a certificate authority is:
// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$ // ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$

View file

@ -12,7 +12,6 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"os" "os"
"reflect" "reflect"
@ -103,7 +102,7 @@ MHcCAQEEIN51Rgg6YcQVLeCRzumdw4pjM3VWqFIdCbnsV3Up1e/goAoGCCqGSM49
AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk
Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg== Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg==
-----END EC PRIVATE KEY-----` -----END EC PRIVATE KEY-----`
// nolint:unused,deadcode // nolint:unused,deadcode,gocritic
testIntermediateKey = `-----BEGIN EC PRIVATE KEY----- testIntermediateKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49 MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49
AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6 AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6
@ -190,7 +189,7 @@ func (b *badSigner) Public() crypto.PublicKey {
return b.pub return b.pub
} }
func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { func (b *badSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
return nil, fmt.Errorf("💥") return nil, fmt.Errorf("💥")
} }
@ -730,7 +729,7 @@ func TestCloudCAS_RevokeCertificate(t *testing.T) {
func Test_createCertificateID(t *testing.T) { func Test_createCertificateID(t *testing.T) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
setTeeReader(t, buf) setTeeReader(t, buf)
uuid, err := uuid.NewRandomFromReader(rand.Reader) id, err := uuid.NewRandomFromReader(rand.Reader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -741,7 +740,7 @@ func Test_createCertificateID(t *testing.T) {
want string want string
wantErr bool wantErr bool
}{ }{
{"ok", uuid.String(), false}, {"ok", id.String(), false},
{"fail", "", true}, {"fail", "", true},
} }
for _, tt := range tests { for _, tt := range tests {
@ -858,7 +857,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) {
return lis.Dial() return lis.Dial()
})) }))
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn)) client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn))

View file

@ -19,9 +19,7 @@ func init() {
}) })
} }
var now = func() time.Time { var now = time.Now
return time.Now()
}
// SoftCAS implements a Certificate Authority Service using Golang or KMS // SoftCAS implements a Certificate Authority Service using Golang or KMS
// crypto. This is the default CAS used in step-ca. // crypto. This is the default CAS used in step-ca.
@ -68,7 +66,7 @@ func (c *SoftCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
} }
req.Template.Issuer = c.CertificateChain[0].Subject req.Template.Issuer = c.CertificateChain[0].Subject
cert, err := x509util.CreateCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -93,7 +91,7 @@ func (c *SoftCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R
req.Template.NotAfter = t.Add(req.Lifetime) req.Template.NotAfter = t.Add(req.Lifetime)
req.Template.Issuer = c.CertificateChain[0].Subject req.Template.Issuer = c.CertificateChain[0].Subject
cert, err := x509util.CreateCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer) cert, err := createCertificate(req.Template, c.CertificateChain[0], req.Template.PublicKey, c.Signer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -150,12 +148,12 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori
var cert *x509.Certificate var cert *x509.Certificate
switch req.Type { switch req.Type {
case apiv1.RootCA: case apiv1.RootCA:
cert, err = x509util.CreateCertificate(req.Template, req.Template, signer.Public(), signer) cert, err = createCertificate(req.Template, req.Template, signer.Public(), signer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
case apiv1.IntermediateCA: case apiv1.IntermediateCA:
cert, err = x509util.CreateCertificate(req.Template, req.Parent.Certificate, signer.Public(), req.Parent.Signer) cert, err = createCertificate(req.Template, req.Parent.Certificate, signer.Public(), req.Parent.Signer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -174,6 +172,7 @@ func (c *SoftCAS) CreateCertificateAuthority(req *apiv1.CreateCertificateAuthori
Name: cert.Subject.CommonName, Name: cert.Subject.CommonName,
Certificate: cert, Certificate: cert,
CertificateChain: chain, CertificateChain: chain,
KeyName: key.Name,
PublicKey: key.PublicKey, PublicKey: key.PublicKey,
PrivateKey: key.PrivateKey, PrivateKey: key.PrivateKey,
Signer: signer, Signer: signer,
@ -210,3 +209,16 @@ func (c *SoftCAS) createSigner(req *kmsapi.CreateSignerRequest) (crypto.Signer,
} }
return c.KeyManager.CreateSigner(req) return c.KeyManager.CreateSigner(req)
} }
// createCertificate sets the SignatureAlgorithm of the template if necessary
// and calls x509util.CreateCertificate.
func createCertificate(template, parent *x509.Certificate, pub crypto.PublicKey, signer crypto.Signer) (*x509.Certificate, error) {
// Signers can specify the signature algorithm. This is especially important
// when x509.CreateCertificate attempts to validate a RSAPSS signature.
if template.SignatureAlgorithm == 0 {
if sa, ok := signer.(apiv1.SignatureAlgorithmGetter); ok {
template.SignatureAlgorithm = sa.SignatureAlgorithm()
}
}
return x509util.CreateCertificate(template, parent, pub, signer)
}

View file

@ -75,6 +75,15 @@ var (
testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour)) testSignedIntermediateTemplate = mustSign(testIntermediateTemplate, testSignedRootTemplate, testNow, testNow.Add(24*time.Hour))
) )
type signatureAlgorithmSigner struct {
crypto.Signer
algorithm x509.SignatureAlgorithm
}
func (s *signatureAlgorithmSigner) SignatureAlgorithm() x509.SignatureAlgorithm {
return s.algorithm
}
type mockKeyManager struct { type mockKeyManager struct {
signer crypto.Signer signer crypto.Signer
errGetPublicKey error errGetPublicKey error
@ -97,6 +106,7 @@ func (m *mockKeyManager) CreateKey(req *kmsapi.CreateKeyRequest) (*kmsapi.Create
signer = m.signer signer = m.signer
} }
return &kmsapi.CreateKeyResponse{ return &kmsapi.CreateKeyResponse{
Name: req.Name,
PrivateKey: signer, PrivateKey: signer,
PublicKey: signer.Public(), PublicKey: signer.Public(),
}, m.errCreateKey }, m.errCreateKey
@ -124,7 +134,7 @@ func (b *badSigner) Public() crypto.PublicKey {
return testSigner.Public() return testSigner.Public()
} }
func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { func (b *badSigner) Sign(_ io.Reader, _ []byte, _ crypto.SignerOpts) ([]byte, error) {
return nil, fmt.Errorf("💥") return nil, fmt.Errorf("💥")
} }
@ -247,6 +257,13 @@ func TestSoftCAS_CreateCertificate(t *testing.T) {
tmplNoSerial := *testTemplate tmplNoSerial := *testTemplate
tmplNoSerial.SerialNumber = nil tmplNoSerial.SerialNumber = nil
saTemplate := *testSignedTemplate
saTemplate.SignatureAlgorithm = 0
saSigner := &signatureAlgorithmSigner{
Signer: testSigner,
algorithm: x509.PureEd25519,
}
type fields struct { type fields struct {
Issuer *x509.Certificate Issuer *x509.Certificate
Signer crypto.Signer Signer crypto.Signer
@ -267,6 +284,12 @@ func TestSoftCAS_CreateCertificate(t *testing.T) {
Certificate: testSignedTemplate, Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.CreateCertificateRequest{
Template: &saTemplate, Lifetime: 24 * time.Hour,
}}, &apiv1.CreateCertificateResponse{
Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer},
}, false},
{"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{ {"ok with notBefore", fields{testIssuer, testSigner}, args{&apiv1.CreateCertificateRequest{
Template: &tmplNotBefore, Lifetime: 24 * time.Hour, Template: &tmplNotBefore, Lifetime: 24 * time.Hour,
}}, &apiv1.CreateCertificateResponse{ }}, &apiv1.CreateCertificateResponse{
@ -316,6 +339,11 @@ func TestSoftCAS_RenewCertificate(t *testing.T) {
tmplNoSerial := *testTemplate tmplNoSerial := *testTemplate
tmplNoSerial.SerialNumber = nil tmplNoSerial.SerialNumber = nil
saSigner := &signatureAlgorithmSigner{
Signer: testSigner,
algorithm: x509.PureEd25519,
}
type fields struct { type fields struct {
Issuer *x509.Certificate Issuer *x509.Certificate
Signer crypto.Signer Signer crypto.Signer
@ -336,6 +364,12 @@ func TestSoftCAS_RenewCertificate(t *testing.T) {
Certificate: testSignedTemplate, Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer}, CertificateChain: []*x509.Certificate{testIssuer},
}, false}, }, false},
{"ok signature algorithm", fields{testIssuer, saSigner}, args{&apiv1.RenewCertificateRequest{
Template: testTemplate, Lifetime: 24 * time.Hour,
}}, &apiv1.RenewCertificateResponse{
Certificate: testSignedTemplate,
CertificateChain: []*x509.Certificate{testIssuer},
}, false},
{"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true}, {"fail template", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Lifetime: 24 * time.Hour}}, nil, true},
{"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true}, {"fail lifetime", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{Template: testTemplate}}, nil, true},
{"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{ {"fail CreateCertificate", fields{testIssuer, testSigner}, args{&apiv1.RenewCertificateRequest{
@ -425,6 +459,11 @@ func Test_now(t *testing.T) {
func TestSoftCAS_CreateCertificateAuthority(t *testing.T) { func TestSoftCAS_CreateCertificateAuthority(t *testing.T) {
mockNow(t) mockNow(t)
saSigner := &signatureAlgorithmSigner{
Signer: testSigner,
algorithm: x509.PureEd25519,
}
type fields struct { type fields struct {
Issuer *x509.Certificate Issuer *x509.Certificate
Signer crypto.Signer Signer crypto.Signer
@ -467,6 +506,33 @@ func TestSoftCAS_CreateCertificateAuthority(t *testing.T) {
PrivateKey: testSigner, PrivateKey: testSigner,
Signer: testSigner, Signer: testSigner,
}, false}, }, false},
{"ok signature algorithm", fields{nil, nil, &mockKeyManager{signer: saSigner}}, args{&apiv1.CreateCertificateAuthorityRequest{
Type: apiv1.RootCA,
Template: testRootTemplate,
Lifetime: 24 * time.Hour,
}}, &apiv1.CreateCertificateAuthorityResponse{
Name: "Test Root CA",
Certificate: testSignedRootTemplate,
PublicKey: testSignedRootTemplate.PublicKey,
PrivateKey: saSigner,
Signer: saSigner,
}, false},
{"ok createKey", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
Type: apiv1.RootCA,
Template: testRootTemplate,
Lifetime: 24 * time.Hour,
CreateKey: &kmsapi.CreateKeyRequest{
Name: "root_ca.crt",
SignatureAlgorithm: kmsapi.ECDSAWithSHA256,
},
}}, &apiv1.CreateCertificateAuthorityResponse{
Name: "Test Root CA",
Certificate: testSignedRootTemplate,
PublicKey: testSignedRootTemplate.PublicKey,
KeyName: "root_ca.crt",
PrivateKey: testSigner,
Signer: testSigner,
}, false},
{"fail template", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{ {"fail template", fields{nil, nil, &mockKeyManager{}}, args{&apiv1.CreateCertificateAuthorityRequest{
Type: apiv1.RootCA, Type: apiv1.RootCA,
Lifetime: 24 * time.Hour, Lifetime: 24 * time.Hour,

View file

@ -47,10 +47,13 @@ func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) {
return nil, err return nil, err
} }
// Create configured issuer var iss stepIssuer
iss, err := newStepIssuer(caURL, client, opts.CertificateIssuer) // Create configured issuer unless we only want to use GetCertificateAuthority.
if err != nil { // This avoid the request for the password if not provided.
return nil, err if !opts.IsCAGetter {
if iss, err = newStepIssuer(caURL, client, opts.CertificateIssuer); err != nil {
return nil, err
}
} }
return &StepCAS{ return &StepCAS{
@ -87,9 +90,9 @@ func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R
return nil, apiv1.ErrNotImplemented{Message: "stepCAS does not support mTLS renewals"} return nil, apiv1.ErrNotImplemented{Message: "stepCAS does not support mTLS renewals"}
} }
// RevokeCertificate revokes a certificate.
func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
switch { if req.SerialNumber == "" && req.Certificate == nil {
case req.SerialNumber == "" && req.Certificate == nil:
return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required") return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required")
} }

View file

@ -411,6 +411,19 @@ func TestNew(t *testing.T) {
client: client, client: client,
fingerprint: testRootFingerprint, fingerprint: testRootFingerprint,
}, false}, }, false},
{"ok ca getter", args{context.TODO(), apiv1.Options{
IsCAGetter: true,
CertificateAuthority: caURL.String(),
CertificateAuthorityFingerprint: testRootFingerprint,
CertificateIssuer: &apiv1.CertificateIssuer{
Type: "jwk",
Provisioner: "ra@doe.org",
},
}}, &StepCAS{
iss: nil,
client: client,
fingerprint: testRootFingerprint,
}, false},
{"fail authority", args{context.TODO(), apiv1.Options{ {"fail authority", args{context.TODO(), apiv1.Options{
CertificateAuthority: "", CertificateAuthority: "",
CertificateAuthorityFingerprint: testRootFingerprint, CertificateAuthorityFingerprint: testRootFingerprint,

View file

@ -19,9 +19,7 @@ const defaultValidity = 5 * time.Minute
// timeNow returns the current time. // timeNow returns the current time.
// This method is used for unit testing purposes. // This method is used for unit testing purposes.
var timeNow = func() time.Time { var timeNow = time.Now
return time.Now()
}
type x5cIssuer struct { type x5cIssuer struct {
caURL *url.URL caURL *url.URL

View file

@ -22,7 +22,7 @@ func (b noneSigner) Public() crypto.PublicKey {
return []byte(b) return []byte(b)
} }
func (b noneSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { func (b noneSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
return digest, nil return digest, nil
} }

View file

@ -24,13 +24,16 @@ import (
func main() { func main() {
var credentialsFile, region string var credentialsFile, region string
var ssh bool var enableSSH bool
flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the AWS KMS credentials.") flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the AWS KMS credentials.")
flag.StringVar(&region, "region", "", "AWS KMS region name.") flag.StringVar(&region, "region", "", "AWS KMS region name.")
flag.BoolVar(&ssh, "ssh", false, "Create SSH keys.") flag.BoolVar(&enableSSH, "ssh", false, "Create SSH keys.")
flag.Usage = usage flag.Usage = usage
flag.Parse() flag.Parse()
// Initialize windows terminal
ui.Init()
c, err := awskms.New(context.Background(), apiv1.Options{ c, err := awskms.New(context.Background(), apiv1.Options{
Type: string(apiv1.AmazonKMS), Type: string(apiv1.AmazonKMS),
Region: region, Region: region,
@ -44,16 +47,20 @@ func main() {
fatal(err) fatal(err)
} }
if ssh { if enableSSH {
ui.Println() ui.Println()
if err := createSSH(c); err != nil { if err := createSSH(c); err != nil {
fatal(err) fatal(err)
} }
} }
// Reset windows terminal
ui.Reset()
} }
func fatal(err error) { func fatal(err error) {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
ui.Reset()
os.Exit(1) os.Exit(1)
} }
@ -113,7 +120,7 @@ func createX509(c *awskms.KMS) error {
return err return err
} }
if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {
@ -156,7 +163,7 @@ func createX509(c *awskms.KMS) error {
return err return err
} }
if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {
@ -186,7 +193,7 @@ func createSSH(c *awskms.KMS) error {
return err return err
} }
if err = fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { if err := fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
return err return err
} }
@ -207,7 +214,7 @@ func createSSH(c *awskms.KMS) error {
return err return err
} }
if err = fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { if err := fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
return err return err
} }

View file

@ -22,10 +22,12 @@ import (
"go.step.sm/cli-utils/command" "go.step.sm/cli-utils/command"
"go.step.sm/cli-utils/command/version" "go.step.sm/cli-utils/command/version"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/ui"
"go.step.sm/cli-utils/usage" "go.step.sm/cli-utils/usage"
// Enabled kms interfaces. // Enabled kms interfaces.
_ "github.com/smallstep/certificates/kms/awskms" _ "github.com/smallstep/certificates/kms/awskms"
_ "github.com/smallstep/certificates/kms/azurekms"
_ "github.com/smallstep/certificates/kms/cloudkms" _ "github.com/smallstep/certificates/kms/cloudkms"
_ "github.com/smallstep/certificates/kms/softkms" _ "github.com/smallstep/certificates/kms/softkms"
_ "github.com/smallstep/certificates/kms/sshagentkms" _ "github.com/smallstep/certificates/kms/sshagentkms"
@ -52,6 +54,11 @@ func init() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
} }
func exit(code int) {
ui.Reset()
os.Exit(code)
}
// appHelpTemplate contains the modified template for the main app // appHelpTemplate contains the modified template for the main app
var appHelpTemplate = `## NAME var appHelpTemplate = `## NAME
**{{.HelpName}}** -- {{.Usage}} **{{.HelpName}}** -- {{.Usage}}
@ -90,6 +97,9 @@ Please send us a sentence or two, good or bad: **feedback@smallstep.com** or htt
` `
func main() { func main() {
// Initialize windows terminal
ui.Init()
// Override global framework components // Override global framework components
cli.VersionPrinter = func(c *cli.Context) { cli.VersionPrinter = func(c *cli.Context) {
version.Command(c) version.Command(c)
@ -107,7 +117,9 @@ func main() {
app.HelpName = "step-ca" app.HelpName = "step-ca"
app.Version = config.Version() app.Version = config.Version()
app.Usage = "an online certificate authority for secure automated certificate management" app.Usage = "an online certificate authority for secure automated certificate management"
app.UsageText = `**step-ca** <config> [**--password-file**=<file>] [**--issuer-password-file**=<file>] [**--resolver**=<addr>] [**--help**] [**--version**]` app.UsageText = `**step-ca** <config> [**--password-file**=<file>]
[**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>]
[**--issuer-password-file**=<file>] [**--resolver**=<addr>] [**--help**] [**--version**]`
app.Description = `**step-ca** runs the Step Online Certificate Authority app.Description = `**step-ca** runs the Step Online Certificate Authority
(Step CA) using the given configuration. (Step CA) using the given configuration.
See the README.md for more detailed configuration documentation. See the README.md for more detailed configuration documentation.
@ -162,8 +174,10 @@ $ step-ca $STEPPATH/config/ca.json --password-file ./password.txt
} else { } else {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
} }
os.Exit(1) exit(1)
} }
exit(0)
} }
func flagValue(f cli.Flag) reflect.Value { func flagValue(f cli.Flag) reflect.Value {
@ -178,8 +192,8 @@ var placeholderString = regexp.MustCompile(`<.*?>`)
func stringifyFlag(f cli.Flag) string { func stringifyFlag(f cli.Flag) string {
fv := flagValue(f) fv := flagValue(f)
usage := fv.FieldByName("Usage").String() usg := fv.FieldByName("Usage").String()
placeholder := placeholderString.FindString(usage) placeholder := placeholderString.FindString(usg)
if placeholder == "" { if placeholder == "" {
switch f.(type) { switch f.(type) {
case cli.BoolFlag, cli.BoolTFlag: case cli.BoolFlag, cli.BoolTFlag:
@ -187,5 +201,5 @@ func stringifyFlag(f cli.Flag) string {
placeholder = "<value>" placeholder = "<value>"
} }
} }
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usg
} }

Some files were not shown because too many files have changed in this diff Show more